diff options
Diffstat (limited to 'decoder/ff_klm.cc')
-rw-r--r-- | decoder/ff_klm.cc | 308 |
1 files changed, 125 insertions, 183 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 24dcb9c3..28bcb6b9 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -12,11 +12,9 @@ #include "lm/model.hh" #include "lm/enumerate_vocab.hh" -using namespace std; +#include "lm/left.hh" -static const unsigned char HAS_FULL_CONTEXT = 1; -static const unsigned char HAS_EOS_ON_RIGHT = 2; -static const unsigned char MASK = 7; +using namespace std; // -x : rules include <s> and </s> // -n NAME : feature id is NAME @@ -70,6 +68,8 @@ string KLanguageModel<Model>::usage(bool /*param*/,bool /*verbose*/) { return "KLanguageModel"; } +namespace { + struct VMapper : public lm::ngram::EnumerateVocab { VMapper(vector<lm::WordIndex>* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); } void Add(lm::WordIndex index, const StringPiece &str) { @@ -82,181 +82,122 @@ struct VMapper : public lm::ngram::EnumerateVocab { const lm::WordIndex kLM_UNKNOWN_TOKEN; }; -template <class Model> -class KLanguageModelImpl { - - // returns the number of unscored words at the left edge of a span - inline int UnscoredSize(const void* state) const { - return *(static_cast<const char*>(state) + unscored_size_offset_); - } +#pragma pack(push) +#pragma pack(1) - inline void SetUnscoredSize(int size, void* state) const { - *(static_cast<char*>(state) + unscored_size_offset_) = size; - } +struct BoundaryAnnotatedState { + lm::ngram::ChartState state; + bool seen_bos, seen_eos; +}; - static inline const lm::ngram::State& RemnantLMState(const void* state) { - return *static_cast<const lm::ngram::State*>(state); - } +#pragma pack(pop) + +template <class Model> class BoundaryRuleScore { + public: + BoundaryRuleScore(const Model &m, BoundaryAnnotatedState &state) : + back_(m, state.state), + bos_(state.seen_bos), + eos_(state.seen_eos), + penalty_(0.0), + end_sentence_(m.GetVocabulary().EndSentence()) { + bos_ = false; + eos_ = false; + } - inline void SetRemnantLMState(const lm::ngram::State& lmstate, void* state) const { - // if we were clever, we could use the memory pointed to by state to do all - // the work, avoiding this copy - memcpy(state, &lmstate, ngram_->StateSize()); - } + void BeginSentence() { + back_.BeginSentence(); + bos_ = true; + } - lm::WordIndex IthUnscoredWord(int i, const void* state) const { - const lm::WordIndex* const mem = reinterpret_cast<const lm::WordIndex*>(static_cast<const char*>(state) + unscored_words_offset_); - return mem[i]; - } + void BeginNonTerminal(const BoundaryAnnotatedState &sub) { + back_.BeginNonTerminal(sub.state, 0.0f); + bos_ = sub.seen_bos; + eos_ = sub.seen_eos; + } - void SetIthUnscoredWord(int i, lm::WordIndex index, void *state) const { - lm::WordIndex* mem = reinterpret_cast<lm::WordIndex*>(static_cast<char*>(state) + unscored_words_offset_); - mem[i] = index; - } + void NonTerminal(const BoundaryAnnotatedState &sub) { + back_.NonTerminal(sub.state, 0.0f); + // cdec only calls this if there's content. + if (sub.seen_bos) { + bos_ = true; + penalty_ -= 100.0f; + } + if (eos_) penalty_ -= 100.0f; + eos_ |= sub.seen_eos; + } - inline bool GetFlag(const void *state, unsigned char flag) const { - return (*(static_cast<const char*>(state) + is_complete_offset_) & flag); - } + void Terminal(lm::WordIndex word) { + back_.Terminal(word); + if (eos_) penalty_ -= 100.0f; + if (word == end_sentence_) eos_ = true; + } - inline void SetFlag(bool on, unsigned char flag, void *state) const { - if (on) { - *(static_cast<char*>(state) + is_complete_offset_) |= flag; - } else { - *(static_cast<char*>(state) + is_complete_offset_) &= (MASK ^ flag); + float Finish() { + return penalty_ + back_.Finish(); } - } - inline bool HasFullContext(const void *state) const { - return GetFlag(state, HAS_FULL_CONTEXT); - } + private: + lm::ngram::RuleScore<Model> back_; + bool &bos_, &eos_; - inline void SetHasFullContext(bool flag, void *state) const { - SetFlag(flag, HAS_FULL_CONTEXT, state); - } + float penalty_; + lm::WordIndex end_sentence_; +}; + +} // namespace + +template <class Model> +class KLanguageModelImpl { public: - double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* pest_sum, double* oovs, double* est_oovs, void* remnant) { - double sum = 0.0; - double est_sum = 0.0; - int num_scored = 0; - int num_estimated = 0; - if (oovs) *oovs = 0; - if (est_oovs) *est_oovs = 0; - bool saw_eos = false; - bool has_some_history = false; - lm::ngram::State state = ngram_->NullContextState(); + double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* oovs, void* remnant) { + *oovs = 0; const vector<WordID>& e = rule.e(); - bool context_complete = false; - for (int j = 0; j < e.size(); ++j) { - if (e[j] < 1) { // handle non-terminal substitution - const void* astate = (ant_states[-e[j]]); - int unscored_ant_len = UnscoredSize(astate); - for (int k = 0; k < unscored_ant_len; ++k) { - const lm::WordIndex cur_word = IthUnscoredWord(k, astate); - const bool is_oov = (cur_word == 0); - double p = 0; - if (cur_word == kSOS_) { - state = ngram_->BeginSentenceState(); - if (has_some_history) { // this is immediately fully scored, and bad - p = -100; - context_complete = true; - } else { // this might be a real <s> - num_scored = max(0, order_ - 2); - } - } else { - const lm::ngram::State scopy(state); - p = ngram_->Score(scopy, cur_word, state); - if (saw_eos) { p = -100; } - saw_eos = (cur_word == kEOS_); - } - has_some_history = true; - ++num_scored; - if (!context_complete) { - if (num_scored >= order_) context_complete = true; - } - if (context_complete) { - sum += p; - if (oovs && is_oov) (*oovs)++; - } else { - if (remnant) - SetIthUnscoredWord(num_estimated, cur_word, remnant); - ++num_estimated; - est_sum += p; - if (est_oovs && is_oov) (*est_oovs)++; - } - } - saw_eos = GetFlag(astate, HAS_EOS_ON_RIGHT); - if (HasFullContext(astate)) { // this is equivalent to the "star" in Chiang 2007 - state = RemnantLMState(astate); - context_complete = true; - } - } else { // handle terminal - const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[j]); // in future, + BoundaryRuleScore<Model> ruleScore(*ngram_, *static_cast<BoundaryAnnotatedState*>(remnant)); + unsigned i = 0; + if (e.size()) { + if (e[i] == kCDEC_SOS) { + ++i; + ruleScore.BeginSentence(); + } else if (e[i] <= 0) { // special case for left-edge NT + ruleScore.BeginNonTerminal(*static_cast<const BoundaryAnnotatedState*>(ant_states[-e[0]])); + ++i; + } + } + for (; i < e.size(); ++i) { + if (e[i] <= 0) { + ruleScore.NonTerminal(*static_cast<const BoundaryAnnotatedState*>(ant_states[-e[i]])); + } else { + const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]); // in future, // maybe handle emission const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id - double p = 0; - const bool is_oov = (cur_word == 0); - if (cur_word == kSOS_) { - state = ngram_->BeginSentenceState(); - if (has_some_history) { // this is immediately fully scored, and bad - p = -100; - context_complete = true; - } else { // this might be a real <s> - num_scored = max(0, order_ - 2); - } - } else { - const lm::ngram::State scopy(state); - p = ngram_->Score(scopy, cur_word, state); - if (saw_eos) { p = -100; } - saw_eos = (cur_word == kEOS_); - } - has_some_history = true; - ++num_scored; - if (!context_complete) { - if (num_scored >= order_) context_complete = true; - } - if (context_complete) { - sum += p; - if (oovs && is_oov) (*oovs)++; - } else { - if (remnant) - SetIthUnscoredWord(num_estimated, cur_word, remnant); - ++num_estimated; - est_sum += p; - if (est_oovs && is_oov) (*est_oovs)++; - } + if (cur_word == 0) (*oovs) += 1.0; + ruleScore.Terminal(cur_word); } } - if (pest_sum) *pest_sum = est_sum; - if (remnant) { - state.ZeroRemaining(); - SetFlag(saw_eos, HAS_EOS_ON_RIGHT, remnant); - SetRemnantLMState(state, remnant); - SetUnscoredSize(num_estimated, remnant); - SetHasFullContext(context_complete || (num_scored >= order_), remnant); - } - return sum; + double ret = ruleScore.Finish(); + static_cast<BoundaryAnnotatedState*>(remnant)->state.ZeroRemaining(); + return ret; } // this assumes no target words on final unary -> goal rule. is that ok? // for <s> (n-1 left words) and (n-1 right words) </s> - double FinalTraversalCost(const void* state, double* oovs) { + double FinalTraversalCost(const void* state_void, double* oovs) { + const BoundaryAnnotatedState &annotated = *static_cast<const BoundaryAnnotatedState*>(state_void); if (add_sos_eos_) { // rules do not produce <s> </s>, so do it here - SetRemnantLMState(ngram_->BeginSentenceState(), dummy_state_); - SetHasFullContext(1, dummy_state_); - SetUnscoredSize(0, dummy_state_); - dummy_ants_[1] = state; - *oovs = 0; - return LookupWords(*dummy_rule_, dummy_ants_, NULL, oovs, NULL, NULL); + assert(!annotated.seen_bos); + assert(!annotated.seen_eos); + lm::ngram::ChartState cstate; + lm::ngram::RuleScore<Model> ruleScore(*ngram_, cstate); + ruleScore.BeginSentence(); + ruleScore.NonTerminal(annotated.state, 0.0f); + ruleScore.Terminal(kEOS_); + return ruleScore.Finish(); } else { // rules DO produce <s> ... </s> - double p = 0; - if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; } - if (UnscoredSize(state) > 0) { // are there unscored words - if (kSOS_ != IthUnscoredWord(0, state)) { - p -= 100 * UnscoredSize(state); - } - } - return p; + double ret = 0.0; + if (!annotated.seen_bos) ret -= 100.0; + if (!annotated.seen_eos) ret -= 100.0; + return ret; } } @@ -282,6 +223,7 @@ class KLanguageModelImpl { public: KLanguageModelImpl(const string& filename, const string& mapfile, bool explicit_markers) : kCDEC_UNK(TD::Convert("<unk>")) , + kCDEC_SOS(TD::Convert("<s>")) , add_sos_eos_(!explicit_markers) { { VMapper vm(&cdec2klm_map_); @@ -291,18 +233,9 @@ class KLanguageModelImpl { } order_ = ngram_->Order(); cerr << "Loaded " << order_ << "-gram KLM from " << filename << " (MapSize=" << cdec2klm_map_.size() << ")\n"; - state_size_ = ngram_->StateSize() + 2 + (order_ - 1) * sizeof(lm::WordIndex); - unscored_size_offset_ = ngram_->StateSize(); - is_complete_offset_ = unscored_size_offset_ + 1; - unscored_words_offset_ = is_complete_offset_ + 1; // special handling of beginning / ending sentence markers - dummy_state_ = new char[state_size_]; - memset(dummy_state_, 0, state_size_); - dummy_ants_.push_back(dummy_state_); - dummy_ants_.push_back(NULL); - dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] </s> ||| X=0")); - kSOS_ = MapWord(TD::Convert("<s>")); + kSOS_ = MapWord(kCDEC_SOS); assert(kSOS_ > 0); kEOS_ = MapWord(TD::Convert("</s>")); assert(kEOS_ > 0); @@ -350,13 +283,13 @@ class KLanguageModelImpl { ~KLanguageModelImpl() { delete ngram_; - delete[] dummy_state_; } - int ReserveStateSize() const { return state_size_; } + int ReserveStateSize() const { return sizeof(BoundaryAnnotatedState); } private: const WordID kCDEC_UNK; + const WordID kCDEC_SOS; lm::WordIndex kSOS_; // <s> - requires special handling. lm::WordIndex kEOS_; // </s> Model* ngram_; @@ -367,15 +300,8 @@ class KLanguageModelImpl { // the sentence) with 0, and anything else with -100 int order_; - int state_size_; - int unscored_size_offset_; - int is_complete_offset_; - int unscored_words_offset_; - char* dummy_state_; - vector<const void*> dummy_ants_; vector<lm::WordIndex> cdec2klm_map_; vector<WordID> word2class_map_; // if this is a class-based LM, this is the word->class mapping - TRulePtr dummy_rule_; }; template <class Model> @@ -393,7 +319,7 @@ KLanguageModel<Model>::KLanguageModel(const string& param) { } fid_ = FD::Convert(featname); oov_fid_ = FD::Convert(featname+"_OOV"); - cerr << "FID: " << oov_fid_ << endl; + // cerr << "FID: " << oov_fid_ << endl; SetStateSize(pimpl_->ReserveStateSize()); } @@ -416,13 +342,9 @@ void KLanguageModel<Model>::TraversalFeaturesImpl(const SentenceMetadata& /* sme void* state) const { double est = 0; double oovs = 0; - double est_oovs = 0; - features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &est, &oovs, &est_oovs, state)); - estimated_features->set_value(fid_, est); - if (oov_fid_) { - if (oovs) features->set_value(oov_fid_, oovs); - if (est_oovs) estimated_features->set_value(oov_fid_, est_oovs); - } + features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &oovs, state)); + if (oovs && oov_fid_) + features->set_value(oov_fid_, oovs); } template <class Model> @@ -469,3 +391,23 @@ boost::shared_ptr<FeatureFunction> KLanguageModelFactory::Create(std::string par std::string KLanguageModelFactory::usage(bool params,bool verbose) const { return KLanguageModel<lm::ngram::Model>::usage(params, verbose); } + + switch (m) { + case HASH_PROBING: + return CreateModel<ProbingModel>(param); + case TRIE_SORTED: + return CreateModel<TrieModel>(param); + case ARRAY_TRIE_SORTED: + return CreateModel<ArrayTrieModel>(param); + case QUANT_TRIE_SORTED: + return CreateModel<QuantTrieModel>(param); + case QUANT_ARRAY_TRIE_SORTED: + return CreateModel<QuantArrayTrieModel>(param); + default: + UTIL_THROW(util::Exception, "Unrecognized kenlm binary file type " << (unsigned)m); + } +} + +std::string KLanguageModelFactory::usage(bool params,bool verbose) const { + return KLanguageModel<lm::ngram::Model>::usage(params, verbose); +} |