summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2011-10-18 10:25:56 +0100
committerKenneth Heafield <github@kheafield.com>2011-10-18 10:25:56 +0100
commit3d1ed02a4e5d81aace80b0e004e96351d116630f (patch)
tree194d61e38362a90544e6349366957b632b1b3f5c /decoder
parent957d90991b4ec80b9877126c736bd60768b094aa (diff)
Revised <s> and </s> handling
Diffstat (limited to 'decoder')
-rw-r--r--decoder/ff_klm.cc84
1 files changed, 58 insertions, 26 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc
index 658aef80..3c941fbf 100644
--- a/decoder/ff_klm.cc
+++ b/decoder/ff_klm.cc
@@ -12,8 +12,8 @@
#include "lm/model.hh"
#include "lm/enumerate_vocab.hh"
+#define NEW_KENLM
#undef NEW_KENLM
-#ifdef NEW_KENLM
#include "lm/left.hh"
@@ -95,14 +95,58 @@ struct BoundaryAnnotatedState {
#pragma pack(pop)
-void BoundaryCheck(bool &annotated, bool sub, double &ret) {
- if (!sub) return;
- if (annotated) {
- ret -= 100.0;
- } else {
- annotated = true;
- }
-}
+template <class Model> class BoundaryRuleScore {
+ public:
+ BoundaryRuleScore(const Model &m, BoundaryAnnotatedState &state) :
+ back_(m, state.state),
+ bos_(state.seen_bos),
+ eos_(state.seen_eos),
+ penalty_(0.0),
+ end_sentence_(m.GetVocabulary().EndSentence()) {
+ bos_ = false;
+ eos_ = false;
+ }
+
+ void BeginSentence() {
+ back_.BeginSentence();
+ bos_ = true;
+ }
+
+ void BeginNonTerminal(const BoundaryAnnotatedState &sub) {
+ back_.BeginNonTerminal(sub.state, 0.0f);
+ bos_ = sub.seen_bos;
+ eos_ = sub.seen_eos;
+ }
+
+ void NonTerminal(const BoundaryAnnotatedState &sub) {
+ back_.NonTerminal(sub.state, 0.0f);
+ // cdec only calls this if there's content.
+ if (sub.seen_bos) {
+ bos_ = true;
+ penalty_ -= 100.0f;
+ }
+ if (eos_) penalty_ -= 100.0f;
+ eos_ |= sub.seen_eos;
+ }
+
+ void Terminal(lm::WordIndex word) {
+ back_.Terminal(word);
+ if (eos_) penalty_ -= 100.0f;
+ if (word == end_sentence_) eos_ = true;
+ }
+
+ float Finish() {
+ return penalty_ + back_.Finish();
+ }
+
+ private:
+ lm::ngram::RuleScore<Model> back_;
+ bool &bos_, &eos_;
+
+ float penalty_;
+
+ lm::WordIndex end_sentence_;
+};
} // namespace
@@ -112,42 +156,30 @@ class KLanguageModelImpl {
double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* oovs, void* remnant) {
*oovs = 0;
const vector<WordID>& e = rule.e();
- BoundaryAnnotatedState &annotated = *static_cast<BoundaryAnnotatedState*>(remnant);
- lm::ngram::RuleScore<Model> ruleScore(*ngram_, annotated.state);
- annotated.seen_bos = false;
- annotated.seen_eos = false;
+ BoundaryRuleScore<Model> ruleScore(*ngram_, *static_cast<BoundaryAnnotatedState*>(remnant));
unsigned i = 0;
- double ret = 0.0;
if (e.size()) {
if (e[i] == kCDEC_SOS) {
++i;
ruleScore.BeginSentence();
- annotated.seen_bos = true;
} else if (e[i] <= 0) { // special case for left-edge NT
- const BoundaryAnnotatedState &sub = *static_cast<const BoundaryAnnotatedState*>(ant_states[-e[0]]);
- ruleScore.BeginNonTerminal(sub.state, 0.0f);
- annotated.seen_bos = sub.seen_bos;
- annotated.seen_eos = sub.seen_eos;
+ ruleScore.BeginNonTerminal(*static_cast<const BoundaryAnnotatedState*>(ant_states[-e[0]]));
++i;
}
}
for (; i < e.size(); ++i) {
if (e[i] <= 0) {
- const BoundaryAnnotatedState &sub = *static_cast<const BoundaryAnnotatedState*>(ant_states[-e[i]]);
- ruleScore.NonTerminal(sub.state, 0.0f);
- BoundaryCheck(annotated.seen_bos, sub.seen_bos, ret);
- BoundaryCheck(annotated.seen_eos, sub.seen_eos, ret);
+ ruleScore.NonTerminal(*static_cast<const BoundaryAnnotatedState*>(ant_states[-e[i]]));
} else {
const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]); // in future,
// maybe handle emission
const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id
if (cur_word == 0) (*oovs) += 1.0;
- BoundaryCheck(annotated.seen_eos, cur_word == kEOS_, ret);
ruleScore.Terminal(cur_word);
}
}
- ret += ruleScore.Finish();
- annotated.state.ZeroRemaining();
+ double ret = ruleScore.Finish();
+ static_cast<BoundaryAnnotatedState*>(remnant)->state.ZeroRemaining();
return ret;
}