diff options
Diffstat (limited to 'training')
-rw-r--r-- training/dtrain/dtrain.cc                  |  6
-rw-r--r-- training/dtrain/dtrain.h                   |  1
-rw-r--r-- training/dtrain/example/standard/dtrain.ini |  5
-rw-r--r-- training/dtrain/update.h                   | 38
4 files changed, 47 insertions, 3 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index 97df530b..378c988f 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -18,6 +18,7 @@ main(int argc, char** argv) const weight_t eta = conf["learning_rate"].as<weight_t>(); const weight_t margin = conf["margin"].as<weight_t>(); const bool average = conf["average"].as<bool>(); + const bool structured = conf["struct"].as<bool>(); const weight_t l1_reg = conf["l1_reg"].as<weight_t>(); const bool keep = conf["keep"].as<bool>(); const string output_fn = conf["output"].as<string>(); @@ -160,7 +161,10 @@ main(int argc, char** argv) // get pairs and update SparseVector<weight_t> updates; - num_up += CollectUpdates(samples, updates, margin); + if (structured) + num_up += CollectUpdatesStruct(samples, updates); + else + num_up += CollectUpdates(samples, updates, margin); SparseVector<weight_t> lambdas_copy; if (l1_reg) lambdas_copy = lambdas; diff --git a/training/dtrain/dtrain.h b/training/dtrain/dtrain.h index 2636fa89..dc3ac30f 100644 --- a/training/dtrain/dtrain.h +++ b/training/dtrain/dtrain.h @@ -59,6 +59,7 @@ dtrain_init(int argc, char** argv, po::variables_map* conf) ("N", po::value<size_t>()->default_value(4), "N for BLEU approximation") ("input_weights,w", po::value<string>(), "input weights file") ("average,a", po::value<bool>()->default_value(false), "output average weights") + ("struct,S", po::value<bool>()->default_value(false), "structured SGD") ("keep,K", po::value<bool>()->default_value(false), "output a weight file per iteration") ("output,o", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT") ("print_weights,P", po::value<string>()->default_value("EgivenFCoherent SampleCountF CountEF MaxLexFgivenE MaxLexEgivenF IsSingletonF IsSingletonFE Glue WordPenalty PassThrough LanguageModel LanguageModel_OOV"), diff --git a/training/dtrain/example/standard/dtrain.ini b/training/dtrain/example/standard/dtrain.ini index c52bef4a..dfb9b844 100644 
--- a/training/dtrain/example/standard/dtrain.ini +++ b/training/dtrain/example/standard/dtrain.ini @@ -4,7 +4,8 @@ decoder_conf=./cdec.ini # config for cdec iterations=3 # run over input 3 times k=100 # use 100best lists N=4 # optimize (approx.) BLEU4 -learning_rate=0.1 # learning rate -margin=1.0 # margin for margin perceptron +learning_rate=0.0001 # learning rate +margin=0 # margin for margin perceptron print_weights=Glue WordPenalty LanguageModel LanguageModel_OOV PhraseModel_0 PhraseModel_1 PhraseModel_2 PhraseModel_3 PhraseModel_4 PhraseModel_5 PhraseModel_6 PassThrough score=nakov +struct=true diff --git a/training/dtrain/update.h b/training/dtrain/update.h index d7224cca..6f42e5bd 100644 --- a/training/dtrain/update.h +++ b/training/dtrain/update.h @@ -10,6 +10,18 @@ _cmp(ScoredHyp a, ScoredHyp b) return a.gold > b.gold; } +bool +_cmpHope(ScoredHyp a, ScoredHyp b) +{ + return (a.model+a.gold) > (b.model+b.gold); +} + +bool +_cmpFear(ScoredHyp a, ScoredHyp b) +{ + return (a.model-a.gold) > (b.model-b.gold); +} + inline bool _good(ScoredHyp& a, ScoredHyp& b, weight_t margin) { @@ -20,6 +32,15 @@ _good(ScoredHyp& a, ScoredHyp& b, weight_t margin) return false; } +inline bool +_goodS(ScoredHyp& a, ScoredHyp& b) +{ + if (a.gold==b.gold) + return true; + + return false; +} + /* * multipartite ranking * sort (descending) by bleu @@ -56,6 +77,23 @@ CollectUpdates(vector<ScoredHyp>* s, return num_up; } +inline size_t +CollectUpdatesStruct(vector<ScoredHyp>* s, + SparseVector<weight_t>& updates, + weight_t unused=-1) +{ + // hope + sort(s->begin(), s->end(), _cmpHope); + ScoredHyp hope = (*s)[0]; + // fear + sort(s->begin(), s->end(), _cmpFear); + ScoredHyp fear = (*s)[0]; + if (!_goodS(hope, fear)) + updates += hope.f - fear.f; + + return updates.size(); +} + } // namespace #endif |