From d94b8d0690249624a8f6a427df9c7edad354e333 Mon Sep 17 00:00:00 2001
From: redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>
Date: Thu, 11 Nov 2010 20:31:56 +0000
Subject: klm stub

git-svn-id: https://ws10smt.googlecode.com/svn/trunk@711 ec762483-ff6d-05da-a07a-a48fb63a330f
---
 decoder/Makefile.am |   6 +-
 decoder/cdec_ff.cc  |   2 +
 decoder/ff_klm.cc   | 312 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 decoder/ff_klm.h    |  33 ++++++
 4 files changed, 351 insertions(+), 2 deletions(-)
 create mode 100644 decoder/ff_klm.cc
 create mode 100644 decoder/ff_klm.h

diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index da0e5987..ea01a4da 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -12,7 +12,7 @@ TESTS = trule_test ff_test parser_test grammar_test hg_test cfg_test
 endif
 
 cdec_SOURCES = cdec.cc
-cdec_LDADD = libcdec.a ../mteval/libmteval.a ../utils/libutils.a -lz
+cdec_LDADD = libcdec.a ../mteval/libmteval.a ../utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
 
 cfg_test_SOURCES = cfg_test.cc
 cfg_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a -lz
@@ -26,7 +26,8 @@ hg_test_SOURCES = hg_test.cc
 hg_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a -lz
 trule_test_SOURCES = trule_test.cc
 trule_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a -lz
-AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. -I../mteval -I../utils
+
+AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. -I../mteval -I../utils -I../klm
 
 rule_lexer.cc: rule_lexer.l
 	$(LEX) -s -CF -8 -o$@ $<
@@ -58,6 +59,7 @@ libcdec_a_SOURCES = \
   trule.cc \
   ff.cc \
   ff_lm.cc \
+  ff_klm.cc \
   ff_ruleshape.cc \
   ff_wordalign.cc \
   ff_csplit.cc \
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index ca5334e9..09a19a7b 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -2,6 +2,7 @@
 
 #include "ff.h"
 #include "ff_lm.h"
+#include "ff_klm.h"
 #include "ff_csplit.h"
 #include "ff_wordalign.h"
 #include "ff_tagger.h"
@@ -29,6 +30,7 @@ void register_feature_functions() {
   RegisterFsaDynToFF<SameFirstLetter>();
 
   RegisterFF<LanguageModel>();
+  RegisterFF<KLanguageModel>();
 
   RegisterFF<WordPenalty>();
   RegisterFF<SourceWordPenalty>();
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc
new file mode 100644
index 00000000..5888c4a3
--- /dev/null
+++ b/decoder/ff_klm.cc
@@ -0,0 +1,312 @@
+#include "ff_klm.h"
+
+#include "hg.h"
+#include "tdict.h"
+#include "lm/model.hh"
+#include "lm/enumerate_vocab.hh"
+
+using namespace std;
+
+string KLanguageModel::usage(bool /*param*/, bool /*verbose*/) {
+  return "KLanguageModel";
+}
+
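+// EnumerateVocab callback: KenLM reports each vocab entry as the model loads,
+// letting us build a cdec-WordID -> KenLM-index map (unseen words stay 0, KenLM's <unk>)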
+struct VMapper : public lm::ngram::EnumerateVocab {
+  VMapper(vector<lm::WordIndex>* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); }
+  void Add(lm::WordIndex index, const StringPiece &str) {
+    const WordID cdec_id = TD::Convert(str.as_string());
+    if (cdec_id >= out_->size())
+      out_->resize(cdec_id + 1, kLM_UNKNOWN_TOKEN);
+    (*out_)[cdec_id] = index;
+  }
+  vector<lm::WordIndex>* out_;
+  const lm::WordIndex kLM_UNKNOWN_TOKEN;
+};
+
+class KLanguageModelImpl {
+  // the remnant length is stored in the last byte of the reserved state block
+  // (the block is state_size_ bytes, so the byte lives at offset state_size_ - 1)
+  inline int StateSize(const void* state) const {
+    return *(static_cast<const char*>(state) + state_size_ - 1);
+  }
+
+  inline void SetStateSize(int size, void* state) const {
+    *(static_cast<char*>(state) + state_size_ - 1) = size;
+  }
+
+#if 0  // scoring code carried over from the SRILM-based ff_lm.cc; it references helpers (UNIDBG, clamp, ngram_.wordProb) that do not exist here, so it stays disabled until ported to KenLM
+  virtual double WordProb(WordID word, WordID const* context) {
+    return ngram_.wordProb(word, (VocabIndex*)context);
+  }
+
+  // may be shorter than actual null-terminated length.  context must be null terminated.  len is just to save effort for subclasses that don't support contextID
+  virtual int ContextSize(WordID const* context,int len) {
+    unsigned ret;
+    ngram_.contextID((VocabIndex*)context,ret);
+    return ret;
+  }
+  virtual double ContextBOW(WordID const* context,int shortened_len) {
+    return ngram_.contextBOW((VocabIndex*)context,shortened_len);
+  }
+
+  inline double LookupProbForBufferContents(int i) {
+//    int k = i; cerr << "P("; while(buffer_[k] > 0) { std::cerr << TD::Convert(buffer_[k++]) << " "; }
+    double p = WordProb(buffer_[i], &buffer_[i+1]);
+    if (p < floor_) p = floor_;
+//    cerr << ")=" << p << endl;
+    return p;
+  }
+
+  string DebugStateToString(const void* state) const {
+    int len = StateSize(state);
+    const int* astate = reinterpret_cast<const int*>(state);
+    string res = "[";
+    for (int i = 0; i < len; ++i) {
+      res += " ";
+      res += TD::Convert(astate[i]);
+    }
+    res += " ]";
+    return res;
+  }
+
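+  // heuristic scoring of a reversed buffer: ids <= 0 delimit phrases, and kSTAR
+  // marks where an antecedent's context was truncated (unknown left context)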
+  inline double ProbNoRemnant(int i, int len) {
+    int edge = len;
+    bool flag = true;
+    double sum = 0.0;
+    while (i >= 0) {
+      if (buffer_[i] == kSTAR) {
+        edge = i;
+        flag = false;
+      } else if (buffer_[i] <= 0) {
+        edge = i;
+        flag = true;
+      } else {
+        if ((edge-i >= order_) || (flag && !(i == (len-1) && buffer_[i] == kSTART)))
+          sum += LookupProbForBufferContents(i);
+      }
+      --i;
+    }
+    return sum;
+  }
+
+  double EstimateProb(const vector<WordID>& phrase) {
+    int len = phrase.size();
+    buffer_.resize(len + 1);
+    buffer_[len] = kNONE;
+    int i = len - 1;
+    for (int j = 0; j < len; ++j,--i)
+      buffer_[i] = phrase[j];
+    return ProbNoRemnant(len - 1, len);
+  }
+
+  //TODO: make sure this doesn't get used in FinalTraversal, or if it does, that it causes no harm.
+
+  //TODO: use stateless_cost instead of ProbNoRemnant, check left words only.  for items w/ fewer words than ctx len, how are they represented?  kNONE padded?
+
+  //Vocab_None is (unsigned)-1 in SRILM, the same value as kNONE; SRILM interprets -1 as a terminator, not a word
+  double EstimateProb(const void* state) {
+    if (unigram) return 0.;
+    int len = StateSize(state);
+    //  << "residual len: " << len << endl;
+    buffer_.resize(len + 1);
+    buffer_[len] = kNONE;
+    const WordID* astate = reinterpret_cast<const WordID*>(state);
+    int i = len - 1;
+    for (int j = 0; j < len; ++j,--i)
+      buffer_[i] = astate[j];
+    return ProbNoRemnant(len - 1, len);
+  }
+
+  //FIXME: this assumes no target words on final unary -> goal rule.  is that ok?
+  // for <s> (n-1 left words) and (n-1 right words) </s>
+  double FinalTraversalCost(const void* state) {
+    if (unigram) return 0.;
+    int slen = StateSize(state);
+    int len = slen + 2;
+    // cerr << "residual len: " << len << endl;
+    buffer_.resize(len + 1);
+    buffer_[len] = kNONE;
+    buffer_[len-1] = kSTART;
+    const WordID* astate = reinterpret_cast<const WordID*>(state);
+    int i = len - 2;
+    for (int j = 0; j < slen; ++j,--i)
+      buffer_[i] = astate[j];
+    buffer_[i] = kSTOP;
+    assert(i == 0);
+    return ProbNoRemnant(len - 1, len);
+  }
+
+  /// just how SRILM likes it: [rbegin,rend) is a phrase in reverse word order and null terminated so *rend == kNONE.
+  /// returns the summed conditional score of each word given the words after it; some kind of log prob (who cares, we're just adding)
+  double stateless_cost(WordID *rbegin,WordID *rend) {
+    UNIDBG("p(");
+    double sum=0;
+    for (;rend>rbegin;--rend) {
+      sum+=clamp(WordProb(rend[-1],rend));
+      UNIDBG(" "<<TD::Convert(rend[-1]));
+    }
+    UNIDBG(")="<<sum<<endl);
+    return sum;
+  }
+
+  //TODO: this would be a fine rule heuristic (for reordering hyperedges prior to rescoring); for now you can just use a same-lm-file -o 1 prelm-rescore :(
+  double stateless_cost(TRule const& rule) {
+    //TODO: make sure this is correct.
+    int len = rule.ELength(); // use a gap for each variable
+    buffer_.resize(len + 1);
+    WordID * const rend=&buffer_[0]+len;
+    *rend=kNONE;
+    WordID *r=rend;  // append by *--r = x
+    const vector<WordID>& e = rule.e();
+    //SRILM is reverse order null terminated
+    //write down each phrase in reverse order and score it.  (we could lay the phrases out consecutively and then score them, since we allocated enough buffer for that, but we won't: touching the whole buffer that way wastes L1 cache)
+    double sum=0.;
+    for (unsigned j = 0; j < e.size(); ++j) {
+      if (e[j] < 1) { // variable
+        sum+=stateless_cost(r,rend);
+        r=rend;
+      } else { // terminal
+          *--r=e[j];
+      }
+    }
+    // last phrase (if any)
+    return sum+stateless_cost(r,rend);
+  }
+
+  //NOTE: this is where the scoring of words happens (heuristic happens in EstimateProb)
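+  // the rule's target side is written into buffer_ in reverse order, with each
+  // antecedent's stored boundary words spliced in at its variable; complete
+  // n-grams are then scored, and the unscored boundary words become the new state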
+  double LookupWords(const TRule& rule, const vector<const void*>& ant_states, void* vstate) {
+    if (unigram)
+      return stateless_cost(rule);
+    int len = rule.ELength() - rule.Arity();
+    for (int i = 0; i < ant_states.size(); ++i)
+      len += StateSize(ant_states[i]);
+    buffer_.resize(len + 1);
+    buffer_[len] = kNONE;
+    int i = len - 1;
+    const vector<WordID>& e = rule.e();
+    for (int j = 0; j < e.size(); ++j) {
+      if (e[j] < 1) {
+        const int* astate = reinterpret_cast<const int*>(ant_states[-e[j]]);
+        int slen = StateSize(astate);
+        for (int k = 0; k < slen; ++k)
+          buffer_[i--] = astate[k];
+      } else {
+        buffer_[i--] = e[j];
+      }
+    }
+
+    double sum = 0.0;
+    int* remnant = reinterpret_cast<int*>(vstate);
+    int j = 0;
+    i = len - 1;
+    int edge = len;
+
+    while (i >= 0) {
+      if (buffer_[i] == kSTAR) {
+        edge = i;
+      } else if (edge-i >= order_) {
+        sum += LookupProbForBufferContents(i);
+      } else if (edge == len && remnant) {
+        remnant[j++] = buffer_[i];
+      }
+      --i;
+    }
+    if (!remnant) return sum;
+
+    if (edge != len || len >= order_) {
+      remnant[j++] = kSTAR;
+      if (order_-1 < edge) edge = order_-1;
+      for (int i = edge-1; i >= 0; --i)
+        remnant[j++] = buffer_[i];
+    }
+
+    SetStateSize(j, vstate);
+    return sum;
+  }
+
+ protected:
+  vector<WordID> buffer_;
+ public:
+  WordID kSTART;
+  WordID kSTOP;
+  WordID kUNKNOWN;
+  WordID kNONE;
+  WordID kSTAR;
+  bool unigram;
+#endif
+
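+  // map a cdec word to the LM's vocab index; words the LM has never seen map
+  // to 0, which KenLM reserves for <unk>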
+  lm::WordIndex MapWord(WordID w) const {
+    if (w >= map_.size())
+      return 0;
+    else
+      return map_[w];
+  }
+
+ public:
+  KLanguageModelImpl(const std::string& param) {
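+    // NOTE: param is currently taken to be the model filename as-is; the
+    // "[-o n]" option advertised in ff_klm.h is not parsed yet in this stub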
+    lm::ngram::Config conf;
+    VMapper vm(&map_);
+    conf.enumerate_vocab = &vm;
+    ngram_ = new lm::ngram::Model(param.c_str(), conf);
+    order_ = ngram_->Order();
+    cerr << "Loaded " << order_ << "-gram KLM from " << param << endl;
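+    // reserve room for KenLM's internal state, one byte holding the remnant
+    // length, and up to order_-1 remnant context words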
+    state_size_ = ngram_->StateSize() + 1 + (order_-1) * sizeof(int);
+  }
+
+  ~KLanguageModelImpl() {
+    delete ngram_;
+  }
+
+  int ReserveStateSize() const { return state_size_; }
+
+ private:
+  lm::ngram::Model* ngram_;
+  int order_;
+  int state_size_;
+  vector<lm::WordIndex> map_;
+
+};
+
+KLanguageModel::KLanguageModel(const string& param) {
+  pimpl_ = new KLanguageModelImpl(param);
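+  // same feature name as the SRILM-based LanguageModel, so the two share a weight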
+  fid_ = FD::Convert("LanguageModel");
+  SetStateSize(pimpl_->ReserveStateSize());
+}
+
+Features KLanguageModel::features() const {
+  return single_feature(fid_);
+}
+
+KLanguageModel::~KLanguageModel() {
+  delete pimpl_;
+}
+
+void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */,
+                                          const Hypergraph::Edge& /* edge */,
+                                          const vector<const void*>& /* ant_states */,
+                                          SparseVector<double>* /* features */,
+                                          SparseVector<double>* /* estimated_features */,
+                                          void* /* state */) const {
+//  features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, state));
+//  estimated_features->set_value(fid_, pimpl_->EstimateProb(state));
+}
+
+void KLanguageModel::FinalTraversalFeatures(const void* /* ant_state */,
+                                           SparseVector<double>* /* features */) const {
+//  features->set_value(fid_, pimpl_->FinalTraversalCost(ant_state));
+}
+
diff --git a/decoder/ff_klm.h b/decoder/ff_klm.h
new file mode 100644
index 00000000..0569286f
--- /dev/null
+++ b/decoder/ff_klm.h
@@ -0,0 +1,33 @@
+#ifndef _KLM_FF_H_
+#define _KLM_FF_H_
+
+#include <vector>
+#include <string>
+
+#include "ff.h"
+
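+// pimpl keeps the KenLM headers out of cdec's feature-function interface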
+struct KLanguageModelImpl;
+
+class KLanguageModel : public FeatureFunction {
+ public:
+  // param = "filename.lm [-o n]"
+  KLanguageModel(const std::string& param);
+  ~KLanguageModel();
+  virtual void FinalTraversalFeatures(const void* context,
+                                      SparseVector<double>* features) const;
+  static std::string usage(bool param,bool verbose);
+  Features features() const;
+ protected:
+  virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+                                     const Hypergraph::Edge& edge,
+                                     const std::vector<const void*>& ant_contexts,
+                                     SparseVector<double>* features,
+                                     SparseVector<double>* estimated_features,
+                                     void* out_context) const;
+ private:
+  int fid_; // conceptually const; mutable only to simplify constructor
+  KLanguageModelImpl* pimpl_;
+};
+
+#endif