summaryrefslogtreecommitdiff
path: root/dtrain
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2011-09-09 09:51:09 +0200
committerPatrick Simianer <p@simianer.de>2011-09-23 19:13:58 +0200
commit2e6ef7cbec77b22ce3d64416a5ada3a6c081f9e2 (patch)
tree13c2013947ed643889811b82c93d434835b05252 /dtrain
parent14637f89c899179f54a5bc327857db8ea1e1d427 (diff)
forgotten files
Diffstat (limited to 'dtrain')
-rw-r--r--dtrain/ksampler.h52
-rw-r--r--dtrain/sample_hg.cc74
-rw-r--r--dtrain/sample_hg.h24
3 files changed, 150 insertions, 0 deletions
diff --git a/dtrain/ksampler.h b/dtrain/ksampler.h
new file mode 100644
index 00000000..a28b69c9
--- /dev/null
+++ b/dtrain/ksampler.h
@@ -0,0 +1,52 @@
+#ifndef _DTRAIN_KSAMPLER_H_
+#define _DTRAIN_KSAMPLER_H_
+
+#include "kbest.h"
+#include "sample_hg.h"
+#include "sampler.h"
+
+namespace dtrain
+{
+
+/*
+ * KSampler
+ *
+ */
+struct KSampler : public DecoderObserver
+{
+ const size_t k_;
+ KBestList kb;
+ MT19937* rng;
+
+ explicit KSampler( const size_t k, MT19937* prng ) :
+ k_(k), rng(prng) {}
+
+ virtual void
+ NotifyTranslationForest( const SentenceMetadata& smeta, Hypergraph* hg )
+ {
+ Sample( *hg );
+ }
+
+ KBestList* GetKBest() { return &kb; }
+
+ void Sample( const Hypergraph& forest ) {
+ kb.sents.clear();
+ kb.feats.clear();
+ kb.model_scores.clear();
+ kb.scores.clear();
+ std::vector<HypergraphSampler::Hypothesis> samples;
+ HypergraphSampler::sample_hypotheses(forest, k_, rng, &samples);
+ for ( size_t i = 0; i < k_; ++i ) {
+ kb.sents.push_back( samples[i].words );
+ kb.feats.push_back( samples[i].fmap );
+ kb.model_scores.push_back( log(samples[i].model_score) );
+ }
+ }
+};
+
+
+} // namespace
+
+
+#endif
+
diff --git a/dtrain/sample_hg.cc b/dtrain/sample_hg.cc
new file mode 100644
index 00000000..33872fb8
--- /dev/null
+++ b/dtrain/sample_hg.cc
@@ -0,0 +1,74 @@
+#include "sample_hg.h"
+
+#include <queue>
+
+#include "viterbi.h"
+#include "inside_outside.h"
+
+using namespace std;
+
+struct SampledDerivationWeightFunction {
+ typedef double Weight;
+ explicit SampledDerivationWeightFunction(const vector<bool>& sampled) : sampled_edges(sampled) {}
+ double operator()(const Hypergraph::Edge& e) const {
+ return static_cast<double>(sampled_edges[e.id_]);
+ }
+ const vector<bool>& sampled_edges;
+};
+
+void HypergraphSampler::sample_hypotheses(const Hypergraph& hg,
+ unsigned n,
+ MT19937* rng,
+ vector<Hypothesis>* hypos) {
+ hypos->clear();
+ hypos->resize(n);
+
+ // compute inside probabilities
+ vector<prob_t> node_probs;
+ Inside<prob_t, EdgeProb>(hg, &node_probs, EdgeProb());
+
+ vector<bool> sampled_edges(hg.edges_.size());
+ queue<unsigned> q;
+ SampleSet<prob_t> ss;
+ for (unsigned i = 0; i < n; ++i) {
+ fill(sampled_edges.begin(), sampled_edges.end(), false);
+ // sample derivation top down
+ assert(q.empty());
+ Hypothesis& hyp = (*hypos)[i];
+ SparseVector<double>& deriv_features = hyp.fmap;
+ q.push(hg.nodes_.size() - 1);
+ prob_t& model_score = hyp.model_score;
+ model_score = prob_t::One();
+ while(!q.empty()) {
+ unsigned cur_node_id = q.front();
+ q.pop();
+ const Hypergraph::Node& node = hg.nodes_[cur_node_id];
+ const unsigned num_in_edges = node.in_edges_.size();
+ unsigned sampled_edge_idx = 0;
+ if (num_in_edges == 1) {
+ sampled_edge_idx = node.in_edges_[0];
+ } else {
+ assert(num_in_edges > 1);
+ ss.clear();
+ for (unsigned j = 0; j < num_in_edges; ++j) {
+ const Hypergraph::Edge& edge = hg.edges_[node.in_edges_[j]];
+ prob_t p = edge.edge_prob_; // edge weight
+ for (unsigned k = 0; k < edge.tail_nodes_.size(); ++k)
+ p *= node_probs[edge.tail_nodes_[k]]; // tail node inside weight
+ ss.add(p);
+ }
+ sampled_edge_idx = node.in_edges_[rng->SelectSample(ss)];
+ }
+ sampled_edges[sampled_edge_idx] = true;
+ const Hypergraph::Edge& sampled_edge = hg.edges_[sampled_edge_idx];
+ deriv_features += sampled_edge.feature_values_;
+ model_score *= sampled_edge.edge_prob_;
+ //sampled_deriv->push_back(sampled_edge_idx);
+ for (unsigned j = 0; j < sampled_edge.tail_nodes_.size(); ++j) {
+ q.push(sampled_edge.tail_nodes_[j]);
+ }
+ }
+ Viterbi(hg, &hyp.words, ESentenceTraversal(), SampledDerivationWeightFunction(sampled_edges));
+ }
+}
+
diff --git a/dtrain/sample_hg.h b/dtrain/sample_hg.h
new file mode 100644
index 00000000..932fd369
--- /dev/null
+++ b/dtrain/sample_hg.h
@@ -0,0 +1,24 @@
+#ifndef _SAMPLE_HG_H_
+#define _SAMPLE_HG_H_
+
+#include <vector>
+#include "sparse_vector.h"
+#include "sampler.h"
+#include "wordid.h"
+
+class Hypergraph;
+
+struct HypergraphSampler {
+ struct Hypothesis {
+ std::vector<WordID> words;
+ SparseVector<double> fmap;
+ prob_t model_score;
+ };
+
+ static void sample_hypotheses(const Hypergraph& hg,
+ unsigned n,
+ MT19937* rng,
+ std::vector<Hypothesis>* hypos);
+};
+
+#endif