diff options
-rw-r--r-- | dtrain/dtrain.cc | 9 | ||||
-rw-r--r-- | dtrain/pairsampling.h | 20 |
2 files changed, 20 insertions, 9 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index 3d3aa2d3..1769c690 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -69,10 +69,9 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) cerr << "Wrong 'filter' param: '" << (*cfg)["filter"].as<string>() << "', use 'uniq' or 'no'." << endl; return false; } - if ((*cfg)["pair_sampling"].as<string>() != "all" - && (*cfg)["pair_sampling"].as<string>() != "5050" && (*cfg)["pair_sampling"].as<string>() != "108010" - && (*cfg)["pair_sampling"].as<string>() != "PRO") { - cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "', use 'all' or 'rand'." << endl; + string s = (*cfg)["pair_sampling"].as<string>(); + if (s != "all" && s != "5050" && s != "108010" && s != "PRO" && s != "alld") { + cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "'." << endl; return false; } if ((*cfg)["select_weights"].as<string>() != "last" @@ -390,6 +389,8 @@ main(int argc, char** argv) multpart108010(samples, pairs); if (pair_sampling == "PRO") PROsampling(samples, pairs); + if (pair_sampling == "alld") + all_pairs_discard(samples, pairs); npairs += pairs.size(); for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin(); diff --git a/dtrain/pairsampling.h b/dtrain/pairsampling.h index 4399dfee..9b88a4be 100644 --- a/dtrain/pairsampling.h +++ b/dtrain/pairsampling.h @@ -82,18 +82,14 @@ PROsampling(vector<ScoredHyp>* s, vector<pair<ScoredHyp,ScoredHyp> >& training) { unsigned max_count = 5000, count = 0; bool b = false; - //unsigned max_pairs = (s->size()*(s->size()-1))/2; - vector<pair<unsigned,unsigned> > taken; for (unsigned i = 0; i < s->size()-1; i++) { for (unsigned j = i+1; j < s->size(); j++) { pair<ScoredHyp,ScoredHyp> p; p.first = (*s)[i]; p.second = (*s)[j]; - vector<pair<unsigned,unsigned> >::iterator it = find(taken.begin(), taken.end(), make_pair(i, j)); - if (_PRO_accept_pair(p) && it == taken.end()) { + if (_PRO_accept_pair(p)) { training.push_back(p); count++; - taken.push_back(make_pair(i, j)); if (count == max_count) { b = true; break; @@ -108,6 +104,20 @@ PROsampling(vector<ScoredHyp>* s, vector<pair<ScoredHyp,ScoredHyp> >& training) return; } +inline void +all_pairs_discard(vector<ScoredHyp>* s, vector<pair<ScoredHyp,ScoredHyp> >& training) +{ + for (unsigned i = 0; i < s->size()-1; i++) { + for (unsigned j = i+1; j < s->size(); j++) { + pair<ScoredHyp,ScoredHyp> p; + p.first = (*s)[i]; + p.second = (*s)[j]; + if(_PRO_accept_pair(p)) + training.push_back(p); + } + } +} + } // namespace |