#include "sample_hg.h"

#include <algorithm>
#include <cassert>
#include <queue>
#include <vector>

#include "viterbi.h"
#include "inside_outside.h"

using namespace std;

namespace {

// Edge weight function passed to Viterbi(): an edge scores 1.0 if it was
// chosen in the current sampled derivation and 0.0 otherwise.
// NOTE(review): this relies on Viterbi() picking the max-product derivation,
// so that the only non-zero derivation is the one built from sampled edges.
//
// Holds a reference, not a copy: the vector given to the constructor must
// outlive this functor (below it is only ever a temporary call argument).
struct SampledDerivationWeightFunction {
  typedef double Weight;
  explicit SampledDerivationWeightFunction(const vector<bool>& sampled)
      : sampled_edges(sampled) {}
  double operator()(const Hypergraph::Edge& e) const {
    return static_cast<double>(sampled_edges[e.id_]);
  }
  const vector<bool>& sampled_edges;  // indexed by Hypergraph::Edge::id_
};

}  // namespace

// Draws n derivations from hg and stores one Hypothesis per draw in *hypos
// (any previous contents are discarded).
//
// Each draw is generated top down, starting from the last node in hg.nodes_:
// at every visited node one incoming edge is chosen with weight
//   edge_prob_ * product of the inside scores of its tail nodes,
// using the inside scores computed once by Inside() for all draws.
// For each draw, the Hypothesis receives:
//   fmap        - the sum of feature_values_ over the sampled edges,
//   model_score - the product of edge_prob_ over the sampled edges,
//   words       - the yield of the sampled derivation, via Viterbi() with
//                 SampledDerivationWeightFunction.
//
// If n == 0 or hg has no nodes, *hypos is left empty.
void HypergraphSampler::sample_hypotheses(const Hypergraph& hg,
                                          unsigned n,
                                          MT19937* rng,
                                          vector<Hypothesis>* hypos) {
  hypos->clear();
  // With no nodes there is no root to expand: hg.nodes_.size() - 1 below
  // would wrap around. With n == 0 there is nothing to do either.
  if (n == 0 || hg.nodes_.empty()) return;
  hypos->resize(n);

  // Inside scores of every node; reused by all n draws.
  vector<prob_t> node_probs;
  Inside(hg, &node_probs, EdgeProb());

  // Scratch state hoisted out of the loop so each draw reuses its storage.
  vector<bool> sampled_edges(hg.edges_.size());
  queue<unsigned> q;
  SampleSet<prob_t> ss;

  for (unsigned i = 0; i < n; ++i) {
    Hypothesis& hyp = (*hypos)[i];
    hyp.model_score = prob_t::One();
    fill(sampled_edges.begin(), sampled_edges.end(), false);

    // Expand the derivation top down from the last node (treated as root).
    assert(q.empty());
    q.push(hg.nodes_.size() - 1);
    while (!q.empty()) {
      // The reference points into hg.nodes_, not the queue, so pop is safe.
      const Hypergraph::Node& node = hg.nodes_[q.front()];
      q.pop();

      const unsigned num_in_edges = node.in_edges_.size();
      assert(num_in_edges > 0);
      unsigned sampled_edge_idx;
      if (num_in_edges == 1) {
        // Only one way to build this node: no need to consult the RNG.
        sampled_edge_idx = node.in_edges_[0];
      } else {
        ss.clear();
        for (unsigned j = 0; j < num_in_edges; ++j) {
          const Hypergraph::Edge& edge = hg.edges_[node.in_edges_[j]];
          prob_t p = edge.edge_prob_;  // edge weight
          for (unsigned k = 0; k < edge.tail_nodes_.size(); ++k)
            p *= node_probs[edge.tail_nodes_[k]];  // tail node inside weight
          ss.add(p);
        }
        sampled_edge_idx = node.in_edges_[rng->SelectSample(ss)];
      }

      sampled_edges[sampled_edge_idx] = true;
      const Hypergraph::Edge& sampled_edge = hg.edges_[sampled_edge_idx];
      hyp.fmap += sampled_edge.feature_values_;
      hyp.model_score *= sampled_edge.edge_prob_;
      for (unsigned j = 0; j < sampled_edge.tail_nodes_.size(); ++j)
        q.push(sampled_edge.tail_nodes_[j]);
    }

    // Read off the yield of the sampled derivation.
    Viterbi(hg, &hyp.words, ESentenceTraversal(),
            SampledDerivationWeightFunction(sampled_edges));
  }
}