| author | Chris Dyer <cdyer@cs.cmu.edu> | 2010-12-13 21:40:08 -0500 |
|---|---|---|
| committer | Chris Dyer <cdyer@cs.cmu.edu> | 2010-12-13 21:40:08 -0500 |
| commit | b58684f29a9ee4331014404cf3128725e3333e78 | |
| tree | 313f4c8938e4ad977369ec9f238b694eeff0e625 | |
| parent | 66e5956906e61b047d2fd451f3053916cbc92433 | |
integration complete with KenLM, not fully tested
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | decoder/ff_klm.cc | 297 |

1 file changed, 107 insertions, 190 deletions
```diff
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc
index 5888c4a3..092c07b0 100644
--- a/decoder/ff_klm.cc
+++ b/decoder/ff_klm.cc
@@ -1,5 +1,7 @@
 #include "ff_klm.h"
 
+#include <cstring>
+
 #include "hg.h"
 #include "tdict.h"
 #include "lm/model.hh"
@@ -24,217 +26,116 @@ struct VMapper : public lm::ngram::EnumerateVocab {
 };
 
 class KLanguageModelImpl {
 
-  inline int StateSize(const void* state) const {
-    return *(static_cast<const char*>(state) + state_size_);
-  }
-
-  inline void SetStateSize(int size, void* state) const {
-    *(static_cast<char*>(state) + state_size_) = size;
-  }
-#if 0
-  virtual double WordProb(WordID word, WordID const* context) {
-    return ngram_.wordProb(word, (VocabIndex*)context);
+  // returns the number of unscored words at the left edge of a span
+  inline int UnscoredSize(const void* state) const {
+    return *(static_cast<const char*>(state) + unscored_size_offset_);
   }
 
-  // may be shorter than actual null-terminated length.  context must be null terminated.  len is just to save effort for subclasses that don't support contextID
-  virtual int ContextSize(WordID const* context,int len) {
-    unsigned ret;
-    ngram_.contextID((VocabIndex*)context,ret);
-    return ret;
-  }
-  virtual double ContextBOW(WordID const* context,int shortened_len) {
-    return ngram_.contextBOW((VocabIndex*)context,shortened_len);
+  inline void SetUnscoredSize(int size, void* state) const {
+    *(static_cast<char*>(state) + unscored_size_offset_) = size;
   }
 
-  inline double LookupProbForBufferContents(int i) {
-//    int k = i; cerr << "P("; while(buffer_[k] > 0) { std::cerr << TD::Convert(buffer_[k++]) << " "; }
-    double p = WordProb(buffer_[i], &buffer_[i+1]);
-    if (p < floor_) p = floor_;
-//    cerr << ")=" << p << endl;
-    return p;
+  static inline const lm::ngram::Model::State& RemnantLMState(const void* state) {
+    return *static_cast<const lm::ngram::Model::State*>(state);
   }
 
-  string DebugStateToString(const void* state) const {
-    int len = StateSize(state);
-    const int* astate = reinterpret_cast<const int*>(state);
-    string res = "[";
-    for (int i = 0; i < len; ++i) {
-      res += " ";
-      res += TD::Convert(astate[i]);
-    }
-    res += " ]";
-    return res;
+  inline void SetRemnantLMState(const lm::ngram::Model::State& lmstate, void* state) const {
+    // if we were clever, we could use the memory pointed to by state to do all
+    // the work, avoiding this copy
+    memcpy(state, &lmstate, ngram_->StateSize());
   }
 
-  inline double ProbNoRemnant(int i, int len) {
-    int edge = len;
-    bool flag = true;
-    double sum = 0.0;
-    while (i >= 0) {
-      if (buffer_[i] == kSTAR) {
-        edge = i;
-        flag = false;
-      } else if (buffer_[i] <= 0) {
-        edge = i;
-        flag = true;
-      } else {
-        if ((edge-i >= order_) || (flag && !(i == (len-1) && buffer_[i] == kSTART)))
-          sum += LookupProbForBufferContents(i);
-      }
-      --i;
-    }
-    return sum;
+  lm::WordIndex IthUnscoredWord(int i, const void* state) const {
+    const lm::WordIndex* const mem = reinterpret_cast<const lm::WordIndex*>(static_cast<const char*>(state) + unscored_words_offset_);
+    return mem[i];
   }
 
-  double EstimateProb(const vector<WordID>& phrase) {
-    int len = phrase.size();
-    buffer_.resize(len + 1);
-    buffer_[len] = kNONE;
-    int i = len - 1;
-    for (int j = 0; j < len; ++j,--i)
-      buffer_[i] = phrase[j];
-    return ProbNoRemnant(len - 1, len);
+  void SetIthUnscoredWord(int i, lm::WordIndex index, void *state) const {
+    lm::WordIndex* mem = reinterpret_cast<lm::WordIndex*>(static_cast<char*>(state) + unscored_words_offset_);
+    mem[i] = index;
   }
 
-  //TODO: make sure this doesn't get used in FinalTraversal, or if it does, that it causes no harm.
-
-  //TODO: use stateless_cost instead of ProbNoRemnant, check left words only.  for items w/ fewer words than ctx len, how are they represented?  kNONE padded?
-
-  //Vocab_None is (unsigned)-1 in srilm, same as kNONE. in srilm (-1), or that SRILM otherwise interprets -1 as a terminator and not a word
-  double EstimateProb(const void* state) {
-    if (unigram) return 0.;
-    int len = StateSize(state);
-    //  << "residual len: " << len << endl;
-    buffer_.resize(len + 1);
-    buffer_[len] = kNONE;
-    const int* astate = reinterpret_cast<const WordID*>(state);
-    int i = len - 1;
-    for (int j = 0; j < len; ++j,--i)
-      buffer_[i] = astate[j];
-    return ProbNoRemnant(len - 1, len);
-  }
-
-  //FIXME: this assumes no target words on final unary -> goal rule.  is that ok?
-  // for <s> (n-1 left words) and (n-1 right words) </s>
-  double FinalTraversalCost(const void* state) {
-    if (unigram) return 0.;
-    int slen = StateSize(state);
-    int len = slen + 2;
-    // cerr << "residual len: " << len << endl;
-    buffer_.resize(len + 1);
-    buffer_[len] = kNONE;
-    buffer_[len-1] = kSTART;
-    const int* astate = reinterpret_cast<const WordID*>(state);
-    int i = len - 2;
-    for (int j = 0; j < slen; ++j,--i)
-      buffer_[i] = astate[j];
-    buffer_[i] = kSTOP;
-    assert(i == 0);
-    return ProbNoRemnant(len - 1, len);
-  }
-
-  /// just how SRILM likes it: [rbegin,rend) is a phrase in reverse word order and null terminated so *rend=kNONE.  return unigram score for rend[-1] plus
-  /// cost returned is some kind of log prob (who cares, we're just adding)
-  double stateless_cost(WordID *rbegin,WordID *rend) {
-    UNIDBG("p(");
-    double sum=0;
-    for (;rend>rbegin;--rend) {
-      sum+=clamp(WordProb(rend[-1],rend));
-      UNIDBG(" "<<TD::Convert(rend[-1]));
-    }
-    UNIDBG(")="<<sum<<endl);
-    return sum;
+  bool HasFullContext(const void *state) const {
+    return *(static_cast<const char*>(state) + is_complete_offset_);
   }
 
-  //TODO: this would be a fine rule heuristic (for reordering hyperedges prior to rescoring.  for now you can just use a same-lm-file -o 1 prelm-rescore :(
-  double stateless_cost(TRule const& rule) {
-    //TODO: make sure this is correct.
-    int len = rule.ELength(); // use a gap for each variable
-    buffer_.resize(len + 1);
-    WordID * const rend=&buffer_[0]+len;
-    *rend=kNONE;
-    WordID *r=rend;  // append by *--r = x
-    const vector<WordID>& e = rule.e();
-    //SRILM is reverse order null terminated
-    //let's write down each phrase in reverse order and score it (note: we could lay them out consecutively then score them (we allocated enough buffer for that), but we won't actually use the whole buffer that way, since it wastes L1 cache.
-    double sum=0.;
-    for (unsigned j = 0; j < e.size(); ++j) {
-      if (e[j] < 1) { // variable
-        sum+=stateless_cost(r,rend);
-        r=rend;
-      } else { // terminal
-          *--r=e[j];
-      }
-    }
-    // last phrase (if any)
-    return sum+stateless_cost(r,rend);
+  void SetHasFullContext(bool flag, void *state) const {
+    *(static_cast<char*>(state) + is_complete_offset_) = flag;
   }
 
-  //NOTE: this is where the scoring of words happens (heuristic happens in EstimateProb)
-  double LookupWords(const TRule& rule, const vector<const void*>& ant_states, void* vstate) {
-    if (unigram)
-      return stateless_cost(rule);
+ public:
+  double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* pest_sum, void* remnant) {
+    double sum = 0.0;
+    double est_sum = 0.0;
     int len = rule.ELength() - rule.Arity();
-    for (int i = 0; i < ant_states.size(); ++i)
-      len += StateSize(ant_states[i]);
-    buffer_.resize(len + 1);
-    buffer_[len] = kNONE;
-    int i = len - 1;
+    int num_scored = 0;
+    int num_estimated = 0;
+    lm::ngram::Model::State state = ngram_->NullContextState();
     const vector<WordID>& e = rule.e();
+    bool context_complete = false;
     for (int j = 0; j < e.size(); ++j) {
       if (e[j] < 1) {
-        const int* astate = reinterpret_cast<const int*>(ant_states[-e[j]]);
-        int slen = StateSize(astate);
-        for (int k = 0; k < slen; ++k)
-          buffer_[i--] = astate[k];
+        const void* astate = (ant_states[-e[j]]);
+        int unscored_ant_len = UnscoredSize(astate);
+        for (int k = 0; k < unscored_ant_len; ++k) {
+          const lm::WordIndex cur_word = IthUnscoredWord(k, astate);
+          const lm::ngram::Model::State scopy(state);
+          const double p = ngram_->Score(scopy, cur_word, state);
+          ++num_scored;
+          if (!context_complete) {
+            if (num_scored >= order_) context_complete = true;
+          }
+          if (context_complete) {
+            sum += p;
+          } else {
+            if (remnant)
+              SetIthUnscoredWord(num_estimated, cur_word, remnant);
+            ++num_estimated;
+            est_sum += p;
+          }
+        }
+        if (HasFullContext(astate)) { // this is equivalent to the "star" in Chiang 2007
+          state = RemnantLMState(astate);
+          context_complete = true;
+        }
       } else {
-        buffer_[i--] = e[j];
+        const lm::WordIndex cur_word = MapWord(e[j]);
+        const lm::ngram::Model::State scopy(state);
+        const double p = ngram_->Score(scopy, cur_word, state);
+        ++num_scored;
+        if (!context_complete) {
+          if (num_scored >= order_) context_complete = true;
+        }
+        if (context_complete) {
+          sum += p;
+        } else {
+          if (remnant)
+            SetIthUnscoredWord(num_estimated, cur_word, remnant);
+          ++num_estimated;
+          est_sum += p;
+        }
       }
     }
-
-    double sum = 0.0;
-    int* remnant = reinterpret_cast<int*>(vstate);
-    int j = 0;
-    i = len - 1;
-    int edge = len;
-
-    while (i >= 0) {
-      if (buffer_[i] == kSTAR) {
-        edge = i;
-      } else if (edge-i >= order_) {
-        sum += LookupProbForBufferContents(i);
-      } else if (edge == len && remnant) {
-        remnant[j++] = buffer_[i];
-      }
-      --i;
+    if (pest_sum) *pest_sum = est_sum;
+    if (remnant) {
+      state.ZeroRemaining();
+      SetRemnantLMState(state, remnant);
+      SetUnscoredSize(num_estimated, remnant);
+      SetHasFullContext(context_complete || (num_scored >= order_), remnant);
     }
-    if (!remnant) return sum;
-
-    if (edge != len || len >= order_) {
-      remnant[j++] = kSTAR;
-      if (order_-1 < edge) edge = order_-1;
-      for (int i = edge-1; i >= 0; --i)
-        remnant[j++] = buffer_[i];
-    }
-
-    SetStateSize(j, vstate);
     return sum;
   }
 
-private:
-public:
-
- protected:
-  vector<WordID> buffer_;
- public:
-  WordID kSTART;
-  WordID kSTOP;
-  WordID kUNKNOWN;
-  WordID kNONE;
-  WordID kSTAR;
-  bool unigram;
-#endif
+  //FIXME: this assumes no target words on final unary -> goal rule.  is that ok?
+  // for <s> (n-1 left words) and (n-1 right words) </s>
+  double FinalTraversalCost(const void* state) {
+    SetRemnantLMState(ngram_->BeginSentenceState(), dummy_state_);
+    SetHasFullContext(1, dummy_state_);
+    SetUnscoredSize(0, dummy_state_);
+    dummy_ants_[1] = state;
+    return LookupWords(*dummy_rule_, dummy_ants_, NULL, NULL);
+  }
 
   lm::WordIndex MapWord(WordID w) const {
     if (w >= map_.size())
@@ -249,23 +150,38 @@ public:
     VMapper vm(&map_);
     conf.enumerate_vocab = &vm;
     ngram_ = new lm::ngram::Model(param.c_str(), conf);
-    cerr << "Loaded " << order_ << "-gram KLM from " << param << endl;
     order_ = ngram_->Order();
-    state_size_ = ngram_->StateSize() + 1 + (order_-1) * sizeof(int);
+    cerr << "Loaded " << order_ << "-gram KLM from " << param << " (MapSize=" << map_.size() << ")\n";
+    state_size_ = ngram_->StateSize() + 2 + (order_ - 1) * sizeof(lm::WordIndex);
+    unscored_size_offset_ = ngram_->StateSize();
+    is_complete_offset_ = unscored_size_offset_ + 1;
+    unscored_words_offset_ = is_complete_offset_ + 1;
+
+    // special handling of beginning / ending sentence markers
+    dummy_state_ = new char[state_size_];
+    dummy_ants_.push_back(dummy_state_);
+    dummy_ants_.push_back(NULL);
+    dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] </s> ||| X=0"));
   }
 
   ~KLanguageModelImpl() {
     delete ngram_;
+    delete[] dummy_state_;
   }
 
-  const int ReserveStateSize() const { return state_size_; }
+  int ReserveStateSize() const { return state_size_; }
 
  private:
   lm::ngram::Model* ngram_;
   int order_;
   int state_size_;
+  int unscored_size_offset_;
+  int is_complete_offset_;
+  int unscored_words_offset_;
+  char* dummy_state_;
+  vector<const void*> dummy_ants_;
   vector<lm::WordIndex> map_;
-
+  TRulePtr dummy_rule_;
 };
 
 KLanguageModel::KLanguageModel(const string& param) {
@@ -288,12 +204,13 @@ void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */,
                                           SparseVector<double>* features,
                                           SparseVector<double>* estimated_features,
                                           void* state) const {
-//  features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, state));
-//  estimated_features->set_value(fid_, imp().EstimateProb(state));
+  double est = 0;
+  features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &est, state));
+  estimated_features->set_value(fid_, est);
 }
 
 void KLanguageModel::FinalTraversalFeatures(const void* ant_state,
                                            SparseVector<double>* features) const {
-//  features->set_value(fid_, imp().FinalTraversalCost(ant_state));
+  features->set_value(fid_, pimpl_->FinalTraversalCost(ant_state));
 }
```
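For anyone tracing the new state handling: the per-hypothesis state that `LookupWords` writes is a flat byte buffer laid out as the KenLM remnant state, then a one-byte count of still-unscored left-edge words, then a one-byte "full context" flag, then up to `order_ - 1` `lm::WordIndex` slots, matching the offsets computed in the constructor above. Below is a minimal standalone sketch of that layout, with hypothetical constants `LM_STATE_SIZE` and `ORDER` standing in for `ngram_->StateSize()` and `ngram_->Order()`:

```cpp
#include <cstdio>
#include <cstring>

typedef unsigned int WordIndex;  // stand-in for lm::WordIndex

// Hypothetical values; the real code queries the loaded model instead.
const int LM_STATE_SIZE = 24;    // ngram_->StateSize()
const int ORDER = 3;             // ngram_->Order()

// Layout mirrors the constructor in the patch:
// [ KenLM state | unscored count (1 byte) | full-context flag (1 byte) | unscored words ]
const int unscored_size_offset = LM_STATE_SIZE;
const int is_complete_offset = unscored_size_offset + 1;
const int unscored_words_offset = is_complete_offset + 1;
const int state_size = LM_STATE_SIZE + 2 + (ORDER - 1) * sizeof(WordIndex);

int main() {
  char state[state_size];
  std::memset(state, 0, sizeof(state));
  // Two left-edge words still lack their full n-gram context, so they were
  // only estimated; the flag stays 0 until order_ words have been scored.
  state[unscored_size_offset] = 2;
  state[is_complete_offset] = 0;
  // As in the patch, this cast assumes unaligned word access is tolerated
  // (the word slots sit 2 bytes past the LM state, so they may not be
  // 4-byte aligned); that holds on x86.
  WordIndex* words = reinterpret_cast<WordIndex*>(state + unscored_words_offset);
  words[0] = 17;  // hypothetical vocabulary ids
  words[1] = 42;
  std::printf("state bytes=%d unscored=%d complete=%d first-word=%u\n",
              state_size, state[unscored_size_offset],
              state[is_complete_offset], words[0]);
  return 0;
}
```

The accessor pairs in the patch (`UnscoredSize`/`SetUnscoredSize`, `HasFullContext`/`SetHasFullContext`, `IthUnscoredWord`/`SetIthUnscoredWord`) are just typed views into this buffer at those three offsets, and `FinalTraversalCost` reuses the same machinery by scoring the dummy rule `[1] [2] </s>` with its first antecedent fixed to a begin-sentence state.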
