summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
Diffstat (limited to 'decoder')
-rw-r--r--decoder/ff_klm.cc297
1 files changed, 107 insertions, 190 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc
index 5888c4a3..092c07b0 100644
--- a/decoder/ff_klm.cc
+++ b/decoder/ff_klm.cc
@@ -1,5 +1,7 @@
#include "ff_klm.h"
+#include <cstring>
+
#include "hg.h"
#include "tdict.h"
#include "lm/model.hh"
@@ -24,217 +26,116 @@ struct VMapper : public lm::ngram::EnumerateVocab {
};
class KLanguageModelImpl {
- inline int StateSize(const void* state) const {
- return *(static_cast<const char*>(state) + state_size_);
- }
-
- inline void SetStateSize(int size, void* state) const {
- *(static_cast<char*>(state) + state_size_) = size;
- }
-#if 0
- virtual double WordProb(WordID word, WordID const* context) {
- return ngram_.wordProb(word, (VocabIndex*)context);
+ // returns the number of unscored words at the left edge of a span
+ inline int UnscoredSize(const void* state) const {
+ return *(static_cast<const char*>(state) + unscored_size_offset_);
}
- // may be shorter than actual null-terminated length. context must be null terminated. len is just to save effort for subclasses that don't support contextID
- virtual int ContextSize(WordID const* context,int len) {
- unsigned ret;
- ngram_.contextID((VocabIndex*)context,ret);
- return ret;
- }
- virtual double ContextBOW(WordID const* context,int shortened_len) {
- return ngram_.contextBOW((VocabIndex*)context,shortened_len);
+ inline void SetUnscoredSize(int size, void* state) const {
+ *(static_cast<char*>(state) + unscored_size_offset_) = size;
}
- inline double LookupProbForBufferContents(int i) {
-// int k = i; cerr << "P("; while(buffer_[k] > 0) { std::cerr << TD::Convert(buffer_[k++]) << " "; }
- double p = WordProb(buffer_[i], &buffer_[i+1]);
- if (p < floor_) p = floor_;
-// cerr << ")=" << p << endl;
- return p;
+ static inline const lm::ngram::Model::State& RemnantLMState(const void* state) {
+ return *static_cast<const lm::ngram::Model::State*>(state);
}
- string DebugStateToString(const void* state) const {
- int len = StateSize(state);
- const int* astate = reinterpret_cast<const int*>(state);
- string res = "[";
- for (int i = 0; i < len; ++i) {
- res += " ";
- res += TD::Convert(astate[i]);
- }
- res += " ]";
- return res;
+ inline void SetRemnantLMState(const lm::ngram::Model::State& lmstate, void* state) const {
+ // if we were clever, we could use the memory pointed to by state to do all
+ // the work, avoiding this copy
+ memcpy(state, &lmstate, ngram_->StateSize());
}
- inline double ProbNoRemnant(int i, int len) {
- int edge = len;
- bool flag = true;
- double sum = 0.0;
- while (i >= 0) {
- if (buffer_[i] == kSTAR) {
- edge = i;
- flag = false;
- } else if (buffer_[i] <= 0) {
- edge = i;
- flag = true;
- } else {
- if ((edge-i >= order_) || (flag && !(i == (len-1) && buffer_[i] == kSTART)))
- sum += LookupProbForBufferContents(i);
- }
- --i;
- }
- return sum;
+ lm::WordIndex IthUnscoredWord(int i, const void* state) const {
+ const lm::WordIndex* const mem = reinterpret_cast<const lm::WordIndex*>(static_cast<const char*>(state) + unscored_words_offset_);
+ return mem[i];
}
- double EstimateProb(const vector<WordID>& phrase) {
- int len = phrase.size();
- buffer_.resize(len + 1);
- buffer_[len] = kNONE;
- int i = len - 1;
- for (int j = 0; j < len; ++j,--i)
- buffer_[i] = phrase[j];
- return ProbNoRemnant(len - 1, len);
+ void SetIthUnscoredWord(int i, lm::WordIndex index, void *state) const {
+ lm::WordIndex* mem = reinterpret_cast<lm::WordIndex*>(static_cast<char*>(state) + unscored_words_offset_);
+ mem[i] = index;
}
- //TODO: make sure this doesn't get used in FinalTraversal, or if it does, that it causes no harm.
-
- //TODO: use stateless_cost instead of ProbNoRemnant, check left words only. for items w/ fewer words than ctx len, how are they represented? kNONE padded?
-
- //Vocab_None is (unsigned)-1 in srilm, same as kNONE. in srilm (-1), or that SRILM otherwise interprets -1 as a terminator and not a word
- double EstimateProb(const void* state) {
- if (unigram) return 0.;
- int len = StateSize(state);
- // << "residual len: " << len << endl;
- buffer_.resize(len + 1);
- buffer_[len] = kNONE;
- const int* astate = reinterpret_cast<const WordID*>(state);
- int i = len - 1;
- for (int j = 0; j < len; ++j,--i)
- buffer_[i] = astate[j];
- return ProbNoRemnant(len - 1, len);
- }
-
- //FIXME: this assumes no target words on final unary -> goal rule. is that ok?
- // for <s> (n-1 left words) and (n-1 right words) </s>
- double FinalTraversalCost(const void* state) {
- if (unigram) return 0.;
- int slen = StateSize(state);
- int len = slen + 2;
- // cerr << "residual len: " << len << endl;
- buffer_.resize(len + 1);
- buffer_[len] = kNONE;
- buffer_[len-1] = kSTART;
- const int* astate = reinterpret_cast<const WordID*>(state);
- int i = len - 2;
- for (int j = 0; j < slen; ++j,--i)
- buffer_[i] = astate[j];
- buffer_[i] = kSTOP;
- assert(i == 0);
- return ProbNoRemnant(len - 1, len);
- }
-
- /// just how SRILM likes it: [rbegin,rend) is a phrase in reverse word order and null terminated so *rend=kNONE. return unigram score for rend[-1] plus
- /// cost returned is some kind of log prob (who cares, we're just adding)
- double stateless_cost(WordID *rbegin,WordID *rend) {
- UNIDBG("p(");
- double sum=0;
- for (;rend>rbegin;--rend) {
- sum+=clamp(WordProb(rend[-1],rend));
- UNIDBG(" "<<TD::Convert(rend[-1]));
- }
- UNIDBG(")="<<sum<<endl);
- return sum;
+ bool HasFullContext(const void *state) const {
+ return *(static_cast<const char*>(state) + is_complete_offset_);
}
- //TODO: this would be a fine rule heuristic (for reordering hyperedges prior to rescoring. for now you can just use a same-lm-file -o 1 prelm-rescore :(
- double stateless_cost(TRule const& rule) {
- //TODO: make sure this is correct.
- int len = rule.ELength(); // use a gap for each variable
- buffer_.resize(len + 1);
- WordID * const rend=&buffer_[0]+len;
- *rend=kNONE;
- WordID *r=rend; // append by *--r = x
- const vector<WordID>& e = rule.e();
- //SRILM is reverse order null terminated
- //let's write down each phrase in reverse order and score it (note: we could lay them out consecutively then score them (we allocated enough buffer for that), but we won't actually use the whole buffer that way, since it wastes L1 cache.
- double sum=0.;
- for (unsigned j = 0; j < e.size(); ++j) {
- if (e[j] < 1) { // variable
- sum+=stateless_cost(r,rend);
- r=rend;
- } else { // terminal
- *--r=e[j];
- }
- }
- // last phrase (if any)
- return sum+stateless_cost(r,rend);
+ void SetHasFullContext(bool flag, void *state) const {
+ *(static_cast<char*>(state) + is_complete_offset_) = flag;
}
- //NOTE: this is where the scoring of words happens (heuristic happens in EstimateProb)
- double LookupWords(const TRule& rule, const vector<const void*>& ant_states, void* vstate) {
- if (unigram)
- return stateless_cost(rule);
+ public:
+ double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* pest_sum, void* remnant) {
+ double sum = 0.0;
+ double est_sum = 0.0;
int len = rule.ELength() - rule.Arity();
- for (int i = 0; i < ant_states.size(); ++i)
- len += StateSize(ant_states[i]);
- buffer_.resize(len + 1);
- buffer_[len] = kNONE;
- int i = len - 1;
+ int num_scored = 0;
+ int num_estimated = 0;
+ lm::ngram::Model::State state = ngram_->NullContextState();
const vector<WordID>& e = rule.e();
+ bool context_complete = false;
for (int j = 0; j < e.size(); ++j) {
if (e[j] < 1) {
- const int* astate = reinterpret_cast<const int*>(ant_states[-e[j]]);
- int slen = StateSize(astate);
- for (int k = 0; k < slen; ++k)
- buffer_[i--] = astate[k];
+ const void* astate = (ant_states[-e[j]]);
+ int unscored_ant_len = UnscoredSize(astate);
+ for (int k = 0; k < unscored_ant_len; ++k) {
+ const lm::WordIndex cur_word = IthUnscoredWord(k, astate);
+ const lm::ngram::Model::State scopy(state);
+ const double p = ngram_->Score(scopy, cur_word, state);
+ ++num_scored;
+ if (!context_complete) {
+ if (num_scored >= order_) context_complete = true;
+ }
+ if (context_complete) {
+ sum += p;
+ } else {
+ if (remnant)
+ SetIthUnscoredWord(num_estimated, cur_word, remnant);
+ ++num_estimated;
+ est_sum += p;
+ }
+ }
+ if (HasFullContext(astate)) { // this is equivalent to the "star" in Chiang 2007
+ state = RemnantLMState(astate);
+ context_complete = true;
+ }
} else {
- buffer_[i--] = e[j];
+ const lm::WordIndex cur_word = MapWord(e[j]);
+ const lm::ngram::Model::State scopy(state);
+ const double p = ngram_->Score(scopy, cur_word, state);
+ ++num_scored;
+ if (!context_complete) {
+ if (num_scored >= order_) context_complete = true;
+ }
+ if (context_complete) {
+ sum += p;
+ } else {
+ if (remnant)
+ SetIthUnscoredWord(num_estimated, cur_word, remnant);
+ ++num_estimated;
+ est_sum += p;
+ }
}
}
-
- double sum = 0.0;
- int* remnant = reinterpret_cast<int*>(vstate);
- int j = 0;
- i = len - 1;
- int edge = len;
-
- while (i >= 0) {
- if (buffer_[i] == kSTAR) {
- edge = i;
- } else if (edge-i >= order_) {
- sum += LookupProbForBufferContents(i);
- } else if (edge == len && remnant) {
- remnant[j++] = buffer_[i];
- }
- --i;
+ if (pest_sum) *pest_sum = est_sum;
+ if (remnant) {
+ state.ZeroRemaining();
+ SetRemnantLMState(state, remnant);
+ SetUnscoredSize(num_estimated, remnant);
+ SetHasFullContext(context_complete || (num_scored >= order_), remnant);
}
- if (!remnant) return sum;
-
- if (edge != len || len >= order_) {
- remnant[j++] = kSTAR;
- if (order_-1 < edge) edge = order_-1;
- for (int i = edge-1; i >= 0; --i)
- remnant[j++] = buffer_[i];
- }
-
- SetStateSize(j, vstate);
return sum;
}
-private:
-public:
-
- protected:
- vector<WordID> buffer_;
- public:
- WordID kSTART;
- WordID kSTOP;
- WordID kUNKNOWN;
- WordID kNONE;
- WordID kSTAR;
- bool unigram;
-#endif
+ //FIXME: this assumes no target words on final unary -> goal rule. is that ok?
+ // for <s> (n-1 left words) and (n-1 right words) </s>
+ double FinalTraversalCost(const void* state) {
+ SetRemnantLMState(ngram_->BeginSentenceState(), dummy_state_);
+ SetHasFullContext(1, dummy_state_);
+ SetUnscoredSize(0, dummy_state_);
+ dummy_ants_[1] = state;
+ return LookupWords(*dummy_rule_, dummy_ants_, NULL, NULL);
+ }
lm::WordIndex MapWord(WordID w) const {
if (w >= map_.size())
@@ -249,23 +150,38 @@ public:
VMapper vm(&map_);
conf.enumerate_vocab = &vm;
ngram_ = new lm::ngram::Model(param.c_str(), conf);
- cerr << "Loaded " << order_ << "-gram KLM from " << param << endl;
order_ = ngram_->Order();
- state_size_ = ngram_->StateSize() + 1 + (order_-1) * sizeof(int);
+ cerr << "Loaded " << order_ << "-gram KLM from " << param << " (MapSize=" << map_.size() << ")\n";
+ state_size_ = ngram_->StateSize() + 2 + (order_ - 1) * sizeof(lm::WordIndex);
+ unscored_size_offset_ = ngram_->StateSize();
+ is_complete_offset_ = unscored_size_offset_ + 1;
+ unscored_words_offset_ = is_complete_offset_ + 1;
+
+ // special handling of beginning / ending sentence markers
+ dummy_state_ = new char[state_size_];
+ dummy_ants_.push_back(dummy_state_);
+ dummy_ants_.push_back(NULL);
+ dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] </s> ||| X=0"));
}
~KLanguageModelImpl() {
delete ngram_;
+ delete[] dummy_state_;
}
- const int ReserveStateSize() const { return state_size_; }
+ int ReserveStateSize() const { return state_size_; }
private:
lm::ngram::Model* ngram_;
int order_;
int state_size_;
+ int unscored_size_offset_;
+ int is_complete_offset_;
+ int unscored_words_offset_;
+ char* dummy_state_;
+ vector<const void*> dummy_ants_;
vector<lm::WordIndex> map_;
-
+ TRulePtr dummy_rule_;
};
KLanguageModel::KLanguageModel(const string& param) {
@@ -288,12 +204,13 @@ void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */,
SparseVector<double>* features,
SparseVector<double>* estimated_features,
void* state) const {
-// features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, state));
-// estimated_features->set_value(fid_, imp().EstimateProb(state));
+ double est = 0;
+ features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &est, state));
+ estimated_features->set_value(fid_, est);
}
void KLanguageModel::FinalTraversalFeatures(const void* ant_state,
SparseVector<double>* features) const {
-// features->set_value(fid_, imp().FinalTraversalCost(ant_state));
+ features->set_value(fid_, pimpl_->FinalTraversalCost(ant_state));
}