diff options
| -rw-r--r-- | decoder/ff_klm.cc | 23 | ||||
| -rw-r--r-- | decoder/ff_klm.h | 1 | 
2 files changed, 20 insertions, 4 deletions
| diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 7c37ddb7..203bced5 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -82,11 +82,13 @@ class KLanguageModelImpl {    }   public: -  double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* pest_sum, void* remnant) { +  double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* pest_sum, double* oovs, double* est_oovs, void* remnant) {      double sum = 0.0;      double est_sum = 0.0;      int num_scored = 0;      int num_estimated = 0; +    if (oovs) *oovs = 0; +    if (est_oovs) *est_oovs = 0;      bool saw_eos = false;      bool has_some_history = false;      lm::ngram::State state = ngram_->NullContextState(); @@ -98,6 +100,7 @@ class KLanguageModelImpl {          int unscored_ant_len = UnscoredSize(astate);          for (int k = 0; k < unscored_ant_len; ++k) {            const lm::WordIndex cur_word = IthUnscoredWord(k, astate); +          const bool is_oov = (cur_word == 0);            double p = 0;            if (cur_word == kSOS_) {              state = ngram_->BeginSentenceState(); @@ -120,11 +123,13 @@ class KLanguageModelImpl {            }            if (context_complete) {              sum += p; +            if (oovs && is_oov) (*oovs)++;            } else {              if (remnant)                SetIthUnscoredWord(num_estimated, cur_word, remnant);              ++num_estimated;              est_sum += p; +            if (est_oovs && is_oov) (*est_oovs)++;            }          }          saw_eos = GetFlag(astate, HAS_EOS_ON_RIGHT); @@ -135,6 +140,7 @@ class KLanguageModelImpl {        } else {   // handle terminal          const lm::WordIndex cur_word = MapWord(e[j]);          double p = 0; +        const bool is_oov = (cur_word == 0);          if (cur_word == kSOS_) {            state = ngram_->BeginSentenceState();            if (has_some_history) {  // this is immediately fully scored, and bad @@ -156,11 +162,13 @@ class KLanguageModelImpl {          }          if (context_complete) {            sum += p; +          if (oovs && is_oov) (*oovs)++;          } else {            if (remnant)              SetIthUnscoredWord(num_estimated, cur_word, remnant);            ++num_estimated;            est_sum += p; +          if (est_oovs && is_oov) (*est_oovs)++;          }        }      } @@ -183,7 +191,7 @@ class KLanguageModelImpl {        SetHasFullContext(1, dummy_state_);        SetUnscoredSize(0, dummy_state_);        dummy_ants_[1] = state; -      return LookupWords(*dummy_rule_, dummy_ants_, NULL, NULL); +      return LookupWords(*dummy_rule_, dummy_ants_, NULL, NULL, NULL, NULL);      } else {  // rules DO produce <s> ... </s>        double p = 0;        if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; } @@ -264,7 +272,8 @@ class KLanguageModelImpl {  template <class Model>  KLanguageModel<Model>::KLanguageModel(const string& param) {    pimpl_ = new KLanguageModelImpl<Model>(param); -  fid_ = FD::Convert("LanguageModel"); +  fid_ = FD::Convert("LanguageModel");  // todo support LM feature name +  oov_fid_ = FD::Convert("OOV");        // should also be named    SetStateSize(pimpl_->ReserveStateSize());  } @@ -286,8 +295,14 @@ void KLanguageModel<Model>::TraversalFeaturesImpl(const SentenceMetadata& /* sme                                            SparseVector<double>* estimated_features,                                            void* state) const {    double est = 0; -  features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &est, state)); +  double oovs = 0; +  double est_oovs = 0; +  features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &est, &oovs, &est_oovs, state));    estimated_features->set_value(fid_, est); +  if (oov_fid_) { +    if (oovs) features->set_value(oov_fid_, oovs); +    if (est_oovs) estimated_features->set_value(oov_fid_, est_oovs); +  }  }  template <class Model> diff --git a/decoder/ff_klm.h b/decoder/ff_klm.h index 95e1e897..5eafe8be 100644 --- a/decoder/ff_klm.h +++ b/decoder/ff_klm.h @@ -30,6 +30,7 @@ class KLanguageModel : public FeatureFunction {                                       void* out_context) const;   private:    int fid_; // conceptually const; mutable only to simplify constructor +  int oov_fid_; // will be zero if extra OOV feature is not configured by decoder    KLanguageModelImpl<Model>* pimpl_;  }; | 
