diff options
-rw-r--r-- | decoder/ff_klm.cc | 60 |
1 files changed, 47 insertions, 13 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 5049f156..5e590053 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -75,13 +75,19 @@ class KLanguageModelImpl { const vector<WordID>& e = rule.e(); bool context_complete = false; for (int j = 0; j < e.size(); ++j) { - if (e[j] < 1) { + if (e[j] < 1) { // handle non-terminal substitution const void* astate = (ant_states[-e[j]]); int unscored_ant_len = UnscoredSize(astate); for (int k = 0; k < unscored_ant_len; ++k) { const lm::WordIndex cur_word = IthUnscoredWord(k, astate); - const lm::ngram::State scopy(state); - const double p = ngram_->Score(scopy, cur_word, state); + double p = 0; + if (cur_word == kSOS_) { + if (state.ValidLength() > 0) { p = -100; cerr << rule << endl; } + state = ngram_->BeginSentenceState(); + } else { + const lm::ngram::State scopy(state); + p = ngram_->Score(scopy, cur_word, state); + } ++num_scored; if (!context_complete) { if (num_scored >= order_) context_complete = true; @@ -99,10 +105,16 @@ class KLanguageModelImpl { state = RemnantLMState(astate); context_complete = true; } - } else { + } else { // handle terminal const lm::WordIndex cur_word = MapWord(e[j]); - const lm::ngram::State scopy(state); - const double p = ngram_->Score(scopy, cur_word, state); + double p = 0; + if (cur_word == kSOS_) { + if (state.ValidLength() > 0) p = -100; + state = ngram_->BeginSentenceState(); + } else { + const lm::ngram::State scopy(state); + p = ngram_->Score(scopy, cur_word, state); + } ++num_scored; if (!context_complete) { if (num_scored >= order_) context_complete = true; @@ -130,11 +142,16 @@ class KLanguageModelImpl { //FIXME: this assumes no target words on final unary -> goal rule. is that ok? // for <s> (n-1 left words) and (n-1 right words) </s> double FinalTraversalCost(const void* state) { - SetRemnantLMState(ngram_->BeginSentenceState(), dummy_state_); - SetHasFullContext(1, dummy_state_); - SetUnscoredSize(0, dummy_state_); - dummy_ants_[1] = state; - return LookupWords(*dummy_rule_, dummy_ants_, NULL, NULL); + if (add_sos_eos_) { + SetRemnantLMState(ngram_->BeginSentenceState(), dummy_state_); + SetHasFullContext(1, dummy_state_); + SetUnscoredSize(0, dummy_state_); + dummy_ants_[1] = state; + return LookupWords(*dummy_rule_, dummy_ants_, NULL, NULL); + } else { + // TODO, figure out whether spans are correct + return 0; + } } lm::WordIndex MapWord(WordID w) const { @@ -146,12 +163,18 @@ class KLanguageModelImpl { public: KLanguageModelImpl(const std::string& param) { + add_sos_eos_ = true; + string fname = param; + if (param.find("-x ") == 0) { + add_sos_eos_ = false; + fname = param.substr(3); + } lm::ngram::Config conf; VMapper vm(&map_); conf.enumerate_vocab = &vm; - ngram_ = new Model(param.c_str(), conf); + ngram_ = new Model(fname.c_str(), conf); order_ = ngram_->Order(); - cerr << "Loaded " << order_ << "-gram KLM from " << param << " (MapSize=" << map_.size() << ")\n"; + cerr << "Loaded " << order_ << "-gram KLM from " << fname << " (MapSize=" << map_.size() << ")\n"; state_size_ = ngram_->StateSize() + 2 + (order_ - 1) * sizeof(lm::WordIndex); unscored_size_offset_ = ngram_->StateSize(); is_complete_offset_ = unscored_size_offset_ + 1; @@ -162,6 +185,9 @@ class KLanguageModelImpl { dummy_ants_.push_back(dummy_state_); dummy_ants_.push_back(NULL); dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] </s> ||| X=0")); + kSOS_ = MapWord(TD::Convert("<s>")); + assert(kSOS_ > 0); + kEOS_ = MapWord(TD::Convert("</s>")); } ~KLanguageModelImpl() { @@ -172,7 +198,15 @@ class KLanguageModelImpl { int ReserveStateSize() const { return state_size_; } private: + lm::WordIndex kSOS_; // <s> - requires special handling. + lm::WordIndex kEOS_; // </s> Model* ngram_; + bool add_sos_eos_; // flag indicating whether the hypergraph produces <s> and </s> + // if this is true, FinalTransitionFeatures will "add" <s> and </s> + // if false, FinalTransitionFeatures will score anything with the + // markers in the right place (i.e., the beginning and end of + // the sentence) with 0, and anything else with -100 + int order_; int state_size_; int unscored_size_offset_; |