diff options
Diffstat (limited to 'decoder/ff_klm.cc')
-rw-r--r-- | decoder/ff_klm.cc | 72 |
1 files changed, 49 insertions, 23 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 6d9aca54..658aef80 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -71,6 +71,8 @@ string KLanguageModel<Model>::usage(bool /*param*/,bool /*verbose*/) { return "KLanguageModel"; } +namespace { + struct VMapper : public lm::ngram::EnumerateVocab { VMapper(vector<lm::WordIndex>* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); } void Add(lm::WordIndex index, const StringPiece &str) { @@ -83,66 +85,90 @@ struct VMapper : public lm::ngram::EnumerateVocab { const lm::WordIndex kLM_UNKNOWN_TOKEN; }; -template <class Model> -class KLanguageModelImpl { +#pragma pack(push) +#pragma pack(1) - static inline const lm::ngram::ChartState& RemnantLMState(const void* state) { - return *static_cast<const lm::ngram::ChartState*>(state); +struct BoundaryAnnotatedState { + lm::ngram::ChartState state; + bool seen_bos, seen_eos; +}; + +#pragma pack(pop) + +void BoundaryCheck(bool &annotated, bool sub, double &ret) { + if (!sub) return; + if (annotated) { + ret -= 100.0; + } else { + annotated = true; } +} +} // namespace + +template <class Model> +class KLanguageModelImpl { public: double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* oovs, void* remnant) { *oovs = 0; const vector<WordID>& e = rule.e(); - lm::ngram::RuleScore<Model> ruleScore(*ngram_, *static_cast<lm::ngram::ChartState*>(remnant)); + BoundaryAnnotatedState &annotated = *static_cast<BoundaryAnnotatedState*>(remnant); + lm::ngram::RuleScore<Model> ruleScore(*ngram_, annotated.state); + annotated.seen_bos = false; + annotated.seen_eos = false; unsigned i = 0; + double ret = 0.0; if (e.size()) { if (e[i] == kCDEC_SOS) { ++i; ruleScore.BeginSentence(); + annotated.seen_bos = true; } else if (e[i] <= 0) { // special case for left-edge NT - const lm::ngram::ChartState& prevState = RemnantLMState(ant_states[-e[0]]); - ruleScore.BeginNonTerminal(prevState, 0.0f); // TODO + const BoundaryAnnotatedState &sub = *static_cast<const BoundaryAnnotatedState*>(ant_states[-e[0]]); + ruleScore.BeginNonTerminal(sub.state, 0.0f); + annotated.seen_bos = sub.seen_bos; + annotated.seen_eos = sub.seen_eos; ++i; } } for (; i < e.size(); ++i) { if (e[i] <= 0) { - const lm::ngram::ChartState& prevState = RemnantLMState(ant_states[-e[i]]); - ruleScore.NonTerminal(prevState, 0.0f); // TODO + const BoundaryAnnotatedState &sub = *static_cast<const BoundaryAnnotatedState*>(ant_states[-e[i]]); + ruleScore.NonTerminal(sub.state, 0.0f); + BoundaryCheck(annotated.seen_bos, sub.seen_bos, ret); + BoundaryCheck(annotated.seen_eos, sub.seen_eos, ret); } else { const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]); // in future, // maybe handle emission const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id if (cur_word == 0) (*oovs) += 1.0; + BoundaryCheck(annotated.seen_eos, cur_word == kEOS_, ret); ruleScore.Terminal(cur_word); } } - double ret = ruleScore.Finish(); - static_cast<lm::ngram::ChartState*>(remnant)->ZeroRemaining(); + ret += ruleScore.Finish(); + annotated.state.ZeroRemaining(); return ret; } // this assumes no target words on final unary -> goal rule. is that ok? // for <s> (n-1 left words) and (n-1 right words) </s> - double FinalTraversalCost(const void* state, double* oovs) { + double FinalTraversalCost(const void* state_void, double* oovs) { + const BoundaryAnnotatedState &annotated = *static_cast<const BoundaryAnnotatedState*>(state_void); if (add_sos_eos_) { // rules do not produce <s> </s>, so do it here + assert(!annotated.seen_bos); + assert(!annotated.seen_eos); lm::ngram::ChartState cstate; lm::ngram::RuleScore<Model> ruleScore(*ngram_, cstate); ruleScore.BeginSentence(); - ruleScore.NonTerminal(RemnantLMState(state), 0.0f); + ruleScore.NonTerminal(annotated.state, 0.0f); ruleScore.Terminal(kEOS_); return ruleScore.Finish(); } else { // rules DO produce <s> ... </s> - double p = 0; - cerr << "not implemented"; abort(); // TODO - //if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; } - //if (UnscoredSize(state) > 0) { // are there unscored words - // if (kSOS_ != IthUnscoredWord(0, state)) { - // p -= 100 * UnscoredSize(state); - // } - //} - return p; + double ret = 0.0; + if (!annotated.seen_bos) ret -= 100.0; + if (!annotated.seen_eos) ret -= 100.0; + return ret; } } @@ -230,7 +256,7 @@ class KLanguageModelImpl { delete ngram_; } - int ReserveStateSize() const { return sizeof(lm::ngram::ChartState); } + int ReserveStateSize() const { return sizeof(BoundaryAnnotatedState); } private: const WordID kCDEC_UNK; |