diff options
author | Patrick Simianer <p@simianer.de> | 2011-11-13 12:26:23 +0100 |
---|---|---|
committer | Patrick Simianer <p@simianer.de> | 2011-11-13 12:26:23 +0100 |
commit | effc9bfc40a0559ce36a155daa15e0dc53e93b75 (patch) | |
tree | 768a29ebad48089e3445c515d47f49c942f09124 /mira | |
parent | ed8ca37550910a540e755ada119e814f13eeef03 (diff) | |
parent | a5592c9ab0266dbf4993e42e82e5a113316990ad (diff) |
merge upstream/master
Diffstat (limited to 'mira')
-rw-r--r-- | mira/kbest_mira.cc | 79 |
1 files changed, 60 insertions, 19 deletions
diff --git a/mira/kbest_mira.cc b/mira/kbest_mira.cc index 811bbd4b..9fda9b32 100644 --- a/mira/kbest_mira.cc +++ b/mira/kbest_mira.cc @@ -10,6 +10,7 @@ #include <boost/program_options.hpp> #include <boost/program_options/variables_map.hpp> +#include "hg_sampler.h" #include "sentence_metadata.h" #include "scorer.h" #include "verbose.h" @@ -54,6 +55,8 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("max_step_size,C", po::value<double>()->default_value(0.01), "regularization strength (C)") //("mt_metric_scale,s", po::value<double>()->default_value(1.0), "Amount to scale MT loss function by") ("k_best_size,k", po::value<int>()->default_value(250), "Size of hypothesis list to search for oracles") + ("sample_forest,f", "Instead of a k-best list, sample k hypotheses from the decoder's forest") + ("sample_forest_unit_weight_vector,x", "Before sampling (must use -f option), rescale the weight vector used so it has unit length; this may improve the quality of the samples") ("random_seed,S", po::value<uint32_t>(), "Random seed (if not specified, /dev/random will be used)") ("decoder_config,c",po::value<string>(),"Decoder configuration file"); po::options_description clo("Command line options"); @@ -91,11 +94,12 @@ struct GoodBadOracle { }; struct TrainingObserver : public DecoderObserver { - TrainingObserver(const int k, const DocScorer& d, vector<GoodBadOracle>* o) : ds(d), oracles(*o), kbest_size(k) {} + TrainingObserver(const int k, const DocScorer& d, bool sf, vector<GoodBadOracle>* o) : ds(d), oracles(*o), kbest_size(k), sample_forest(sf) {} const DocScorer& ds; vector<GoodBadOracle>& oracles; shared_ptr<HypothesisInfo> cur_best; const int kbest_size; + const bool sample_forest; const HypothesisInfo& GetCurrentBestHypothesis() const { return *cur_best; @@ -116,24 +120,43 @@ struct TrainingObserver : public DecoderObserver { shared_ptr<HypothesisInfo>& cur_good = oracles[sent_id].good; shared_ptr<HypothesisInfo>& cur_bad = oracles[sent_id].bad; cur_bad.reset(); // TODO get rid of?? - KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, kbest_size); - for (int i = 0; i < kbest_size; ++i) { - const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d = - kbest.LazyKthBest(forest.nodes_.size() - 1, i); - if (!d) break; - float sentscore = ds[sent_id]->ScoreCandidate(d->yield)->ComputeScore(); - if (invert_score) sentscore *= -1.0; - // cerr << TD::GetString(d->yield) << " ||| " << d->score << " ||| " << sentscore << endl; - if (i == 0) - cur_best = MakeHypothesisInfo(d->feature_values, sentscore); - if (!cur_good || sentscore > cur_good->mt_metric) - cur_good = MakeHypothesisInfo(d->feature_values, sentscore); - if (!cur_bad || sentscore < cur_bad->mt_metric) - cur_bad = MakeHypothesisInfo(d->feature_values, sentscore); + + if (sample_forest) { + vector<WordID> cur_prediction; + ViterbiESentence(forest, &cur_prediction); + float sentscore = ds[sent_id]->ScoreCandidate(cur_prediction)->ComputeScore(); + cur_best = MakeHypothesisInfo(ViterbiFeatures(forest), sentscore); + + vector<HypergraphSampler::Hypothesis> samples; + HypergraphSampler::sample_hypotheses(forest, kbest_size, &*rng, &samples); + for (unsigned i = 0; i < samples.size(); ++i) { + sentscore = ds[sent_id]->ScoreCandidate(samples[i].words)->ComputeScore(); + if (invert_score) sentscore *= -1.0; + if (!cur_good || sentscore > cur_good->mt_metric) + cur_good = MakeHypothesisInfo(samples[i].fmap, sentscore); + if (!cur_bad || sentscore < cur_bad->mt_metric) + cur_bad = MakeHypothesisInfo(samples[i].fmap, sentscore); + } + } else { + KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, kbest_size); + for (int i = 0; i < kbest_size; ++i) { + const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(forest.nodes_.size() - 1, i); + if (!d) break; + float sentscore = ds[sent_id]->ScoreCandidate(d->yield)->ComputeScore(); + if (invert_score) sentscore *= -1.0; + // cerr << TD::GetString(d->yield) << " ||| " << d->score << " ||| " << sentscore << endl; + if (i == 0) + cur_best = MakeHypothesisInfo(d->feature_values, sentscore); + if (!cur_good || sentscore > cur_good->mt_metric) + cur_good = MakeHypothesisInfo(d->feature_values, sentscore); + if (!cur_bad || sentscore < cur_bad->mt_metric) + cur_bad = MakeHypothesisInfo(d->feature_values, sentscore); + } + //cerr << "GOOD: " << cur_good->mt_metric << endl; + //cerr << " CUR: " << cur_best->mt_metric << endl; + //cerr << " BAD: " << cur_bad->mt_metric << endl; } - //cerr << "GOOD: " << cur_good->mt_metric << endl; - //cerr << " CUR: " << cur_best->mt_metric << endl; - //cerr << " BAD: " << cur_bad->mt_metric << endl; } }; @@ -164,6 +187,12 @@ int main(int argc, char** argv) { rng.reset(new MT19937(conf["random_seed"].as<uint32_t>())); else rng.reset(new MT19937); + const bool sample_forest = conf.count("sample_forest") > 0; + const bool sample_forest_unit_weight_vector = conf.count("sample_forest_unit_weight_vector") > 0; + if (sample_forest_unit_weight_vector && !sample_forest) { + cerr << "Cannot --sample_forest_unit_weight_vector without --sample_forest" << endl; + return 1; + } vector<string> corpus; ReadTrainingCorpus(conf["source"].as<string>(), &corpus); const string metric_name = conf["mt_metric"].as<string>(); @@ -195,7 +224,7 @@ int main(int argc, char** argv) { assert(corpus.size() > 0); vector<GoodBadOracle> oracles(corpus.size()); - TrainingObserver observer(conf["k_best_size"].as<int>(), ds, &oracles); + TrainingObserver observer(conf["k_best_size"].as<int>(), ds, sample_forest, &oracles); int cur_sent = 0; int lcount = 0; int normalizer = 0; @@ -234,7 +263,19 @@ int main(int argc, char** argv) { cerr << "PASS " << (lcount / corpus.size() + 1) << endl; } decoder.SetId(order[cur_sent]); + double sc = 1.0; + if (sample_forest_unit_weight_vector) { + sc = lambdas.l2norm(); + if (sc > 0) { + for (unsigned i = 0; i < dense_weights.size(); ++i) + dense_weights[i] /= sc; + } + } decoder.Decode(corpus[order[cur_sent]], &observer); // update oracles + if (sc && sc != 1.0) { + for (unsigned i = 0; i < dense_weights.size(); ++i) + dense_weights[i] *= sc; + } const HypothesisInfo& cur_hyp = observer.GetCurrentBestHypothesis(); const HypothesisInfo& cur_good = *oracles[order[cur_sent]].good; const HypothesisInfo& cur_bad = *oracles[order[cur_sent]].bad; |