diff options
-rw-r--r-- | decoder/ff_klm.cc | 23 | ||||
-rw-r--r-- | decoder/ff_klm.h | 1 |
2 files changed, 20 insertions, 4 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 7c37ddb7..203bced5 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -82,11 +82,13 @@ class KLanguageModelImpl { } public: - double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* pest_sum, void* remnant) { + double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* pest_sum, double* oovs, double* est_oovs, void* remnant) { double sum = 0.0; double est_sum = 0.0; int num_scored = 0; int num_estimated = 0; + if (oovs) *oovs = 0; + if (est_oovs) *est_oovs = 0; bool saw_eos = false; bool has_some_history = false; lm::ngram::State state = ngram_->NullContextState(); @@ -98,6 +100,7 @@ class KLanguageModelImpl { int unscored_ant_len = UnscoredSize(astate); for (int k = 0; k < unscored_ant_len; ++k) { const lm::WordIndex cur_word = IthUnscoredWord(k, astate); + const bool is_oov = (cur_word == 0); double p = 0; if (cur_word == kSOS_) { state = ngram_->BeginSentenceState(); @@ -120,11 +123,13 @@ class KLanguageModelImpl { } if (context_complete) { sum += p; + if (oovs && is_oov) (*oovs)++; } else { if (remnant) SetIthUnscoredWord(num_estimated, cur_word, remnant); ++num_estimated; est_sum += p; + if (est_oovs && is_oov) (*est_oovs)++; } } saw_eos = GetFlag(astate, HAS_EOS_ON_RIGHT); @@ -135,6 +140,7 @@ class KLanguageModelImpl { } else { // handle terminal const lm::WordIndex cur_word = MapWord(e[j]); double p = 0; + const bool is_oov = (cur_word == 0); if (cur_word == kSOS_) { state = ngram_->BeginSentenceState(); if (has_some_history) { // this is immediately fully scored, and bad @@ -156,11 +162,13 @@ class KLanguageModelImpl { } if (context_complete) { sum += p; + if (oovs && is_oov) (*oovs)++; } else { if (remnant) SetIthUnscoredWord(num_estimated, cur_word, remnant); ++num_estimated; est_sum += p; + if (est_oovs && is_oov) (*est_oovs)++; } } } @@ -183,7 +191,7 @@ class KLanguageModelImpl { SetHasFullContext(1, dummy_state_); SetUnscoredSize(0, dummy_state_); dummy_ants_[1] = state; - return LookupWords(*dummy_rule_, dummy_ants_, NULL, NULL); + return LookupWords(*dummy_rule_, dummy_ants_, NULL, NULL, NULL, NULL); } else { // rules DO produce <s> ... </s> double p = 0; if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; } @@ -264,7 +272,8 @@ class KLanguageModelImpl { template <class Model> KLanguageModel<Model>::KLanguageModel(const string& param) { pimpl_ = new KLanguageModelImpl<Model>(param); - fid_ = FD::Convert("LanguageModel"); + fid_ = FD::Convert("LanguageModel"); // todo support LM feature name + oov_fid_ = FD::Convert("OOV"); // should also be named SetStateSize(pimpl_->ReserveStateSize()); } @@ -286,8 +295,14 @@ void KLanguageModel<Model>::TraversalFeaturesImpl(const SentenceMetadata& /* sme SparseVector<double>* estimated_features, void* state) const { double est = 0; - features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &est, state)); + double oovs = 0; + double est_oovs = 0; + features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &est, &oovs, &est_oovs, state)); estimated_features->set_value(fid_, est); + if (oov_fid_) { + if (oovs) features->set_value(oov_fid_, oovs); + if (est_oovs) estimated_features->set_value(oov_fid_, est_oovs); + } } template <class Model> diff --git a/decoder/ff_klm.h b/decoder/ff_klm.h index 95e1e897..5eafe8be 100644 --- a/decoder/ff_klm.h +++ b/decoder/ff_klm.h @@ -30,6 +30,7 @@ class KLanguageModel : public FeatureFunction { void* out_context) const; private: int fid_; // conceptually const; mutable only to simplify constructor + int oov_fid_; // will be zero if extra OOV feature is not configured by decoder KLanguageModelImpl<Model>* pimpl_; }; |