From 105a52a8d37497fe69a01a7de771ef9b9300cd71 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 11 Nov 2011 17:12:39 -0500 Subject: optionally sample from forest to get training instances, rather than k-best it --- mira/kbest_mira.cc | 79 +++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 60 insertions(+), 19 deletions(-) (limited to 'mira/kbest_mira.cc') diff --git a/mira/kbest_mira.cc b/mira/kbest_mira.cc index 459a5e6f..904eba74 100644 --- a/mira/kbest_mira.cc +++ b/mira/kbest_mira.cc @@ -10,6 +10,7 @@ #include #include +#include "hg_sampler.h" #include "sentence_metadata.h" #include "scorer.h" #include "verbose.h" @@ -54,6 +55,8 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("max_step_size,C", po::value()->default_value(0.01), "regularization strength (C)") ("mt_metric_scale,s", po::value()->default_value(1.0), "Amount to scale MT loss function by") ("k_best_size,k", po::value()->default_value(250), "Size of hypothesis list to search for oracles") + ("sample_forest,f", "Instead of a k-best list, sample k hypotheses from the decoder's forest") + ("sample_forest_unit_weight_vector,x", "Before sampling (must use -f option), rescale the weight vector used so it has unit length; this may improve the quality of the samples") ("random_seed,S", po::value(), "Random seed (if not specified, /dev/random will be used)") ("decoder_config,c",po::value(),"Decoder configuration file"); po::options_description clo("Command line options"); @@ -91,11 +94,12 @@ struct GoodBadOracle { }; struct TrainingObserver : public DecoderObserver { - TrainingObserver(const int k, const DocScorer& d, vector* o) : ds(d), oracles(*o), kbest_size(k) {} + TrainingObserver(const int k, const DocScorer& d, bool sf, vector* o) : ds(d), oracles(*o), kbest_size(k), sample_forest(sf) {} const DocScorer& ds; vector& oracles; shared_ptr cur_best; const int kbest_size; + const bool sample_forest; const HypothesisInfo& GetCurrentBestHypothesis() const { return *cur_best; @@ -116,24 +120,43 @@ struct TrainingObserver : public DecoderObserver { shared_ptr& cur_good = oracles[sent_id].good; shared_ptr& cur_bad = oracles[sent_id].bad; cur_bad.reset(); // TODO get rid of?? - KBest::KBestDerivations, ESentenceTraversal> kbest(forest, kbest_size); - for (int i = 0; i < kbest_size; ++i) { - const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = - kbest.LazyKthBest(forest.nodes_.size() - 1, i); - if (!d) break; - float sentscore = ds[sent_id]->ScoreCandidate(d->yield)->ComputeScore(); - if (invert_score) sentscore *= -1.0; - // cerr << TD::GetString(d->yield) << " ||| " << d->score << " ||| " << sentscore << endl; - if (i == 0) - cur_best = MakeHypothesisInfo(d->feature_values, sentscore); - if (!cur_good || sentscore > cur_good->mt_metric) - cur_good = MakeHypothesisInfo(d->feature_values, sentscore); - if (!cur_bad || sentscore < cur_bad->mt_metric) - cur_bad = MakeHypothesisInfo(d->feature_values, sentscore); + + if (sample_forest) { + vector cur_prediction; + ViterbiESentence(forest, &cur_prediction); + float sentscore = ds[sent_id]->ScoreCandidate(cur_prediction)->ComputeScore(); + cur_best = MakeHypothesisInfo(ViterbiFeatures(forest), sentscore); + + vector samples; + HypergraphSampler::sample_hypotheses(forest, kbest_size, &*rng, &samples); + for (unsigned i = 0; i < samples.size(); ++i) { + sentscore = ds[sent_id]->ScoreCandidate(samples[i].words)->ComputeScore(); + if (invert_score) sentscore *= -1.0; + if (!cur_good || sentscore > cur_good->mt_metric) + cur_good = MakeHypothesisInfo(samples[i].fmap, sentscore); + if (!cur_bad || sentscore < cur_bad->mt_metric) + cur_bad = MakeHypothesisInfo(samples[i].fmap, sentscore); + } + } else { + KBest::KBestDerivations, ESentenceTraversal> kbest(forest, kbest_size); + for (int i = 0; i < kbest_size; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(forest.nodes_.size() - 1, i); + if (!d) break; + float sentscore = ds[sent_id]->ScoreCandidate(d->yield)->ComputeScore(); + if (invert_score) sentscore *= -1.0; + // cerr << TD::GetString(d->yield) << " ||| " << d->score << " ||| " << sentscore << endl; + if (i == 0) + cur_best = MakeHypothesisInfo(d->feature_values, sentscore); + if (!cur_good || sentscore > cur_good->mt_metric) + cur_good = MakeHypothesisInfo(d->feature_values, sentscore); + if (!cur_bad || sentscore < cur_bad->mt_metric) + cur_bad = MakeHypothesisInfo(d->feature_values, sentscore); + } + //cerr << "GOOD: " << cur_good->mt_metric << endl; + //cerr << " CUR: " << cur_best->mt_metric << endl; + //cerr << " BAD: " << cur_bad->mt_metric << endl; } - //cerr << "GOOD: " << cur_good->mt_metric << endl; - //cerr << " CUR: " << cur_best->mt_metric << endl; - //cerr << " BAD: " << cur_bad->mt_metric << endl; } }; @@ -164,6 +187,12 @@ int main(int argc, char** argv) { rng.reset(new MT19937(conf["random_seed"].as())); else rng.reset(new MT19937); + const bool sample_forest = conf.count("sample_forest") > 0; + const bool sample_forest_unit_weight_vector = conf.count("sample_forest_unit_weight_vector") > 0; + if (sample_forest_unit_weight_vector && !sample_forest) { + cerr << "Cannot --sample_forest_unit_weight_vector without --sample_forest" << endl; + return 1; + } vector corpus; ReadTrainingCorpus(conf["source"].as(), &corpus); const string metric_name = conf["mt_metric"].as(); @@ -195,7 +224,7 @@ int main(int argc, char** argv) { assert(corpus.size() > 0); vector oracles(corpus.size()); - TrainingObserver observer(conf["k_best_size"].as(), ds, &oracles); + TrainingObserver observer(conf["k_best_size"].as(), ds, sample_forest, &oracles); int cur_sent = 0; int lcount = 0; int normalizer = 0; @@ -234,7 +263,19 @@ int main(int argc, char** argv) { cerr << "PASS " << (lcount / corpus.size() + 1) << endl; } decoder.SetId(order[cur_sent]); + double sc = 1.0; + if (sample_forest_unit_weight_vector) { + sc = lambdas.l2norm(); + if (sc > 0) { + for (unsigned i = 0; i < dense_weights.size(); ++i) + dense_weights[i] /= sc; + } + } decoder.Decode(corpus[order[cur_sent]], &observer); // update oracles + if (sc && sc != 1.0) { + for (unsigned i = 0; i < dense_weights.size(); ++i) + dense_weights[i] *= sc; + } const HypothesisInfo& cur_hyp = observer.GetCurrentBestHypothesis(); const HypothesisInfo& cur_good = *oracles[order[cur_sent]].good; const HypothesisInfo& cur_bad = *oracles[order[cur_sent]].bad; -- cgit v1.2.3