summaryrefslogtreecommitdiff
path: root/mira/kbest_mira.cc
diff options
context:
space:
mode:
Diffstat (limited to 'mira/kbest_mira.cc')
-rw-r--r--mira/kbest_mira.cc79
1 files changed, 60 insertions, 19 deletions
diff --git a/mira/kbest_mira.cc b/mira/kbest_mira.cc
index 459a5e6f..904eba74 100644
--- a/mira/kbest_mira.cc
+++ b/mira/kbest_mira.cc
@@ -10,6 +10,7 @@
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
+#include "hg_sampler.h"
#include "sentence_metadata.h"
#include "scorer.h"
#include "verbose.h"
@@ -54,6 +55,8 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
("max_step_size,C", po::value<double>()->default_value(0.01), "regularization strength (C)")
("mt_metric_scale,s", po::value<double>()->default_value(1.0), "Amount to scale MT loss function by")
("k_best_size,k", po::value<int>()->default_value(250), "Size of hypothesis list to search for oracles")
+ ("sample_forest,f", "Instead of a k-best list, sample k hypotheses from the decoder's forest")
+ ("sample_forest_unit_weight_vector,x", "Before sampling (must use -f option), rescale the weight vector used so it has unit length; this may improve the quality of the samples")
("random_seed,S", po::value<uint32_t>(), "Random seed (if not specified, /dev/random will be used)")
("decoder_config,c",po::value<string>(),"Decoder configuration file");
po::options_description clo("Command line options");
@@ -91,11 +94,12 @@ struct GoodBadOracle {
};
struct TrainingObserver : public DecoderObserver {
- TrainingObserver(const int k, const DocScorer& d, vector<GoodBadOracle>* o) : ds(d), oracles(*o), kbest_size(k) {}
+ TrainingObserver(const int k, const DocScorer& d, bool sf, vector<GoodBadOracle>* o) : ds(d), oracles(*o), kbest_size(k), sample_forest(sf) {}
const DocScorer& ds;
vector<GoodBadOracle>& oracles;
shared_ptr<HypothesisInfo> cur_best;
const int kbest_size;
+ const bool sample_forest;
const HypothesisInfo& GetCurrentBestHypothesis() const {
return *cur_best;
@@ -116,24 +120,43 @@ struct TrainingObserver : public DecoderObserver {
shared_ptr<HypothesisInfo>& cur_good = oracles[sent_id].good;
shared_ptr<HypothesisInfo>& cur_bad = oracles[sent_id].bad;
cur_bad.reset(); // TODO get rid of??
- KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, kbest_size);
- for (int i = 0; i < kbest_size; ++i) {
- const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
- kbest.LazyKthBest(forest.nodes_.size() - 1, i);
- if (!d) break;
- float sentscore = ds[sent_id]->ScoreCandidate(d->yield)->ComputeScore();
- if (invert_score) sentscore *= -1.0;
- // cerr << TD::GetString(d->yield) << " ||| " << d->score << " ||| " << sentscore << endl;
- if (i == 0)
- cur_best = MakeHypothesisInfo(d->feature_values, sentscore);
- if (!cur_good || sentscore > cur_good->mt_metric)
- cur_good = MakeHypothesisInfo(d->feature_values, sentscore);
- if (!cur_bad || sentscore < cur_bad->mt_metric)
- cur_bad = MakeHypothesisInfo(d->feature_values, sentscore);
+
+ if (sample_forest) {
+ vector<WordID> cur_prediction;
+ ViterbiESentence(forest, &cur_prediction);
+ float sentscore = ds[sent_id]->ScoreCandidate(cur_prediction)->ComputeScore();
+ cur_best = MakeHypothesisInfo(ViterbiFeatures(forest), sentscore);
+
+ vector<HypergraphSampler::Hypothesis> samples;
+ HypergraphSampler::sample_hypotheses(forest, kbest_size, &*rng, &samples);
+ for (unsigned i = 0; i < samples.size(); ++i) {
+ sentscore = ds[sent_id]->ScoreCandidate(samples[i].words)->ComputeScore();
+ if (invert_score) sentscore *= -1.0;
+ if (!cur_good || sentscore > cur_good->mt_metric)
+ cur_good = MakeHypothesisInfo(samples[i].fmap, sentscore);
+ if (!cur_bad || sentscore < cur_bad->mt_metric)
+ cur_bad = MakeHypothesisInfo(samples[i].fmap, sentscore);
+ }
+ } else {
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, kbest_size);
+ for (int i = 0; i < kbest_size; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
+ kbest.LazyKthBest(forest.nodes_.size() - 1, i);
+ if (!d) break;
+ float sentscore = ds[sent_id]->ScoreCandidate(d->yield)->ComputeScore();
+ if (invert_score) sentscore *= -1.0;
+ // cerr << TD::GetString(d->yield) << " ||| " << d->score << " ||| " << sentscore << endl;
+ if (i == 0)
+ cur_best = MakeHypothesisInfo(d->feature_values, sentscore);
+ if (!cur_good || sentscore > cur_good->mt_metric)
+ cur_good = MakeHypothesisInfo(d->feature_values, sentscore);
+ if (!cur_bad || sentscore < cur_bad->mt_metric)
+ cur_bad = MakeHypothesisInfo(d->feature_values, sentscore);
+ }
+ //cerr << "GOOD: " << cur_good->mt_metric << endl;
+ //cerr << " CUR: " << cur_best->mt_metric << endl;
+ //cerr << " BAD: " << cur_bad->mt_metric << endl;
}
- //cerr << "GOOD: " << cur_good->mt_metric << endl;
- //cerr << " CUR: " << cur_best->mt_metric << endl;
- //cerr << " BAD: " << cur_bad->mt_metric << endl;
}
};
@@ -164,6 +187,12 @@ int main(int argc, char** argv) {
rng.reset(new MT19937(conf["random_seed"].as<uint32_t>()));
else
rng.reset(new MT19937);
+ const bool sample_forest = conf.count("sample_forest") > 0;
+ const bool sample_forest_unit_weight_vector = conf.count("sample_forest_unit_weight_vector") > 0;
+ if (sample_forest_unit_weight_vector && !sample_forest) {
+ cerr << "Cannot --sample_forest_unit_weight_vector without --sample_forest" << endl;
+ return 1;
+ }
vector<string> corpus;
ReadTrainingCorpus(conf["source"].as<string>(), &corpus);
const string metric_name = conf["mt_metric"].as<string>();
@@ -195,7 +224,7 @@ int main(int argc, char** argv) {
assert(corpus.size() > 0);
vector<GoodBadOracle> oracles(corpus.size());
- TrainingObserver observer(conf["k_best_size"].as<int>(), ds, &oracles);
+ TrainingObserver observer(conf["k_best_size"].as<int>(), ds, sample_forest, &oracles);
int cur_sent = 0;
int lcount = 0;
int normalizer = 0;
@@ -234,7 +263,19 @@ int main(int argc, char** argv) {
cerr << "PASS " << (lcount / corpus.size() + 1) << endl;
}
decoder.SetId(order[cur_sent]);
+ double sc = 1.0;
+ if (sample_forest_unit_weight_vector) {
+ sc = lambdas.l2norm();
+ if (sc > 0) {
+ for (unsigned i = 0; i < dense_weights.size(); ++i)
+ dense_weights[i] /= sc;
+ }
+ }
decoder.Decode(corpus[order[cur_sent]], &observer); // update oracles
+ if (sc && sc != 1.0) {
+ for (unsigned i = 0; i < dense_weights.size(); ++i)
+ dense_weights[i] *= sc;
+ }
const HypothesisInfo& cur_hyp = observer.GetCurrentBestHypothesis();
const HypothesisInfo& cur_good = *oracles[order[cur_sent]].good;
const HypothesisInfo& cur_bad = *oracles[order[cur_sent]].bad;