-rw-r--r--  decoder/Makefile.am |   6 ++++--
-rw-r--r--  decoder/cdec_ff.cc  |   2 ++
-rw-r--r--  decoder/ff_klm.cc   | 299 ++++++++++++++++++++++++++++++++++++++++
-rw-r--r--  decoder/ff_klm.h    |  32 ++++++++++
4 files changed, 337 insertions, 2 deletions
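Once registered, KLanguageModel would presumably be switched on from the decoder configuration like any other feature function. A hypothetical cdec.ini line, assuming the usual feature_function key and a placeholder model file (note that the "-o n" flag advertised in ff_klm.h is not actually parsed at this stage):

    feature_function=KLanguageModel somefile.klm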
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index da0e5987..ea01a4da 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -12,7 +12,7 @@ TESTS = trule_test ff_test parser_test grammar_test hg_test cfg_test
 endif
 
 cdec_SOURCES = cdec.cc
-cdec_LDADD = libcdec.a ../mteval/libmteval.a ../utils/libutils.a -lz
+cdec_LDADD = libcdec.a ../mteval/libmteval.a ../utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
 
 cfg_test_SOURCES = cfg_test.cc
 cfg_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a -lz
@@ -26,7 +26,8 @@ hg_test_SOURCES = hg_test.cc
 hg_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a -lz
 trule_test_SOURCES = trule_test.cc
 trule_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a -lz
-AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. -I../mteval -I../utils
+
+AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. -I../mteval -I../utils -I../klm
 
 rule_lexer.cc: rule_lexer.l
 	$(LEX) -s -CF -8 -o$@ $<
@@ -58,6 +59,7 @@ libcdec_a_SOURCES = \
   trule.cc \
   ff.cc \
   ff_lm.cc \
+  ff_klm.cc \
   ff_ruleshape.cc \
   ff_wordalign.cc \
   ff_csplit.cc \
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index ca5334e9..09a19a7b 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -2,6 +2,7 @@
 
 #include "ff.h"
 #include "ff_lm.h"
+#include "ff_klm.h"
 #include "ff_csplit.h"
 #include "ff_wordalign.h"
 #include "ff_tagger.h"
@@ -29,6 +30,7 @@ void register_feature_functions() {
 
   RegisterFsaDynToFF<SameFirstLetter>();
   RegisterFF<LanguageModel>();
+  RegisterFF<KLanguageModel>();
 
   RegisterFF<WordPenalty>();
   RegisterFF<SourceWordPenalty>();
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc
new file mode 100644
index 00000000..5888c4a3
--- /dev/null
+++ b/decoder/ff_klm.cc
@@ -0,0 +1,299 @@
+#include "ff_klm.h"
+
+#include "hg.h"
+#include "tdict.h"
+#include "lm/model.hh"
+#include "lm/enumerate_vocab.hh"
+
+using namespace std;
+
+string KLanguageModel::usage(bool param, bool verbose) {
+  return "KLanguageModel";
+}
+
+// Builds a table indexed by cdec word ID holding the corresponding KenLM
+// word index, filled in as KenLM enumerates its vocabulary during loading.
+struct VMapper : public lm::ngram::EnumerateVocab {
+  VMapper(vector<lm::WordIndex>* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); }
+  void Add(lm::WordIndex index, const StringPiece &str) {
+    const WordID cdec_id = TD::Convert(str.as_string());
+    if (cdec_id >= out_->size())
+      out_->resize(cdec_id + 1, kLM_UNKNOWN_TOKEN);
+    (*out_)[cdec_id] = index;
+  }
+  vector<lm::WordIndex>* out_;
+  const lm::WordIndex kLM_UNKNOWN_TOKEN;
+};
+
+class KLanguageModelImpl {
+  // The per-node state buffer is followed by one byte holding its length.
+  inline int StateSize(const void* state) const {
+    return *(static_cast<const char*>(state) + state_size_);
+  }
+
+  inline void SetStateSize(int size, void* state) const {
+    *(static_cast<char*>(state) + state_size_) = size;
+  }
+
+#if 0
+  // Carried over from the SRILM-based ff_lm.cc; disabled until ported to KenLM.
+  virtual double WordProb(WordID word, WordID const* context) {
+    return ngram_.wordProb(word, (VocabIndex*)context);
+  }
+
+  // may be shorter than actual null-terminated length. context must be null terminated. len is just to save effort for subclasses that don't support contextID
+  virtual int ContextSize(WordID const* context, int len) {
+    unsigned ret;
+    ngram_.contextID((VocabIndex*)context, ret);
+    return ret;
+  }
+
+  virtual double ContextBOW(WordID const* context, int shortened_len) {
+    return ngram_.contextBOW((VocabIndex*)context, shortened_len);
+  }
+
+  inline double LookupProbForBufferContents(int i) {
+//    int k = i; cerr << "P("; while(buffer_[k] > 0) { std::cerr << TD::Convert(buffer_[k++]) << " "; }
+    double p = WordProb(buffer_[i], &buffer_[i+1]);
+    if (p < floor_) p = floor_;
+//    cerr << ")=" << p << endl;
+    return p;
+  }
+
+  string DebugStateToString(const void* state) const {
+    int len = StateSize(state);
+    const int* astate = reinterpret_cast<const int*>(state);
+    string res = "[";
+    for (int i = 0; i < len; ++i) {
+      res += " ";
+      res += TD::Convert(astate[i]);
+    }
+    res += " ]";
+    return res;
+  }
+
+  inline double ProbNoRemnant(int i, int len) {
+    int edge = len;
+    bool flag = true;
+    double sum = 0.0;
+    while (i >= 0) {
+      if (buffer_[i] == kSTAR) {
+        edge = i;
+        flag = false;
+      } else if (buffer_[i] <= 0) {
+        edge = i;
+        flag = true;
+      } else {
+        if ((edge-i >= order_) || (flag && !(i == (len-1) && buffer_[i] == kSTART)))
+          sum += LookupProbForBufferContents(i);
+      }
+      --i;
+    }
+    return sum;
+  }
+
+  double EstimateProb(const vector<WordID>& phrase) {
+    int len = phrase.size();
+    buffer_.resize(len + 1);
+    buffer_[len] = kNONE;
+    int i = len - 1;
+    for (int j = 0; j < len; ++j, --i)
+      buffer_[i] = phrase[j];
+    return ProbNoRemnant(len - 1, len);
+  }
+
+  //TODO: make sure this doesn't get used in FinalTraversal, or if it does, that it causes no harm.
+
+  //TODO: use stateless_cost instead of ProbNoRemnant, check left words only. for items w/ fewer words than ctx len, how are they represented? kNONE padded?
+
+  //Vocab_None is (unsigned)-1 in SRILM, same as kNONE; SRILM interprets -1 as a terminator, not a word.
+  double EstimateProb(const void* state) {
+    if (unigram) return 0.;
+    int len = StateSize(state);
+    // cerr << "residual len: " << len << endl;
+    buffer_.resize(len + 1);
+    buffer_[len] = kNONE;
+    const int* astate = reinterpret_cast<const WordID*>(state);
+    int i = len - 1;
+    for (int j = 0; j < len; ++j, --i)
+      buffer_[i] = astate[j];
+    return ProbNoRemnant(len - 1, len);
+  }
+
+  //FIXME: this assumes no target words on the final unary -> goal rule. is that ok?
+  // for <s> (n-1 left words) and (n-1 right words) </s>
+  double FinalTraversalCost(const void* state) {
+    if (unigram) return 0.;
+    int slen = StateSize(state);
+    int len = slen + 2;
+    // cerr << "residual len: " << len << endl;
+    buffer_.resize(len + 1);
+    buffer_[len] = kNONE;
+    buffer_[len-1] = kSTART;
+    const int* astate = reinterpret_cast<const WordID*>(state);
+    int i = len - 2;
+    for (int j = 0; j < slen; ++j, --i)
+      buffer_[i] = astate[j];
+    buffer_[i] = kSTOP;
+    assert(i == 0);
+    return ProbNoRemnant(len - 1, len);
+  }
+
+  /// just how SRILM likes it: [rbegin,rend) is a phrase in reverse word order, null terminated so *rend=kNONE.
+  /// the cost returned is some kind of log prob (who cares, we're just adding)
+  double stateless_cost(WordID *rbegin, WordID *rend) {
+    UNIDBG("p(");
+    double sum = 0;
+    for (; rend > rbegin; --rend) {
+      sum += clamp(WordProb(rend[-1], rend));
+      UNIDBG(" " << TD::Convert(rend[-1]));
+    }
+    UNIDBG(")=" << sum << endl);
+    return sum;
+  }
+
+  //TODO: this would be a fine rule heuristic (for reordering hyperedges prior to rescoring). for now you can just use a same-lm-file -o 1 prelm-rescore :(
+  double stateless_cost(TRule const& rule) {
+    //TODO: make sure this is correct.
+    int len = rule.ELength(); // use a gap for each variable
+    buffer_.resize(len + 1);
+    WordID * const rend = &buffer_[0] + len;
+    *rend = kNONE;
+    WordID *r = rend;  // append by *--r = x
+    const vector<WordID>& e = rule.e();
+    //SRILM is reverse order null terminated
+    //write down each phrase in reverse order and score it (note: we could lay them out consecutively and then score them -- we allocated enough buffer for that -- but we won't actually use the whole buffer that way, since it wastes L1 cache)
+    double sum = 0.;
+    for (unsigned j = 0; j < e.size(); ++j) {
+      if (e[j] < 1) { // variable
+        sum += stateless_cost(r, rend);
+        r = rend;
+      } else { // terminal
+        *--r = e[j];
+      }
+    }
+    // last phrase (if any)
+    return sum + stateless_cost(r, rend);
+  }
+
+  //NOTE: this is where the scoring of words happens (the heuristic happens in EstimateProb)
+  double LookupWords(const TRule& rule, const vector<const void*>& ant_states, void* vstate) {
+    if (unigram)
+      return stateless_cost(rule);
+    int len = rule.ELength() - rule.Arity();
+    for (int i = 0; i < ant_states.size(); ++i)
+      len += StateSize(ant_states[i]);
+    buffer_.resize(len + 1);
+    buffer_[len] = kNONE;
+    int i = len - 1;
+    const vector<WordID>& e = rule.e();
+    for (int j = 0; j < e.size(); ++j) {
+      if (e[j] < 1) {
+        const int* astate = reinterpret_cast<const int*>(ant_states[-e[j]]);
+        int slen = StateSize(astate);
+        for (int k = 0; k < slen; ++k)
+          buffer_[i--] = astate[k];
+      } else {
+        buffer_[i--] = e[j];
+      }
+    }
+
+    double sum = 0.0;
+    int* remnant = reinterpret_cast<int*>(vstate);
+    int j = 0;
+    i = len - 1;
+    int edge = len;
+
+    while (i >= 0) {
+      if (buffer_[i] == kSTAR) {
+        edge = i;
+      } else if (edge-i >= order_) {
+        sum += LookupProbForBufferContents(i);
+      } else if (edge == len && remnant) {
+        remnant[j++] = buffer_[i];
+      }
+      --i;
+    }
+    if (!remnant) return sum;
+
+    if (edge != len || len >= order_) {
+      remnant[j++] = kSTAR;
+      if (order_-1 < edge) edge = order_-1;
+      for (int i = edge-1; i >= 0; --i)
+        remnant[j++] = buffer_[i];
+    }
+
+    SetStateSize(j, vstate);
+    return sum;
+  }
+
+ protected:
+  vector<WordID> buffer_;
+ public:
+  WordID kSTART;
+  WordID kSTOP;
+  WordID kUNKNOWN;
+  WordID kNONE;
+  WordID kSTAR;
+  bool unigram;
+#endif
+
+  lm::WordIndex MapWord(WordID w) const {
+    if (w >= map_.size())
+      return 0;
+    else
+      return map_[w];
+  }
+
+ public:
+  KLanguageModelImpl(const std::string& param) {
+    lm::ngram::Config conf;
+    VMapper vm(&map_);
+    conf.enumerate_vocab = &vm;
+    ngram_ = new lm::ngram::Model(param.c_str(), conf);
+    order_ = ngram_->Order();  // set order_ before reporting it
+    cerr << "Loaded " << order_ << "-gram KLM from " << param << endl;
+    state_size_ = ngram_->StateSize() + 1 + (order_ - 1) * sizeof(int);
+  }
+
+  ~KLanguageModelImpl() {
+    delete ngram_;
+  }
+
+  int ReserveStateSize() const { return state_size_; }
+
+ private:
+  lm::ngram::Model* ngram_;
+  int order_;
+  int state_size_;
+  vector<lm::WordIndex> map_;
+};
+
+KLanguageModel::KLanguageModel(const string& param) {
+  pimpl_ = new KLanguageModelImpl(param);
+  fid_ = FD::Convert("LanguageModel");
+  SetStateSize(pimpl_->ReserveStateSize());
+}
+
+Features KLanguageModel::features() const {
+  return single_feature(fid_);
+}
+
+KLanguageModel::~KLanguageModel() {
+  delete pimpl_;
+}
+
+void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */,
+                                           const Hypergraph::Edge& edge,
+                                           const vector<const void*>& ant_states,
+                                           SparseVector<double>* features,
+                                           SparseVector<double>* estimated_features,
+                                           void* state) const {
+//  features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, state));
+//  estimated_features->set_value(fid_, pimpl_->EstimateProb(state));
+}
+
+void KLanguageModel::FinalTraversalFeatures(const void* ant_state,
+                                            SparseVector<double>* features) const {
+//  features->set_value(fid_, pimpl_->FinalTraversalCost(ant_state));
+}
diff --git a/decoder/ff_klm.h b/decoder/ff_klm.h
new file mode 100644
index 00000000..0569286f
--- /dev/null
+++ b/decoder/ff_klm.h
@@ -0,0 +1,32 @@
+#ifndef _KLM_FF_H_
+#define _KLM_FF_H_
+
+#include <vector>
+#include <string>
+
+#include "ff.h"
+
+struct KLanguageModelImpl;
+
+class KLanguageModel : public FeatureFunction {
+ public:
+  // param = "filename.lm [-o n]"
+  KLanguageModel(const std::string& param);
+  ~KLanguageModel();
+  virtual void FinalTraversalFeatures(const void* context,
+                                      SparseVector<double>* features) const;
+  static std::string usage(bool param, bool verbose);
+  Features features() const;
+ protected:
+  virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+                                     const Hypergraph::Edge& edge,
+                                     const std::vector<const void*>& ant_contexts,
+                                     SparseVector<double>* features,
+                                     SparseVector<double>* estimated_features,
+                                     void* out_context) const;
+ private:
+  int fid_; // conceptually const; mutable only to simplify constructor
+  KLanguageModelImpl* pimpl_;
+};
+
+#endif
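For reference, a minimal standalone sketch (not part of the commit) of the KenLM query API that KLanguageModelImpl wraps. The model file name is a placeholder, Score() returns log10 probabilities, and the State objects carried between calls are what the feature will eventually pack into cdec's per-node state next to a length byte and up to order-1 context words:

    #include "lm/model.hh"
    
    #include <iostream>
    
    int main() {
      // "example.arpa" is a placeholder; any ARPA or binarized KenLM file works.
      lm::ngram::Model model("example.arpa");
      const lm::ngram::Model::Vocabulary &vocab = model.GetVocabulary();
      // Start from the begin-of-sentence state; Score() advances it one word at a time.
      lm::ngram::Model::State state(model.BeginSentenceState()), out;
      const char *words[] = {"the", "cat", "</s>"};
      float total = 0.0f;  // accumulated log10 probability
      for (int i = 0; i < 3; ++i) {
        total += model.Score(state, vocab.Index(words[i]), out);
        state = out;
      }
      std::cout << "log10 p(the cat </s> | <s>) = " << total << std::endl;
      return 0;
    }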