Diffstat (limited to 'decoder')
-rw-r--r--  decoder/Makefile.am |   6
-rw-r--r--  decoder/cdec_ff.cc  |   2
-rw-r--r--  decoder/ff_klm.cc   | 299
-rw-r--r--  decoder/ff_klm.h    |  32
4 files changed, 337 insertions(+), 2 deletions(-)
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index da0e5987..ea01a4da 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -12,7 +12,7 @@ TESTS = trule_test ff_test parser_test grammar_test hg_test cfg_test
endif
cdec_SOURCES = cdec.cc
-cdec_LDADD = libcdec.a ../mteval/libmteval.a ../utils/libutils.a -lz
+cdec_LDADD = libcdec.a ../mteval/libmteval.a ../utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
cfg_test_SOURCES = cfg_test.cc
cfg_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a -lz
@@ -26,7 +26,8 @@ hg_test_SOURCES = hg_test.cc
hg_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a -lz
trule_test_SOURCES = trule_test.cc
trule_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a -lz
-AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. -I../mteval -I../utils
+
+AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. -I../mteval -I../utils -I../klm
rule_lexer.cc: rule_lexer.l
$(LEX) -s -CF -8 -o$@ $<
@@ -58,6 +59,7 @@ libcdec_a_SOURCES = \
trule.cc \
ff.cc \
ff_lm.cc \
+ ff_klm.cc \
ff_ruleshape.cc \
ff_wordalign.cc \
ff_csplit.cc \
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index ca5334e9..09a19a7b 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -2,6 +2,7 @@
#include "ff.h"
#include "ff_lm.h"
+#include "ff_klm.h"
#include "ff_csplit.h"
#include "ff_wordalign.h"
#include "ff_tagger.h"
@@ -29,6 +30,7 @@ void register_feature_functions() {
RegisterFsaDynToFF<SameFirstLetter>();
RegisterFF<LanguageModel>();
+ RegisterFF<KLanguageModel>();
RegisterFF<WordPenalty>();
RegisterFF<SourceWordPenalty>();
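
With the registration above, the new feature can be selected by name from a
decoder configuration. A hypothetical cdec.ini line (the model file name is
made up for illustration):

    feature_function=KLanguageModel test.klm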
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc
new file mode 100644
index 00000000..5888c4a3
--- /dev/null
+++ b/decoder/ff_klm.cc
@@ -0,0 +1,299 @@
+#include "ff_klm.h"
+
+#include "hg.h"
+#include "tdict.h"
+#include "lm/model.hh"
+#include "lm/enumerate_vocab.hh"
+
+using namespace std;
+
+string KLanguageModel::usage(bool /*param*/, bool /*verbose*/) {
+  return "KLanguageModel";
+}
+
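+// Builds the cdec-ID -> KenLM-index mapping: KenLM calls Add() once per
+// vocabulary entry while the model loads; cdec words the LM never saw keep
+// the default kLM_UNKNOWN_TOKEN (0, KenLM's <unk>).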
+struct VMapper : public lm::ngram::EnumerateVocab {
+ VMapper(vector<lm::WordIndex>* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); }
+ void Add(lm::WordIndex index, const StringPiece &str) {
+ const WordID cdec_id = TD::Convert(str.as_string());
+ if (cdec_id >= out_->size())
+ out_->resize(cdec_id + 1, kLM_UNKNOWN_TOKEN);
+ (*out_)[cdec_id] = index;
+ }
+ vector<lm::WordIndex>* out_;
+ const lm::WordIndex kLM_UNKNOWN_TOKEN;
+};
+
+class KLanguageModelImpl {
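+  // The per-node feature state is a buffer of boundary WordIDs; the word
+  // count lives in a single byte at offset state_size_.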
+ inline int StateSize(const void* state) const {
+ return *(static_cast<const char*>(state) + state_size_);
+ }
+
+ inline void SetStateSize(int size, void* state) const {
+ *(static_cast<char*>(state) + state_size_) = size;
+ }
+
+#if 0
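+// Everything in this block is scoring code carried over from the SRILM-based
+// ff_lm.cc, disabled and kept only as a reference while it is ported to KenLM.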
+ virtual double WordProb(WordID word, WordID const* context) {
+ return ngram_.wordProb(word, (VocabIndex*)context);
+ }
+
+  // May be shorter than the actual null-terminated length; context must be
+  // null-terminated. len just saves effort for subclasses that don't support contextID.
+ virtual int ContextSize(WordID const* context,int len) {
+ unsigned ret;
+ ngram_.contextID((VocabIndex*)context,ret);
+ return ret;
+ }
+ virtual double ContextBOW(WordID const* context,int shortened_len) {
+ return ngram_.contextBOW((VocabIndex*)context,shortened_len);
+ }
+
+ inline double LookupProbForBufferContents(int i) {
+// int k = i; cerr << "P("; while(buffer_[k] > 0) { std::cerr << TD::Convert(buffer_[k++]) << " "; }
+ double p = WordProb(buffer_[i], &buffer_[i+1]);
+ if (p < floor_) p = floor_;
+// cerr << ")=" << p << endl;
+ return p;
+ }
+
+ string DebugStateToString(const void* state) const {
+ int len = StateSize(state);
+ const int* astate = reinterpret_cast<const int*>(state);
+ string res = "[";
+ for (int i = 0; i < len; ++i) {
+ res += " ";
+ res += TD::Convert(astate[i]);
+ }
+ res += " ]";
+ return res;
+ }
+
+ inline double ProbNoRemnant(int i, int len) {
+ int edge = len;
+ bool flag = true;
+ double sum = 0.0;
+ while (i >= 0) {
+ if (buffer_[i] == kSTAR) {
+ edge = i;
+ flag = false;
+ } else if (buffer_[i] <= 0) {
+ edge = i;
+ flag = true;
+ } else {
+ if ((edge-i >= order_) || (flag && !(i == (len-1) && buffer_[i] == kSTART)))
+ sum += LookupProbForBufferContents(i);
+ }
+ --i;
+ }
+ return sum;
+ }
+
+ double EstimateProb(const vector<WordID>& phrase) {
+ int len = phrase.size();
+ buffer_.resize(len + 1);
+ buffer_[len] = kNONE;
+ int i = len - 1;
+ for (int j = 0; j < len; ++j,--i)
+ buffer_[i] = phrase[j];
+ return ProbNoRemnant(len - 1, len);
+ }
+
+ //TODO: make sure this doesn't get used in FinalTraversal, or if it does, that it causes no harm.
+
+ //TODO: use stateless_cost instead of ProbNoRemnant, check left words only. for items w/ fewer words than ctx len, how are they represented? kNONE padded?
+
+  //Vocab_None is (unsigned)-1 in SRILM, the same as kNONE; SRILM interprets -1 as a terminator, not a word
+ double EstimateProb(const void* state) {
+ if (unigram) return 0.;
+ int len = StateSize(state);
+    // cerr << "residual len: " << len << endl;
+ buffer_.resize(len + 1);
+ buffer_[len] = kNONE;
+ const int* astate = reinterpret_cast<const WordID*>(state);
+ int i = len - 1;
+ for (int j = 0; j < len; ++j,--i)
+ buffer_[i] = astate[j];
+ return ProbNoRemnant(len - 1, len);
+ }
+
+ //FIXME: this assumes no target words on final unary -> goal rule. is that ok?
+ // for <s> (n-1 left words) and (n-1 right words) </s>
+ double FinalTraversalCost(const void* state) {
+ if (unigram) return 0.;
+ int slen = StateSize(state);
+ int len = slen + 2;
+ // cerr << "residual len: " << len << endl;
+ buffer_.resize(len + 1);
+ buffer_[len] = kNONE;
+ buffer_[len-1] = kSTART;
+ const int* astate = reinterpret_cast<const WordID*>(state);
+ int i = len - 2;
+ for (int j = 0; j < slen; ++j,--i)
+ buffer_[i] = astate[j];
+ buffer_[i] = kSTOP;
+ assert(i == 0);
+ return ProbNoRemnant(len - 1, len);
+ }
+
+  /// just how SRILM likes it: [rbegin,rend) is a phrase in reverse word order, null terminated so *rend==kNONE
+  /// returns the sum of the words' scores; the cost is some kind of log prob (who cares, we're just adding)
+ double stateless_cost(WordID *rbegin,WordID *rend) {
+ UNIDBG("p(");
+ double sum=0;
+ for (;rend>rbegin;--rend) {
+ sum+=clamp(WordProb(rend[-1],rend));
+ UNIDBG(" "<<TD::Convert(rend[-1]));
+ }
+ UNIDBG(")="<<sum<<endl);
+ return sum;
+ }
+
+  //TODO: this would be a fine rule heuristic (for reordering hyperedges prior to rescoring); for now you can just use a same-lm-file -o 1 prelm-rescore :(
+ double stateless_cost(TRule const& rule) {
+ //TODO: make sure this is correct.
+ int len = rule.ELength(); // use a gap for each variable
+ buffer_.resize(len + 1);
+ WordID * const rend=&buffer_[0]+len;
+ *rend=kNONE;
+ WordID *r=rend; // append by *--r = x
+ const vector<WordID>& e = rule.e();
+  //SRILM is reverse order, null terminated
+  // Write each phrase down in reverse order and score it. (We could lay the phrases out consecutively and then score them, since we allocated enough buffer, but that would waste L1 cache.)
+ double sum=0.;
+ for (unsigned j = 0; j < e.size(); ++j) {
+ if (e[j] < 1) { // variable
+ sum+=stateless_cost(r,rend);
+ r=rend;
+ } else { // terminal
+ *--r=e[j];
+ }
+ }
+ // last phrase (if any)
+ return sum+stateless_cost(r,rend);
+ }
+
+ //NOTE: this is where the scoring of words happens (heuristic happens in EstimateProb)
+ double LookupWords(const TRule& rule, const vector<const void*>& ant_states, void* vstate) {
+ if (unigram)
+ return stateless_cost(rule);
+ int len = rule.ELength() - rule.Arity();
+ for (int i = 0; i < ant_states.size(); ++i)
+ len += StateSize(ant_states[i]);
+ buffer_.resize(len + 1);
+ buffer_[len] = kNONE;
+ int i = len - 1;
+ const vector<WordID>& e = rule.e();
+ for (int j = 0; j < e.size(); ++j) {
+ if (e[j] < 1) {
+ const int* astate = reinterpret_cast<const int*>(ant_states[-e[j]]);
+ int slen = StateSize(astate);
+ for (int k = 0; k < slen; ++k)
+ buffer_[i--] = astate[k];
+ } else {
+ buffer_[i--] = e[j];
+ }
+ }
+
+ double sum = 0.0;
+ int* remnant = reinterpret_cast<int*>(vstate);
+ int j = 0;
+ i = len - 1;
+ int edge = len;
+
+ while (i >= 0) {
+ if (buffer_[i] == kSTAR) {
+ edge = i;
+ } else if (edge-i >= order_) {
+ sum += LookupProbForBufferContents(i);
+ } else if (edge == len && remnant) {
+ remnant[j++] = buffer_[i];
+ }
+ --i;
+ }
+ if (!remnant) return sum;
+
+ if (edge != len || len >= order_) {
+ remnant[j++] = kSTAR;
+ if (order_-1 < edge) edge = order_-1;
+ for (int i = edge-1; i >= 0; --i)
+ remnant[j++] = buffer_[i];
+ }
+
+ SetStateSize(j, vstate);
+ return sum;
+ }
+
+
+ protected:
+ vector<WordID> buffer_;
+ public:
+ WordID kSTART;
+ WordID kSTOP;
+ WordID kUNKNOWN;
+ WordID kNONE;
+ WordID kSTAR;
+ bool unigram;
+#endif
+
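+  // Map a cdec WordID to KenLM's WordIndex; words the LM never saw map to 0 (<unk>).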
+ lm::WordIndex MapWord(WordID w) const {
+ if (w >= map_.size())
+ return 0;
+ else
+ return map_[w];
+ }
+
+ public:
+ KLanguageModelImpl(const std::string& param) {
+ lm::ngram::Config conf;
+ VMapper vm(&map_);
+ conf.enumerate_vocab = &vm;
+ ngram_ = new lm::ngram::Model(param.c_str(), conf);
+ cerr << "Loaded " << order_ << "-gram KLM from " << param << endl;
+ order_ = ngram_->Order();
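+    // KenLM's internal state, plus one length byte, plus up to order-1 boundary words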
+ state_size_ = ngram_->StateSize() + 1 + (order_-1) * sizeof(int);
+ }
+
+ ~KLanguageModelImpl() {
+ delete ngram_;
+ }
+
+  int ReserveStateSize() const { return state_size_; }
+
+ private:
+ lm::ngram::Model* ngram_;
+ int order_;
+ int state_size_;
+ vector<lm::WordIndex> map_;
+
+};
+
+KLanguageModel::KLanguageModel(const string& param) {
+ pimpl_ = new KLanguageModelImpl(param);
+ fid_ = FD::Convert("LanguageModel");
+ SetStateSize(pimpl_->ReserveStateSize());
+}
+
+Features KLanguageModel::features() const {
+ return single_feature(fid_);
+}
+
+KLanguageModel::~KLanguageModel() {
+ delete pimpl_;
+}
+
+void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_states,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* state) const {
+//  features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, state));
+//  estimated_features->set_value(fid_, pimpl_->EstimateProb(state));
+}
+
+void KLanguageModel::FinalTraversalFeatures(const void* ant_state,
+ SparseVector<double>* features) const {
+//  features->set_value(fid_, pimpl_->FinalTraversalCost(ant_state));
+}
+
diff --git a/decoder/ff_klm.h b/decoder/ff_klm.h
new file mode 100644
index 00000000..0569286f
--- /dev/null
+++ b/decoder/ff_klm.h
@@ -0,0 +1,32 @@
+#ifndef _KLM_FF_H_
+#define _KLM_FF_H_
+
+#include <vector>
+#include <string>
+
+#include "ff.h"
+
+struct KLanguageModelImpl;
+
+class KLanguageModel : public FeatureFunction {
+ public:
+ // param = "filename.lm [-o n]"
+ KLanguageModel(const std::string& param);
+ ~KLanguageModel();
+ virtual void FinalTraversalFeatures(const void* context,
+ SparseVector<double>* features) const;
+ static std::string usage(bool param,bool verbose);
+ Features features() const;
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* out_context) const;
+ private:
+ int fid_; // conceptually const; mutable only to simplify constructor
+ KLanguageModelImpl* pimpl_;
+};
+
+#endif
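
For reference, a minimal standalone sketch of the KenLM calls the wrapper
builds on (test.arpa is a made-up file name; this program is not part of the
commit):

    #include <iostream>
    #include "lm/model.hh"

    int main() {
      lm::ngram::Model model("test.arpa");  // hypothetical model file
      lm::ngram::Model::State in(model.BeginSentenceState()), out;
      // Score "the" following <s>; KenLM returns a log10 probability.
      float lp = model.Score(in, model.GetVocabulary().Index("the"), out);
      std::cout << "log10 p(the | <s>) = " << lp << std::endl;
      return 0;
    }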