// KenLM-backed language-model feature for cdec (KLanguageModel) and its
// private implementation class (KLanguageModelImpl).
//
// NOTE(review): this copy of the source appears to have been damaged before
// review. It cannot compile as it stands; recover the original from version
// control rather than hand-patching it.
//   * The file has been collapsed onto four physical lines. The first line
//     begins with `#include "ff_klm.h"`, so every later token on that line is
//     extra text on that #include directive. Likewise, every `//` comment now
//     runs to the end of its very long line and swallows the code after it.
//     Examples: the rest of the second line after
//     `LookupProbForBufferContents(int i) {`, and the rest of the last line
//     after `/// cost returned ...`, including all KLanguageModel methods.
//   * Text between `<` and `>` seems to have been stripped, as an HTML
//     sanitizer would do. Affected spans include `vector* out_`,
//     `static_cast(state)`, `reinterpret_cast(state)`,
//     `const vector& phrase`, `SparseVector* features` and `vector map_`,
//     which all lack their template arguments. The comment
//     "for (n-1 left words) and (n-1 right words)" probably also lost its
//     sentence-boundary markers.
//   * On the last line, everything between `UNIDBG(" "<` and `Order();` is
//     gone. That span presumably held:
//       - the rest of stateless_cost;
//       - the end of the `#if 0` region that begins on the first line;
//       - an access specifier;
//       - the start of the KLanguageModelImpl(param) constructor, which
//         KLanguageModel's constructor calls.
//
// KLanguageModel::usage returns the fixed string "KLanguageModel"; both
// arguments are ignored.
//
// VMapper is a lm::ngram::EnumerateVocab callback that builds a table from
// cdec WordID to KenLM lm::WordIndex:
//   * The constructor clears *out_.
//   * Add() converts each vocabulary string with TD::Convert and stores
//     (*out_)[cdec_id] = index. It first grows *out_ if needed, filling the
//     new slots with kLM_UNKNOWN_TOKEN (0).
// NOTE(review): `cdec_id >= out_->size()` compares a WordID with a container
// size. If WordID is signed, a negative id would wrap around. Confirm that
// TD::Convert never returns one.
//
// KLanguageModelImpl::StateSize / SetStateSize read and write a single char at
// byte offset state_size_ of an opaque state buffer, used as the stored state
// length.
// NOTE(review): state_size_ is also what ReserveStateSize() returns, and what
// KLanguageModel's constructor passes to SetStateSize(). If that call
// reserves exactly state_size_ bytes, offset state_size_ is one past the end
// of the buffer; the length byte probably belongs at state_size_ - 1.
// In the visible text these two helpers are only referenced from the disabled
// (`#if 0`) code, so confirm the intended layout before enabling anything
// that uses them.
#include "ff_klm.h" #include "hg.h" #include "tdict.h" #include "lm/model.hh" #include "lm/enumerate_vocab.hh" using namespace std; string KLanguageModel::usage(bool param,bool verbose) { return "KLanguageModel"; } struct VMapper : public lm::ngram::EnumerateVocab { VMapper(vector* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); } void Add(lm::WordIndex index, const StringPiece &str) { const WordID cdec_id = TD::Convert(str.as_string()); if (cdec_id >= out_->size()) out_->resize(cdec_id + 1, kLM_UNKNOWN_TOKEN); (*out_)[cdec_id] = index; } vector* out_; const lm::WordIndex kLM_UNKNOWN_TOKEN; }; class KLanguageModelImpl { inline int StateSize(const void* state) const { return *(static_cast(state) + state_size_); } inline void SetStateSize(int size, void* state) const { *(static_cast(state) + state_size_) = size; } #if 0 virtual double WordProb(WordID word, WordID const* context) { return ngram_.wordProb(word, (VocabIndex*)context); } // may be shorter than actual null-terminated length. context must be null terminated. 
// The code below is meant to sit inside the `#if 0` region, i.e. disabled.
// It is SRILM-style scoring carried over from another LM wrapper and would not
// compile if re-enabled as-is:
//   * It calls ngram_.wordProb / contextID / contextBOW with VocabIndex,
//     but the member declared at the end of the class is a pointer,
//     `lm::ngram::Model* ngram_`.
//   * It uses buffer_, floor_, unigram, kSTAR, kSTART, kSTOP and kNONE,
//     which are not among the visible members (ngram_, order_, state_size_,
//     map_).
//
// What each function does:
//   * LookupProbForBufferContents(i) returns WordProb(buffer_[i], &buffer_[i+1]),
//     raised to floor_ if lower.
//   * DebugStateToString renders the first StateSize(state) ints of the state
//     as "[ w1 w2 ... ]".
//   * ProbNoRemnant(i, len) walks buffer_ from index i down to 0. It
//     remembers the most recent kSTAR or non-positive entry as `edge`. It sums
//     the lookup probability for each word that either (a) is at least order_
//     positions below edge, or (b) follows a non-kSTAR boundary. A kSTART in
//     slot len-1 is never scored.
//   * EstimateProb(phrase) copies the phrase into buffer_ in reverse order,
//     terminated by kNONE, and returns ProbNoRemnant over it.
// (The trailing "len is just to save effort..." text continues the comment
// that ends the previous line.)
len is just to save effort for subclasses that don't support contextID virtual int ContextSize(WordID const* context,int len) { unsigned ret; ngram_.contextID((VocabIndex*)context,ret); return ret; } virtual double ContextBOW(WordID const* context,int shortened_len) { return ngram_.contextBOW((VocabIndex*)context,shortened_len); } inline double LookupProbForBufferContents(int i) { // int k = i; cerr << "P("; while(buffer_[k] > 0) { std::cerr << TD::Convert(buffer_[k++]) << " "; } double p = WordProb(buffer_[i], &buffer_[i+1]); if (p < floor_) p = floor_; // cerr << ")=" << p << endl; return p; } string DebugStateToString(const void* state) const { int len = StateSize(state); const int* astate = reinterpret_cast(state); string res = "["; for (int i = 0; i < len; ++i) { res += " "; res += TD::Convert(astate[i]); } res += " ]"; return res; } inline double ProbNoRemnant(int i, int len) { int edge = len; bool flag = true; double sum = 0.0; while (i >= 0) { if (buffer_[i] == kSTAR) { edge = i; flag = false; } else if (buffer_[i] <= 0) { edge = i; flag = true; } else { if ((edge-i >= order_) || (flag && !(i == (len-1) && buffer_[i] == kSTART))) sum += LookupProbForBufferContents(i); } --i; } return sum; } double EstimateProb(const vector& phrase) { int len = phrase.size(); buffer_.resize(len + 1); buffer_[len] = kNONE; int i = len - 1; for (int j = 0; j < len; ++j,--i) buffer_[i] = phrase[j]; return ProbNoRemnant(len - 1, len); } //TODO: make sure this doesn't get used in FinalTraversal, or if it does, that it causes no harm. //TODO: use stateless_cost instead of ProbNoRemnant, check left words only. for items w/ fewer words than ctx len, how are they represented? kNONE padded? //Vocab_None is (unsigned)-1 in srilm, same as kNONE. 
// Still disabled (`#if 0`) code:
//   * EstimateProb(state) returns 0 when `unigram` is set. Otherwise it
//     copies the StateSize(state) ints of the state into buffer_ in reverse
//     order, terminated by kNONE, and returns ProbNoRemnant over them.
//   * FinalTraversalCost(state) does the same, but frames the state words
//     with kSTART at buffer_[len-1] (first in sentence order) and kSTOP at
//     buffer_[0] (last). It asserts that the framing fills the buffer exactly
//     (i == 0).
// NOTE(review): both functions read the state as ints, while the constructor
// sizes the state as ngram_->StateSize() + 1 + (order_-1) * sizeof(int). Check
// that the two layouts agree before re-enabling this code.
in srilm (-1), or that SRILM otherwise interprets -1 as a terminator and not a word double EstimateProb(const void* state) { if (unigram) return 0.; int len = StateSize(state); // << "residual len: " << len << endl; buffer_.resize(len + 1); buffer_[len] = kNONE; const int* astate = reinterpret_cast(state); int i = len - 1; for (int j = 0; j < len; ++j,--i) buffer_[i] = astate[j]; return ProbNoRemnant(len - 1, len); } //FIXME: this assumes no target words on final unary -> goal rule. is that ok? // for (n-1 left words) and (n-1 right words) double FinalTraversalCost(const void* state) { if (unigram) return 0.; int slen = StateSize(state); int len = slen + 2; // cerr << "residual len: " << len << endl; buffer_.resize(len + 1); buffer_[len] = kNONE; buffer_[len-1] = kSTART; const int* astate = reinterpret_cast(state); int i = len - 2; for (int j = 0; j < slen; ++j,--i) buffer_[i] = astate[j]; buffer_[i] = kSTOP; assert(i == 0); return ProbNoRemnant(len - 1, len); } /// just how SRILM likes it: [rbegin,rend) is a phrase in reverse word order and null terminated so *rend=kNONE. 
// stateless_cost(rbegin, rend) is disabled. It accumulates
// clamp(WordProb(rend[-1], rend)) while stepping rend back toward rbegin. The
// rest of it is lost (see the file-level note).
//
// KLanguageModelImpl, visible active part (its beginning is lost):
//   * Constructor tail: sets state_size_ = ngram_->StateSize() + 1 +
//     (order_-1) * sizeof(int). Presumably this is the KenLM state, one
//     length byte, and room for order_-1 int word ids. TODO: confirm.
//   * Destructor: deletes ngram_.
//   * ReserveStateSize(): returns state_size_.
//   * Private members: ngram_ (owned), order_, state_size_, and map_ (the
//     cdec-to-KenLM word table that VMapper fills).
// NOTE(review): ngram_ is owned through a raw pointer and deleted in the
// destructor, but no copy constructor or copy assignment is declared. Copying
// a KLanguageModelImpl would therefore double-delete. Consider
// std::unique_ptr, or deleting the copy operations.
// NOTE(review): the top-level `const` on ReserveStateSize's `int` return type
// has no effect.
//
// KLanguageModel:
//   * Constructor: builds the impl from `param`, registers feature id
//     "LanguageModel", and passes ReserveStateSize() to SetStateSize()
//     (presumably reserving that many bytes of state per node; confirm
//     against FeatureFunction).
//   * features(): reports that single feature.
//   * Destructor: deletes pimpl_.
// NOTE(review): TraversalFeaturesImpl and FinalTraversalFeatures have their
// bodies commented out, so as written this feature never sets a value.
// In this damaged copy, everything on the line below after the leading `///`
// is itself inside a `//` comment.
return unigram score for rend[-1] plus /// cost returned is some kind of log prob (who cares, we're just adding) double stateless_cost(WordID *rbegin,WordID *rend) { UNIDBG("p("); double sum=0; for (;rend>rbegin;--rend) { sum+=clamp(WordProb(rend[-1],rend)); UNIDBG(" "<Order(); state_size_ = ngram_->StateSize() + 1 + (order_-1) * sizeof(int); } ~KLanguageModelImpl() { delete ngram_; } const int ReserveStateSize() const { return state_size_; } private: lm::ngram::Model* ngram_; int order_; int state_size_; vector map_; }; KLanguageModel::KLanguageModel(const string& param) { pimpl_ = new KLanguageModelImpl(param); fid_ = FD::Convert("LanguageModel"); SetStateSize(pimpl_->ReserveStateSize()); } Features KLanguageModel::features() const { return single_feature(fid_); } KLanguageModel::~KLanguageModel() { delete pimpl_; } void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, const Hypergraph::Edge& edge, const vector& ant_states, SparseVector* features, SparseVector* estimated_features, void* state) const { // features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, state)); // estimated_features->set_value(fid_, imp().EstimateProb(state)); } void KLanguageModel::FinalTraversalFeatures(const void* ant_state, SparseVector* features) const { // features->set_value(fid_, imp().FinalTraversalCost(ant_state)); }