summaryrefslogtreecommitdiff
path: root/dtrain
diff options
context:
space:
mode:
Diffstat (limited to 'dtrain')
-rw-r--r--dtrain/dtrain.cc4
-rw-r--r--dtrain/dtrain.h2
-rw-r--r--dtrain/pairsampling.h70
-rw-r--r--dtrain/test/example/cdec.ini16
-rw-r--r--dtrain/test/example/dtrain.ini10
5 files changed, 91 insertions, 11 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index 1769c690..434ae2d6 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -70,7 +70,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
return false;
}
string s = (*cfg)["pair_sampling"].as<string>();
- if (s != "all" && s != "5050" && s != "108010" && s != "PRO" && s != "alld") {
+ if (s != "all" && s != "5050" && s != "108010" && s != "PRO" && s != "alld" && s != "108010d") {
cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "'." << endl;
return false;
}
@@ -391,6 +391,8 @@ main(int argc, char** argv)
PROsampling(samples, pairs);
if (pair_sampling == "alld")
all_pairs_discard(samples, pairs);
+ if (pair_sampling == "108010d")
+ multpart108010_discard(samples, pairs);
npairs += pairs.size();
for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();
diff --git a/dtrain/dtrain.h b/dtrain/dtrain.h
index 3d76bd7f..14ef410e 100644
--- a/dtrain/dtrain.h
+++ b/dtrain/dtrain.h
@@ -13,7 +13,7 @@
#include "filelib.h"
-//#define DTRAIN_LOCAL
+#define DTRAIN_LOCAL
#define DTRAIN_DOTS 100 // when to display a '.'
#define DTRAIN_GRAMMAR_DELIM "########EOS########"
diff --git a/dtrain/pairsampling.h b/dtrain/pairsampling.h
index 9b88a4be..0951f8e9 100644
--- a/dtrain/pairsampling.h
+++ b/dtrain/pairsampling.h
@@ -49,17 +49,17 @@ multpart108010(vector<ScoredHyp>* s, vector<pair<ScoredHyp,ScoredHyp> >& trainin
unsigned sep = sz%slice;
if (sep == 0) sep = sz/slice;
for (unsigned i = 0; i < sep; i++) {
- for(unsigned j = sep; j < sz; j++) {
+ for (unsigned j = sep; j < sz; j++) {
p.first = (*s)[i];
p.second = (*s)[j];
- if(p.first.rank < p.second.rank) training.push_back(p);
+ if (p.first.rank < p.second.rank) training.push_back(p);
}
}
for (unsigned i = sep; i < sz-sep; i++) {
for (unsigned j = sz-sep; j < sz; j++) {
p.first = (*s)[i];
p.second = (*s)[j];
- if(p.first.rank < p.second.rank) training.push_back(p);
+ if (p.first.rank < p.second.rank) training.push_back(p);
}
}
}
@@ -118,6 +118,70 @@ all_pairs_discard(vector<ScoredHyp>* s, vector<pair<ScoredHyp,ScoredHyp> >& trai
}
}
+inline void
+multpart108010_discard(vector<ScoredHyp>* s, vector<pair<ScoredHyp,ScoredHyp> >& training)
+{
+ sort(s->begin(), s->end(), _multpart_cmp_hyp_by_score);
+ pair<ScoredHyp,ScoredHyp> p;
+ unsigned sz = s->size();
+ unsigned slice = 10;
+ unsigned sep = sz%slice;
+ if (sep == 0) sep = sz/slice;
+ for (unsigned i = 0; i < sep; i++) {
+ for (unsigned j = sep; j < sz; j++) {
+ p.first = (*s)[i];
+ p.second = (*s)[j];
+ if (p.first.rank < p.second.rank) {
+ if (_PRO_accept_pair(p)) training.push_back(p);
+ }
+ }
+ }
+ for (unsigned i = sep; i < sz-sep; i++) {
+ for (unsigned j = sz-sep; j < sz; j++) {
+ p.first = (*s)[i];
+ p.second = (*s)[j];
+ if (p.first.rank < p.second.rank) {
+ if (_PRO_accept_pair(p)) training.push_back(p);
+ }
+ }
+ }
+ sort(training.begin(), training.end(), _PRO_cmp_pair_by_diff);
+ if (training.size() > 50)
+ training.erase(training.begin()+50, training.end());
+}
+
+inline void
+multpart108010_discard1(vector<ScoredHyp>* s, vector<pair<ScoredHyp,ScoredHyp> >& training)
+{
+ sort(s->begin(), s->end(), _multpart_cmp_hyp_by_score);
+ pair<ScoredHyp,ScoredHyp> p;
+ unsigned sz = s->size();
+ unsigned slice = 10;
+ unsigned sep = sz%slice;
+ if (sep == 0) sep = sz/slice;
+ for (unsigned i = 0; i < sep; i++) {
+ for (unsigned j = sep; j < sz; j++) {
+ p.first = (*s)[i];
+ p.second = (*s)[j];
+ if (p.first.rank < p.second.rank) {
+ if (_PRO_accept_pair(p)) training.push_back(p);
+ }
+ }
+ }
+ for (unsigned i = sep; i < sz-sep; i++) {
+ for (unsigned j = sz-sep; j < sz; j++) {
+ p.first = (*s)[i];
+ p.second = (*s)[j];
+ if (p.first.rank < p.second.rank) {
+ if (_PRO_accept_pair(p)) training.push_back(p);
+ }
+ }
+ }
+ sort(training.begin(), training.end(), _PRO_cmp_pair_by_diff);
+ if (training.size() > 50)
+ training.erase(training.begin()+50, training.end());
+}
+
} // namespace
diff --git a/dtrain/test/example/cdec.ini b/dtrain/test/example/cdec.ini
index ff99de7b..51edab09 100644
--- a/dtrain/test/example/cdec.ini
+++ b/dtrain/test/example/cdec.ini
@@ -5,4 +5,18 @@ intersection_strategy=cube_pruning
cubepruning_pop_limit=30
feature_function=WordPenalty
feature_function=KLanguageModel test/example/nc-wmt11.en.srilm.gz
-#feature_function=RuleIdentityFeatures
+
+feature_function=RuleIdentityFeatures
+#feature_function=SpanFeatures
+#feature_function=SourceWordPenalty
+#feature_function=SourceSpanSizeFeatures
+#feature_function=RuleShape
+#feature_function=RuleNgramFeatures
+#feature_function=OutputIndicator
+#feature_function=NonLatinCount
+#feature_function=NgramFeatures
+#feature_function=NewJump
+#feature_function=LexNullJump
+#feature_function=InputIndicator
+#feature_function=CMR2008ReorderingFeatures
+#feature_function=ArityPenalty
diff --git a/dtrain/test/example/dtrain.ini b/dtrain/test/example/dtrain.ini
index fab4d317..09b493ad 100644
--- a/dtrain/test/example/dtrain.ini
+++ b/dtrain/test/example/dtrain.ini
@@ -1,19 +1,19 @@
decoder_config=test/example/cdec.ini
-k=100
+k=1500
N=3
learning_rate=0.0005
gamma=0
-epochs=3
+epochs=2
input=test/example/nc-wmt11.1k.gz
output=-
scorer=stupid_bleu
-sample_from=forest
+sample_from=kbest
#filter=unique
-pair_sampling=5050
+pair_sampling=PRO
select_weights=last
print_weights=Glue WordPenalty LanguageModel LanguageModel_OOV PhraseModel_0 PhraseModel_1 PhraseModel_2 PhraseModel_3 PhraseModel_4 PassThrough
tmp=/tmp
-stop_after=10
+stop_after=100
#keep_w=
#update_ok=
#l1_reg=clip