From 427edc9f1869f3e091779fa83776630013db5ad1 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Mon, 17 Oct 2011 16:58:26 +0100 Subject: Chris, I'd like you to review this for use with your rules that contain and . --- decoder/ff_klm.cc | 72 +++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 49 insertions(+), 23 deletions(-) (limited to 'decoder') diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 6d9aca54..658aef80 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -71,6 +71,8 @@ string KLanguageModel::usage(bool /*param*/,bool /*verbose*/) { return "KLanguageModel"; } +namespace { + struct VMapper : public lm::ngram::EnumerateVocab { VMapper(vector* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); } void Add(lm::WordIndex index, const StringPiece &str) { @@ -83,66 +85,90 @@ struct VMapper : public lm::ngram::EnumerateVocab { const lm::WordIndex kLM_UNKNOWN_TOKEN; }; -template -class KLanguageModelImpl { +#pragma pack(push) +#pragma pack(1) - static inline const lm::ngram::ChartState& RemnantLMState(const void* state) { - return *static_cast(state); +struct BoundaryAnnotatedState { + lm::ngram::ChartState state; + bool seen_bos, seen_eos; +}; + +#pragma pack(pop) + +void BoundaryCheck(bool &annotated, bool sub, double &ret) { + if (!sub) return; + if (annotated) { + ret -= 100.0; + } else { + annotated = true; } +} +} // namespace + +template +class KLanguageModelImpl { public: double LookupWords(const TRule& rule, const vector& ant_states, double* oovs, void* remnant) { *oovs = 0; const vector& e = rule.e(); - lm::ngram::RuleScore ruleScore(*ngram_, *static_cast(remnant)); + BoundaryAnnotatedState &annotated = *static_cast(remnant); + lm::ngram::RuleScore ruleScore(*ngram_, annotated.state); + annotated.seen_bos = false; + annotated.seen_eos = false; unsigned i = 0; + double ret = 0.0; if (e.size()) { if (e[i] == kCDEC_SOS) { ++i; ruleScore.BeginSentence(); + annotated.seen_bos = true; } else if (e[i] <= 0) { // special case for left-edge NT - const lm::ngram::ChartState& prevState = RemnantLMState(ant_states[-e[0]]); - ruleScore.BeginNonTerminal(prevState, 0.0f); // TODO + const BoundaryAnnotatedState &sub = *static_cast(ant_states[-e[0]]); + ruleScore.BeginNonTerminal(sub.state, 0.0f); + annotated.seen_bos = sub.seen_bos; + annotated.seen_eos = sub.seen_eos; ++i; } } for (; i < e.size(); ++i) { if (e[i] <= 0) { - const lm::ngram::ChartState& prevState = RemnantLMState(ant_states[-e[i]]); - ruleScore.NonTerminal(prevState, 0.0f); // TODO + const BoundaryAnnotatedState &sub = *static_cast(ant_states[-e[i]]); + ruleScore.NonTerminal(sub.state, 0.0f); + BoundaryCheck(annotated.seen_bos, sub.seen_bos, ret); + BoundaryCheck(annotated.seen_eos, sub.seen_eos, ret); } else { const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]); // in future, // maybe handle emission const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id if (cur_word == 0) (*oovs) += 1.0; + BoundaryCheck(annotated.seen_eos, cur_word == kEOS_, ret); ruleScore.Terminal(cur_word); } } - double ret = ruleScore.Finish(); - static_cast(remnant)->ZeroRemaining(); + ret += ruleScore.Finish(); + annotated.state.ZeroRemaining(); return ret; } // this assumes no target words on final unary -> goal rule. is that ok? // for (n-1 left words) and (n-1 right words) - double FinalTraversalCost(const void* state, double* oovs) { + double FinalTraversalCost(const void* state_void, double* oovs) { + const BoundaryAnnotatedState &annotated = *static_cast(state_void); if (add_sos_eos_) { // rules do not produce , so do it here + assert(!annotated.seen_bos); + assert(!annotated.seen_eos); lm::ngram::ChartState cstate; lm::ngram::RuleScore ruleScore(*ngram_, cstate); ruleScore.BeginSentence(); - ruleScore.NonTerminal(RemnantLMState(state), 0.0f); + ruleScore.NonTerminal(annotated.state, 0.0f); ruleScore.Terminal(kEOS_); return ruleScore.Finish(); } else { // rules DO produce ... - double p = 0; - cerr << "not implemented"; abort(); // TODO - //if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; } - //if (UnscoredSize(state) > 0) { // are there unscored words - // if (kSOS_ != IthUnscoredWord(0, state)) { - // p -= 100 * UnscoredSize(state); - // } - //} - return p; + double ret = 0.0; + if (!annotated.seen_bos) ret -= 100.0; + if (!annotated.seen_eos) ret -= 100.0; + return ret; } } @@ -230,7 +256,7 @@ class KLanguageModelImpl { delete ngram_; } - int ReserveStateSize() const { return sizeof(lm::ngram::ChartState); } + int ReserveStateSize() const { return sizeof(BoundaryAnnotatedState); } private: const WordID kCDEC_UNK; -- cgit v1.2.3