From 0c9fe1dd72fb2321a5652e0ee66d1c897a3d9f80 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 1 Mar 2011 03:05:25 -0500 Subject: support explicit sentence boundary markers with cdec --- decoder/ff_klm.cc | 60 +++++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 47 insertions(+), 13 deletions(-) (limited to 'decoder') diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 5049f156..5e590053 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -75,13 +75,19 @@ class KLanguageModelImpl { const vector& e = rule.e(); bool context_complete = false; for (int j = 0; j < e.size(); ++j) { - if (e[j] < 1) { + if (e[j] < 1) { // handle non-terminal substitution const void* astate = (ant_states[-e[j]]); int unscored_ant_len = UnscoredSize(astate); for (int k = 0; k < unscored_ant_len; ++k) { const lm::WordIndex cur_word = IthUnscoredWord(k, astate); - const lm::ngram::State scopy(state); - const double p = ngram_->Score(scopy, cur_word, state); + double p = 0; + if (cur_word == kSOS_) { + if (state.ValidLength() > 0) { p = -100; cerr << rule << endl; } + state = ngram_->BeginSentenceState(); + } else { + const lm::ngram::State scopy(state); + p = ngram_->Score(scopy, cur_word, state); + } ++num_scored; if (!context_complete) { if (num_scored >= order_) context_complete = true; @@ -99,10 +105,16 @@ class KLanguageModelImpl { state = RemnantLMState(astate); context_complete = true; } - } else { + } else { // handle terminal const lm::WordIndex cur_word = MapWord(e[j]); - const lm::ngram::State scopy(state); - const double p = ngram_->Score(scopy, cur_word, state); + double p = 0; + if (cur_word == kSOS_) { + if (state.ValidLength() > 0) p = -100; + state = ngram_->BeginSentenceState(); + } else { + const lm::ngram::State scopy(state); + p = ngram_->Score(scopy, cur_word, state); + } ++num_scored; if (!context_complete) { if (num_scored >= order_) context_complete = true; @@ -130,11 +142,16 @@ class KLanguageModelImpl { //FIXME: this assumes no target words on final unary -> goal rule. is that ok? // for (n-1 left words) and (n-1 right words) double FinalTraversalCost(const void* state) { - SetRemnantLMState(ngram_->BeginSentenceState(), dummy_state_); - SetHasFullContext(1, dummy_state_); - SetUnscoredSize(0, dummy_state_); - dummy_ants_[1] = state; - return LookupWords(*dummy_rule_, dummy_ants_, NULL, NULL); + if (add_sos_eos_) { + SetRemnantLMState(ngram_->BeginSentenceState(), dummy_state_); + SetHasFullContext(1, dummy_state_); + SetUnscoredSize(0, dummy_state_); + dummy_ants_[1] = state; + return LookupWords(*dummy_rule_, dummy_ants_, NULL, NULL); + } else { + // TODO, figure out whether spans are correct + return 0; + } } lm::WordIndex MapWord(WordID w) const { @@ -146,12 +163,18 @@ class KLanguageModelImpl { public: KLanguageModelImpl(const std::string& param) { + add_sos_eos_ = true; + string fname = param; + if (param.find("-x ") == 0) { + add_sos_eos_ = false; + fname = param.substr(3); + } lm::ngram::Config conf; VMapper vm(&map_); conf.enumerate_vocab = &vm; - ngram_ = new Model(param.c_str(), conf); + ngram_ = new Model(fname.c_str(), conf); order_ = ngram_->Order(); - cerr << "Loaded " << order_ << "-gram KLM from " << param << " (MapSize=" << map_.size() << ")\n"; + cerr << "Loaded " << order_ << "-gram KLM from " << fname << " (MapSize=" << map_.size() << ")\n"; state_size_ = ngram_->StateSize() + 2 + (order_ - 1) * sizeof(lm::WordIndex); unscored_size_offset_ = ngram_->StateSize(); is_complete_offset_ = unscored_size_offset_ + 1; @@ -162,6 +185,9 @@ class KLanguageModelImpl { dummy_ants_.push_back(dummy_state_); dummy_ants_.push_back(NULL); dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] ||| X=0")); + kSOS_ = MapWord(TD::Convert("")); + assert(kSOS_ > 0); + kEOS_ = MapWord(TD::Convert("")); } ~KLanguageModelImpl() { @@ -172,7 +198,15 @@ class KLanguageModelImpl { int ReserveStateSize() const { return state_size_; } private: + lm::WordIndex kSOS_; // - requires special handling. + lm::WordIndex kEOS_; // Model* ngram_; + bool add_sos_eos_; // flag indicating whether the hypergraph produces and + // if this is true, FinalTransitionFeatures will "add" and + // if false, FinalTransitionFeatures will score anything with the + // markers in the right place (i.e., the beginning and end of + // the sentence) with 0, and anything else with -100 + int order_; int state_size_; int unscored_size_offset_; -- cgit v1.2.3