From ef74e67449515ff68f598f06ffc9d221eb13f919 Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Fri, 9 Sep 2011 09:51:09 +0200 Subject: forgotten files --- dtrain/ksampler.h | 52 +++++++++++++++++++++++++++++++++++++ dtrain/sample_hg.cc | 74 +++++++++++++++++++++++++++++++++++++++++++++++++++++ dtrain/sample_hg.h | 24 +++++++++++++++++ 3 files changed, 150 insertions(+) create mode 100644 dtrain/ksampler.h create mode 100644 dtrain/sample_hg.cc create mode 100644 dtrain/sample_hg.h (limited to 'dtrain') diff --git a/dtrain/ksampler.h b/dtrain/ksampler.h new file mode 100644 index 00000000..a28b69c9 --- /dev/null +++ b/dtrain/ksampler.h @@ -0,0 +1,52 @@ +#ifndef _DTRAIN_KSAMPLER_H_ +#define _DTRAIN_KSAMPLER_H_ + +#include "kbest.h" +#include "sample_hg.h" +#include "sampler.h" + +namespace dtrain +{ + +/* + * KSampler + * + */ +struct KSampler : public DecoderObserver +{ + const size_t k_; + KBestList kb; + MT19937* rng; + + explicit KSampler( const size_t k, MT19937* prng ) : + k_(k), rng(prng) {} + + virtual void + NotifyTranslationForest( const SentenceMetadata& smeta, Hypergraph* hg ) + { + Sample( *hg ); + } + + KBestList* GetKBest() { return &kb; } + + void Sample( const Hypergraph& forest ) { + kb.sents.clear(); + kb.feats.clear(); + kb.model_scores.clear(); + kb.scores.clear(); + std::vector samples; + HypergraphSampler::sample_hypotheses(forest, k_, rng, &samples); + for ( size_t i = 0; i < k_; ++i ) { + kb.sents.push_back( samples[i].words ); + kb.feats.push_back( samples[i].fmap ); + kb.model_scores.push_back( log(samples[i].model_score) ); + } + } +}; + + +} // namespace + + +#endif + diff --git a/dtrain/sample_hg.cc b/dtrain/sample_hg.cc new file mode 100644 index 00000000..33872fb8 --- /dev/null +++ b/dtrain/sample_hg.cc @@ -0,0 +1,74 @@ +#include "sample_hg.h" + +#include + +#include "viterbi.h" +#include "inside_outside.h" + +using namespace std; + +struct SampledDerivationWeightFunction { + typedef double Weight; + explicit SampledDerivationWeightFunction(const vector& sampled) : sampled_edges(sampled) {} + double operator()(const Hypergraph::Edge& e) const { + return static_cast(sampled_edges[e.id_]); + } + const vector& sampled_edges; +}; + +void HypergraphSampler::sample_hypotheses(const Hypergraph& hg, + unsigned n, + MT19937* rng, + vector* hypos) { + hypos->clear(); + hypos->resize(n); + + // compute inside probabilities + vector node_probs; + Inside(hg, &node_probs, EdgeProb()); + + vector sampled_edges(hg.edges_.size()); + queue q; + SampleSet ss; + for (unsigned i = 0; i < n; ++i) { + fill(sampled_edges.begin(), sampled_edges.end(), false); + // sample derivation top down + assert(q.empty()); + Hypothesis& hyp = (*hypos)[i]; + SparseVector& deriv_features = hyp.fmap; + q.push(hg.nodes_.size() - 1); + prob_t& model_score = hyp.model_score; + model_score = prob_t::One(); + while(!q.empty()) { + unsigned cur_node_id = q.front(); + q.pop(); + const Hypergraph::Node& node = hg.nodes_[cur_node_id]; + const unsigned num_in_edges = node.in_edges_.size(); + unsigned sampled_edge_idx = 0; + if (num_in_edges == 1) { + sampled_edge_idx = node.in_edges_[0]; + } else { + assert(num_in_edges > 1); + ss.clear(); + for (unsigned j = 0; j < num_in_edges; ++j) { + const Hypergraph::Edge& edge = hg.edges_[node.in_edges_[j]]; + prob_t p = edge.edge_prob_; // edge weight + for (unsigned k = 0; k < edge.tail_nodes_.size(); ++k) + p *= node_probs[edge.tail_nodes_[k]]; // tail node inside weight + ss.add(p); + } + sampled_edge_idx = node.in_edges_[rng->SelectSample(ss)]; + } + sampled_edges[sampled_edge_idx] = true; + const Hypergraph::Edge& sampled_edge = hg.edges_[sampled_edge_idx]; + deriv_features += sampled_edge.feature_values_; + model_score *= sampled_edge.edge_prob_; + //sampled_deriv->push_back(sampled_edge_idx); + for (unsigned j = 0; j < sampled_edge.tail_nodes_.size(); ++j) { + q.push(sampled_edge.tail_nodes_[j]); + } + } + Viterbi(hg, &hyp.words, ESentenceTraversal(), SampledDerivationWeightFunction(sampled_edges)); + } +} + diff --git a/dtrain/sample_hg.h b/dtrain/sample_hg.h new file mode 100644 index 00000000..932fd369 --- /dev/null +++ b/dtrain/sample_hg.h @@ -0,0 +1,24 @@ +#ifndef _SAMPLE_HG_H_ +#define _SAMPLE_HG_H_ + +#include +#include "sparse_vector.h" +#include "sampler.h" +#include "wordid.h" + +class Hypergraph; + +struct HypergraphSampler { + struct Hypothesis { + std::vector words; + SparseVector fmap; + prob_t model_score; + }; + + static void sample_hypotheses(const Hypergraph& hg, + unsigned n, + MT19937* rng, + std::vector* hypos); +}; + +#endif -- cgit v1.2.3