From 557f85fd642541aa8da5ffeff26dbd3fdcc7474e Mon Sep 17 00:00:00 2001 From: Guest_account Guest_account prguest11 Date: Fri, 23 Sep 2011 20:49:43 +0100 Subject: stub work to talk to new kenlm --- decoder/ff_klm.cc | 349 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 349 insertions(+) (limited to 'decoder') diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 24dcb9c3..016aad26 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -12,6 +12,353 @@ #include "lm/model.hh" #include "lm/enumerate_vocab.hh" +#undef NEW_KENLM +#ifdef NEW_KENLM + +#include "lm/left.hh" + +using namespace std; + +// -x : rules include and +// -n NAME : feature id is NAME +bool ParseLMArgs(string const& in, string* filename, string* mapfile, bool* explicit_markers, string* featname) { + vector const& argv=SplitOnWhitespace(in); + *explicit_markers = false; + *featname="LanguageModel"; + *mapfile = ""; +#define LMSPEC_NEXTARG if (i==argv.end()) { \ + cerr << "Missing argument for "<<*last<<". "; goto usage; \ + } else { ++i; } + + for (vector::const_iterator last,i=argv.begin(),e=argv.end();i!=e;++i) { + string const& s=*i; + if (s[0]=='-') { + if (s.size()>2) goto fail; + switch (s[1]) { + case 'x': + *explicit_markers = true; + break; + case 'm': + LMSPEC_NEXTARG; *mapfile=*i; + break; + case 'n': + LMSPEC_NEXTARG; *featname=*i; + break; +#undef LMSPEC_NEXTARG + default: + fail: + cerr<<"Unknown KLanguageModel option "<empty()) + *filename=s; + else { + cerr<<"More than one filename provided. "; + goto usage; + } + } + } + if (!filename->empty()) + return true; +usage: + cerr << "KLanguageModel is incorrect!\n"; + return false; +} + +template +string KLanguageModel::usage(bool /*param*/,bool /*verbose*/) { + return "KLanguageModel"; +} + +struct VMapper : public lm::ngram::EnumerateVocab { + VMapper(vector* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); } + void Add(lm::WordIndex index, const StringPiece &str) { + const WordID cdec_id = TD::Convert(str.as_string()); + if (cdec_id >= out_->size()) + out_->resize(cdec_id + 1, kLM_UNKNOWN_TOKEN); + (*out_)[cdec_id] = index; + } + vector* out_; + const lm::WordIndex kLM_UNKNOWN_TOKEN; +}; + +template +class KLanguageModelImpl { + + static inline const lm::ngram::ChartState& RemnantLMState(const void* state) { + return *static_cast(state); + } + + inline void SetRemnantLMState(const lm::ngram::ChartState& lmstate, void* state) const { + // if we were clever, we could use the memory pointed to by state to do all + // the work, avoiding this copy + memcpy(state, &lmstate, ngram_->StateSize()); + } + + public: + double LookupWords(const TRule& rule, const vector& ant_states, double* oovs, void* remnant) { + double sum = 0.0; + if (oovs) *oovs = 0; + const vector& e = rule.e(); + lm::ngram::ChartState state; + lm::ngram::RuleScore ruleScore(*ngram_, state); + unsigned i = 0; + if (e.size()) { + if (e[i] == kCDEC_SOS) { + ++i; + ruleScore.BeginSentence(); + } else if (e[i] <= 0) { // special case for left-edge NT + const lm::ngram::ChartState& prevState = RemnantLMState(ant_states[-e[0]]); + ruleScore.BeginNonTerminal(prevState, 0.0f); // TODO + ++i; + } + } + for (; i < e.size(); ++i) { + if (e[i] <= 0) { + const lm::ngram::ChartState& prevState = RemnantLMState(ant_states[-e[i]]); + ruleScore.NonTerminal(prevState, 0.0f); // TODO + } else { + const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]); // in future, + // maybe handle emission + const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id + const bool is_oov = (cur_word == 0); + if (is_oov) (*oovs) += 1.0; + ruleScore.Terminal(cur_word); + } + } + if (remnant) SetRemnantLMState(state, remnant); + return ruleScore.Finish(); + } + + // this assumes no target words on final unary -> goal rule. is that ok? + // for (n-1 left words) and (n-1 right words) + double FinalTraversalCost(const void* state, double* oovs) { + if (add_sos_eos_) { // rules do not produce , so do it here + lm::ngram::ChartState cstate; + lm::ngram::RuleScore ruleScore(*ngram_, cstate); + ruleScore.BeginSentence(); + SetRemnantLMState(cstate, dummy_state_); + dummy_ants_[1] = state; + *oovs = 0; + return LookupWords(*dummy_rule_, dummy_ants_, oovs, NULL); + } else { // rules DO produce ... + double p = 0; + cerr << "not implemented"; abort(); // TODO + //if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; } + //if (UnscoredSize(state) > 0) { // are there unscored words + // if (kSOS_ != IthUnscoredWord(0, state)) { + // p -= 100 * UnscoredSize(state); + // } + //} + return p; + } + } + + // if this is not a class-based LM, returns w untransformed, + // otherwise returns a word class mapping of w, + // returns TD::Convert("") if there is no mapping for w + WordID ClassifyWordIfNecessary(WordID w) const { + if (word2class_map_.empty()) return w; + if (w >= word2class_map_.size()) + return kCDEC_UNK; + else + return word2class_map_[w]; + } + + // converts to cdec word id's to KenLM's id space, OOVs and end up at 0 + lm::WordIndex MapWord(WordID w) const { + if (w >= cdec2klm_map_.size()) + return 0; + else + return cdec2klm_map_[w]; + } + + public: + KLanguageModelImpl(const string& filename, const string& mapfile, bool explicit_markers) : + kCDEC_UNK(TD::Convert("")) , + kCDEC_SOS(TD::Convert("")) , + add_sos_eos_(!explicit_markers) { + { + VMapper vm(&cdec2klm_map_); + lm::ngram::Config conf; + conf.enumerate_vocab = &vm; + ngram_ = new Model(filename.c_str(), conf); + } + order_ = ngram_->Order(); + cerr << "Loaded " << order_ << "-gram KLM from " << filename << " (MapSize=" << cdec2klm_map_.size() << ")\n"; + state_size_ = sizeof(lm::ngram::ChartState); + + // special handling of beginning / ending sentence markers + dummy_state_ = new char[state_size_]; + memset(dummy_state_, 0, state_size_); + dummy_ants_.push_back(dummy_state_); + dummy_ants_.push_back(NULL); + dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] ||| X=0")); + kSOS_ = MapWord(kCDEC_SOS); + assert(kSOS_ > 0); + kEOS_ = MapWord(TD::Convert("")); + assert(kEOS_ > 0); + assert(MapWord(kCDEC_UNK) == 0); // KenLM invariant + + // handle class-based LMs (unambiguous word->class mapping reqd.) + if (mapfile.size()) + LoadWordClasses(mapfile); + } + + void LoadWordClasses(const string& file) { + ReadFile rf(file); + istream& in = *rf.stream(); + string line; + vector dummy; + int lc = 0; + cerr << " Loading word classes from " << file << " ...\n"; + AddWordToClassMapping_(TD::Convert(""), TD::Convert("")); + AddWordToClassMapping_(TD::Convert(""), TD::Convert("")); + while(in) { + getline(in, line); + if (!in) continue; + dummy.clear(); + TD::ConvertSentence(line, &dummy); + ++lc; + if (dummy.size() != 2) { + cerr << " Format error in " << file << ", line " << lc << ": " << line << endl; + abort(); + } + AddWordToClassMapping_(dummy[0], dummy[1]); + } + } + + void AddWordToClassMapping_(WordID word, WordID cls) { + if (word2class_map_.size() <= word) { + word2class_map_.resize((word + 10) * 1.1, kCDEC_UNK); + assert(word2class_map_.size() > word); + } + if(word2class_map_[word] != kCDEC_UNK) { + cerr << "Multiple classes for symbol " << TD::Convert(word) << endl; + abort(); + } + word2class_map_[word] = cls; + } + + ~KLanguageModelImpl() { + delete ngram_; + delete[] dummy_state_; + } + + int ReserveStateSize() const { return state_size_; } + + private: + const WordID kCDEC_UNK; + const WordID kCDEC_SOS; + lm::WordIndex kSOS_; // - requires special handling. + lm::WordIndex kEOS_; // + Model* ngram_; + const bool add_sos_eos_; // flag indicating whether the hypergraph produces and + // if this is true, FinalTransitionFeatures will "add" and + // if false, FinalTransitionFeatures will score anything with the + // markers in the right place (i.e., the beginning and end of + // the sentence) with 0, and anything else with -100 + + int order_; + int state_size_; + char* dummy_state_; + vector dummy_ants_; + vector cdec2klm_map_; + vector word2class_map_; // if this is a class-based LM, this is the word->class mapping + TRulePtr dummy_rule_; +}; + +template +KLanguageModel::KLanguageModel(const string& param) { + string filename, mapfile, featname; + bool explicit_markers; + if (!ParseLMArgs(param, &filename, &mapfile, &explicit_markers, &featname)) { + abort(); + } + try { + pimpl_ = new KLanguageModelImpl(filename, mapfile, explicit_markers); + } catch (std::exception &e) { + std::cerr << e.what() << std::endl; + abort(); + } + fid_ = FD::Convert(featname); + oov_fid_ = FD::Convert(featname+"_OOV"); + // cerr << "FID: " << oov_fid_ << endl; + SetStateSize(pimpl_->ReserveStateSize()); +} + +template +Features KLanguageModel::features() const { + return single_feature(fid_); +} + +template +KLanguageModel::~KLanguageModel() { + delete pimpl_; +} + +template +void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, + const Hypergraph::Edge& edge, + const vector& ant_states, + SparseVector* features, + SparseVector* estimated_features, + void* state) const { + double est = 0; + double oovs = 0; + features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &oovs, state)); + if (oovs && oov_fid_) + features->set_value(oov_fid_, oovs); +} + +template +void KLanguageModel::FinalTraversalFeatures(const void* ant_state, + SparseVector* features) const { + double oovs = 0; + double lm = pimpl_->FinalTraversalCost(ant_state, &oovs); + features->set_value(fid_, lm); + if (oov_fid_ && oovs) + features->set_value(oov_fid_, oovs); +} + +template boost::shared_ptr CreateModel(const std::string ¶m) { + KLanguageModel *ret = new KLanguageModel(param); + ret->Init(); + return boost::shared_ptr(ret); +} + +boost::shared_ptr KLanguageModelFactory::Create(std::string param) const { + using namespace lm::ngram; + std::string filename, ignored_map; + bool ignored_markers; + std::string ignored_featname; + ParseLMArgs(param, &filename, &ignored_map, &ignored_markers, &ignored_featname); + ModelType m; + if (!RecognizeBinary(filename.c_str(), m)) m = HASH_PROBING; + + switch (m) { + case HASH_PROBING: + return CreateModel(param); + case TRIE_SORTED: + return CreateModel(param); + case ARRAY_TRIE_SORTED: + return CreateModel(param); + case QUANT_TRIE_SORTED: + return CreateModel(param); + case QUANT_ARRAY_TRIE_SORTED: + return CreateModel(param); + default: + UTIL_THROW(util::Exception, "Unrecognized kenlm binary file type " << (unsigned)m); + } +} + +std::string KLanguageModelFactory::usage(bool params,bool verbose) const { + return KLanguageModel::usage(params, verbose); +} + +#else + using namespace std; static const unsigned char HAS_FULL_CONTEXT = 1; @@ -469,3 +816,5 @@ boost::shared_ptr KLanguageModelFactory::Create(std::string par std::string KLanguageModelFactory::usage(bool params,bool verbose) const { return KLanguageModel::usage(params, verbose); } + +#endif -- cgit v1.2.3