summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--dtrain/dtrain.cc9
-rw-r--r--dtrain/pairsampling.h20
2 files changed, 20 insertions, 9 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index 3d3aa2d3..1769c690 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -69,10 +69,9 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
cerr << "Wrong 'filter' param: '" << (*cfg)["filter"].as<string>() << "', use 'uniq' or 'no'." << endl;
return false;
}
- if ((*cfg)["pair_sampling"].as<string>() != "all"
- && (*cfg)["pair_sampling"].as<string>() != "5050" && (*cfg)["pair_sampling"].as<string>() != "108010"
- && (*cfg)["pair_sampling"].as<string>() != "PRO") {
- cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "', use 'all' or 'rand'." << endl;
+ string s = (*cfg)["pair_sampling"].as<string>();
+ if (s != "all" && s != "5050" && s != "108010" && s != "PRO" && s != "alld") {
+ cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "'." << endl;
return false;
}
if ((*cfg)["select_weights"].as<string>() != "last"
@@ -390,6 +389,8 @@ main(int argc, char** argv)
multpart108010(samples, pairs);
if (pair_sampling == "PRO")
PROsampling(samples, pairs);
+ if (pair_sampling == "alld")
+ all_pairs_discard(samples, pairs);
npairs += pairs.size();
for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();
diff --git a/dtrain/pairsampling.h b/dtrain/pairsampling.h
index 4399dfee..9b88a4be 100644
--- a/dtrain/pairsampling.h
+++ b/dtrain/pairsampling.h
@@ -82,18 +82,14 @@ PROsampling(vector<ScoredHyp>* s, vector<pair<ScoredHyp,ScoredHyp> >& training)
{
unsigned max_count = 5000, count = 0;
bool b = false;
- //unsigned max_pairs = (s->size()*(s->size()-1))/2;
- vector<pair<unsigned,unsigned> > taken;
for (unsigned i = 0; i < s->size()-1; i++) {
for (unsigned j = i+1; j < s->size(); j++) {
pair<ScoredHyp,ScoredHyp> p;
p.first = (*s)[i];
p.second = (*s)[j];
- vector<pair<unsigned,unsigned> >::iterator it = find(taken.begin(), taken.end(), make_pair(i, j));
- if (_PRO_accept_pair(p) && it == taken.end()) {
+ if (_PRO_accept_pair(p)) {
training.push_back(p);
count++;
- taken.push_back(make_pair(i, j));
if (count == max_count) {
b = true;
break;
@@ -108,6 +104,20 @@ PROsampling(vector<ScoredHyp>* s, vector<pair<ScoredHyp,ScoredHyp> >& training)
return;
}
+inline void
+all_pairs_discard(vector<ScoredHyp>* s, vector<pair<ScoredHyp,ScoredHyp> >& training)
+{
+ for (unsigned i = 0; i < s->size()-1; i++) {
+ for (unsigned j = i+1; j < s->size(); j++) {
+ pair<ScoredHyp,ScoredHyp> p;
+ p.first = (*s)[i];
+ p.second = (*s)[j];
+ if(_PRO_accept_pair(p))
+ training.push_back(p);
+ }
+ }
+}
+
} // namespace