summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2012-08-07 23:22:44 -0400
committerChris Dyer <cdyer@cs.cmu.edu>2012-08-07 23:22:44 -0400
commitbc2992ba96cd7af83da8522bdeb6e5dd94a5a11b (patch)
tree26fdfcde0138c26447514d2f97c26de8e1decca4
parent2a4fd2dac126cb5753ae32b6ea3ba1255551a810 (diff)
sample trees from hypergraphs
-rw-r--r--decoder/hg_sampler.cc55
-rw-r--r--decoder/hg_sampler.h7
-rw-r--r--python/src/hypergraph.pxd4
-rw-r--r--python/src/hypergraph.pxi12
4 files changed, 78 insertions, 0 deletions
diff --git a/decoder/hg_sampler.cc b/decoder/hg_sampler.cc
index cdf0ec3c..c4d3dede 100644
--- a/decoder/hg_sampler.cc
+++ b/decoder/hg_sampler.cc
@@ -71,3 +71,58 @@ void HypergraphSampler::sample_hypotheses(const Hypergraph& hg,
Viterbi(hg, &hyp.words, ESentenceTraversal(), SampledDerivationWeightFunction(sampled_edges));
}
}
+
+// Samples n derivations from hg top-down and stores one tree string per sample in
+// *trees (cleared, then resized to n). At each node an in-edge is drawn via rng with
+// weight edge_prob_ times the inside scores (node_probs) of the edge's tail nodes.
+void HypergraphSampler::sample_trees(const Hypergraph& hg,
+                                     unsigned n,
+                                     MT19937* rng,
+                                     vector<string>* trees) {
+  trees->clear();
+  trees->resize(n);
+
+  // compute inside probabilities
+  vector<prob_t> node_probs;
+  Inside<prob_t, EdgeProb>(hg, &node_probs, EdgeProb());
+
+  vector<bool> sampled_edges(hg.edges_.size());
+  queue<unsigned> q;
+  SampleSet<prob_t> ss;
+  for (unsigned i = 0; i < n; ++i) {
+    fill(sampled_edges.begin(), sampled_edges.end(), false);
+    // sample derivation top down, starting from the last node in hg.nodes_
+    assert(q.empty());
+    q.push(hg.nodes_.size() - 1);
+    while (!q.empty()) {
+      unsigned cur_node_id = q.front();
+      q.pop();
+      const Hypergraph::Node& node = hg.nodes_[cur_node_id];
+      const unsigned num_in_edges = node.in_edges_.size();
+      unsigned sampled_edge_idx = 0;
+      if (num_in_edges == 1) {
+        sampled_edge_idx = node.in_edges_[0];  // only one choice, no sampling needed
+      } else {
+        assert(num_in_edges > 1);
+        ss.clear();
+        for (unsigned j = 0; j < num_in_edges; ++j) {
+          const Hypergraph::Edge& edge = hg.edges_[node.in_edges_[j]];
+          prob_t p = edge.edge_prob_;  // edge weight
+          for (unsigned k = 0; k < edge.tail_nodes_.size(); ++k)
+            p *= node_probs[edge.tail_nodes_[k]];  // tail node inside weight
+          ss.add(p);
+        }
+        sampled_edge_idx = node.in_edges_[rng->SelectSample(ss)];
+      }
+      sampled_edges[sampled_edge_idx] = true;
+      const Hypergraph::Edge& sampled_edge = hg.edges_[sampled_edge_idx];
+      for (unsigned j = 0; j < sampled_edge.tail_nodes_.size(); ++j) {
+        q.push(sampled_edge.tail_nodes_[j]);  // expand the children of the chosen edge
+      }
+    }
+    vector<WordID> tmp;
+    Viterbi(hg, &tmp, ETreeTraversal(), SampledDerivationWeightFunction(sampled_edges));
+    (*trees)[i] = TD::GetString(tmp);  // slot for sample i; valid indices are 0..n-1
+  }
+}
+
diff --git a/decoder/hg_sampler.h b/decoder/hg_sampler.h
index bf4e1eb0..6ac39a20 100644
--- a/decoder/hg_sampler.h
+++ b/decoder/hg_sampler.h
@@ -3,6 +3,7 @@
#include <vector>
+#include <string>
#include "sparse_vector.h"
#include "sampler.h"
#include "wordid.h"
@@ -22,6 +23,12 @@ struct HypergraphSampler {
unsigned n, // how many samples to draw
MT19937* rng,
std::vector<Hypothesis>* hypos);
+
+ static void
+ sample_trees(const Hypergraph& hg,
+ unsigned n, // how many trees to sample
+ MT19937* rng, // random source used to pick among a node's incoming edges
+ std::vector<std::string>* trees); // output: cleared, then resized to n strings
};
#endif
diff --git a/python/src/hypergraph.pxd b/python/src/hypergraph.pxd
index 886660bf..abd6759c 100644
--- a/python/src/hypergraph.pxd
+++ b/python/src/hypergraph.pxd
@@ -76,6 +76,10 @@ cdef extern from "decoder/hg_sampler.h" namespace "HypergraphSampler":
unsigned n,
MT19937* rng,
vector[Hypothesis]* hypos)
+ void sample_trees(Hypergraph& hg,
+ unsigned n, # how many trees to sample
+ MT19937* rng,
+ vector[string]* trees) # output vector written by the C++ implementation
cdef extern from "decoder/csplit.h" namespace "CompoundSplit":
int GetFullWordEdgeIndex(Hypergraph& forest)
diff --git a/python/src/hypergraph.pxi b/python/src/hypergraph.pxi
index b210f440..62dd5bb1 100644
--- a/python/src/hypergraph.pxi
+++ b/python/src/hypergraph.pxi
@@ -85,6 +85,18 @@ cdef class Hypergraph:
finally:
del hypos
+ def sample_trees(self, unsigned n): # generator: yields the sampled trees as unicode strings
+ cdef vector[string]* trees = new vector[string]() # heap-allocated; released in the finally below
+ if self.rng == NULL: # create the random number generator if this object has none yet
+ self.rng = new MT19937()
+ hypergraph.sample_trees(self.hg[0], n, self.rng, trees)
+ cdef unsigned k
+ try:
+ for k in range(trees.size()):
+ yield unicode(trees[0][k].c_str(), 'utf8') # decode the C++ string bytes as UTF-8
+ finally:
+ del trees # runs when the generator finishes or is closed
+
def intersect(self, Lattice lat):
return hypergraph.Intersect(lat.lattice[0], self.hg)