diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2011-05-26 01:21:44 -0400 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2011-05-26 01:21:44 -0400 |
commit | c135e494b7a49d20b51f9c181f104d9adb0099af (patch) | |
tree | 2e2ea1ec49b4f5789b6f548d9d0e8d995d75fba4 | |
parent | cc9a613359d707b452ac0daf2adb782cb96e0223 (diff) |
fix bug preventing oovs from firing when they're near the beginning of a word
-rw-r--r-- | decoder/ff_klm.cc | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index ab44232a..35b35d36 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -239,13 +239,14 @@ class KLanguageModelImpl { // this assumes no target words on final unary -> goal rule. is that ok? // for <s> (n-1 left words) and (n-1 right words) </s> - double FinalTraversalCost(const void* state) { + double FinalTraversalCost(const void* state, double* oovs) { if (add_sos_eos_) { // rules do not produce <s> </s>, so do it here SetRemnantLMState(ngram_->BeginSentenceState(), dummy_state_); SetHasFullContext(1, dummy_state_); SetUnscoredSize(0, dummy_state_); dummy_ants_[1] = state; - return LookupWords(*dummy_rule_, dummy_ants_, NULL, NULL, NULL, NULL); + *oovs = 0; + return LookupWords(*dummy_rule_, dummy_ants_, NULL, oovs, NULL, NULL); } else { // rules DO produce <s> ... </s> double p = 0; if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; } @@ -387,6 +388,7 @@ KLanguageModel<Model>::KLanguageModel(const string& param) { pimpl_ = new KLanguageModelImpl<Model>(filename, mapfile, explicit_markers); fid_ = FD::Convert(featname); oov_fid_ = FD::Convert(featname+"_OOV"); + cerr << "FID: " << oov_fid_ << endl; SetStateSize(pimpl_->ReserveStateSize()); } @@ -421,7 +423,11 @@ void KLanguageModel<Model>::TraversalFeaturesImpl(const SentenceMetadata& /* sme template <class Model> void KLanguageModel<Model>::FinalTraversalFeatures(const void* ant_state, SparseVector<double>* features) const { - features->set_value(fid_, pimpl_->FinalTraversalCost(ant_state)); + double oovs = 0; + double lm = pimpl_->FinalTraversalCost(ant_state, &oovs); + features->set_value(fid_, lm); + if (oov_fid_ && oovs) + features->set_value(oov_fid_, oovs); } // instantiate templates |