Diffstat (limited to 'decoder')
 -rw-r--r--  decoder/Makefile.am |   6
 -rw-r--r--  decoder/cdec_ff.cc  |   2
 -rw-r--r--  decoder/ff_klm.cc   | 299
 -rw-r--r--  decoder/ff_klm.h    |  32
4 files changed, 337 insertions, 2 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index da0e5987..ea01a4da 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -12,7 +12,7 @@ TESTS = trule_test ff_test parser_test grammar_test hg_test cfg_test
 endif
 
 cdec_SOURCES = cdec.cc
-cdec_LDADD = libcdec.a ../mteval/libmteval.a ../utils/libutils.a -lz
+cdec_LDADD = libcdec.a ../mteval/libmteval.a ../utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
 
 cfg_test_SOURCES = cfg_test.cc
 cfg_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a -lz
@@ -26,7 +26,8 @@ hg_test_SOURCES = hg_test.cc
 hg_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a -lz
 trule_test_SOURCES = trule_test.cc
 trule_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a -lz
-AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. -I../mteval -I../utils
+
+AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. -I../mteval -I../utils -I../klm
 
 rule_lexer.cc: rule_lexer.l
 	$(LEX) -s -CF -8 -o$@ $<
@@ -58,6 +59,7 @@ libcdec_a_SOURCES = \
   trule.cc \
   ff.cc \
   ff_lm.cc \
+  ff_klm.cc \
   ff_ruleshape.cc \
   ff_wordalign.cc \
   ff_csplit.cc \
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index ca5334e9..09a19a7b 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -2,6 +2,7 @@
 
 #include "ff.h"
 #include "ff_lm.h"
+#include "ff_klm.h"
 #include "ff_csplit.h"
 #include "ff_wordalign.h"
 #include "ff_tagger.h"
@@ -29,6 +30,7 @@ void register_feature_functions() {
   RegisterFsaDynToFF<SameFirstLetter>();
 
   RegisterFF<LanguageModel>();
+  RegisterFF<KLanguageModel>();
 
   RegisterFF<WordPenalty>();
   RegisterFF<SourceWordPenalty>();
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc
new file mode 100644
index 00000000..5888c4a3
--- /dev/null
+++ b/decoder/ff_klm.cc
@@ -0,0 +1,299 @@
+#include "ff_klm.h"
+
+#include "hg.h"
+#include "tdict.h"
+#include "lm/model.hh"
+#include "lm/enumerate_vocab.hh"
+
+using namespace std;
+
+string KLanguageModel::usage(bool param,bool verbose) {
+  return "KLanguageModel";
+}
+
+struct VMapper : public lm::ngram::EnumerateVocab {
+  VMapper(vector<lm::WordIndex>* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); }
+  void Add(lm::WordIndex index, const StringPiece &str) {
+    const WordID cdec_id = TD::Convert(str.as_string());
+    if (cdec_id >= out_->size())
+      out_->resize(cdec_id + 1, kLM_UNKNOWN_TOKEN);
+    (*out_)[cdec_id] = index;
+  }
+  vector<lm::WordIndex>* out_;
+  const lm::WordIndex kLM_UNKNOWN_TOKEN;
+};
+
+class KLanguageModelImpl {
+  inline int StateSize(const void* state) const {
+    return *(static_cast<const char*>(state) + state_size_);
+  }
+
+  inline void SetStateSize(int size, void* state) const {
+    *(static_cast<char*>(state) + state_size_) = size;
+  }
+
+#if 0
+  virtual double WordProb(WordID word, WordID const* context) {
+    return ngram_.wordProb(word, (VocabIndex*)context);
+  }
+
+  // may be shorter than actual null-terminated length.  context must be null terminated.  len is just to save effort for subclasses that don't support contextID
+  virtual int ContextSize(WordID const* context,int len) {
+    unsigned ret;
+    ngram_.contextID((VocabIndex*)context,ret);
+    return ret;
+  }
+  virtual double ContextBOW(WordID const* context,int shortened_len) {
+    return ngram_.contextBOW((VocabIndex*)context,shortened_len);
+  }
+
+  inline double LookupProbForBufferContents(int i) {
+//    int k = i; cerr << "P("; while(buffer_[k] > 0) { std::cerr << TD::Convert(buffer_[k++]) << " "; }
+    double p = WordProb(buffer_[i], &buffer_[i+1]);
+    if (p < floor_) p = floor_;
+//    cerr << ")=" << p << endl;
+    return p;
+  }
+
+  string DebugStateToString(const void* state) const {
+    int len = StateSize(state);
+    const int* astate = reinterpret_cast<const int*>(state);
+    string res = "[";
+    for (int i = 0; i < len; ++i) {
+      res += " ";
+      res += TD::Convert(astate[i]);
+    }
+    res += " ]";
+    return res;
+  }
+
+  inline double ProbNoRemnant(int i, int len) {
+    int edge = len;
+    bool flag = true;
+    double sum = 0.0;
+    while (i >= 0) {
+      if (buffer_[i] == kSTAR) {
+        edge = i;
+        flag = false;
+      } else if (buffer_[i] <= 0) {
+        edge = i;
+        flag = true;
+      } else {
+        if ((edge-i >= order_) || (flag && !(i == (len-1) && buffer_[i] == kSTART)))
+          sum += LookupProbForBufferContents(i);
+      }
+      --i;
+    }
+    return sum;
+  }
+
+  double EstimateProb(const vector<WordID>& phrase) {
+    int len = phrase.size();
+    buffer_.resize(len + 1);
+    buffer_[len] = kNONE;
+    int i = len - 1;
+    for (int j = 0; j < len; ++j,--i)
+      buffer_[i] = phrase[j];
+    return ProbNoRemnant(len - 1, len);
+  }
+
+  //TODO: make sure this doesn't get used in FinalTraversal, or if it does, that it causes no harm.
+
+  //TODO: use stateless_cost instead of ProbNoRemnant, check left words only.  for items w/ fewer words than ctx len, how are they represented?  kNONE padded?
+
+  //Vocab_None is (unsigned)-1 in srilm, same as kNONE. in srilm (-1), or that SRILM otherwise interprets -1 as a terminator and not a word
+  double EstimateProb(const void* state) {
+    if (unigram) return 0.;
+    int len = StateSize(state);
+    //  << "residual len: " << len << endl;
+    buffer_.resize(len + 1);
+    buffer_[len] = kNONE;
+    const int* astate = reinterpret_cast<const WordID*>(state);
+    int i = len - 1;
+    for (int j = 0; j < len; ++j,--i)
+      buffer_[i] = astate[j];
+    return ProbNoRemnant(len - 1, len);
+  }
+
+  //FIXME: this assumes no target words on final unary -> goal rule.  is that ok?
+  // for <s> (n-1 left words) and (n-1 right words) </s>
+  double FinalTraversalCost(const void* state) {
+    if (unigram) return 0.;
+    int slen = StateSize(state);
+    int len = slen + 2;
+    // cerr << "residual len: " << len << endl;
+    buffer_.resize(len + 1);
+    buffer_[len] = kNONE;
+    buffer_[len-1] = kSTART;
+    const int* astate = reinterpret_cast<const WordID*>(state);
+    int i = len - 2;
+    for (int j = 0; j < slen; ++j,--i)
+      buffer_[i] = astate[j];
+    buffer_[i] = kSTOP;
+    assert(i == 0);
+    return ProbNoRemnant(len - 1, len);
+  }
+
+  /// just how SRILM likes it: [rbegin,rend) is a phrase in reverse word order and null terminated so *rend=kNONE.  return unigram score for rend[-1] plus
+  /// cost returned is some kind of log prob (who cares, we're just adding)
+  double stateless_cost(WordID *rbegin,WordID *rend) {
+    UNIDBG("p(");
+    double sum=0;
+    for (;rend>rbegin;--rend) {
+      sum+=clamp(WordProb(rend[-1],rend));
+      UNIDBG(" "<<TD::Convert(rend[-1]));
+    }
+    UNIDBG(")="<<sum<<endl);
+    return sum;
+  }
+
+  //TODO: this would be a fine rule heuristic (for reordering hyperedges prior to rescoring.  for now you can just use a same-lm-file -o 1 prelm-rescore :(
+  double stateless_cost(TRule const& rule) {
+    //TODO: make sure this is correct.
+    int len = rule.ELength(); // use a gap for each variable
+    buffer_.resize(len + 1);
+    WordID * const rend=&buffer_[0]+len;
+    *rend=kNONE;
+    WordID *r=rend;  // append by *--r = x
+    const vector<WordID>& e = rule.e();
+    //SRILM is reverse order null terminated
+    //let's write down each phrase in reverse order and score it (note: we could lay them out consecutively then score them (we allocated enough buffer for that), but we won't actually use the whole buffer that way, since it wastes L1 cache.
+    double sum=0.;
+    for (unsigned j = 0; j < e.size(); ++j) {
+      if (e[j] < 1) { // variable
+        sum+=stateless_cost(r,rend);
+        r=rend;
+      } else { // terminal
+          *--r=e[j];
+      }
+    }
+    // last phrase (if any)
+    return sum+stateless_cost(r,rend);
+  }
+
+  //NOTE: this is where the scoring of words happens (heuristic happens in EstimateProb)
+  double LookupWords(const TRule& rule, const vector<const void*>& ant_states, void* vstate) {
+    if (unigram)
+      return stateless_cost(rule);
+    int len = rule.ELength() - rule.Arity();
+    for (int i = 0; i < ant_states.size(); ++i)
+      len += StateSize(ant_states[i]);
+    buffer_.resize(len + 1);
+    buffer_[len] = kNONE;
+    int i = len - 1;
+    const vector<WordID>& e = rule.e();
+    for (int j = 0; j < e.size(); ++j) {
+      if (e[j] < 1) {
+        const int* astate = reinterpret_cast<const int*>(ant_states[-e[j]]);
+        int slen = StateSize(astate);
+        for (int k = 0; k < slen; ++k)
+          buffer_[i--] = astate[k];
+      } else {
+        buffer_[i--] = e[j];
+      }
+    }
+
+    double sum = 0.0;
+    int* remnant = reinterpret_cast<int*>(vstate);
+    int j = 0;
+    i = len - 1;
+    int edge = len;
+
+    while (i >= 0) {
+      if (buffer_[i] == kSTAR) {
+        edge = i;
+      } else if (edge-i >= order_) {
+        sum += LookupProbForBufferContents(i);
+      } else if (edge == len && remnant) {
+        remnant[j++] = buffer_[i];
+      }
+      --i;
+    }
+    if (!remnant) return sum;
+
+    if (edge != len || len >= order_) {
+      remnant[j++] = kSTAR;
+      if (order_-1 < edge) edge = order_-1;
+      for (int i = edge-1; i >= 0; --i)
+        remnant[j++] = buffer_[i];
+    }
+
+    SetStateSize(j, vstate);
+    return sum;
+  }
+
+private:
+public:
+
+ protected:
+  vector<WordID> buffer_;
+ public:
+  WordID kSTART;
+  WordID kSTOP;
+  WordID kUNKNOWN;
+  WordID kNONE;
+  WordID kSTAR;
+  bool unigram;
+#endif
+
+  lm::WordIndex MapWord(WordID w) const {
+    if (w >= map_.size())
+      return 0;
+    else
+      return map_[w];
+  }
+
+ public:
+  KLanguageModelImpl(const std::string& param) {
+    lm::ngram::Config conf;
+    VMapper vm(&map_);
+    conf.enumerate_vocab = &vm;
+    ngram_ = new lm::ngram::Model(param.c_str(), conf);
+    order_ = ngram_->Order();
+    cerr << "Loaded " << order_ << "-gram KLM from " << param << endl;
+    state_size_ = ngram_->StateSize() + 1 + (order_-1) * sizeof(int);
+  }
+
+  ~KLanguageModelImpl() {
+    delete ngram_;
+  }
+
+  const int ReserveStateSize() const { return state_size_; }
+
+ private:
+  lm::ngram::Model* ngram_;
+  int order_;
+  int state_size_;
+  vector<lm::WordIndex> map_;
+
+};
+
+KLanguageModel::KLanguageModel(const string& param) {
+  pimpl_ = new KLanguageModelImpl(param);
+  fid_ = FD::Convert("LanguageModel");
+  SetStateSize(pimpl_->ReserveStateSize());
+}
+
+Features KLanguageModel::features() const {
+  return single_feature(fid_);
+}
+
+KLanguageModel::~KLanguageModel() {
+  delete pimpl_;
+}
+
+void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */,
+                                          const Hypergraph::Edge& edge,
+                                          const vector<const void*>& ant_states,
+                                          SparseVector<double>* features,
+                                          SparseVector<double>* estimated_features,
+                                          void* state) const {
+//  features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, state));
+//  estimated_features->set_value(fid_, imp().EstimateProb(state));
+}
+
+void KLanguageModel::FinalTraversalFeatures(const void* ant_state,
+                                           SparseVector<double>* features) const {
+//  features->set_value(fid_, imp().FinalTraversalCost(ant_state));
+}
+
diff --git a/decoder/ff_klm.h b/decoder/ff_klm.h
new file mode 100644
index 00000000..0569286f
--- /dev/null
+++ b/decoder/ff_klm.h
@@ -0,0 +1,32 @@
+#ifndef _KLM_FF_H_
+#define _KLM_FF_H_
+
+#include <vector>
+#include <string>
+
+#include "ff.h"
+
+struct KLanguageModelImpl;
+
+class KLanguageModel : public FeatureFunction {
+ public:
+  // param = "filename.lm [-o n]"
+  KLanguageModel(const std::string& param);
+  ~KLanguageModel();
+  virtual void FinalTraversalFeatures(const void* context,
+                                      SparseVector<double>* features) const;
+  static std::string usage(bool param,bool verbose);
+  Features features() const;
+ protected:
+  virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+                                     const Hypergraph::Edge& edge,
+                                     const std::vector<const void*>& ant_contexts,
+                                     SparseVector<double>* features,
+                                     SparseVector<double>* estimated_features,
+                                     void* out_context) const;
+ private:
+  int fid_; // conceptually const; mutable only to simplify constructor
+  KLanguageModelImpl* pimpl_;
+};
+
+#endif
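Note on the vocabulary bridge in ff_klm.cc: KenLM announces every vocabulary entry through the lm::ngram::EnumerateVocab hook while the model loads, and VMapper records each entry in a table indexed by cdec WordID, so MapWord can later translate decoder words into KenLM indices with a single vector lookup (unmapped words fall back to index 0, KenLM's <unk>). A minimal standalone sketch of the same hook outside cdec follows; the PrintVocab class and the "test.arpa" path are illustrative, not part of this commit.

    // Sketch: KenLM's load-time vocabulary enumeration, outside cdec.
    // Assumes a KenLM checkout on the include path; "test.arpa" is a placeholder.
    #include <iostream>

    #include "lm/enumerate_vocab.hh"
    #include "lm/model.hh"

    // KenLM invokes Add() once per vocabulary entry while the model loads.
    struct PrintVocab : public lm::ngram::EnumerateVocab {
      void Add(lm::WordIndex index, const StringPiece &str) {
        std::cout << index << '\t' << str << '\n';
      }
    };

    int main() {
      lm::ngram::Config conf;
      PrintVocab printer;
      conf.enumerate_vocab = &printer;            // hook fires during construction below
      lm::ngram::Model model("test.arpa", conf);  // placeholder model path
      std::cout << "order = " << static_cast<int>(model.Order()) << '\n';
    }

Enumerating once at load time keeps the per-edge MapWord call to a plain array index instead of a hash probe on every scored word.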

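The other piece worth unpacking is ReserveStateSize(): each hypergraph node's state blob is sized to hold KenLM's internal state (ngram_->StateSize()), up to order_-1 boundary word IDs, and one byte recording how many of those IDs are in use, which is what the StateSize/SetStateSize helpers read and write. A minimal sketch of that flat-buffer packing idea follows; StateLayout and all names in it are hypothetical, and the KenLM portion of the blob is omitted.

    #include <cassert>
    #include <cstring>
    #include <iostream>
    #include <vector>

    // Illustrative packed-state layout: [ (order-1) int word ids | 1 length byte ].
    // cdec's real state additionally prepends KenLM's own state; omitted here.
    struct StateLayout {
      int order;
      int bytes() const { return (order - 1) * sizeof(int) + 1; }

      void Write(void* state, const std::vector<int>& left_words) const {
        assert((int)left_words.size() <= order - 1);
        std::memcpy(state, left_words.data(), left_words.size() * sizeof(int));
        // The count lives in the last byte of the reserved region.
        static_cast<char*>(state)[bytes() - 1] = static_cast<char>(left_words.size());
      }

      std::vector<int> Read(const void* state) const {
        int n = static_cast<const char*>(state)[bytes() - 1];
        std::vector<int> out(n);
        std::memcpy(out.data(), state, n * sizeof(int));
        return out;
      }
    };

    int main() {
      StateLayout layout{3};  // trigram: keep at most 2 boundary words
      std::vector<char> blob(layout.bytes());
      layout.Write(blob.data(), {42, 7});
      for (int w : layout.Read(blob.data())) std::cout << w << ' ';
      std::cout << '\n';      // prints: 42 7
    }

Keeping the count in a trailing byte lets the decoder treat states as opaque fixed-size memory while still supporting variable-length left contexts; in this WIP commit the scoring calls in TraversalFeaturesImpl are still commented out, so the layout is reserved but not yet exercised.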