diff options
-rw-r--r-- | decoder/ff_klm.cc | 96 |
1 files changed, 82 insertions, 14 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 854653c3..446485f6 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -1,7 +1,11 @@ #include "ff_klm.h" #include <cstring> +#include <iostream> +#include <boost/scoped_ptr.hpp> + +#include "filelib.h" #include "stringlib.h" #include "hg.h" #include "tdict.h" @@ -15,10 +19,11 @@ static const unsigned char MASK = 7; // -x : rules include <s> and </s> // -n NAME : feature id is NAME -bool ParseLMArgs(string const& in, string* filename, bool* explicit_markers, string* featname) { +bool ParseLMArgs(string const& in, string* filename, string* mapfile, bool* explicit_markers, string* featname) { vector<string> const& argv=SplitOnWhitespace(in); *explicit_markers = true; *featname="LanguageModel"; + *mapfile = ""; #define LMSPEC_NEXTARG if (i==argv.end()) { \ cerr << "Missing argument for "<<*last<<". "; goto usage; \ } else { ++i; } @@ -31,6 +36,9 @@ bool ParseLMArgs(string const& in, string* filename, bool* explicit_markers, str case 'x': *explicit_markers = true; break; + case 'm': + LMSPEC_NEXTARG; *mapfile=*i; + break; case 'n': LMSPEC_NEXTARG; *featname=*i; break; @@ -182,7 +190,9 @@ class KLanguageModelImpl { context_complete = true; } } else { // handle terminal - const lm::WordIndex cur_word = MapWord(e[j]); + const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[j]); // in future, + // maybe handle emission + const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id double p = 0; const bool is_oov = (cur_word == 0); if (cur_word == kSOS_) { @@ -248,22 +258,38 @@ class KLanguageModelImpl { } } + // if this is not a class-based LM, returns w untransformed, + // otherwise returns a word class mapping of w, + // returns TD::Convert("<unk>") if there is no mapping for w + WordID ClassifyWordIfNecessary(WordID w) const { + if (word2class_map_.empty()) return w; + if (w >= word2class_map_.size()) + return kCDEC_UNK; + else + return word2class_map_[w]; + } + + // converts to cdec word id's to KenLM's id space, OOVs and <unk> end up at 0 lm::WordIndex MapWord(WordID w) const { - if (w >= map_.size()) + if (w >= cdec2klm_map_.size()) return 0; else - return map_[w]; + return cdec2klm_map_[w]; } public: - KLanguageModelImpl(const string& filename, bool explicit_markers) : + KLanguageModelImpl(const string& filename, const string& mapfile, bool explicit_markers) : + kCDEC_UNK(TD::Convert("<unk>")) , add_sos_eos_(!explicit_markers) { - lm::ngram::Config conf; - VMapper vm(&map_); - conf.enumerate_vocab = &vm; - ngram_ = new Model(filename.c_str(), conf); + if (true) { + boost::scoped_ptr<lm::ngram::EnumerateVocab> vm; + vm.reset(new VMapper(&cdec2klm_map_)); + lm::ngram::Config conf; + conf.enumerate_vocab = vm.get(); + ngram_ = new Model(filename.c_str(), conf); + } order_ = ngram_->Order(); - cerr << "Loaded " << order_ << "-gram KLM from " << filename << " (MapSize=" << map_.size() << ")\n"; + cerr << "Loaded " << order_ << "-gram KLM from " << filename << " (MapSize=" << cdec2klm_map_.size() << ")\n"; state_size_ = ngram_->StateSize() + 2 + (order_ - 1) * sizeof(lm::WordIndex); unscored_size_offset_ = ngram_->StateSize(); is_complete_offset_ = unscored_size_offset_ + 1; @@ -278,6 +304,46 @@ class KLanguageModelImpl { assert(kSOS_ > 0); kEOS_ = MapWord(TD::Convert("</s>")); assert(kEOS_ > 0); + assert(MapWord(kCDEC_UNK) == 0); // KenLM invariant + + // handle class-based LMs (unambiguous word->class mapping reqd.) + if (mapfile.size()) + LoadWordClasses(mapfile); + } + + void LoadWordClasses(const string& file) { + ReadFile rf(file); + istream& in = *rf.stream(); + string line; + vector<WordID> dummy; + int lc = 0; + cerr << " Loading word classes from " << file << " ...\n"; + AddWordToClassMapping_(TD::Convert("<s>"), TD::Convert("</s>")); + AddWordToClassMapping_(TD::Convert("</s>"), TD::Convert("</s>")); + while(in) { + getline(in, line); + if (!in) continue; + dummy.clear(); + TD::ConvertSentence(line, &dummy); + ++lc; + if (dummy.size() != 2) { + cerr << " Format error in " << file << ", line " << lc << ": " << line << endl; + abort(); + } + AddWordToClassMapping_(dummy[0], dummy[1]); + } + } + + void AddWordToClassMapping_(WordID word, WordID cls) { + if (word2class_map_.size() <= word) { + word2class_map_.resize((word + 10) * 1.1, kCDEC_UNK); + assert(word2class_map_.size() > word); + } + if(word2class_map_[word] == kCDEC_UNK) { + cerr << "Multiple classes for symbol " << TD::Convert(word) << endl; + abort(); + } + word2class_map_[word] = cls; } ~KLanguageModelImpl() { @@ -288,6 +354,7 @@ class KLanguageModelImpl { int ReserveStateSize() const { return state_size_; } private: + const WordID kCDEC_UNK; lm::WordIndex kSOS_; // <s> - requires special handling. lm::WordIndex kEOS_; // </s> Model* ngram_; @@ -304,18 +371,19 @@ class KLanguageModelImpl { int unscored_words_offset_; char* dummy_state_; vector<const void*> dummy_ants_; - vector<lm::WordIndex> map_; + vector<lm::WordIndex> cdec2klm_map_; + vector<WordID> word2class_map_; // if this is a class-based LM, this is the word->class mapping TRulePtr dummy_rule_; }; template <class Model> KLanguageModel<Model>::KLanguageModel(const string& param) { - string filename, featname; + string filename, mapfile, featname; bool explicit_markers; - if (!ParseLMArgs(param, &filename, &explicit_markers, &featname)) { + if (!ParseLMArgs(param, &filename, &mapfile, &explicit_markers, &featname)) { abort(); } - pimpl_ = new KLanguageModelImpl<Model>(filename, explicit_markers); + pimpl_ = new KLanguageModelImpl<Model>(filename, mapfile, explicit_markers); fid_ = FD::Convert(featname); oov_fid_ = FD::Convert(featname+"_OOV"); SetStateSize(pimpl_->ReserveStateSize()); |