diff options
-rw-r--r-- | training/dtrain/dtrain.cc | 31 | ||||
-rw-r--r-- | training/dtrain/update.h | 37 |
2 files changed, 40 insertions, 28 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index ddd27211..3e9902ab 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -143,10 +143,15 @@ main(int argc, char** argv) time_t total_time = 0.; // output - WriteFile raw_out; - if (output_raw) raw_out.Init(output_raw_fn); - WriteFile updates_out; - if (output_updates) updates_out.Init(output_raw_fn); + WriteFile out_up, out_raw; + if (output_raw) { + out_raw.Init(output_raw_fn); + *out_raw << setprecision(numeric_limits<double>::digits10+1); + } + if (output_updates) { + out_up.Init(output_updates_fn); + *out_up << setprecision(numeric_limits<double>::digits10+1); + } for (size_t t = 0; t < T; t++) // T iterations @@ -220,25 +225,25 @@ main(int argc, char** argv) list_sz += observer->effective_size; if (output_raw) - output_sample(sample); + output_sample(sample, *out_raw, i); // update model if (!noup) { SparseVector<weight_t> updates; if (structured) - num_up += update_structured(sample, updates, margin/*, - output_updates, updates_out.get()*/); // FIXME + num_up += update_structured(sample, updates, margin, + output_updates, *out_up, i); else if (all_pairs) - num_up += updates_all(sample, updates, max_up, threshold/*, - output_updates, updates_out.get()*/); // FIXME + num_up += updates_all(sample, updates, max_up, threshold, + output_updates, *out_up, i); else if (pro) - num_up += updates_pro(sample, updates, cut, max_up, threshold/*, - output_updates, updates_out.get()*/); // FIXME + num_up += updates_pro(sample, updates, cut, max_up, threshold, + output_updates, *out_up, i); else num_up += updates_multipartite(sample, updates, cut, margin, - max_up, threshold, adjust_cut/*, - output_updates, updates_out.get()*/); // FIXME + max_up, threshold, adjust_cut, + output_updates, *out_up, i); SparseVector<weight_t> lambdas_copy; if (l1_reg) lambdas_copy = lambdas; diff --git a/training/dtrain/update.h b/training/dtrain/update.h index 30b14771..f6aa9842 100644 --- a/training/dtrain/update.h +++ b/training/dtrain/update.h @@ -21,7 +21,8 @@ updates_multipartite(vector<Hyp>* sample, weight_t threshold, bool adjust, bool output=false, - ostream& os=cout) + ostream& os=cout, + size_t id=0) { size_t up = 0; size_t sz = sample->size(); @@ -45,11 +46,11 @@ updates_multipartite(vector<Hyp>* sample, for (size_t j = sep_hi; j < sz; j++) { Hyp& first=(*sample)[i], second=(*sample)[j]; if ((first.model-second.model)>margin - || (!adjust && first.gold==second.gold) + || (first.gold==second.gold) || (threshold && (first.gold-second.gold < threshold))) continue; if (output) - os << first.f-second.f << endl; + os << id << "\t" << first.f-second.f << endl; updates += first.f-second.f; if (++up==max_up) return up; @@ -65,11 +66,11 @@ updates_multipartite(vector<Hyp>* sample, for (size_t j = sep_lo; j < sz; j++) { Hyp& first=(*sample)[i], second=(*sample)[j]; if ((first.model-second.model)>margin - || (!adjust && first.gold==second.gold) + || (first.gold==second.gold) || (threshold && (first.gold-second.gold < threshold))) continue; if (output) - os << first.f-second.f << endl; + os << id << "\t" << first.f-second.f << endl; updates += first.f-second.f; if (++up==max_up) break; @@ -91,7 +92,8 @@ updates_all(vector<Hyp>* sample, size_t max_up, weight_t threshold, bool output=false, - ostream& os=cout) + ostream& os=cout, + size_t id=0) { size_t up = 0; size_t sz = sample->size(); @@ -102,11 +104,11 @@ updates_all(vector<Hyp>* sample, for (size_t i = 0; i < sz-1; i++) { for (size_t j = i+1; j < sz; j++) { Hyp& first=(*sample)[i], second=(*sample)[j]; - if (first.gold == second.gold + if ((first.gold == second.gold) || (threshold && (first.gold-second.gold < threshold))) continue; if (output) - os << first.f-second.f << endl; + os << id << "\t" << first.f-second.f << endl; updates += first.f-second.f; if (++up==max_up) break; @@ -126,7 +128,8 @@ update_structured(vector<Hyp>* sample, SparseVector<weight_t>& updates, weight_t margin, bool output=false, - ostream& os=cout) + ostream& os=cout, + size_t id=0) { // hope sort(sample->begin(), sample->end(), [](Hyp first, Hyp second) @@ -144,7 +147,7 @@ update_structured(vector<Hyp>* sample, if (hope.gold != fear.gold) { updates += hope.f - fear.f; if (output) - os << hope.f << "\t" << fear.f << endl; + os << id << "\t" << hope.f << "\t" << fear.f << endl; return 1; } @@ -170,7 +173,8 @@ updates_pro(vector<Hyp>* sample, size_t max_up, weight_t threshold, bool output=false, - ostream& os=cout) + ostream& os=cout, + size_t id=0) { size_t sz = sample->size(), s; @@ -198,7 +202,7 @@ updates_pro(vector<Hyp>* sample, for (auto i: g) { if (output) - os << i.first->f-i.second->f << endl; + os << id << "\t" << i.first->f-i.second->f << endl; updates += i.first->f-i.second->f; } @@ -212,16 +216,19 @@ updates_pro(vector<Hyp>* sample, inline void output_sample(vector<Hyp>* sample, ostream& os=cout, + size_t id=0, bool sorted=true) { - if (sorted) + if (sorted) { sort(sample->begin(), sample->end(), [](Hyp first, Hyp second) { return first.gold > second.gold; }); + } size_t j = 0; - for (auto i: *sample) { - os << j << "\t" << i.gold << "\t" << i.model << "\t" << i.f << endl; + for (auto k: *sample) { + os << id << "\t" << j << "\t" << k.gold << "\t" << k.model + << "\t" << k.f << endl; j++; } } |