From 243544888fe06e5416e66cf723a3ac34ab33d59a Mon Sep 17 00:00:00 2001
From: Chris Dyer <cdyer@cs.cmu.edu>
Date: Thu, 24 Mar 2011 18:04:06 -0400
Subject: various summary feature types, part 1

---
 decoder/decoder.cc | 108 ++++++++++++++++++++++++++++++++++++-----------------
 1 file changed, 73 insertions(+), 35 deletions(-)

(limited to 'decoder')

diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index b7774acc..fdaf8cb1 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -66,6 +66,13 @@ void DecoderObserver::NotifyAlignmentFailure(const SentenceMetadata&) {}
 void DecoderObserver::NotifyAlignmentForest(const SentenceMetadata&, Hypergraph*) {}
 void DecoderObserver::NotifyDecodingComplete(const SentenceMetadata&) {}
 
+enum SummaryFeature {
+  kNODE_RISK = 1,
+  kEDGE_RISK,
+  kEDGE_PROB
+};
+
+
 struct ELengthWeightFunction {
   double operator()(const Hypergraph::Edge& e) const {
     return e.rule_->ELength() - e.rule_->Arity();
@@ -364,6 +371,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
         ("feature_function,F",po::value<vector<string> >()->composing(), "Pass 1 additional feature function(s) (-L for list)")
         ("intersection_strategy,I",po::value<string>()->default_value("cube_pruning"), "Pass 1 intersection strategy for incorporating finite-state features; values include Cube_pruning, Full")
         ("summary_feature", po::value<string>(), "Compute a 'summary feature' at the end of the pass (before any pruning) with name=arg and value=inside-outside/Z")
+        ("summary_feature_type", po::value<string>()->default_value("node_risk"), "Summary feature types: node_risk, edge_risk, edge_prob")
         ("density_prune", po::value<double>(), "Pass 1 pruning: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)")
         ("beam_prune", po::value<double>(), "Pass 1 pruning: Prune paths from scored forest, keep paths within exp(alpha>=0)")
 
@@ -386,8 +394,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
     ("apply_fsa_by",po::value<string>()->default_value("BU_CUBE"), "Method for applying fsa_feature_functions - BU_FULL BU_CUBE EARLEY") //+ApplyFsaBy::all_names()
 #endif
         ("add_pass_through_rules,P","Add rules to translate OOV words as themselves")
-	("k_best,k",po::value<int>(),"Extract the k best derivations")
-	("unique_k_best,r", "Unique k-best translation list")
+        ("k_best,k",po::value<int>(),"Extract the k best derivations")
+        ("unique_k_best,r", "Unique k-best translation list")
         ("cubepruning_pop_limit,K",po::value<int>()->default_value(200), "Max number of pops from the candidate heap at each node")
         ("aligner,a", "Run as a word/phrase aligner (src & ref required)")
         ("aligner_use_viterbi", "If run in alignment mode, compute the Viterbi (rather than MAP) alignment")
@@ -775,6 +783,18 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
     cerr << "  Expected length  (words): " << res.r / res.p << "\t" << res << endl;
   }
 
+  SummaryFeature summary_feature_type = kNODE_RISK;
+  if (conf["summary_feature_type"].as<string>() == "edge_risk")
+    summary_feature_type = kEDGE_RISK;
+  else if (conf["summary_feature_type"].as<string>() == "node_risk")
+    summary_feature_type = kNODE_RISK;
+  else if (conf["summary_feature_type"].as<string>() == "edge_prob")
+    summary_feature_type = kEDGE_PROB;
+  else {
+    cerr << "Bad summary_feature_type: " << conf["summary_feature_type"].as<string>() << endl;
+    abort();
+  }
+
   for (int pass = 0; pass < rescoring_passes.size(); ++pass) {
     const RescoringPass& rp = rescoring_passes[pass];
     const vector<double>& cur_weights = rp.weight_vector;
@@ -806,43 +826,61 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
     }
 
     if (rp.fid_summary) {
-#if 0
-      const prob_t z = forest.PushWeightsToGoal(1.0);
-      if (!SILENT) { cerr << "  " << passtr << " adding summary feature " << FD::Convert(rp.fid_summary) << " log(Z)=" << log(z) << endl; }
-      if (!isfinite(log(z)) || isnan(log(z))) {
-        cerr << "  " << passtr << " !!! Invalid partition detected, abandoning.\n";
-      } else {
-        for (int i = 0; i < forest.edges_.size(); ++i) {
-          const double log_prob_transition = log(forest.edges_[i].edge_prob_); // locally normalized by the edge
-                                                                            // head node by forest.PushWeightsToGoal
-          if (!isfinite(log_prob_transition) || isnan(log_prob_transition)) {
-            cerr << "Edge: i=" << i << " got bad inside prob: " << *forest.edges_[i].rule_ << endl;
-            abort();
+      if (summary_feature_type == kEDGE_PROB) {
+        const prob_t z = forest.PushWeightsToGoal(1.0);
+        if (!isfinite(log(z)) || isnan(log(z))) {
+          cerr << "  " << passtr << " !!! Invalid partition detected, abandoning.\n";
+        } else {
+          for (int i = 0; i < forest.edges_.size(); ++i) {
+            const double log_prob_transition = log(forest.edges_[i].edge_prob_); // locally normalized by the edge
+                                                                              // head node by forest.PushWeightsToGoal
+            if (!isfinite(log_prob_transition) || isnan(log_prob_transition)) {
+              cerr << "Edge: i=" << i << " got bad inside prob: " << *forest.edges_[i].rule_ << endl;
+              abort();
+            }
+
+            forest.edges_[i].feature_values_.set_value(rp.fid_summary, log_prob_transition);
           }
-
-          forest.edges_[i].feature_values_.set_value(rp.fid_summary, log_prob_transition);
+          forest.Reweight(cur_weights);  // reset weights
         }
-        forest.Reweight(cur_weights);  // reset weights
-      }
-#endif
-      Hypergraph::EdgeProbs posts;
-      const prob_t z = forest.ComputeEdgePosteriors(1.0, &posts);
-      if (!isfinite(log(z)) || isnan(log(z))) {
-        cerr << "  " << passtr << " !!! Invalid partition detected, abandoning.\n";
-      } else {
-        for (int i = 0; i < forest.nodes_.size(); ++i) {
-          const Hypergraph::EdgesVector& in_edges = forest.nodes_[i].in_edges_;
-          prob_t node_post = prob_t(0);
-          for (int j = 0; j < in_edges.size(); ++j)
-            node_post += (posts[in_edges[j]] / z);
-          const double log_np = log(node_post);
-          if (!isfinite(log_np) || isnan(log_np)) {
-            cerr << "got bad posterior prob for node " << i << endl;
-            abort();
+      } else if (summary_feature_type == kNODE_RISK) {
+        Hypergraph::EdgeProbs posts;
+        const prob_t z = forest.ComputeEdgePosteriors(1.0, &posts);
+        if (!isfinite(log(z)) || isnan(log(z))) {
+          cerr << "  " << passtr << " !!! Invalid partition detected, abandoning.\n";
+        } else {
+          for (int i = 0; i < forest.nodes_.size(); ++i) {
+            const Hypergraph::EdgesVector& in_edges = forest.nodes_[i].in_edges_;
+            prob_t node_post = prob_t(0);
+            for (int j = 0; j < in_edges.size(); ++j)
+              node_post += (posts[in_edges[j]] / z);
+            const double log_np = log(node_post);
+            if (!isfinite(log_np) || isnan(log_np)) {
+              cerr << "got bad posterior prob for node " << i << endl;
+              abort();
+            }
+            for (int j = 0; j < in_edges.size(); ++j)
+              forest.edges_[in_edges[j]].feature_values_.set_value(rp.fid_summary, exp(log_np));
           }
-          for (int j = 0; j < in_edges.size(); ++j)
-            forest.edges_[in_edges[j]].feature_values_.set_value(rp.fid_summary, exp(log_np));
         }
+      } else if (summary_feature_type == kEDGE_RISK) {
+        Hypergraph::EdgeProbs posts;
+        const prob_t z = forest.ComputeEdgePosteriors(1.0, &posts);
+        if (!isfinite(log(z)) || isnan(log(z))) {
+          cerr << "  " << passtr << " !!! Invalid partition detected, abandoning.\n";
+        } else {
+          assert(posts.size() == forest.edges_.size());
+          for (int i = 0; i < posts.size(); ++i) {
+            const double log_np = log(posts[i] / z);
+            if (!isfinite(log_np) || isnan(log_np)) {
+              cerr << "got bad posterior prob for node " << i << endl;
+              abort();
+            }
+            forest.edges_[i].feature_values_.set_value(rp.fid_summary, exp(log_np));
+          }
+        }
+      } else {
+        assert(!"shouldn't happen");
       }
     }
 
-- 
cgit v1.2.3