summaryrefslogtreecommitdiff
path: root/decoder/ff_klm.cc
diff options
context:
space:
mode:
Diffstat (limited to 'decoder/ff_klm.cc')
-rw-r--r--decoder/ff_klm.cc23
1 files changed, 19 insertions, 4 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc
index 7c37ddb7..203bced5 100644
--- a/decoder/ff_klm.cc
+++ b/decoder/ff_klm.cc
@@ -82,11 +82,13 @@ class KLanguageModelImpl {
}
public:
- double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* pest_sum, void* remnant) {
+ double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* pest_sum, double* oovs, double* est_oovs, void* remnant) {
double sum = 0.0;
double est_sum = 0.0;
int num_scored = 0;
int num_estimated = 0;
+ if (oovs) *oovs = 0;
+ if (est_oovs) *est_oovs = 0;
bool saw_eos = false;
bool has_some_history = false;
lm::ngram::State state = ngram_->NullContextState();
@@ -98,6 +100,7 @@ class KLanguageModelImpl {
int unscored_ant_len = UnscoredSize(astate);
for (int k = 0; k < unscored_ant_len; ++k) {
const lm::WordIndex cur_word = IthUnscoredWord(k, astate);
+ const bool is_oov = (cur_word == 0);
double p = 0;
if (cur_word == kSOS_) {
state = ngram_->BeginSentenceState();
@@ -120,11 +123,13 @@ class KLanguageModelImpl {
}
if (context_complete) {
sum += p;
+ if (oovs && is_oov) (*oovs)++;
} else {
if (remnant)
SetIthUnscoredWord(num_estimated, cur_word, remnant);
++num_estimated;
est_sum += p;
+ if (est_oovs && is_oov) (*est_oovs)++;
}
}
saw_eos = GetFlag(astate, HAS_EOS_ON_RIGHT);
@@ -135,6 +140,7 @@ class KLanguageModelImpl {
} else { // handle terminal
const lm::WordIndex cur_word = MapWord(e[j]);
double p = 0;
+ const bool is_oov = (cur_word == 0);
if (cur_word == kSOS_) {
state = ngram_->BeginSentenceState();
if (has_some_history) { // this is immediately fully scored, and bad
@@ -156,11 +162,13 @@ class KLanguageModelImpl {
}
if (context_complete) {
sum += p;
+ if (oovs && is_oov) (*oovs)++;
} else {
if (remnant)
SetIthUnscoredWord(num_estimated, cur_word, remnant);
++num_estimated;
est_sum += p;
+ if (est_oovs && is_oov) (*est_oovs)++;
}
}
}
@@ -183,7 +191,7 @@ class KLanguageModelImpl {
SetHasFullContext(1, dummy_state_);
SetUnscoredSize(0, dummy_state_);
dummy_ants_[1] = state;
- return LookupWords(*dummy_rule_, dummy_ants_, NULL, NULL);
+ return LookupWords(*dummy_rule_, dummy_ants_, NULL, NULL, NULL, NULL);
} else { // rules DO produce <s> ... </s>
double p = 0;
if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; }
@@ -264,7 +272,8 @@ class KLanguageModelImpl {
template <class Model>
KLanguageModel<Model>::KLanguageModel(const string& param) {
pimpl_ = new KLanguageModelImpl<Model>(param);
- fid_ = FD::Convert("LanguageModel");
+ fid_ = FD::Convert("LanguageModel"); // todo support LM feature name
+ oov_fid_ = FD::Convert("OOV"); // should also be named
SetStateSize(pimpl_->ReserveStateSize());
}
@@ -286,8 +295,14 @@ void KLanguageModel<Model>::TraversalFeaturesImpl(const SentenceMetadata& /* sme
SparseVector<double>* estimated_features,
void* state) const {
double est = 0;
- features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &est, state));
+ double oovs = 0;
+ double est_oovs = 0;
+ features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &est, &oovs, &est_oovs, state));
estimated_features->set_value(fid_, est);
+ if (oov_fid_) {
+ if (oovs) features->set_value(oov_fid_, oovs);
+ if (est_oovs) estimated_features->set_value(oov_fid_, est_oovs);
+ }
}
template <class Model>