-rw-r--r--  dtrain/dtrain.cc               |  4
-rw-r--r--  dtrain/dtrain.h                |  2
-rw-r--r--  dtrain/pairsampling.h          | 70
-rw-r--r--  dtrain/test/example/cdec.ini   | 16
-rw-r--r--  dtrain/test/example/dtrain.ini | 10
5 files changed, 91 insertions, 11 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index 1769c690..434ae2d6 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -70,7 +70,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
     return false;
   }
   string s = (*cfg)["pair_sampling"].as<string>();
-  if (s != "all" && s != "5050" && s != "108010" && s != "PRO" && s != "alld") {
+  if (s != "all" && s != "5050" && s != "108010" && s != "PRO" && s != "alld" && s != "108010d") {
     cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "'." << endl;
     return false;
   }
@@ -391,6 +391,8 @@ main(int argc, char** argv)
         PROsampling(samples, pairs);
       if (pair_sampling == "alld")
         all_pairs_discard(samples, pairs);
+      if (pair_sampling == "108010d")
+        multpart108010_discard(samples, pairs);
       npairs += pairs.size();
 
       for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();
diff --git a/dtrain/dtrain.h b/dtrain/dtrain.h
index 3d76bd7f..14ef410e 100644
--- a/dtrain/dtrain.h
+++ b/dtrain/dtrain.h
@@ -13,7 +13,7 @@
 
 #include "filelib.h"
 
-//#define DTRAIN_LOCAL
+#define DTRAIN_LOCAL
 
 #define DTRAIN_DOTS 100 // when to display a '.'
 #define DTRAIN_GRAMMAR_DELIM "########EOS########"
diff --git a/dtrain/pairsampling.h b/dtrain/pairsampling.h
index 9b88a4be..0951f8e9 100644
--- a/dtrain/pairsampling.h
+++ b/dtrain/pairsampling.h
@@ -49,17 +49,17 @@ multpart108010(vector<ScoredHyp>* s, vector<pair<ScoredHyp,ScoredHyp> >& trainin
   unsigned sep = sz%slice;
   if (sep == 0) sep = sz/slice;
   for (unsigned i = 0; i < sep; i++) {
-    for(unsigned j = sep; j < sz; j++) {
+    for (unsigned j = sep; j < sz; j++) {
       p.first = (*s)[i];
       p.second = (*s)[j];
-      if(p.first.rank < p.second.rank) training.push_back(p);
+      if (p.first.rank < p.second.rank) training.push_back(p);
     }
   }
   for (unsigned i = sep; i < sz-sep; i++) {
     for (unsigned j = sz-sep; j < sz; j++) {
       p.first = (*s)[i];
       p.second = (*s)[j];
-      if(p.first.rank < p.second.rank) training.push_back(p);
+      if (p.first.rank < p.second.rank) training.push_back(p);
     }
   }
 }
@@ -118,6 +118,70 @@ all_pairs_discard(vector<ScoredHyp>* s, vector<pair<ScoredHyp,ScoredHyp> >& trai
   }
 }
 
+inline void
+multpart108010_discard(vector<ScoredHyp>* s, vector<pair<ScoredHyp,ScoredHyp> >& training)
+{
+  sort(s->begin(), s->end(), _multpart_cmp_hyp_by_score);
+  pair<ScoredHyp,ScoredHyp> p;
+  unsigned sz = s->size();
+  unsigned slice = 10;
+  unsigned sep = sz%slice;
+  if (sep == 0) sep = sz/slice;
+  for (unsigned i = 0; i < sep; i++) {
+    for (unsigned j = sep; j < sz; j++) {
+      p.first = (*s)[i];
+      p.second = (*s)[j];
+      if (p.first.rank < p.second.rank) {
+        if (_PRO_accept_pair(p)) training.push_back(p);
+      }
+    }
+  }
+  for (unsigned i = sep; i < sz-sep; i++) {
+    for (unsigned j = sz-sep; j < sz; j++) {
+      p.first = (*s)[i];
+      p.second = (*s)[j];
+      if (p.first.rank < p.second.rank) {
+        if (_PRO_accept_pair(p)) training.push_back(p);
+      }
+    }
+  }
+  sort(training.begin(), training.end(), _PRO_cmp_pair_by_diff);
+  if (training.size() > 50)
+    training.erase(training.begin()+50, training.end());
+}
+
+inline void
+multpart108010_discard1(vector<ScoredHyp>* s, vector<pair<ScoredHyp,ScoredHyp> >& training)
+{
+  sort(s->begin(), s->end(), _multpart_cmp_hyp_by_score);
+  pair<ScoredHyp,ScoredHyp> p;
+  unsigned sz = s->size();
+  unsigned slice = 10;
+  unsigned sep = sz%slice;
+  if (sep == 0) sep = sz/slice;
+  for (unsigned i = 0; i < sep; i++) {
+    for (unsigned j = sep; j < sz; j++) {
+      p.first = (*s)[i];
+      p.second = (*s)[j];
+      if (p.first.rank < p.second.rank) {
+        if (_PRO_accept_pair(p)) training.push_back(p);
+      }
+    }
+  }
+  for (unsigned i = sep; i < sz-sep; i++) {
+    for (unsigned j = sz-sep; j < sz; j++) {
+      p.first = (*s)[i];
+      p.second = (*s)[j];
+      if (p.first.rank < p.second.rank) {
+        if (_PRO_accept_pair(p)) training.push_back(p);
+      }
+    }
+  }
+  sort(training.begin(), training.end(), _PRO_cmp_pair_by_diff);
+  if (training.size() > 50)
+    training.erase(training.begin()+50, training.end());
+}
+
 } // namespace
diff --git a/dtrain/test/example/cdec.ini b/dtrain/test/example/cdec.ini
index ff99de7b..51edab09 100644
--- a/dtrain/test/example/cdec.ini
+++ b/dtrain/test/example/cdec.ini
@@ -5,4 +5,18 @@ intersection_strategy=cube_pruning
 cubepruning_pop_limit=30
 feature_function=WordPenalty
 feature_function=KLanguageModel test/example/nc-wmt11.en.srilm.gz
-#feature_function=RuleIdentityFeatures
+
+feature_function=RuleIdentityFeatures
+#feature_function=SpanFeatures
+#feature_function=SourceWordPenalty
+#feature_function=SourceSpanSizeFeatures
+#feature_function=RuleShape
+#feature_function=RuleNgramFeatures
+#feature_function=OutputIndicator
+#feature_function=NonLatinCount
+#feature_function=NgramFeatures
+#feature_function=NewJump
+#feature_function=LexNullJump
+#feature_function=InputIndicator
+#feature_function=CMR2008ReorderingFeatures
+#feature_function=ArityPenalty
diff --git a/dtrain/test/example/dtrain.ini b/dtrain/test/example/dtrain.ini
index fab4d317..09b493ad 100644
--- a/dtrain/test/example/dtrain.ini
+++ b/dtrain/test/example/dtrain.ini
@@ -1,19 +1,19 @@
 decoder_config=test/example/cdec.ini
-k=100
+k=1500
 N=3
 learning_rate=0.0005
 gamma=0
-epochs=3
+epochs=2
 input=test/example/nc-wmt11.1k.gz
 output=-
 scorer=stupid_bleu
-sample_from=forest
+sample_from=kbest
 #filter=unique
-pair_sampling=5050
+pair_sampling=PRO
 select_weights=last
 print_weights=Glue WordPenalty LanguageModel LanguageModel_OOV PhraseModel_0 PhraseModel_1 PhraseModel_2 PhraseModel_3 PhraseModel_4 PassThrough
 tmp=/tmp
-stop_after=10
+stop_after=100
 #keep_w=
 #update_ok=
 #l1_reg=clip
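For readers who do not want to trace the diff, the following is a minimal, self-contained sketch of the idea behind the new "108010d" pair-sampling mode, not dtrain's actual code: the Hyp struct, the 0.05 acceptance threshold, and the comparators are illustrative stand-ins for ScoredHyp, _PRO_accept_pair, _multpart_cmp_hyp_by_score and _PRO_cmp_pair_by_diff, and sep is simplified to sz/10 (the committed function derives it as sz%slice, falling back to sz/slice when that is zero). The scheme: sort hypotheses by metric score, pair the top ~10% slice against everything below it and the middle against the bottom ~10%, drop pairs whose scores are too close (the PRO-style discard), and keep only the 50 pairs with the largest score difference.

// sketch_108010d.cc -- hypothetical standalone illustration, not part of dtrain
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <utility>
#include <vector>

// Stand-in for dtrain's ScoredHyp: 'rank' is the k-best (model) rank,
// 'score' the sentence-level metric score.
struct Hyp { unsigned rank; double score; };

static bool by_score_desc(const Hyp& a, const Hyp& b) { return a.score > b.score; }

// Stand-in for _PRO_accept_pair: keep a pair only if the metric scores
// differ by more than an assumed threshold of 0.05.
static bool accept(const std::pair<Hyp,Hyp>& p) {
  return std::fabs(p.first.score - p.second.score) > 0.05;
}

// Stand-in for _PRO_cmp_pair_by_diff: larger score difference first.
static bool by_diff_desc(const std::pair<Hyp,Hyp>& a, const std::pair<Hyp,Hyp>& b) {
  return std::fabs(a.first.score - a.second.score) > std::fabs(b.first.score - b.second.score);
}

void sample_108010d(std::vector<Hyp>& s, std::vector<std::pair<Hyp,Hyp> >& training)
{
  std::sort(s.begin(), s.end(), by_score_desc);
  const unsigned sz = s.size();
  const unsigned sep = sz / 10;               // size of the top and bottom slices (simplified)
  std::pair<Hyp,Hyp> p;
  for (unsigned i = 0; i < sep; i++)          // top slice vs. everything below it
    for (unsigned j = sep; j < sz; j++) {
      p.first = s[i]; p.second = s[j];
      if (p.first.rank < p.second.rank && accept(p)) training.push_back(p);
    }
  for (unsigned i = sep; i < sz - sep; i++)   // middle vs. bottom slice
    for (unsigned j = sz - sep; j < sz; j++) {
      p.first = s[i]; p.second = s[j];
      if (p.first.rank < p.second.rank && accept(p)) training.push_back(p);
    }
  std::sort(training.begin(), training.end(), by_diff_desc);
  if (training.size() > 50)                   // keep the 50 largest-margin pairs
    training.erase(training.begin() + 50, training.end());
}

int main() {
  std::vector<Hyp> hyps;
  for (unsigned r = 0; r < 100; r++) {        // fake 100-best list with noisy scores
    Hyp h;
    h.rank = r;
    h.score = 1.0 - 0.01 * r + 0.001 * (std::rand() % 10);
    hyps.push_back(h);
  }
  std::vector<std::pair<Hyp,Hyp> > pairs;
  sample_108010d(hyps, pairs);
  std::cout << pairs.size() << " training pairs kept" << std::endl;
  return 0;
}

Compiled with g++ and run, the sketch prints how many pairs survive the discard step (at most 50), which is the same cap the committed multpart108010_discard applies before the pairs reach the update loop.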