diff options
-rw-r--r-- | decoder/ff_klm.cc | 297 |
1 files changed, 107 insertions, 190 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 5888c4a3..092c07b0 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -1,5 +1,7 @@ #include "ff_klm.h" +#include <cstring> + #include "hg.h" #include "tdict.h" #include "lm/model.hh" @@ -24,217 +26,116 @@ struct VMapper : public lm::ngram::EnumerateVocab { }; class KLanguageModelImpl { - inline int StateSize(const void* state) const { - return *(static_cast<const char*>(state) + state_size_); - } - - inline void SetStateSize(int size, void* state) const { - *(static_cast<char*>(state) + state_size_) = size; - } -#if 0 - virtual double WordProb(WordID word, WordID const* context) { - return ngram_.wordProb(word, (VocabIndex*)context); + // returns the number of unscored words at the left edge of a span + inline int UnscoredSize(const void* state) const { + return *(static_cast<const char*>(state) + unscored_size_offset_); } - // may be shorter than actual null-terminated length. context must be null terminated. len is just to save effort for subclasses that don't support contextID - virtual int ContextSize(WordID const* context,int len) { - unsigned ret; - ngram_.contextID((VocabIndex*)context,ret); - return ret; - } - virtual double ContextBOW(WordID const* context,int shortened_len) { - return ngram_.contextBOW((VocabIndex*)context,shortened_len); + inline void SetUnscoredSize(int size, void* state) const { + *(static_cast<char*>(state) + unscored_size_offset_) = size; } - inline double LookupProbForBufferContents(int i) { -// int k = i; cerr << "P("; while(buffer_[k] > 0) { std::cerr << TD::Convert(buffer_[k++]) << " "; } - double p = WordProb(buffer_[i], &buffer_[i+1]); - if (p < floor_) p = floor_; -// cerr << ")=" << p << endl; - return p; + static inline const lm::ngram::Model::State& RemnantLMState(const void* state) { + return *static_cast<const lm::ngram::Model::State*>(state); } - string DebugStateToString(const void* state) const { - int len = StateSize(state); - const int* astate = reinterpret_cast<const int*>(state); - string res = "["; - for (int i = 0; i < len; ++i) { - res += " "; - res += TD::Convert(astate[i]); - } - res += " ]"; - return res; + inline void SetRemnantLMState(const lm::ngram::Model::State& lmstate, void* state) const { + // if we were clever, we could use the memory pointed to by state to do all + // the work, avoiding this copy + memcpy(state, &lmstate, ngram_->StateSize()); } - inline double ProbNoRemnant(int i, int len) { - int edge = len; - bool flag = true; - double sum = 0.0; - while (i >= 0) { - if (buffer_[i] == kSTAR) { - edge = i; - flag = false; - } else if (buffer_[i] <= 0) { - edge = i; - flag = true; - } else { - if ((edge-i >= order_) || (flag && !(i == (len-1) && buffer_[i] == kSTART))) - sum += LookupProbForBufferContents(i); - } - --i; - } - return sum; + lm::WordIndex IthUnscoredWord(int i, const void* state) const { + const lm::WordIndex* const mem = reinterpret_cast<const lm::WordIndex*>(static_cast<const char*>(state) + unscored_words_offset_); + return mem[i]; } - double EstimateProb(const vector<WordID>& phrase) { - int len = phrase.size(); - buffer_.resize(len + 1); - buffer_[len] = kNONE; - int i = len - 1; - for (int j = 0; j < len; ++j,--i) - buffer_[i] = phrase[j]; - return ProbNoRemnant(len - 1, len); + void SetIthUnscoredWord(int i, lm::WordIndex index, void *state) const { + lm::WordIndex* mem = reinterpret_cast<lm::WordIndex*>(static_cast<char*>(state) + unscored_words_offset_); + mem[i] = index; } - //TODO: make sure this doesn't get used in FinalTraversal, or if it does, that it causes no harm. - - //TODO: use stateless_cost instead of ProbNoRemnant, check left words only. for items w/ fewer words than ctx len, how are they represented? kNONE padded? - - //Vocab_None is (unsigned)-1 in srilm, same as kNONE. in srilm (-1), or that SRILM otherwise interprets -1 as a terminator and not a word - double EstimateProb(const void* state) { - if (unigram) return 0.; - int len = StateSize(state); - // << "residual len: " << len << endl; - buffer_.resize(len + 1); - buffer_[len] = kNONE; - const int* astate = reinterpret_cast<const WordID*>(state); - int i = len - 1; - for (int j = 0; j < len; ++j,--i) - buffer_[i] = astate[j]; - return ProbNoRemnant(len - 1, len); - } - - //FIXME: this assumes no target words on final unary -> goal rule. is that ok? - // for <s> (n-1 left words) and (n-1 right words) </s> - double FinalTraversalCost(const void* state) { - if (unigram) return 0.; - int slen = StateSize(state); - int len = slen + 2; - // cerr << "residual len: " << len << endl; - buffer_.resize(len + 1); - buffer_[len] = kNONE; - buffer_[len-1] = kSTART; - const int* astate = reinterpret_cast<const WordID*>(state); - int i = len - 2; - for (int j = 0; j < slen; ++j,--i) - buffer_[i] = astate[j]; - buffer_[i] = kSTOP; - assert(i == 0); - return ProbNoRemnant(len - 1, len); - } - - /// just how SRILM likes it: [rbegin,rend) is a phrase in reverse word order and null terminated so *rend=kNONE. return unigram score for rend[-1] plus - /// cost returned is some kind of log prob (who cares, we're just adding) - double stateless_cost(WordID *rbegin,WordID *rend) { - UNIDBG("p("); - double sum=0; - for (;rend>rbegin;--rend) { - sum+=clamp(WordProb(rend[-1],rend)); - UNIDBG(" "<<TD::Convert(rend[-1])); - } - UNIDBG(")="<<sum<<endl); - return sum; + bool HasFullContext(const void *state) const { + return *(static_cast<const char*>(state) + is_complete_offset_); } - //TODO: this would be a fine rule heuristic (for reordering hyperedges prior to rescoring. for now you can just use a same-lm-file -o 1 prelm-rescore :( - double stateless_cost(TRule const& rule) { - //TODO: make sure this is correct. - int len = rule.ELength(); // use a gap for each variable - buffer_.resize(len + 1); - WordID * const rend=&buffer_[0]+len; - *rend=kNONE; - WordID *r=rend; // append by *--r = x - const vector<WordID>& e = rule.e(); - //SRILM is reverse order null terminated - //let's write down each phrase in reverse order and score it (note: we could lay them out consecutively then score them (we allocated enough buffer for that), but we won't actually use the whole buffer that way, since it wastes L1 cache. - double sum=0.; - for (unsigned j = 0; j < e.size(); ++j) { - if (e[j] < 1) { // variable - sum+=stateless_cost(r,rend); - r=rend; - } else { // terminal - *--r=e[j]; - } - } - // last phrase (if any) - return sum+stateless_cost(r,rend); + void SetHasFullContext(bool flag, void *state) const { + *(static_cast<char*>(state) + is_complete_offset_) = flag; } - //NOTE: this is where the scoring of words happens (heuristic happens in EstimateProb) - double LookupWords(const TRule& rule, const vector<const void*>& ant_states, void* vstate) { - if (unigram) - return stateless_cost(rule); + public: + double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* pest_sum, void* remnant) { + double sum = 0.0; + double est_sum = 0.0; int len = rule.ELength() - rule.Arity(); - for (int i = 0; i < ant_states.size(); ++i) - len += StateSize(ant_states[i]); - buffer_.resize(len + 1); - buffer_[len] = kNONE; - int i = len - 1; + int num_scored = 0; + int num_estimated = 0; + lm::ngram::Model::State state = ngram_->NullContextState(); const vector<WordID>& e = rule.e(); + bool context_complete = false; for (int j = 0; j < e.size(); ++j) { if (e[j] < 1) { - const int* astate = reinterpret_cast<const int*>(ant_states[-e[j]]); - int slen = StateSize(astate); - for (int k = 0; k < slen; ++k) - buffer_[i--] = astate[k]; + const void* astate = (ant_states[-e[j]]); + int unscored_ant_len = UnscoredSize(astate); + for (int k = 0; k < unscored_ant_len; ++k) { + const lm::WordIndex cur_word = IthUnscoredWord(k, astate); + const lm::ngram::Model::State scopy(state); + const double p = ngram_->Score(scopy, cur_word, state); + ++num_scored; + if (!context_complete) { + if (num_scored >= order_) context_complete = true; + } + if (context_complete) { + sum += p; + } else { + if (remnant) + SetIthUnscoredWord(num_estimated, cur_word, remnant); + ++num_estimated; + est_sum += p; + } + } + if (HasFullContext(astate)) { // this is equivalent to the "star" in Chiang 2007 + state = RemnantLMState(astate); + context_complete = true; + } } else { - buffer_[i--] = e[j]; + const lm::WordIndex cur_word = MapWord(e[j]); + const lm::ngram::Model::State scopy(state); + const double p = ngram_->Score(scopy, cur_word, state); + ++num_scored; + if (!context_complete) { + if (num_scored >= order_) context_complete = true; + } + if (context_complete) { + sum += p; + } else { + if (remnant) + SetIthUnscoredWord(num_estimated, cur_word, remnant); + ++num_estimated; + est_sum += p; + } } } - - double sum = 0.0; - int* remnant = reinterpret_cast<int*>(vstate); - int j = 0; - i = len - 1; - int edge = len; - - while (i >= 0) { - if (buffer_[i] == kSTAR) { - edge = i; - } else if (edge-i >= order_) { - sum += LookupProbForBufferContents(i); - } else if (edge == len && remnant) { - remnant[j++] = buffer_[i]; - } - --i; + if (pest_sum) *pest_sum = est_sum; + if (remnant) { + state.ZeroRemaining(); + SetRemnantLMState(state, remnant); + SetUnscoredSize(num_estimated, remnant); + SetHasFullContext(context_complete || (num_scored >= order_), remnant); } - if (!remnant) return sum; - - if (edge != len || len >= order_) { - remnant[j++] = kSTAR; - if (order_-1 < edge) edge = order_-1; - for (int i = edge-1; i >= 0; --i) - remnant[j++] = buffer_[i]; - } - - SetStateSize(j, vstate); return sum; } -private: -public: - - protected: - vector<WordID> buffer_; - public: - WordID kSTART; - WordID kSTOP; - WordID kUNKNOWN; - WordID kNONE; - WordID kSTAR; - bool unigram; -#endif + //FIXME: this assumes no target words on final unary -> goal rule. is that ok? + // for <s> (n-1 left words) and (n-1 right words) </s> + double FinalTraversalCost(const void* state) { + SetRemnantLMState(ngram_->BeginSentenceState(), dummy_state_); + SetHasFullContext(1, dummy_state_); + SetUnscoredSize(0, dummy_state_); + dummy_ants_[1] = state; + return LookupWords(*dummy_rule_, dummy_ants_, NULL, NULL); + } lm::WordIndex MapWord(WordID w) const { if (w >= map_.size()) @@ -249,23 +150,38 @@ public: VMapper vm(&map_); conf.enumerate_vocab = &vm; ngram_ = new lm::ngram::Model(param.c_str(), conf); - cerr << "Loaded " << order_ << "-gram KLM from " << param << endl; order_ = ngram_->Order(); - state_size_ = ngram_->StateSize() + 1 + (order_-1) * sizeof(int); + cerr << "Loaded " << order_ << "-gram KLM from " << param << " (MapSize=" << map_.size() << ")\n"; + state_size_ = ngram_->StateSize() + 2 + (order_ - 1) * sizeof(lm::WordIndex); + unscored_size_offset_ = ngram_->StateSize(); + is_complete_offset_ = unscored_size_offset_ + 1; + unscored_words_offset_ = is_complete_offset_ + 1; + + // special handling of beginning / ending sentence markers + dummy_state_ = new char[state_size_]; + dummy_ants_.push_back(dummy_state_); + dummy_ants_.push_back(NULL); + dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] </s> ||| X=0")); } ~KLanguageModelImpl() { delete ngram_; + delete[] dummy_state_; } - const int ReserveStateSize() const { return state_size_; } + int ReserveStateSize() const { return state_size_; } private: lm::ngram::Model* ngram_; int order_; int state_size_; + int unscored_size_offset_; + int is_complete_offset_; + int unscored_words_offset_; + char* dummy_state_; + vector<const void*> dummy_ants_; vector<lm::WordIndex> map_; - + TRulePtr dummy_rule_; }; KLanguageModel::KLanguageModel(const string& param) { @@ -288,12 +204,13 @@ void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, SparseVector<double>* features, SparseVector<double>* estimated_features, void* state) const { -// features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, state)); -// estimated_features->set_value(fid_, imp().EstimateProb(state)); + double est = 0; + features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &est, state)); + estimated_features->set_value(fid_, est); } void KLanguageModel::FinalTraversalFeatures(const void* ant_state, SparseVector<double>* features) const { -// features->set_value(fid_, imp().FinalTraversalCost(ant_state)); + features->set_value(fid_, pimpl_->FinalTraversalCost(ant_state)); } |