diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2012-08-07 23:22:44 -0400 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2012-08-07 23:22:44 -0400 |
commit | 6ba464f6d78e38970d5467b10ce1114f4d7feaa4 (patch) | |
tree | 15614a9ee68a2f0b5d05025752b0dd2bc94ae67d | |
parent | c6f5711203782e2677ad95da3ffa7a79fc0fbf3a (diff) |
sample trees from hypergraphs
-rw-r--r-- | decoder/hg_sampler.cc | 55 | ||||
-rw-r--r-- | decoder/hg_sampler.h | 7 | ||||
-rw-r--r-- | python/src/hypergraph.pxd | 4 | ||||
-rw-r--r-- | python/src/hypergraph.pxi | 12 |
4 files changed, 78 insertions, 0 deletions
diff --git a/decoder/hg_sampler.cc b/decoder/hg_sampler.cc index cdf0ec3c..c4d3dede 100644 --- a/decoder/hg_sampler.cc +++ b/decoder/hg_sampler.cc @@ -71,3 +71,58 @@ void HypergraphSampler::sample_hypotheses(const Hypergraph& hg, Viterbi(hg, &hyp.words, ESentenceTraversal(), SampledDerivationWeightFunction(sampled_edges)); } } + +void HypergraphSampler::sample_trees(const Hypergraph& hg, + unsigned n, + MT19937* rng, + vector<string>* trees) { + trees->clear(); + trees->resize(n); + + // compute inside probabilities + vector<prob_t> node_probs; + Inside<prob_t, EdgeProb>(hg, &node_probs, EdgeProb()); + + vector<bool> sampled_edges(hg.edges_.size()); + queue<unsigned> q; + SampleSet<prob_t> ss; + for (unsigned i = 0; i < n; ++i) { + fill(sampled_edges.begin(), sampled_edges.end(), false); + // sample derivation top down + assert(q.empty()); + q.push(hg.nodes_.size() - 1); + prob_t model_score = prob_t::One(); + while(!q.empty()) { + unsigned cur_node_id = q.front(); + q.pop(); + const Hypergraph::Node& node = hg.nodes_[cur_node_id]; + const unsigned num_in_edges = node.in_edges_.size(); + unsigned sampled_edge_idx = 0; + if (num_in_edges == 1) { + sampled_edge_idx = node.in_edges_[0]; + } else { + assert(num_in_edges > 1); + ss.clear(); + for (unsigned j = 0; j < num_in_edges; ++j) { + const Hypergraph::Edge& edge = hg.edges_[node.in_edges_[j]]; + prob_t p = edge.edge_prob_; // edge weight + for (unsigned k = 0; k < edge.tail_nodes_.size(); ++k) + p *= node_probs[edge.tail_nodes_[k]]; // tail node inside weight + ss.add(p); + } + sampled_edge_idx = node.in_edges_[rng->SelectSample(ss)]; + } + sampled_edges[sampled_edge_idx] = true; + const Hypergraph::Edge& sampled_edge = hg.edges_[sampled_edge_idx]; + model_score *= sampled_edge.edge_prob_; + //sampled_deriv->push_back(sampled_edge_idx); + for (unsigned j = 0; j < sampled_edge.tail_nodes_.size(); ++j) { + q.push(sampled_edge.tail_nodes_[j]); + } + } + vector<WordID> tmp; + Viterbi(hg, &tmp, ETreeTraversal(), SampledDerivationWeightFunction(sampled_edges)); + (*trees)[n] = TD::GetString(tmp); + } +} + diff --git a/decoder/hg_sampler.h b/decoder/hg_sampler.h index bf4e1eb0..6ac39a20 100644 --- a/decoder/hg_sampler.h +++ b/decoder/hg_sampler.h @@ -3,6 +3,7 @@ #include <vector> +#include <string> #include "sparse_vector.h" #include "sampler.h" #include "wordid.h" @@ -22,6 +23,12 @@ struct HypergraphSampler { unsigned n, // how many samples to draw MT19937* rng, std::vector<Hypothesis>* hypos); + + static void + sample_trees(const Hypergraph& hg, + unsigned n, + MT19937* rng, + std::vector<std::string>* trees); }; #endif diff --git a/python/src/hypergraph.pxd b/python/src/hypergraph.pxd index 886660bf..abd6759c 100644 --- a/python/src/hypergraph.pxd +++ b/python/src/hypergraph.pxd @@ -76,6 +76,10 @@ cdef extern from "decoder/hg_sampler.h" namespace "HypergraphSampler": unsigned n, MT19937* rng, vector[Hypothesis]* hypos) + void sample_trees(Hypergraph& hg, + unsigned n, + MT19937* rng, + vector[string]* trees) cdef extern from "decoder/csplit.h" namespace "CompoundSplit": int GetFullWordEdgeIndex(Hypergraph& forest) diff --git a/python/src/hypergraph.pxi b/python/src/hypergraph.pxi index b210f440..62dd5bb1 100644 --- a/python/src/hypergraph.pxi +++ b/python/src/hypergraph.pxi @@ -85,6 +85,18 @@ cdef class Hypergraph: finally: del hypos + def sample_trees(self, unsigned n): + cdef vector[string]* trees = new vector[string]() + if self.rng == NULL: + self.rng = new MT19937() + hypergraph.sample_trees(self.hg[0], n, self.rng, trees) + cdef unsigned k + try: + for k in range(trees.size()): + yield unicode(trees[0][k].c_str(), 'utf8') + finally: + del trees + def intersect(self, Lattice lat): return hypergraph.Intersect(lat.lattice[0], self.hg) |