summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-05-26 01:21:44 -0400
committerChris Dyer <cdyer@cs.cmu.edu>2011-05-26 01:21:44 -0400
commitc135e494b7a49d20b51f9c181f104d9adb0099af (patch)
tree2e2ea1ec49b4f5789b6f548d9d0e8d995d75fba4 /decoder
parentcc9a613359d707b452ac0daf2adb782cb96e0223 (diff)
fix bug preventing oovs from firing when they're near the beginning of a word
Diffstat (limited to 'decoder')
-rw-r--r--decoder/ff_klm.cc12
1 files changed, 9 insertions, 3 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc
index ab44232a..35b35d36 100644
--- a/decoder/ff_klm.cc
+++ b/decoder/ff_klm.cc
@@ -239,13 +239,14 @@ class KLanguageModelImpl {
// this assumes no target words on final unary -> goal rule. is that ok?
// for <s> (n-1 left words) and (n-1 right words) </s>
- double FinalTraversalCost(const void* state) {
+ double FinalTraversalCost(const void* state, double* oovs) {
if (add_sos_eos_) { // rules do not produce <s> </s>, so do it here
SetRemnantLMState(ngram_->BeginSentenceState(), dummy_state_);
SetHasFullContext(1, dummy_state_);
SetUnscoredSize(0, dummy_state_);
dummy_ants_[1] = state;
- return LookupWords(*dummy_rule_, dummy_ants_, NULL, NULL, NULL, NULL);
+ *oovs = 0;
+ return LookupWords(*dummy_rule_, dummy_ants_, NULL, oovs, NULL, NULL);
} else { // rules DO produce <s> ... </s>
double p = 0;
if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; }
@@ -387,6 +388,7 @@ KLanguageModel<Model>::KLanguageModel(const string& param) {
pimpl_ = new KLanguageModelImpl<Model>(filename, mapfile, explicit_markers);
fid_ = FD::Convert(featname);
oov_fid_ = FD::Convert(featname+"_OOV");
+ cerr << "FID: " << oov_fid_ << endl;
SetStateSize(pimpl_->ReserveStateSize());
}
@@ -421,7 +423,11 @@ void KLanguageModel<Model>::TraversalFeaturesImpl(const SentenceMetadata& /* sme
template <class Model>
void KLanguageModel<Model>::FinalTraversalFeatures(const void* ant_state,
SparseVector<double>* features) const {
- features->set_value(fid_, pimpl_->FinalTraversalCost(ant_state));
+ double oovs = 0;
+ double lm = pimpl_->FinalTraversalCost(ant_state, &oovs);
+ features->set_value(fid_, lm);
+ if (oov_fid_ && oovs)
+ features->set_value(oov_fid_, oovs);
}
// instantiate templates