summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/dtrain/dtrain.cc35
-rw-r--r--training/dtrain/dtrain.h9
-rw-r--r--training/dtrain/example/standard/dtrain.ini1
-rw-r--r--training/dtrain/sample.h6
-rw-r--r--training/dtrain/score.h356
5 files changed, 327 insertions, 80 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index e5cfd50a..97df530b 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -10,9 +10,9 @@ main(int argc, char** argv)
{
// get configuration
po::variables_map conf;
- if (!dtrain_init(argc, argv, &conf))
- exit(1); // something is wrong
+ dtrain_init(argc, argv, &conf);
const size_t k = conf["k"].as<size_t>();
+ const string score_name = conf["score"].as<string>();
const size_t N = conf["N"].as<size_t>();
const size_t T = conf["iterations"].as<size_t>();
const weight_t eta = conf["learning_rate"].as<weight_t>();
@@ -25,12 +25,28 @@ main(int argc, char** argv)
boost::split(print_weights, conf["print_weights"].as<string>(),
boost::is_any_of(" "));
- // setup decoder
+ // setup decoder and scorer
register_feature_functions();
SetSilent(true);
ReadFile f(conf["decoder_conf"].as<string>());
Decoder decoder(f.stream());
- ScoredKbest* observer = new ScoredKbest(k, new PerSentenceBleuScorer(N));
+ Scorer* scorer;
+ if (score_name == "nakov") {
+ scorer = static_cast<PerSentenceBleuScorer*>(new PerSentenceBleuScorer(N));
+ } else if (score_name == "papineni") {
+ scorer = static_cast<BleuScorer*>(new BleuScorer(N));
+ } else if (score_name == "lin") {
+ scorer = static_cast<OriginalPerSentenceBleuScorer*>\
+ (new OriginalPerSentenceBleuScorer(N));
+ } else if (score_name == "liang") {
+ scorer = static_cast<SmoothPerSentenceBleuScorer*>\
+ (new SmoothPerSentenceBleuScorer(N));
+ } else if (score_name == "chiang") {
+ scorer = static_cast<ApproxBleuScorer*>(new ApproxBleuScorer(N));
+ } else {
+ assert(false);
+ }
+ ScoredKbest* observer = new ScoredKbest(k, scorer);
// weights
vector<weight_t>& decoder_weights = decoder.CurrentWeightVector();
@@ -52,6 +68,7 @@ main(int argc, char** argv)
// output configuration
cerr << "dtrain" << endl << "Parameters:" << endl;
cerr << setw(25) << "k " << k << endl;
+ cerr << setw(25) << "score " << "'" << score_name << "'" << endl;
cerr << setw(25) << "N " << N << endl;
cerr << setw(25) << "T " << T << endl;
cerr << setw(25) << "learning rate " << eta << endl;
@@ -149,6 +166,16 @@ main(int argc, char** argv)
lambdas_copy = lambdas;
lambdas.plus_eq_v_times_s(updates, eta);
+ // update context for approx. BLEU
+ if (score_name == "chiang") {
+ for (auto it: *samples) {
+ if (it.rank == 0) {
+ scorer->UpdateContext(it.w, buf_ngs[i], buf_ls[i], 0.9);
+ break;
+ }
+ }
+ }
+
// l1 regularization
// NB: regularization is done after each sentence,
// not after every single pair!
diff --git a/training/dtrain/dtrain.h b/training/dtrain/dtrain.h
index b0ee348c..2636fa89 100644
--- a/training/dtrain/dtrain.h
+++ b/training/dtrain/dtrain.h
@@ -43,7 +43,7 @@ inline ostream& _np(ostream& out) { return out << resetiosflags(ios::showpos); }
inline ostream& _p(ostream& out) { return out << setiosflags(ios::showpos); }
inline ostream& _p4(ostream& out) { return out << setprecision(4); }
-bool
+void
dtrain_init(int argc, char** argv, po::variables_map* conf)
{
po::options_description ini("Configuration File Options");
@@ -55,6 +55,7 @@ dtrain_init(int argc, char** argv, po::variables_map* conf)
("learning_rate,l", po::value<weight_t>()->default_value(1.0), "learning rate")
("l1_reg,r", po::value<weight_t>()->default_value(0.), "l1 regularization strength")
("margin,m", po::value<weight_t>()->default_value(0.), "margin for margin perceptron")
+ ("score,s", po::value<string>()->default_value("nakov"), "per-sentence BLEU approx.")
("N", po::value<size_t>()->default_value(4), "N for BLEU approximation")
("input_weights,w", po::value<string>(), "input weights file")
("average,a", po::value<bool>()->default_value(false), "output average weights")
@@ -74,14 +75,12 @@ dtrain_init(int argc, char** argv, po::variables_map* conf)
po::notify(*conf);
if (!conf->count("decoder_conf")) {
cerr << "Missing decoder configuration." << endl;
- return false;
+ assert(false);
}
if (!conf->count("bitext")) {
cerr << "No input given." << endl;
- return false;
+ assert(false);
}
-
- return true;
}
} // namespace
diff --git a/training/dtrain/example/standard/dtrain.ini b/training/dtrain/example/standard/dtrain.ini
index a0d64fa0..c52bef4a 100644
--- a/training/dtrain/example/standard/dtrain.ini
+++ b/training/dtrain/example/standard/dtrain.ini
@@ -7,3 +7,4 @@ N=4 # optimize (approx.) BLEU4
learning_rate=0.1 # learning rate
margin=1.0 # margin for margin perceptron
print_weights=Glue WordPenalty LanguageModel LanguageModel_OOV PhraseModel_0 PhraseModel_1 PhraseModel_2 PhraseModel_3 PhraseModel_4 PhraseModel_5 PhraseModel_6 PassThrough
+score=nakov
diff --git a/training/dtrain/sample.h b/training/dtrain/sample.h
index 03cc82c3..1249e372 100644
--- a/training/dtrain/sample.h
+++ b/training/dtrain/sample.h
@@ -13,15 +13,15 @@ struct ScoredKbest : public DecoderObserver
const size_t k_;
size_t feature_count_, effective_sz_;
vector<ScoredHyp> samples_;
- PerSentenceBleuScorer* scorer_;
+ Scorer* scorer_;
vector<Ngrams>* ref_ngs_;
vector<size_t>* ref_ls_;
- ScoredKbest(const size_t k, PerSentenceBleuScorer* scorer) :
+ ScoredKbest(const size_t k, Scorer* scorer) :
k_(k), scorer_(scorer) {}
virtual void
- NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg)
+ NotifyTranslationForest(const SentenceMetadata& /*smeta*/, Hypergraph* hg)
{
samples_.clear(); effective_sz_ = feature_count_ = 0;
KBest::KBestDerivations<vector<WordID>, ESentenceTraversal,
diff --git a/training/dtrain/score.h b/training/dtrain/score.h
index 06dbc5a4..ca3da39b 100644
--- a/training/dtrain/score.h
+++ b/training/dtrain/score.h
@@ -12,6 +12,8 @@ struct NgramCounts
map<size_t, weight_t> clipped_;
map<size_t, weight_t> sum_;
+ NgramCounts() {}
+
NgramCounts(const size_t N) : N_(N) { Zero(); }
inline void
@@ -24,13 +26,13 @@ struct NgramCounts
}
}
- inline const NgramCounts
- operator+(const NgramCounts &other) const
+ inline void
+ operator*=(const weight_t rhs)
{
- NgramCounts result = *this;
- result += other;
-
- return result;
+ for (unsigned i = 0; i < N_; i++) {
+ this->clipped_[i] *= rhs;
+ this->sum_[i] *= rhs;
+ }
}
inline void
@@ -112,85 +114,303 @@ MakeNgramCounts(const vector<WordID>& hyp,
return counts;
}
+class Scorer
+{
+ protected:
+ const size_t N_;
+ vector<weight_t> w_;
+
+ public:
+ Scorer(size_t n): N_(n)
+ {
+ for (size_t i = 1; i <= N_; i++)
+ w_.push_back(1.0/N_);
+ }
+
+ inline bool
+ Init(const vector<WordID>& hyp,
+ const vector<Ngrams>& ref_ngs,
+ const vector<size_t>& ref_ls,
+ size_t& hl,
+ size_t& rl,
+ size_t& M,
+ vector<weight_t>& v,
+ NgramCounts& counts)
+ {
+ hl = hyp.size();
+ if (hl == 0) return false;
+ rl = BestMatchLength(hl, ref_ls);
+ if (rl == 0) return false;
+ counts = MakeNgramCounts(hyp, ref_ngs, N_);
+ if (rl < N_) {
+ M = rl;
+ for (size_t i = 0; i < M; i++) v.push_back(1/((weight_t)M));
+ } else {
+ M = N_;
+ v = w_;
+ }
+
+ return true;
+ }
+
+ inline weight_t
+ BrevityPenalty(const size_t hl, const size_t rl)
+ {
+ if (hl > rl)
+ return 1;
+
+ return exp(1 - (weight_t)rl/hl);
+ }
+
+ inline size_t
+ BestMatchLength(const size_t hl,
+ const vector<size_t>& ref_ls)
+ {
+ size_t m;
+ if (ref_ls.size() == 1) {
+ m = ref_ls.front();
+ } else {
+ size_t i = 0, best_idx = 0;
+ size_t best = numeric_limits<size_t>::max();
+ for (auto l: ref_ls) {
+ size_t d = abs(hl-l);
+ if (d < best) {
+ best_idx = i;
+ best = d;
+ }
+ i += 1;
+ }
+ m = ref_ls[best_idx];
+ }
+
+ return m;
+ }
+
+ virtual weight_t
+ Score(const vector<WordID>&,
+ const vector<Ngrams>&,
+ const vector<size_t>&) = 0;
+
+ void
+ UpdateContext(const vector<WordID>& /*hyp*/,
+ const vector<Ngrams>& /*ref_ngs*/,
+ const vector<size_t>& /*ref_ls*/,
+ weight_t /*decay*/) {}
+};
+
/*
- * per-sentence BLEU
+ * 'fixed' per-sentence BLEU
+ * simply add 1 to reference length for calculation of BP
+ *
* as in "Optimizing for Sentence-Level BLEU+1
* Yields Short Translations"
* (Nakov et al. '12)
*
- * [simply add 1 to reference length for calculation of BP]
+ */
+class PerSentenceBleuScorer : public Scorer
+{
+ public:
+ PerSentenceBleuScorer(size_t n) : Scorer(n) {}
+
+ weight_t
+ Score(const vector<WordID>& hyp,
+ const vector<Ngrams>& ref_ngs,
+ const vector<size_t>& ref_ls)
+ {
+ size_t hl, rl, M;
+ vector<weight_t> v;
+ NgramCounts counts;
+ if (!Init(hyp, ref_ngs, ref_ls, hl, rl, M, v, counts))
+ return 0.;
+ weight_t sum=0, add=0;
+ for (size_t i = 0; i < M; i++) {
+ if (i == 0 && (counts.sum_[i] == 0 || counts.clipped_[i] == 0)) return 0.;
+ if (i > 0) add = 1;
+ sum += v[i] * log(((weight_t)counts.clipped_[i] + add)
+ / ((counts.sum_[i] + add)));
+ }
+
+ return BrevityPenalty(hl, rl+1) * exp(sum);
+ }
+};
+
+
+/*
+ * BLEU
+ * 0 if for one n \in {1..N} count is 0
+ *
+ * as in "BLEU: a Method for Automatic Evaluation
+ * of Machine Translation"
+ * (Papineni et al. '02)
*
*/
-struct PerSentenceBleuScorer
+
+class BleuScorer : public Scorer
{
- const size_t N_;
- vector<weight_t> w_;
+ public:
+ BleuScorer(size_t n) : Scorer(n) {}
- PerSentenceBleuScorer(size_t n) : N_(n)
- {
- for (size_t i = 1; i <= N_; i++)
- w_.push_back(1.0/N_);
- }
+ weight_t
+ Score(const vector<WordID>& hyp,
+ const vector<Ngrams>& ref_ngs,
+ const vector<size_t>& ref_ls)
+ {
+ size_t hl, rl, M;
+ vector<weight_t> v;
+ NgramCounts counts;
+ if (!Init(hyp, ref_ngs, ref_ls, hl, rl, M, v, counts))
+ return 0.;
+ weight_t sum = 0;
+ for (size_t i = 0; i < M; i++) {
+ if (counts.sum_[i] == 0 || counts.clipped_[i] == 0) return 0.;
+ sum += v[i] * log((weight_t)counts.clipped_[i]/counts.sum_[i]);
+ }
- inline weight_t
- BrevityPenalty(const size_t hl, const size_t rl)
- {
- if (hl > rl)
- return 1;
+ return BrevityPenalty(hl, rl) * exp(sum);
+ }
+};
- return exp(1 - (weight_t)rl/hl);
- }
+/*
+ * original BLEU+1
+ * 0 iff no 1gram match ('grounded')
+ *
+ * as in "ORANGE: a Method for Evaluating
+ * Automatic Evaluation Metrics
+ * for Machine Translation"
+ * (Lin & Och '04)
+ *
+ */
+class OriginalPerSentenceBleuScorer : public Scorer
+{
+ public:
+ OriginalPerSentenceBleuScorer(size_t n) : Scorer(n) {}
- inline size_t
- BestMatchLength(const size_t hl,
- const vector<size_t>& ref_ls)
- {
- size_t m;
- if (ref_ls.size() == 1) {
- m = ref_ls.front();
- } else {
- size_t i = 0, best_idx = 0;
- size_t best = numeric_limits<size_t>::max();
- for (auto l: ref_ls) {
- size_t d = abs(hl-l);
- if (d < best) {
- best_idx = i;
- best = d;
+ weight_t
+ Score(const vector<WordID>& hyp,
+ const vector<Ngrams>& ref_ngs,
+ const vector<size_t>& ref_ls)
+ {
+ size_t hl, rl, M;
+ vector<weight_t> v;
+ NgramCounts counts;
+ if (!Init(hyp, ref_ngs, ref_ls, hl, rl, M, v, counts))
+ return 0.;
+ weight_t sum=0, add=0;
+ for (size_t i = 0; i < M; i++) {
+ if (i == 0 && (counts.sum_[i] == 0 || counts.clipped_[i] == 0)) return 0.;
+ if (i == 1) add = 1;
+ sum += v[i] * log(((weight_t)counts.clipped_[i] + add)/((counts.sum_[i] + add)));
+ }
+
+ return BrevityPenalty(hl, rl) * exp(sum);
+ }
+};
+
+/*
+ * smooth BLEU
+ * max is 0.9375 (with N=4)
+ *
+ * as in "An End-to-End Discriminative Approach
+ * to Machine Translation"
+ * (Liang et al. '06)
+ *
+ */
+class SmoothPerSentenceBleuScorer : public Scorer
+{
+ public:
+ SmoothPerSentenceBleuScorer(size_t n) : Scorer(n) {}
+
+ weight_t
+ Score(const vector<WordID>& hyp,
+ const vector<Ngrams>& ref_ngs,
+ const vector<size_t>& ref_ls)
+ {
+ size_t hl=hyp.size(), rl=BestMatchLength(hl, ref_ls);
+ if (hl == 0 || rl == 0) return 0.;
+ NgramCounts counts = MakeNgramCounts(hyp, ref_ngs, N_);
+ size_t M = N_;
+ if (rl < N_) M = rl;
+ weight_t sum = 0.;
+ vector<weight_t> i_bleu;
+ for (size_t i=0; i < M; i++)
+ i_bleu.push_back(0.);
+ for (size_t i=0; i < M; i++) {
+ if (counts.sum_[i] == 0 || counts.clipped_[i] == 0) {
+ break;
+ } else {
+ weight_t i_score = log((weight_t)counts.clipped_[i]/counts.sum_[i]);
+ for (size_t j=i; j < M; j++) {
+ i_bleu[j] += (1/((weight_t)j+1)) * i_score;
+ }
}
- i += 1;
+ sum += exp(i_bleu[i])/pow(2.0, (double)(N_-i+2));
}
- m = ref_ls[best_idx];
+
+ return BrevityPenalty(hl, hl) * sum;
+ }
+};
+
+/*
+ * approx. bleu
+ * Needs some more code in dtrain.cc .
+ * We do not scaling by source lengths, as hypotheses are compared only
+ * within an kbest list, not globally.
+ *
+ * as in "Online Large-Margin Training of Syntactic
+ * and Structural Translation Features"
+ * (Chiang et al. '08)
+ *
+
+ */
+class ApproxBleuScorer : public Scorer
+{
+ private:
+ NgramCounts context;
+ weight_t hyp_sz_sum;
+ weight_t ref_sz_sum;
+
+ public:
+ ApproxBleuScorer(size_t n) :
+ Scorer(n), context(n), hyp_sz_sum(0), ref_sz_sum(0) {}
+
+ weight_t
+ Score(const vector<WordID>& hyp,
+ const vector<Ngrams>& ref_ngs,
+ const vector<size_t>& ref_ls)
+ {
+ size_t hl, rl, M;
+ vector<weight_t> v;
+ NgramCounts counts;
+ if (!Init(hyp, ref_ngs, ref_ls, hl, rl, M, v, counts))
+ return 0.;
+ counts += context;
+ weight_t sum = 0;
+ for (size_t i = 0; i < M; i++) {
+ if (counts.sum_[i] == 0 || counts.clipped_[i] == 0) return 0.;
+ sum += v[i] * log((weight_t)counts.clipped_[i]/counts.sum_[i]);
+ }
+
+ return BrevityPenalty(hyp_sz_sum+hl, ref_sz_sum+rl) * exp(sum);
}
- return m;
- }
+ void
+ UpdateContext(const vector<WordID>& hyp,
+ const vector<Ngrams>& ref_ngs,
+ const vector<size_t>& ref_ls,
+ weight_t decay=0.9)
+ {
+ size_t hl, rl, M;
+ vector<weight_t> v;
+ NgramCounts counts;
+ Init(hyp, ref_ngs, ref_ls, hl, rl, M, v, counts);
- weight_t
- Score(const vector<WordID>& hyp,
- const vector<Ngrams>& ref_ngs,
- const vector<size_t>& ref_ls)
- {
- size_t hl = hyp.size(), rl = 0;
- if (hl == 0) return 0.;
- rl = BestMatchLength(hl, ref_ls);
- if (rl == 0) return 0.;
- NgramCounts counts = MakeNgramCounts(hyp, ref_ngs, N_);
- size_t M = N_;
- vector<weight_t> v = w_;
- if (rl < N_) {
- M = rl;
- for (size_t i = 0; i < M; i++) v[i] = 1/((weight_t)M);
- }
- weight_t sum = 0, add = 0;
- for (size_t i = 0; i < M; i++) {
- if (i == 0 && (counts.sum_[i] == 0 || counts.clipped_[i] == 0)) return 0.;
- if (i > 0) add = 1;
- sum += v[i] * log(((weight_t)counts.clipped_[i] + add)
- / ((counts.sum_[i] + add)));
- }
-
- return BrevityPenalty(hl, rl+1) * exp(sum);
- }
+ context += counts;
+ context *= decay;
+ hyp_sz_sum += hl;
+ hyp_sz_sum *= decay;
+ ref_sz_sum += rl;
+ ref_sz_sum *= decay;
+ }
};
} // namespace