diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2011-11-11 17:12:39 -0500 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2011-11-11 17:12:39 -0500 |
commit | a5592c9ab0266dbf4993e42e82e5a113316990ad (patch) | |
tree | e29f4bb543b962f49319f788d8262663f9b0a5b6 | |
parent | cb762a9c0e50e4e49b688dcc3f52498191efb20a (diff) |
optionally sample from forest to get training instances, rather than k-best it
-rw-r--r-- | decoder/Makefile.am | 1 | ||||
-rw-r--r-- | decoder/hg_sampler.cc | 73 | ||||
-rw-r--r-- | decoder/hg_sampler.h | 27 | ||||
-rw-r--r-- | mira/kbest_mira.cc | 79 |
4 files changed, 161 insertions, 19 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 6b9360d8..30eaf04d 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -51,6 +51,7 @@ libcdec_a_SOURCES = \ hg_io.cc \ decoder.cc \ hg_intersect.cc \ + hg_sampler.cc \ factored_lexicon_helper.cc \ viterbi.cc \ lattice.cc \ diff --git a/decoder/hg_sampler.cc b/decoder/hg_sampler.cc new file mode 100644 index 00000000..cdf0ec3c --- /dev/null +++ b/decoder/hg_sampler.cc @@ -0,0 +1,73 @@ +#include "hg_sampler.h" + +#include <queue> + +#include "viterbi.h" +#include "inside_outside.h" + +using namespace std; + +struct SampledDerivationWeightFunction { + typedef double Weight; + explicit SampledDerivationWeightFunction(const vector<bool>& sampled) : sampled_edges(sampled) {} + double operator()(const Hypergraph::Edge& e) const { + return static_cast<double>(sampled_edges[e.id_]); + } + const vector<bool>& sampled_edges; +}; + +void HypergraphSampler::sample_hypotheses(const Hypergraph& hg, + unsigned n, + MT19937* rng, + vector<Hypothesis>* hypos) { + hypos->clear(); + hypos->resize(n); + + // compute inside probabilities + vector<prob_t> node_probs; + Inside<prob_t, EdgeProb>(hg, &node_probs, EdgeProb()); + + vector<bool> sampled_edges(hg.edges_.size()); + queue<unsigned> q; + SampleSet<prob_t> ss; + for (unsigned i = 0; i < n; ++i) { + fill(sampled_edges.begin(), sampled_edges.end(), false); + // sample derivation top down + assert(q.empty()); + Hypothesis& hyp = (*hypos)[i]; + SparseVector<double>& deriv_features = hyp.fmap; + q.push(hg.nodes_.size() - 1); + prob_t& model_score = hyp.model_score; + model_score = prob_t::One(); + while(!q.empty()) { + unsigned cur_node_id = q.front(); + q.pop(); + const Hypergraph::Node& node = hg.nodes_[cur_node_id]; + const unsigned num_in_edges = node.in_edges_.size(); + unsigned sampled_edge_idx = 0; + if (num_in_edges == 1) { + sampled_edge_idx = node.in_edges_[0]; + } else { + assert(num_in_edges > 1); + ss.clear(); + for (unsigned j = 0; j < num_in_edges; ++j) { + const Hypergraph::Edge& edge = hg.edges_[node.in_edges_[j]]; + prob_t p = edge.edge_prob_; // edge weight + for (unsigned k = 0; k < edge.tail_nodes_.size(); ++k) + p *= node_probs[edge.tail_nodes_[k]]; // tail node inside weight + ss.add(p); + } + sampled_edge_idx = node.in_edges_[rng->SelectSample(ss)]; + } + sampled_edges[sampled_edge_idx] = true; + const Hypergraph::Edge& sampled_edge = hg.edges_[sampled_edge_idx]; + deriv_features += sampled_edge.feature_values_; + model_score *= sampled_edge.edge_prob_; + //sampled_deriv->push_back(sampled_edge_idx); + for (unsigned j = 0; j < sampled_edge.tail_nodes_.size(); ++j) { + q.push(sampled_edge.tail_nodes_[j]); + } + } + Viterbi(hg, &hyp.words, ESentenceTraversal(), SampledDerivationWeightFunction(sampled_edges)); + } +} diff --git a/decoder/hg_sampler.h b/decoder/hg_sampler.h new file mode 100644 index 00000000..bf4e1eb0 --- /dev/null +++ b/decoder/hg_sampler.h @@ -0,0 +1,27 @@ +#ifndef _HG_SAMPLER_H_ +#define _HG_SAMPLER_H_ + + +#include <vector> +#include "sparse_vector.h" +#include "sampler.h" +#include "wordid.h" + +class Hypergraph; + +struct HypergraphSampler { + + struct Hypothesis { + std::vector<WordID> words; + SparseVector<double> fmap; + prob_t model_score; // log unnormalized probability + }; + + static void + sample_hypotheses(const Hypergraph& hg, + unsigned n, // how many samples to draw + MT19937* rng, + std::vector<Hypothesis>* hypos); +}; + +#endif diff --git a/mira/kbest_mira.cc b/mira/kbest_mira.cc index 459a5e6f..904eba74 100644 --- a/mira/kbest_mira.cc +++ b/mira/kbest_mira.cc @@ -10,6 +10,7 @@ #include <boost/program_options.hpp> #include <boost/program_options/variables_map.hpp> +#include "hg_sampler.h" #include "sentence_metadata.h" #include "scorer.h" #include "verbose.h" @@ -54,6 +55,8 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("max_step_size,C", po::value<double>()->default_value(0.01), "regularization strength (C)") ("mt_metric_scale,s", po::value<double>()->default_value(1.0), "Amount to scale MT loss function by") ("k_best_size,k", po::value<int>()->default_value(250), "Size of hypothesis list to search for oracles") + ("sample_forest,f", "Instead of a k-best list, sample k hypotheses from the decoder's forest") + ("sample_forest_unit_weight_vector,x", "Before sampling (must use -f option), rescale the weight vector used so it has unit length; this may improve the quality of the samples") ("random_seed,S", po::value<uint32_t>(), "Random seed (if not specified, /dev/random will be used)") ("decoder_config,c",po::value<string>(),"Decoder configuration file"); po::options_description clo("Command line options"); @@ -91,11 +94,12 @@ struct GoodBadOracle { }; struct TrainingObserver : public DecoderObserver { - TrainingObserver(const int k, const DocScorer& d, vector<GoodBadOracle>* o) : ds(d), oracles(*o), kbest_size(k) {} + TrainingObserver(const int k, const DocScorer& d, bool sf, vector<GoodBadOracle>* o) : ds(d), oracles(*o), kbest_size(k), sample_forest(sf) {} const DocScorer& ds; vector<GoodBadOracle>& oracles; shared_ptr<HypothesisInfo> cur_best; const int kbest_size; + const bool sample_forest; const HypothesisInfo& GetCurrentBestHypothesis() const { return *cur_best; @@ -116,24 +120,43 @@ struct TrainingObserver : public DecoderObserver { shared_ptr<HypothesisInfo>& cur_good = oracles[sent_id].good; shared_ptr<HypothesisInfo>& cur_bad = oracles[sent_id].bad; cur_bad.reset(); // TODO get rid of?? - KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, kbest_size); - for (int i = 0; i < kbest_size; ++i) { - const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d = - kbest.LazyKthBest(forest.nodes_.size() - 1, i); - if (!d) break; - float sentscore = ds[sent_id]->ScoreCandidate(d->yield)->ComputeScore(); - if (invert_score) sentscore *= -1.0; - // cerr << TD::GetString(d->yield) << " ||| " << d->score << " ||| " << sentscore << endl; - if (i == 0) - cur_best = MakeHypothesisInfo(d->feature_values, sentscore); - if (!cur_good || sentscore > cur_good->mt_metric) - cur_good = MakeHypothesisInfo(d->feature_values, sentscore); - if (!cur_bad || sentscore < cur_bad->mt_metric) - cur_bad = MakeHypothesisInfo(d->feature_values, sentscore); + + if (sample_forest) { + vector<WordID> cur_prediction; + ViterbiESentence(forest, &cur_prediction); + float sentscore = ds[sent_id]->ScoreCandidate(cur_prediction)->ComputeScore(); + cur_best = MakeHypothesisInfo(ViterbiFeatures(forest), sentscore); + + vector<HypergraphSampler::Hypothesis> samples; + HypergraphSampler::sample_hypotheses(forest, kbest_size, &*rng, &samples); + for (unsigned i = 0; i < samples.size(); ++i) { + sentscore = ds[sent_id]->ScoreCandidate(samples[i].words)->ComputeScore(); + if (invert_score) sentscore *= -1.0; + if (!cur_good || sentscore > cur_good->mt_metric) + cur_good = MakeHypothesisInfo(samples[i].fmap, sentscore); + if (!cur_bad || sentscore < cur_bad->mt_metric) + cur_bad = MakeHypothesisInfo(samples[i].fmap, sentscore); + } + } else { + KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, kbest_size); + for (int i = 0; i < kbest_size; ++i) { + const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(forest.nodes_.size() - 1, i); + if (!d) break; + float sentscore = ds[sent_id]->ScoreCandidate(d->yield)->ComputeScore(); + if (invert_score) sentscore *= -1.0; + // cerr << TD::GetString(d->yield) << " ||| " << d->score << " ||| " << sentscore << endl; + if (i == 0) + cur_best = MakeHypothesisInfo(d->feature_values, sentscore); + if (!cur_good || sentscore > cur_good->mt_metric) + cur_good = MakeHypothesisInfo(d->feature_values, sentscore); + if (!cur_bad || sentscore < cur_bad->mt_metric) + cur_bad = MakeHypothesisInfo(d->feature_values, sentscore); + } + //cerr << "GOOD: " << cur_good->mt_metric << endl; + //cerr << " CUR: " << cur_best->mt_metric << endl; + //cerr << " BAD: " << cur_bad->mt_metric << endl; } - //cerr << "GOOD: " << cur_good->mt_metric << endl; - //cerr << " CUR: " << cur_best->mt_metric << endl; - //cerr << " BAD: " << cur_bad->mt_metric << endl; } }; @@ -164,6 +187,12 @@ int main(int argc, char** argv) { rng.reset(new MT19937(conf["random_seed"].as<uint32_t>())); else rng.reset(new MT19937); + const bool sample_forest = conf.count("sample_forest") > 0; + const bool sample_forest_unit_weight_vector = conf.count("sample_forest_unit_weight_vector") > 0; + if (sample_forest_unit_weight_vector && !sample_forest) { + cerr << "Cannot --sample_forest_unit_weight_vector without --sample_forest" << endl; + return 1; + } vector<string> corpus; ReadTrainingCorpus(conf["source"].as<string>(), &corpus); const string metric_name = conf["mt_metric"].as<string>(); @@ -195,7 +224,7 @@ int main(int argc, char** argv) { assert(corpus.size() > 0); vector<GoodBadOracle> oracles(corpus.size()); - TrainingObserver observer(conf["k_best_size"].as<int>(), ds, &oracles); + TrainingObserver observer(conf["k_best_size"].as<int>(), ds, sample_forest, &oracles); int cur_sent = 0; int lcount = 0; int normalizer = 0; @@ -234,7 +263,19 @@ int main(int argc, char** argv) { cerr << "PASS " << (lcount / corpus.size() + 1) << endl; } decoder.SetId(order[cur_sent]); + double sc = 1.0; + if (sample_forest_unit_weight_vector) { + sc = lambdas.l2norm(); + if (sc > 0) { + for (unsigned i = 0; i < dense_weights.size(); ++i) + dense_weights[i] /= sc; + } + } decoder.Decode(corpus[order[cur_sent]], &observer); // update oracles + if (sc && sc != 1.0) { + for (unsigned i = 0; i < dense_weights.size(); ++i) + dense_weights[i] *= sc; + } const HypothesisInfo& cur_hyp = observer.GetCurrentBestHypothesis(); const HypothesisInfo& cur_good = *oracles[order[cur_sent]].good; const HypothesisInfo& cur_bad = *oracles[order[cur_sent]].bad; |