diff options
| -rw-r--r-- | training/dtrain/dtrain.cc | 6 | ||||
| -rw-r--r-- | training/dtrain/dtrain.h | 1 | ||||
| -rw-r--r-- | training/dtrain/example/standard/dtrain.ini | 5 | ||||
| -rw-r--r-- | training/dtrain/update.h | 38 | 
4 files changed, 47 insertions, 3 deletions
| diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index 97df530b..378c988f 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -18,6 +18,7 @@ main(int argc, char** argv)    const weight_t eta          = conf["learning_rate"].as<weight_t>();    const weight_t margin       = conf["margin"].as<weight_t>();    const bool average          = conf["average"].as<bool>(); +  const bool structured       = conf["struct"].as<bool>();    const weight_t l1_reg       = conf["l1_reg"].as<weight_t>();    const bool keep             = conf["keep"].as<bool>();    const string output_fn      = conf["output"].as<string>(); @@ -160,7 +161,10 @@ main(int argc, char** argv)      // get pairs and update      SparseVector<weight_t> updates; -    num_up += CollectUpdates(samples, updates, margin); +    if (structured) +      num_up += CollectUpdatesStruct(samples, updates); +    else +      num_up += CollectUpdates(samples, updates, margin);      SparseVector<weight_t> lambdas_copy;      if (l1_reg)        lambdas_copy = lambdas; diff --git a/training/dtrain/dtrain.h b/training/dtrain/dtrain.h index 2636fa89..dc3ac30f 100644 --- a/training/dtrain/dtrain.h +++ b/training/dtrain/dtrain.h @@ -59,6 +59,7 @@ dtrain_init(int argc, char** argv, po::variables_map* conf)      ("N",                 po::value<size_t>()->default_value(4),              "N for BLEU approximation")      ("input_weights,w",   po::value<string>(),                                      "input weights file")      ("average,a",         po::value<bool>()->default_value(false),              "output average weights") +    ("struct,S",          po::value<bool>()->default_value(false),                      "structured SGD")      ("keep,K",            po::value<bool>()->default_value(false),  "output a weight file per iteration")      ("output,o",          po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT")      ("print_weights,P",   po::value<string>()->default_value("EgivenFCoherent SampleCountF CountEF MaxLexFgivenE MaxLexEgivenF IsSingletonF IsSingletonFE Glue WordPenalty PassThrough LanguageModel LanguageModel_OOV"), diff --git a/training/dtrain/example/standard/dtrain.ini b/training/dtrain/example/standard/dtrain.ini index c52bef4a..dfb9b844 100644 --- a/training/dtrain/example/standard/dtrain.ini +++ b/training/dtrain/example/standard/dtrain.ini @@ -4,7 +4,8 @@ decoder_conf=./cdec.ini   # config for cdec  iterations=3              # run over input 3 times  k=100                     # use 100best lists  N=4                       # optimize (approx.) BLEU4 -learning_rate=0.1         # learning rate -margin=1.0                # margin for margin perceptron +learning_rate=0.0001         # learning rate +margin=0                # margin for margin perceptron  print_weights=Glue WordPenalty LanguageModel LanguageModel_OOV PhraseModel_0 PhraseModel_1 PhraseModel_2 PhraseModel_3 PhraseModel_4 PhraseModel_5 PhraseModel_6 PassThrough  score=nakov +struct=true diff --git a/training/dtrain/update.h b/training/dtrain/update.h index d7224cca..6f42e5bd 100644 --- a/training/dtrain/update.h +++ b/training/dtrain/update.h @@ -10,6 +10,18 @@ _cmp(ScoredHyp a, ScoredHyp b)    return a.gold > b.gold;  } +bool +_cmpHope(ScoredHyp a, ScoredHyp b) +{ +  return (a.model+a.gold) > (b.model+b.gold); +} + +bool +_cmpFear(ScoredHyp a, ScoredHyp b) +{ +  return (a.model-a.gold) > (b.model-b.gold); +} +  inline bool  _good(ScoredHyp& a, ScoredHyp& b, weight_t margin)  { @@ -20,6 +32,15 @@ _good(ScoredHyp& a, ScoredHyp& b, weight_t margin)    return false;  } +inline bool +_goodS(ScoredHyp& a, ScoredHyp& b) +{ +  if (a.gold==b.gold) +    return true; + +  return false; +} +  /*   * multipartite ranking   *  sort (descending) by bleu @@ -56,6 +77,23 @@ CollectUpdates(vector<ScoredHyp>* s,    return num_up;  } +inline size_t +CollectUpdatesStruct(vector<ScoredHyp>* s, +                     SparseVector<weight_t>& updates, +                     weight_t unused=-1) +{ +  // hope +  sort(s->begin(), s->end(), _cmpHope); +  ScoredHyp hope = (*s)[0]; +  // fear +  sort(s->begin(), s->end(), _cmpFear); +  ScoredHyp fear = (*s)[0]; +  if (!_goodS(hope, fear)) +    updates += hope.f - fear.f; + +  return updates.size(); +} +  } // namespace  #endif | 
