From 0e46089cafa4e8e2f060e370d7afaceeda6b90a9 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Mon, 22 Apr 2013 22:50:14 -0400 Subject: support emission probabilities in class-based LMs --- decoder/ff_klm.cc | 49 ++++++++++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 19 deletions(-) (limited to 'decoder/ff_klm.cc') diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index fefa90bd..c8ca917a 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -1,6 +1,7 @@ #include "ff_klm.h" #include +#include #include #include @@ -151,8 +152,9 @@ template class BoundaryRuleScore { template class KLanguageModelImpl { public: - double LookupWords(const TRule& rule, const vector& ant_states, double* oovs, void* remnant) { + double LookupWords(const TRule& rule, const vector& ant_states, double* oovs, double* emit, void* remnant) { *oovs = 0; + *emit = 0; const vector& e = rule.e(); BoundaryRuleScore ruleScore(*ngram_, *static_cast(remnant)); unsigned i = 0; @@ -169,8 +171,9 @@ class KLanguageModelImpl { if (e[i] <= 0) { ruleScore.NonTerminal(*static_cast(ant_states[-e[i]])); } else { - const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]); // in future, - // maybe handle emission + float ep = 0.f; + const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i], &ep); + if (ep) { *emit += ep; } const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id if (cur_word == 0) (*oovs) += 1.0; ruleScore.Terminal(cur_word); @@ -205,12 +208,14 @@ class KLanguageModelImpl { // if this is not a class-based LM, returns w untransformed, // otherwise returns a word class mapping of w, // returns TD::Convert("") if there is no mapping for w - WordID ClassifyWordIfNecessary(WordID w) const { + WordID ClassifyWordIfNecessary(WordID w, float* emitp) const { if (word2class_map_.empty()) return w; if (w >= word2class_map_.size()) return kCDEC_UNK; - else - return word2class_map_[w]; + else { + *emitp = word2class_map_[w].second; + return word2class_map_[w].first; + } } // converts to cdec word id's to KenLM's id space, OOVs and end up at 0 @@ -256,32 +261,32 @@ class KLanguageModelImpl { int lc = 0; if (!SILENT) cerr << " Loading word classes from " << file << " ...\n"; - AddWordToClassMapping_(TD::Convert(""), TD::Convert("")); - AddWordToClassMapping_(TD::Convert(""), TD::Convert("")); - while(in) { - getline(in, line); - if (!in) continue; + AddWordToClassMapping_(TD::Convert(""), TD::Convert(""), 0.0); + AddWordToClassMapping_(TD::Convert(""), TD::Convert(""), 0.0); + while(getline(in, line)) { dummy.clear(); TD::ConvertSentence(line, &dummy); ++lc; - if (dummy.size() != 2) { + if (dummy.size() != 3) { + cerr << " Class map file expects: CLASS WORD logp(WORD|CLASS)\n"; cerr << " Format error in " << file << ", line " << lc << ": " << line << endl; abort(); } - AddWordToClassMapping_(dummy[0], dummy[1]); + AddWordToClassMapping_(dummy[1], dummy[0], strtof(TD::Convert(dummy[2]).c_str(), NULL)); } } - void AddWordToClassMapping_(WordID word, WordID cls) { + void AddWordToClassMapping_(WordID word, WordID cls, float emit) { if (word2class_map_.size() <= word) { - word2class_map_.resize((word + 10) * 1.1, kCDEC_UNK); + word2class_map_.resize((word + 10) * 1.1, pair(kCDEC_UNK,0.f)); assert(word2class_map_.size() > word); } - if(word2class_map_[word] != kCDEC_UNK) { + if(word2class_map_[word].first != kCDEC_UNK) { cerr << "Multiple classes for symbol " << TD::Convert(word) << endl; abort(); } - word2class_map_[word] = cls; + word2class_map_[word].first = cls; + word2class_map_[word].second = emit; } ~KLanguageModelImpl() { @@ -304,7 +309,9 @@ class KLanguageModelImpl { int order_; vector cdec2klm_map_; - vector word2class_map_; // if this is a class-based LM, this is the word->class mapping + vector > word2class_map_; // if this is a class-based LM, + // .first is the word->class mapping + // .second is the emission log probability }; template @@ -322,6 +329,7 @@ KLanguageModel::KLanguageModel(const string& param) { } fid_ = FD::Convert(featname); oov_fid_ = FD::Convert(featname+"_OOV"); + emit_fid_ = FD::Convert(featname+"_Emit"); // cerr << "FID: " << oov_fid_ << endl; SetStateSize(pimpl_->ReserveStateSize()); } @@ -340,9 +348,12 @@ void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* sme void* state) const { double est = 0; double oovs = 0; - features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &oovs, state)); + double emit = 0; + features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &oovs, &emit, state)); if (oovs && oov_fid_) features->set_value(oov_fid_, oovs); + if (emit && emit_fid_) + features->set_value(emit_fid_, emit); } template -- cgit v1.2.3