From 5c40dcfe323dbf95a16a995588a77f393d42749c Mon Sep 17 00:00:00 2001
From: graehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f>
Date: Sat, 7 Aug 2010 03:41:32 +0000
Subject: apply fsa models (so far only bottom-up) in cdec

git-svn-id: https://ws10smt.googlecode.com/svn/trunk@487 ec762483-ff6d-05da-a07a-a48fb63a330f
---
 decoder/Makefile.am         |   1 +
 decoder/apply_fsa_models.cc | 102 ++++++++++++++++++++++++++++++++++++++++++++
 decoder/apply_fsa_models.h  |  29 ++++++++++++-
 decoder/apply_models.h      |   6 +++
 decoder/cdec.cc             |  29 ++++++++++---
 decoder/feature_vector.h    |   4 ++
 decoder/stringlib.h         |  12 ++++++
 7 files changed, 176 insertions(+), 7 deletions(-)
 create mode 100755 decoder/apply_fsa_models.cc

diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index 189e28b0..88d6d17a 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -44,6 +44,7 @@ rule_lexer.cc: rule_lexer.l
 noinst_LIBRARIES = libcdec.a
 
 libcdec_a_SOURCES = \
+  apply_fsa_models.cc \
   rule_lexer.cc \
   fst_translator.cc \
   csplit.cc \
diff --git a/decoder/apply_fsa_models.cc b/decoder/apply_fsa_models.cc
new file mode 100755
index 00000000..27773b0d
--- /dev/null
+++ b/decoder/apply_fsa_models.cc
@@ -0,0 +1,102 @@
+#include "apply_fsa_models.h"
+#include "hg.h"
+#include "ff_fsa_dynamic.h"
+#include "feature_vector.h"
+#include "stringlib.h"
+#include "apply_models.h"
+#include <stdexcept>
+#include <cassert>
+
+using namespace std;
+
+struct ApplyFsa {
+  ApplyFsa(const Hypergraph& ih,
+           const SentenceMetadata& smeta,
+           const FsaFeatureFunction& fsa,
+           DenseWeightVector const& weights,
+           ApplyFsaBy const& cfg,
+           Hypergraph* oh)
+    :ih(ih),smeta(smeta),fsa(fsa),weights(weights),cfg(cfg),oh(oh)
+  {
+//    sparse_to_dense(weight_vector,&weights);
+    Init();
+  }
+  void Init() {
+    ApplyBottomUp();
+    //TODO: implement l->r
+  }
+  void ApplyBottomUp() {
+    assert(cfg.IsBottomUp());
+    vector<const FeatureFunction*> ffs; // NOTE: the fsa model isn't wired into the model set yet (ffs stays empty), so this currently just runs the plain intersection/pruning pass
+    ModelSet models(weights, ffs);
+    IntersectionConfiguration i(cfg.BottomUpAlgorithm(),cfg.pop_limit);
+    ApplyModelSet(ih,smeta,models,i,oh);
+  }
+private:
+  const Hypergraph& ih;
+  const SentenceMetadata& smeta;
+  const FsaFeatureFunction& fsa;
+//  WeightVector weight_vector;
+  DenseWeightVector weights;
+  ApplyFsaBy cfg;
+  Hypergraph* oh;
+};
+
+
+void ApplyFsaModels(const Hypergraph& ih,
+                    const SentenceMetadata& smeta,
+                    const FsaFeatureFunction& fsa,
+                    DenseWeightVector const& weight_vector,
+                    ApplyFsaBy const& cfg,
+                    Hypergraph* oh)
+{
+  ApplyFsa a(ih,smeta,fsa,weight_vector,cfg,oh);
+}
+
+
+namespace {
+char const* anames[]={
+  "BU_CUBE",
+  "BU_FULL",
+  "EARLEY",
+  0
+};
+}
+
+//TODO: named enum type in boost?
+
+std::string ApplyFsaBy::name() const {
+  return anames[algorithm];
+}
+
+std::string ApplyFsaBy::all_names() {
+  std::ostringstream o;
+  for (int i=0;i<N_ALGORITHMS;++i) {
+    assert(anames[i]);
+    if (i) o<<' ';
+    o<<anames[i];
+  }
+  return o.str();
+}
+
+ApplyFsaBy::ApplyFsaBy(std::string const& n, int pop_limit) : pop_limit(pop_limit){
+  algorithm=0;
+  std::string uname=toupper(n);
+  while(anames[algorithm] && anames[algorithm] != uname) ++algorithm;
+  if (!anames[algorithm])
+    throw std::runtime_error("Unknown ApplyFsaBy type: "+n+" - legal types: "+all_names());
+}
+
+ApplyFsaBy::ApplyFsaBy(int i, int pop_limit) : pop_limit(pop_limit) {
+  assert (i>=0);
+  assert (i<N_ALGORITHMS);
+  algorithm=i;
+}
+
+int ApplyFsaBy::BottomUpAlgorithm() const {
+  assert(IsBottomUp());
+  return algorithm==BU_CUBE ?
+    IntersectionConfiguration::CUBE
+    :IntersectionConfiguration::FULL;
+}
+
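For reference, the name parsing in ApplyFsaBy::ApplyFsaBy above (uppercase the argument, scan a zero-terminated table, and on failure report the legal names) can be exercised in isolation. The standalone sketch below reproduces the same pattern outside the cdec tree; kNames, Upper, and LookupAlgorithm are hypothetical stand-ins for anames, the new toupper(std::string) helper, and the constructor logic, not names from this patch.

  #include <cassert>
  #include <cctype>
  #include <iostream>
  #include <sstream>
  #include <stdexcept>
  #include <string>

  // Zero-terminated name table, mirroring anames[] above.
  static char const* const kNames[] = { "BU_CUBE", "BU_FULL", "EARLEY", 0 };

  // Uppercase a copy of s (same idea as stringlib.h's new toupper(std::string)).
  static std::string Upper(std::string s) {
    for (std::string::size_type i = 0; i < s.size(); ++i)
      s[i] = (char)std::toupper((unsigned char)s[i]);
    return s;
  }

  // Map a case-insensitive name to its index, or throw listing the legal names.
  static int LookupAlgorithm(std::string const& name) {
    std::string const u = Upper(name);
    int a = 0;
    while (kNames[a] && kNames[a] != u) ++a;
    if (!kNames[a]) {
      std::ostringstream o;
      for (int i = 0; kNames[i]; ++i) { if (i) o << ' '; o << kNames[i]; }
      throw std::runtime_error("Unknown ApplyFsaBy type: " + name + " - legal types: " + o.str());
    }
    return a;
  }

  int main() {
    assert(LookupAlgorithm("bu_cube") == 0);                      // ApplyFsaBy::BU_CUBE
    std::cout << kNames[LookupAlgorithm("earley")] << std::endl;  // prints EARLEY
    return 0;
  }
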
diff --git a/decoder/apply_fsa_models.h b/decoder/apply_fsa_models.h
index 0a8615b5..64ebab39 100755
--- a/decoder/apply_fsa_models.h
+++ b/decoder/apply_fsa_models.h
@@ -1,15 +1,42 @@
 #ifndef _APPLY_FSA_MODELS_H_
 #define _APPLY_FSA_MODELS_H_
 
-#include "ff_fsa_dynamic.h"
+#include <iostream>
+#include "feature_vector.h"
 
 struct FsaFeatureFunction;
 struct Hypergraph;
 struct SentenceMetadata;
 
+struct ApplyFsaBy {
+  enum {
+    BU_CUBE,
+    BU_FULL,
+    EARLEY,
+    N_ALGORITHMS
+  };
+  int pop_limit; // only applies to BU_CUBE (cube pruning) so far
+  bool IsBottomUp() const {
+    return algorithm==BU_FULL || algorithm==BU_CUBE;
+  }
+  int BottomUpAlgorithm() const;
+  int algorithm;
+  std::string name() const;
+  friend inline std::ostream &operator << (std::ostream &o,ApplyFsaBy const& c) {
+    return o << c.name();
+  }
+  explicit ApplyFsaBy(int alg, int poplimit=200);
+  ApplyFsaBy(std::string const& name, int poplimit=200);
+  ApplyFsaBy(const ApplyFsaBy &o) : pop_limit(o.pop_limit), algorithm(o.algorithm) {  } // copy pop_limit too so it isn't left uninitialized
+  static std::string all_names(); // space separated
+};
+
+
 void ApplyFsaModels(const Hypergraph& in,
                     const SentenceMetadata& smeta,
                     const FsaFeatureFunction& fsa,
+                    DenseWeightVector const& weights, // precondition: 'in' is already weighted by these (with the fsa feature's value still 0 at this point)
+                    ApplyFsaBy const& cfg,
                     Hypergraph* out);
 
 #endif
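As a usage note, the sketch below shows how the extended ApplyFsaModels signature is meant to be driven, mirroring the cdec.cc hunk later in this patch. It is illustrative only: RescoreWithFsa is a hypothetical wrapper, and it assumes the cdec decoder tree (hg.h, an already-built and weighted forest, sentence metadata, a single FSA feature, and dense feature weights).

  #include "apply_fsa_models.h"
  #include "hg.h"

  // Hypothetical helper, not part of this patch.
  void RescoreWithFsa(Hypergraph& forest,
                      SentenceMetadata const& smeta,
                      FsaFeatureFunction const& fsa,
                      DenseWeightVector const& weights,  // forest is already weighted by these
                      int pop_limit) {
    // Parsed case-insensitively; "BU_FULL" also works, while "EARLEY" would hit
    // the bottom-up assert since only the bottom-up path is implemented so far.
    ApplyFsaBy cfg("bu_cube", pop_limit);
    Hypergraph rescored;
    ApplyFsaModels(forest, smeta, fsa, weights, cfg, &rescored);
    forest.swap(rescored);
    forest.Reweight(weights);  // as cdec.cc does after the swap
  }
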
diff --git a/decoder/apply_models.h b/decoder/apply_models.h
index 61a5b8f7..81fa068e 100644
--- a/decoder/apply_models.h
+++ b/decoder/apply_models.h
@@ -8,6 +8,12 @@ struct SentenceMetadata;
 struct exhaustive_t {};
 
 struct IntersectionConfiguration {
+  enum {
+    FULL,
+    CUBE,
+    N_ALGORITHMS
+  };
+
   const int algorithm; // 0 = full intersection, 1 = cube pruning
   const int pop_limit; // max number of pops off the heap at each node
   IntersectionConfiguration(int alg, int k) : algorithm(alg), pop_limit(k) {}
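The enum added here is what ApplyFsaBy::BottomUpAlgorithm() returns, and it lets callers name the algorithm rather than pass the bare 0 or 1 documented in the existing comment on algorithm. A short sketch, assuming the cdec headers:

  #include "apply_models.h"

  const IntersectionConfiguration full_conf(IntersectionConfiguration::FULL, 0);    // exhaustive intersection; pop_limit unused
  const IntersectionConfiguration cube_conf(IntersectionConfiguration::CUBE, 200);  // cube pruning, up to 200 pops per node
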
diff --git a/decoder/cdec.cc b/decoder/cdec.cc
index a7c99307..72f0b95e 100644
--- a/decoder/cdec.cc
+++ b/decoder/cdec.cc
@@ -25,6 +25,7 @@
 #include "weights.h"
 #include "tdict.h"
 #include "ff.h"
+#include "ff_fsa_dynamic.h"
 #include "ff_factory.h"
 #include "hg_intersect.h"
 #include "apply_models.h"
@@ -119,7 +120,8 @@ void InitCommandLine(int argc, char** argv, OracleBleu &ob, po::variables_map* c
     ("warn_0_weight","Warn about any feature id that has a 0 weight (this is perfectly safe if you intend 0 weight, though)")
         ("no_freeze_feature_set,Z", "Do not freeze feature set after reading feature weights file")
         ("feature_function,F",po::value<vector<string> >()->composing(), "Additional feature function(s) (-L for list)")
-        ("fsa_feature_function",po::value<vector<string> >()->composing(), "Additional FSA feature function(s) (-L for list)")
+        ("fsa_feature_function,A",po::value<vector<string> >()->composing(), "Additional FSA feature function(s) (-L for list)")
+    ("apply_fsa_by",po::value<string>()->default_value("BU_CUBE"), "Method for applying fsa_feature_functions - BU_FULL BU_CUBE EARLEY") //+ApplyFsaBy::all_names()
         ("list_feature_functions,L","List available feature functions")
         ("add_pass_through_rules,P","Add rules to translate OOV words as themselves")
 	("k_best,k",po::value<int>(),"Extract the k best derivations")
@@ -147,7 +149,7 @@ void InitCommandLine(int argc, char** argv, OracleBleu &ob, po::variables_map* c
         ("ctf_no_exhaustive", "Do not fall back to exhaustive parse if coarse-to-fine parsing fails")
         ("beam_prune", po::value<double>(), "Prune paths from +LM forest, keep paths within exp(alpha>=0)")
     ("scale_prune_srclen", "scale beams by the input length (in # of tokens; may not be what you want for lattices")
-    ("promise_power",po::value<double>()->default_value(0), "Give more beam budget to more promising previous-pass nodes when pruning - but allocate the same average beams.  0 means off, 1 means beam proportional to inside*outside prob, n means nth power (affects just --cubepruning_pop_limit).  note: for the same poplimit, this gives more search error unless very close to 0 (recommend disabled; even 0.01 is slightly worse than 0) which is a bad sign and suggests this isn't doing a good job; further it's slightly slower to LM cube rescore with 0.01 compared to 0, as well as giving (very insignificantly) lower BLEU.  TODO: test under more conditions, or try idea with different formula, or prob. cube beams.")
+    ("promise_power",po::value<double>()->default_value(0), "Give more beam budget to more promising previous-pass nodes when pruning - but allocate the same average beams.  0 means off, 1 means beam proportional to inside*outside prob, n means nth power (affects just --cubepruning_pop_limit).  note: for the same pop_limit, this gives more search error unless very close to 0 (recommend disabled; even 0.01 is slightly worse than 0) which is a bad sign and suggests this isn't doing a good job; further it's slightly slower to LM cube rescore with 0.01 compared to 0, as well as giving (very insignificantly) lower BLEU.  TODO: test under more conditions, or try idea with different formula, or prob. cube beams.")
         ("lexalign_use_null", "Support source-side null words in lexical translation")
         ("tagger_tagset,t", po::value<string>(), "(Tagger) file containing tag set")
         ("csplit_output_plf", "(Compound splitter) Output lattice in PLF format")
@@ -519,7 +521,8 @@ int main(int argc, char** argv) {
     palg = 0;
     cerr << "Using full intersection (no pruning).\n";
   }
-  const IntersectionConfiguration inter_conf(palg, conf["cubepruning_pop_limit"].as<int>());
+  int pop_limit=conf["cubepruning_pop_limit"].as<int>();
+  const IntersectionConfiguration inter_conf(palg, pop_limit);
 
   const int sample_max_trans = conf.count("max_translation_sample") ?
     conf["max_translation_sample"].as<int>() : 0;
@@ -619,7 +622,7 @@ int main(int argc, char** argv) {
                     inter_conf, // this is now reduced to exhaustive if all are stateless
                     &prelm_forest);
       forest.swap(prelm_forest);
-      forest.Reweight(prelm_feature_weights);
+      forest.Reweight(prelm_feature_weights); //FIXME: why the reweighting here and below?  perhaps because we may already have a featval for that id and ApplyModelSet only adds to the prob rather than recomputing it?
       forest_stats(forest," prelm forest",show_tree_structure,show_features,prelm_feature_weights,oracle.show_derivation);
     }
 
@@ -642,8 +645,20 @@ int main(int argc, char** argv) {
 
     maybe_prune(forest,conf,"beam_prune","density_prune","+LM",srclen);
 
-    vector<WordID> trans;
-    ViterbiESentence(forest, &trans);
+
+
+    if (!fsa_ffs.empty()) {
+      Timer t("Target FSA rescoring:");
+      if (!has_late_models)
+        forest.Reweight(feature_weights);
+      Hypergraph fsa_forest;
+      assert(fsa_ffs.size()==1); // only a single FSA feature function is supported so far
+      ApplyFsaBy cfg(str("apply_fsa_by",conf),pop_limit);
+      ApplyFsaModels(forest,smeta,*fsa_ffs[0],feature_weights,cfg,&fsa_forest);
+      forest.swap(fsa_forest);
+      forest.Reweight(feature_weights);
+      forest_stats(forest,"  +FSA forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation);
+    }
 
 
     /*Oracle Rescoring*/
@@ -687,6 +702,8 @@ int main(int argc, char** argv) {
         cout << HypergraphIO::AsPLF(forest, false) << endl;
       } else {
         if (!graphviz && !has_ref && !joshua_viz) {
+          vector<WordID> trans;
+          ViterbiESentence(forest, &trans);
           cout << TD::GetString(trans) << endl << flush;
         }
         if (joshua_viz) {
diff --git a/decoder/feature_vector.h b/decoder/feature_vector.h
index 1b272506..be378a6a 100755
--- a/decoder/feature_vector.h
+++ b/decoder/feature_vector.h
@@ -11,4 +11,8 @@ typedef SparseVector<Featval> FeatureVector;
 typedef SparseVector<Featval> WeightVector;
 typedef std::vector<Featval> DenseWeightVector;
 
+inline void sparse_to_dense(WeightVector const& wv,DenseWeightVector *dv) {
+  wv.init_vector(dv);
+}
+
 #endif
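The sparse_to_dense helper just expands a sparse id-to-weight map into a plain, indexable std::vector via SparseVector::init_vector. The standalone sketch below illustrates the idea with std::map standing in for SparseVector; the Toy* names and toy_sparse_to_dense are illustrative only, not cdec's API.

  #include <cassert>
  #include <map>
  #include <vector>

  typedef std::map<int, double> ToySparseVector;  // feature id -> weight
  typedef std::vector<double> ToyDenseVector;

  void toy_sparse_to_dense(ToySparseVector const& wv, ToyDenseVector* dv) {
    dv->clear();
    for (ToySparseVector::const_iterator i = wv.begin(); i != wv.end(); ++i) {
      if (i->first >= (int)dv->size()) dv->resize(i->first + 1, 0.0);
      (*dv)[i->first] = i->second;
    }
  }

  int main() {
    ToySparseVector w;
    w[3] = 0.5;  // only feature id 3 carries a weight
    ToyDenseVector d;
    toy_sparse_to_dense(w, &d);
    assert(d.size() == 4 && d[0] == 0.0 && d[3] == 0.5);
    return 0;
  }
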
diff --git a/decoder/stringlib.h b/decoder/stringlib.h
index b3097bd1..53e6fe50 100644
--- a/decoder/stringlib.h
+++ b/decoder/stringlib.h
@@ -18,6 +18,18 @@
 #include <cstring>
 #include <string>
 #include <sstream>
+#include <algorithm>
+
+struct toupperc {
+  inline char operator()(char c) const {
+    return std::toupper((unsigned char)c); // cast: passing a negative char to std::toupper is undefined behavior
+  }
+};
+
+inline std::string toupper(std::string s) {
+  std::transform(s.begin(),s.end(),s.begin(),toupperc());
+  return s;
+}
 
 template <class Istr, class Isubstr> inline
 bool match_begin(Istr bstr,Istr estr,Isubstr bsub,Isubstr esub)
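A side note on the toupperc functor: std::toupper is overloaded (the <cctype> and <locale> versions), so handing it to std::transform directly can be ambiguous, and the <cctype> version takes an int whose behavior is undefined for negative char values. Wrapping it in a small functor sidesteps both issues. A standalone illustration follows; ToUpperChar and Upper are stand-in names, not cdec code.

  #include <algorithm>
  #include <cassert>
  #include <cctype>
  #include <string>

  struct ToUpperChar {
    char operator()(char c) const { return (char)std::toupper((unsigned char)c); }
  };

  inline std::string Upper(std::string s) {
    std::transform(s.begin(), s.end(), s.begin(), ToUpperChar());
    return s;
  }

  int main() {
    assert(Upper("bu_cube") == "BU_CUBE");  // the form --apply_fsa_by accepts case-insensitively
    return 0;
  }
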
-- 
cgit v1.2.3