summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-11-11 17:12:39 -0500
committerChris Dyer <cdyer@cs.cmu.edu>2011-11-11 17:12:39 -0500
commit105a52a8d37497fe69a01a7de771ef9b9300cd71 (patch)
treef20c4cc8bf31ccf1ce7204301bfa169c6fa080a7
parentb4fd470d2cb80b0c88d4210f7e5bb10d2aa4531d (diff)
optionally sample from forest to get training instances, rather than k-best it
-rw-r--r--decoder/Makefile.am1
-rw-r--r--decoder/hg_sampler.cc73
-rw-r--r--decoder/hg_sampler.h27
-rw-r--r--mira/kbest_mira.cc79
4 files changed, 161 insertions, 19 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index 6b9360d8..30eaf04d 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -51,6 +51,7 @@ libcdec_a_SOURCES = \
hg_io.cc \
decoder.cc \
hg_intersect.cc \
+ hg_sampler.cc \
factored_lexicon_helper.cc \
viterbi.cc \
lattice.cc \
diff --git a/decoder/hg_sampler.cc b/decoder/hg_sampler.cc
new file mode 100644
index 00000000..cdf0ec3c
--- /dev/null
+++ b/decoder/hg_sampler.cc
@@ -0,0 +1,73 @@
+#include "hg_sampler.h"
+
+#include <queue>
+
+#include "viterbi.h"
+#include "inside_outside.h"
+
+using namespace std;
+
+struct SampledDerivationWeightFunction {
+ typedef double Weight;
+ explicit SampledDerivationWeightFunction(const vector<bool>& sampled) : sampled_edges(sampled) {}
+ double operator()(const Hypergraph::Edge& e) const {
+ return static_cast<double>(sampled_edges[e.id_]);
+ }
+ const vector<bool>& sampled_edges;
+};
+
+void HypergraphSampler::sample_hypotheses(const Hypergraph& hg,
+ unsigned n,
+ MT19937* rng,
+ vector<Hypothesis>* hypos) {
+ hypos->clear();
+ hypos->resize(n);
+
+ // compute inside probabilities
+ vector<prob_t> node_probs;
+ Inside<prob_t, EdgeProb>(hg, &node_probs, EdgeProb());
+
+ vector<bool> sampled_edges(hg.edges_.size());
+ queue<unsigned> q;
+ SampleSet<prob_t> ss;
+ for (unsigned i = 0; i < n; ++i) {
+ fill(sampled_edges.begin(), sampled_edges.end(), false);
+ // sample derivation top down
+ assert(q.empty());
+ Hypothesis& hyp = (*hypos)[i];
+ SparseVector<double>& deriv_features = hyp.fmap;
+ q.push(hg.nodes_.size() - 1);
+ prob_t& model_score = hyp.model_score;
+ model_score = prob_t::One();
+ while(!q.empty()) {
+ unsigned cur_node_id = q.front();
+ q.pop();
+ const Hypergraph::Node& node = hg.nodes_[cur_node_id];
+ const unsigned num_in_edges = node.in_edges_.size();
+ unsigned sampled_edge_idx = 0;
+ if (num_in_edges == 1) {
+ sampled_edge_idx = node.in_edges_[0];
+ } else {
+ assert(num_in_edges > 1);
+ ss.clear();
+ for (unsigned j = 0; j < num_in_edges; ++j) {
+ const Hypergraph::Edge& edge = hg.edges_[node.in_edges_[j]];
+ prob_t p = edge.edge_prob_; // edge weight
+ for (unsigned k = 0; k < edge.tail_nodes_.size(); ++k)
+ p *= node_probs[edge.tail_nodes_[k]]; // tail node inside weight
+ ss.add(p);
+ }
+ sampled_edge_idx = node.in_edges_[rng->SelectSample(ss)];
+ }
+ sampled_edges[sampled_edge_idx] = true;
+ const Hypergraph::Edge& sampled_edge = hg.edges_[sampled_edge_idx];
+ deriv_features += sampled_edge.feature_values_;
+ model_score *= sampled_edge.edge_prob_;
+ //sampled_deriv->push_back(sampled_edge_idx);
+ for (unsigned j = 0; j < sampled_edge.tail_nodes_.size(); ++j) {
+ q.push(sampled_edge.tail_nodes_[j]);
+ }
+ }
+ Viterbi(hg, &hyp.words, ESentenceTraversal(), SampledDerivationWeightFunction(sampled_edges));
+ }
+}
diff --git a/decoder/hg_sampler.h b/decoder/hg_sampler.h
new file mode 100644
index 00000000..bf4e1eb0
--- /dev/null
+++ b/decoder/hg_sampler.h
@@ -0,0 +1,27 @@
+#ifndef _HG_SAMPLER_H_
+#define _HG_SAMPLER_H_
+
+
+#include <vector>
+#include "sparse_vector.h"
+#include "sampler.h"
+#include "wordid.h"
+
+class Hypergraph;
+
+struct HypergraphSampler {
+
+ struct Hypothesis {
+ std::vector<WordID> words;
+ SparseVector<double> fmap;
+ prob_t model_score; // log unnormalized probability
+ };
+
+ static void
+ sample_hypotheses(const Hypergraph& hg,
+ unsigned n, // how many samples to draw
+ MT19937* rng,
+ std::vector<Hypothesis>* hypos);
+};
+
+#endif
diff --git a/mira/kbest_mira.cc b/mira/kbest_mira.cc
index 459a5e6f..904eba74 100644
--- a/mira/kbest_mira.cc
+++ b/mira/kbest_mira.cc
@@ -10,6 +10,7 @@
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
+#include "hg_sampler.h"
#include "sentence_metadata.h"
#include "scorer.h"
#include "verbose.h"
@@ -54,6 +55,8 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
("max_step_size,C", po::value<double>()->default_value(0.01), "regularization strength (C)")
("mt_metric_scale,s", po::value<double>()->default_value(1.0), "Amount to scale MT loss function by")
("k_best_size,k", po::value<int>()->default_value(250), "Size of hypothesis list to search for oracles")
+ ("sample_forest,f", "Instead of a k-best list, sample k hypotheses from the decoder's forest")
+ ("sample_forest_unit_weight_vector,x", "Before sampling (must use -f option), rescale the weight vector used so it has unit length; this may improve the quality of the samples")
("random_seed,S", po::value<uint32_t>(), "Random seed (if not specified, /dev/random will be used)")
("decoder_config,c",po::value<string>(),"Decoder configuration file");
po::options_description clo("Command line options");
@@ -91,11 +94,12 @@ struct GoodBadOracle {
};
struct TrainingObserver : public DecoderObserver {
- TrainingObserver(const int k, const DocScorer& d, vector<GoodBadOracle>* o) : ds(d), oracles(*o), kbest_size(k) {}
+ TrainingObserver(const int k, const DocScorer& d, bool sf, vector<GoodBadOracle>* o) : ds(d), oracles(*o), kbest_size(k), sample_forest(sf) {}
const DocScorer& ds;
vector<GoodBadOracle>& oracles;
shared_ptr<HypothesisInfo> cur_best;
const int kbest_size;
+ const bool sample_forest;
const HypothesisInfo& GetCurrentBestHypothesis() const {
return *cur_best;
@@ -116,24 +120,43 @@ struct TrainingObserver : public DecoderObserver {
shared_ptr<HypothesisInfo>& cur_good = oracles[sent_id].good;
shared_ptr<HypothesisInfo>& cur_bad = oracles[sent_id].bad;
cur_bad.reset(); // TODO get rid of??
- KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, kbest_size);
- for (int i = 0; i < kbest_size; ++i) {
- const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
- kbest.LazyKthBest(forest.nodes_.size() - 1, i);
- if (!d) break;
- float sentscore = ds[sent_id]->ScoreCandidate(d->yield)->ComputeScore();
- if (invert_score) sentscore *= -1.0;
- // cerr << TD::GetString(d->yield) << " ||| " << d->score << " ||| " << sentscore << endl;
- if (i == 0)
- cur_best = MakeHypothesisInfo(d->feature_values, sentscore);
- if (!cur_good || sentscore > cur_good->mt_metric)
- cur_good = MakeHypothesisInfo(d->feature_values, sentscore);
- if (!cur_bad || sentscore < cur_bad->mt_metric)
- cur_bad = MakeHypothesisInfo(d->feature_values, sentscore);
+
+ if (sample_forest) {
+ vector<WordID> cur_prediction;
+ ViterbiESentence(forest, &cur_prediction);
+ float sentscore = ds[sent_id]->ScoreCandidate(cur_prediction)->ComputeScore();
+ cur_best = MakeHypothesisInfo(ViterbiFeatures(forest), sentscore);
+
+ vector<HypergraphSampler::Hypothesis> samples;
+ HypergraphSampler::sample_hypotheses(forest, kbest_size, &*rng, &samples);
+ for (unsigned i = 0; i < samples.size(); ++i) {
+ sentscore = ds[sent_id]->ScoreCandidate(samples[i].words)->ComputeScore();
+ if (invert_score) sentscore *= -1.0;
+ if (!cur_good || sentscore > cur_good->mt_metric)
+ cur_good = MakeHypothesisInfo(samples[i].fmap, sentscore);
+ if (!cur_bad || sentscore < cur_bad->mt_metric)
+ cur_bad = MakeHypothesisInfo(samples[i].fmap, sentscore);
+ }
+ } else {
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, kbest_size);
+ for (int i = 0; i < kbest_size; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
+ kbest.LazyKthBest(forest.nodes_.size() - 1, i);
+ if (!d) break;
+ float sentscore = ds[sent_id]->ScoreCandidate(d->yield)->ComputeScore();
+ if (invert_score) sentscore *= -1.0;
+ // cerr << TD::GetString(d->yield) << " ||| " << d->score << " ||| " << sentscore << endl;
+ if (i == 0)
+ cur_best = MakeHypothesisInfo(d->feature_values, sentscore);
+ if (!cur_good || sentscore > cur_good->mt_metric)
+ cur_good = MakeHypothesisInfo(d->feature_values, sentscore);
+ if (!cur_bad || sentscore < cur_bad->mt_metric)
+ cur_bad = MakeHypothesisInfo(d->feature_values, sentscore);
+ }
+ //cerr << "GOOD: " << cur_good->mt_metric << endl;
+ //cerr << " CUR: " << cur_best->mt_metric << endl;
+ //cerr << " BAD: " << cur_bad->mt_metric << endl;
}
- //cerr << "GOOD: " << cur_good->mt_metric << endl;
- //cerr << " CUR: " << cur_best->mt_metric << endl;
- //cerr << " BAD: " << cur_bad->mt_metric << endl;
}
};
@@ -164,6 +187,12 @@ int main(int argc, char** argv) {
rng.reset(new MT19937(conf["random_seed"].as<uint32_t>()));
else
rng.reset(new MT19937);
+ const bool sample_forest = conf.count("sample_forest") > 0;
+ const bool sample_forest_unit_weight_vector = conf.count("sample_forest_unit_weight_vector") > 0;
+ if (sample_forest_unit_weight_vector && !sample_forest) {
+ cerr << "Cannot --sample_forest_unit_weight_vector without --sample_forest" << endl;
+ return 1;
+ }
vector<string> corpus;
ReadTrainingCorpus(conf["source"].as<string>(), &corpus);
const string metric_name = conf["mt_metric"].as<string>();
@@ -195,7 +224,7 @@ int main(int argc, char** argv) {
assert(corpus.size() > 0);
vector<GoodBadOracle> oracles(corpus.size());
- TrainingObserver observer(conf["k_best_size"].as<int>(), ds, &oracles);
+ TrainingObserver observer(conf["k_best_size"].as<int>(), ds, sample_forest, &oracles);
int cur_sent = 0;
int lcount = 0;
int normalizer = 0;
@@ -234,7 +263,19 @@ int main(int argc, char** argv) {
cerr << "PASS " << (lcount / corpus.size() + 1) << endl;
}
decoder.SetId(order[cur_sent]);
+ double sc = 1.0;
+ if (sample_forest_unit_weight_vector) {
+ sc = lambdas.l2norm();
+ if (sc > 0) {
+ for (unsigned i = 0; i < dense_weights.size(); ++i)
+ dense_weights[i] /= sc;
+ }
+ }
decoder.Decode(corpus[order[cur_sent]], &observer); // update oracles
+ if (sc && sc != 1.0) {
+ for (unsigned i = 0; i < dense_weights.size(); ++i)
+ dense_weights[i] *= sc;
+ }
const HypothesisInfo& cur_hyp = observer.GetCurrentBestHypothesis();
const HypothesisInfo& cur_good = *oracles[order[cur_sent]].good;
const HypothesisInfo& cur_bad = *oracles[order[cur_sent]].bad;