diff options
-rw-r--r-- | training/dtrain/dtrain.cc | 5 | ||||
-rw-r--r-- | training/dtrain/dtrain.h | 4 | ||||
-rw-r--r-- | training/dtrain/update.h | 27 |
3 files changed, 25 insertions, 11 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index bbe14547..30c64037 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -48,6 +48,7 @@ main(int argc, char** argv) const string adadelta_output = conf["adadelta_output"].as<string>(); const size_t max_input = conf["stop_after"].as<size_t>(); const bool batch = conf["batch"].as<bool>(); + const bool all = conf["all"].as<bool>(); // setup decoder register_feature_functions(); @@ -267,14 +268,14 @@ main(int argc, char** argv) num_up += update_structured(sample, updates, margin, out_up, i); else if (all_pairs) - num_up += updates_all(sample, updates, max_up, threshold, + num_up += updates_all(sample, updates, max_up, margin, threshold, all, out_up, i); else if (pro) num_up += updates_pro(sample, updates, cut, max_up, threshold, out_up, i); else num_up += updates_multipartite(sample, updates, cut, margin, - max_up, threshold, adjust_cut, + max_up, threshold, adjust_cut, all, out_up, i); SparseVector<weight_t> lambdas_copy; diff --git a/training/dtrain/dtrain.h b/training/dtrain/dtrain.h index 883e6028..83dcd945 100644 --- a/training/dtrain/dtrain.h +++ b/training/dtrain/dtrain.h @@ -71,6 +71,8 @@ dtrain_init(int argc, "use top/bottom 10% (default) of k-best as 'good' and 'bad' for pair sampling, 0 to use all pairs") ("adjust,A", po::bool_switch()->default_value(false), "adjust cut for optimal pos. in k-best to cut") + ("all,A", po::bool_switch()->default_value(false), + "update using all pairs, ignoring margin and threshold") ("score,s", po::value<string>()->default_value("nakov"), "per-sentence BLEU (approx.)") ("nakov_fix", po::value<weight_t>()->default_value(1.0), @@ -106,7 +108,7 @@ dtrain_init(int argc, "output raw data (e.g. k-best lists) [to filename]") ("stop_after", po::value<size_t>()->default_value(numeric_limits<size_t>::max()), "only look at this number of segments") - ("print_weights,P", po::value<string>()->default_value("EgivenFCoherent SampleCountF CountEF MaxLexFgivenE MaxLexEgivenF IsSingletonF IsSingletonFE Glue WordPenalty PassThrough LanguageModel LanguageModel_OOV"), + ("print_weights", po::value<string>()->default_value("EgivenFCoherent SampleCountF CountEF MaxLexFgivenE MaxLexEgivenF IsSingletonF IsSingletonFE Glue WordPenalty PassThrough LanguageModel LanguageModel_OOV"), "list of weights to print after each iteration"); po::options_description clopts("Command Line Options"); clopts.add_options() diff --git a/training/dtrain/update.h b/training/dtrain/update.h index 405a3f76..edcfe391 100644 --- a/training/dtrain/update.h +++ b/training/dtrain/update.h @@ -20,6 +20,7 @@ updates_multipartite(vector<Hyp>* sample, size_t max_up, weight_t threshold, bool adjust, + bool all, WriteFile& output, size_t id) { @@ -44,9 +45,11 @@ updates_multipartite(vector<Hyp>* sample, for (size_t i = 0; i < sep_hi; i++) { for (size_t j = sep_hi; j < sz; j++) { Hyp& first=(*sample)[i], second=(*sample)[j]; - if ((first.model-second.model)>margin - || (first.gold==second.gold) - || (threshold && (first.gold-second.gold < threshold))) + if (first.gold==second.gold) + continue; + if (!all + && (((first.model-second.model)>margin) + || (threshold && (first.gold-second.gold < threshold)))) continue; if (output) *output << id << "\t" << first.f-second.f << endl; @@ -64,9 +67,11 @@ updates_multipartite(vector<Hyp>* sample, for (size_t i = sep_hi; i < sep_lo; i++) { for (size_t j = sep_lo; j < sz; j++) { Hyp& first=(*sample)[i], second=(*sample)[j]; - if ((first.model-second.model)>margin - || (first.gold==second.gold) - || (threshold && (first.gold-second.gold < threshold))) + if (first.gold==second.gold) + continue; + if (!all + && (((first.model-second.model)>margin) + || (threshold && (first.gold-second.gold < threshold)))) continue; if (output) *output << id << "\t" << first.f-second.f << endl; @@ -83,13 +88,16 @@ updates_multipartite(vector<Hyp>* sample, * all pairs * only ignore a pair if gold scores are * identical + * FIXME: that's really _all_ * */ inline size_t updates_all(vector<Hyp>* sample, SparseVector<weight_t>& updates, size_t max_up, + weight_t margin, weight_t threshold, + bool all, WriteFile output, size_t id) { @@ -102,8 +110,11 @@ updates_all(vector<Hyp>* sample, for (size_t i = 0; i < sz-1; i++) { for (size_t j = i+1; j < sz; j++) { Hyp& first=(*sample)[i], second=(*sample)[j]; - if ((first.gold == second.gold) - || (threshold && (first.gold-second.gold < threshold))) + if (first.gold == second.gold) + continue; + if (!all + && (((first.model-second.model)>margin) + || (threshold && (first.gold-second.gold < threshold)))) continue; if (output) *output << id << "\t" << first.f-second.f << endl; |