diff options
author | Guest_account Guest_account prguest11 <prguest11@taipan.cs> | 2011-09-23 20:49:43 +0100 |
---|---|---|
committer | Guest_account Guest_account prguest11 <prguest11@taipan.cs> | 2011-09-23 20:49:43 +0100 |
commit | 8ecf63852d730f99e7c1bbacfbffdf518d5a0c3f (patch) | |
tree | 9cc80e9a47ca8c6d543667c97af5162b9e251516 /decoder | |
parent | e1b61419329c83709018ca397a29d069e4294bd1 (diff) |
stub work to talk to new kenlm
Diffstat (limited to 'decoder')
-rw-r--r-- | decoder/ff_klm.cc | 349 |
1 files changed, 349 insertions, 0 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 24dcb9c3..016aad26 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -12,6 +12,353 @@ #include "lm/model.hh" #include "lm/enumerate_vocab.hh" +#undef NEW_KENLM +#ifdef NEW_KENLM + +#include "lm/left.hh" + +using namespace std; + +// -x : rules include <s> and </s> +// -n NAME : feature id is NAME +bool ParseLMArgs(string const& in, string* filename, string* mapfile, bool* explicit_markers, string* featname) { + vector<string> const& argv=SplitOnWhitespace(in); + *explicit_markers = false; + *featname="LanguageModel"; + *mapfile = ""; +#define LMSPEC_NEXTARG if (i==argv.end()) { \ + cerr << "Missing argument for "<<*last<<". "; goto usage; \ + } else { ++i; } + + for (vector<string>::const_iterator last,i=argv.begin(),e=argv.end();i!=e;++i) { + string const& s=*i; + if (s[0]=='-') { + if (s.size()>2) goto fail; + switch (s[1]) { + case 'x': + *explicit_markers = true; + break; + case 'm': + LMSPEC_NEXTARG; *mapfile=*i; + break; + case 'n': + LMSPEC_NEXTARG; *featname=*i; + break; +#undef LMSPEC_NEXTARG + default: + fail: + cerr<<"Unknown KLanguageModel option "<<s<<" ; "; + goto usage; + } + } else { + if (filename->empty()) + *filename=s; + else { + cerr<<"More than one filename provided. "; + goto usage; + } + } + } + if (!filename->empty()) + return true; +usage: + cerr << "KLanguageModel is incorrect!\n"; + return false; +} + +template <class Model> +string KLanguageModel<Model>::usage(bool /*param*/,bool /*verbose*/) { + return "KLanguageModel"; +} + +struct VMapper : public lm::ngram::EnumerateVocab { + VMapper(vector<lm::WordIndex>* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); } + void Add(lm::WordIndex index, const StringPiece &str) { + const WordID cdec_id = TD::Convert(str.as_string()); + if (cdec_id >= out_->size()) + out_->resize(cdec_id + 1, kLM_UNKNOWN_TOKEN); + (*out_)[cdec_id] = index; + } + vector<lm::WordIndex>* out_; + const lm::WordIndex kLM_UNKNOWN_TOKEN; +}; + +template <class Model> +class KLanguageModelImpl { + + static inline const lm::ngram::ChartState& RemnantLMState(const void* state) { + return *static_cast<const lm::ngram::ChartState*>(state); + } + + inline void SetRemnantLMState(const lm::ngram::ChartState& lmstate, void* state) const { + // if we were clever, we could use the memory pointed to by state to do all + // the work, avoiding this copy + memcpy(state, &lmstate, ngram_->StateSize()); + } + + public: + double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* oovs, void* remnant) { + double sum = 0.0; + if (oovs) *oovs = 0; + const vector<WordID>& e = rule.e(); + lm::ngram::ChartState state; + lm::ngram::RuleScore<Model> ruleScore(*ngram_, state); + unsigned i = 0; + if (e.size()) { + if (e[i] == kCDEC_SOS) { + ++i; + ruleScore.BeginSentence(); + } else if (e[i] <= 0) { // special case for left-edge NT + const lm::ngram::ChartState& prevState = RemnantLMState(ant_states[-e[0]]); + ruleScore.BeginNonTerminal(prevState, 0.0f); // TODO + ++i; + } + } + for (; i < e.size(); ++i) { + if (e[i] <= 0) { + const lm::ngram::ChartState& prevState = RemnantLMState(ant_states[-e[i]]); + ruleScore.NonTerminal(prevState, 0.0f); // TODO + } else { + const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]); // in future, + // maybe handle emission + const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id + const bool is_oov = (cur_word == 0); + if (is_oov) (*oovs) += 1.0; + ruleScore.Terminal(cur_word); + } + } + if (remnant) SetRemnantLMState(state, remnant); + return ruleScore.Finish(); + } + + // this assumes no target words on final unary -> goal rule. is that ok? + // for <s> (n-1 left words) and (n-1 right words) </s> + double FinalTraversalCost(const void* state, double* oovs) { + if (add_sos_eos_) { // rules do not produce <s> </s>, so do it here + lm::ngram::ChartState cstate; + lm::ngram::RuleScore<Model> ruleScore(*ngram_, cstate); + ruleScore.BeginSentence(); + SetRemnantLMState(cstate, dummy_state_); + dummy_ants_[1] = state; + *oovs = 0; + return LookupWords(*dummy_rule_, dummy_ants_, oovs, NULL); + } else { // rules DO produce <s> ... </s> + double p = 0; + cerr << "not implemented"; abort(); // TODO + //if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; } + //if (UnscoredSize(state) > 0) { // are there unscored words + // if (kSOS_ != IthUnscoredWord(0, state)) { + // p -= 100 * UnscoredSize(state); + // } + //} + return p; + } + } + + // if this is not a class-based LM, returns w untransformed, + // otherwise returns a word class mapping of w, + // returns TD::Convert("<unk>") if there is no mapping for w + WordID ClassifyWordIfNecessary(WordID w) const { + if (word2class_map_.empty()) return w; + if (w >= word2class_map_.size()) + return kCDEC_UNK; + else + return word2class_map_[w]; + } + + // converts to cdec word id's to KenLM's id space, OOVs and <unk> end up at 0 + lm::WordIndex MapWord(WordID w) const { + if (w >= cdec2klm_map_.size()) + return 0; + else + return cdec2klm_map_[w]; + } + + public: + KLanguageModelImpl(const string& filename, const string& mapfile, bool explicit_markers) : + kCDEC_UNK(TD::Convert("<unk>")) , + kCDEC_SOS(TD::Convert("<s>")) , + add_sos_eos_(!explicit_markers) { + { + VMapper vm(&cdec2klm_map_); + lm::ngram::Config conf; + conf.enumerate_vocab = &vm; + ngram_ = new Model(filename.c_str(), conf); + } + order_ = ngram_->Order(); + cerr << "Loaded " << order_ << "-gram KLM from " << filename << " (MapSize=" << cdec2klm_map_.size() << ")\n"; + state_size_ = sizeof(lm::ngram::ChartState); + + // special handling of beginning / ending sentence markers + dummy_state_ = new char[state_size_]; + memset(dummy_state_, 0, state_size_); + dummy_ants_.push_back(dummy_state_); + dummy_ants_.push_back(NULL); + dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] </s> ||| X=0")); + kSOS_ = MapWord(kCDEC_SOS); + assert(kSOS_ > 0); + kEOS_ = MapWord(TD::Convert("</s>")); + assert(kEOS_ > 0); + assert(MapWord(kCDEC_UNK) == 0); // KenLM invariant + + // handle class-based LMs (unambiguous word->class mapping reqd.) + if (mapfile.size()) + LoadWordClasses(mapfile); + } + + void LoadWordClasses(const string& file) { + ReadFile rf(file); + istream& in = *rf.stream(); + string line; + vector<WordID> dummy; + int lc = 0; + cerr << " Loading word classes from " << file << " ...\n"; + AddWordToClassMapping_(TD::Convert("<s>"), TD::Convert("<s>")); + AddWordToClassMapping_(TD::Convert("</s>"), TD::Convert("</s>")); + while(in) { + getline(in, line); + if (!in) continue; + dummy.clear(); + TD::ConvertSentence(line, &dummy); + ++lc; + if (dummy.size() != 2) { + cerr << " Format error in " << file << ", line " << lc << ": " << line << endl; + abort(); + } + AddWordToClassMapping_(dummy[0], dummy[1]); + } + } + + void AddWordToClassMapping_(WordID word, WordID cls) { + if (word2class_map_.size() <= word) { + word2class_map_.resize((word + 10) * 1.1, kCDEC_UNK); + assert(word2class_map_.size() > word); + } + if(word2class_map_[word] != kCDEC_UNK) { + cerr << "Multiple classes for symbol " << TD::Convert(word) << endl; + abort(); + } + word2class_map_[word] = cls; + } + + ~KLanguageModelImpl() { + delete ngram_; + delete[] dummy_state_; + } + + int ReserveStateSize() const { return state_size_; } + + private: + const WordID kCDEC_UNK; + const WordID kCDEC_SOS; + lm::WordIndex kSOS_; // <s> - requires special handling. + lm::WordIndex kEOS_; // </s> + Model* ngram_; + const bool add_sos_eos_; // flag indicating whether the hypergraph produces <s> and </s> + // if this is true, FinalTransitionFeatures will "add" <s> and </s> + // if false, FinalTransitionFeatures will score anything with the + // markers in the right place (i.e., the beginning and end of + // the sentence) with 0, and anything else with -100 + + int order_; + int state_size_; + char* dummy_state_; + vector<const void*> dummy_ants_; + vector<lm::WordIndex> cdec2klm_map_; + vector<WordID> word2class_map_; // if this is a class-based LM, this is the word->class mapping + TRulePtr dummy_rule_; +}; + +template <class Model> +KLanguageModel<Model>::KLanguageModel(const string& param) { + string filename, mapfile, featname; + bool explicit_markers; + if (!ParseLMArgs(param, &filename, &mapfile, &explicit_markers, &featname)) { + abort(); + } + try { + pimpl_ = new KLanguageModelImpl<Model>(filename, mapfile, explicit_markers); + } catch (std::exception &e) { + std::cerr << e.what() << std::endl; + abort(); + } + fid_ = FD::Convert(featname); + oov_fid_ = FD::Convert(featname+"_OOV"); + // cerr << "FID: " << oov_fid_ << endl; + SetStateSize(pimpl_->ReserveStateSize()); +} + +template <class Model> +Features KLanguageModel<Model>::features() const { + return single_feature(fid_); +} + +template <class Model> +KLanguageModel<Model>::~KLanguageModel() { + delete pimpl_; +} + +template <class Model> +void KLanguageModel<Model>::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, + const Hypergraph::Edge& edge, + const vector<const void*>& ant_states, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* state) const { + double est = 0; + double oovs = 0; + features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &oovs, state)); + if (oovs && oov_fid_) + features->set_value(oov_fid_, oovs); +} + +template <class Model> +void KLanguageModel<Model>::FinalTraversalFeatures(const void* ant_state, + SparseVector<double>* features) const { + double oovs = 0; + double lm = pimpl_->FinalTraversalCost(ant_state, &oovs); + features->set_value(fid_, lm); + if (oov_fid_ && oovs) + features->set_value(oov_fid_, oovs); +} + +template <class Model> boost::shared_ptr<FeatureFunction> CreateModel(const std::string ¶m) { + KLanguageModel<Model> *ret = new KLanguageModel<Model>(param); + ret->Init(); + return boost::shared_ptr<FeatureFunction>(ret); +} + +boost::shared_ptr<FeatureFunction> KLanguageModelFactory::Create(std::string param) const { + using namespace lm::ngram; + std::string filename, ignored_map; + bool ignored_markers; + std::string ignored_featname; + ParseLMArgs(param, &filename, &ignored_map, &ignored_markers, &ignored_featname); + ModelType m; + if (!RecognizeBinary(filename.c_str(), m)) m = HASH_PROBING; + + switch (m) { + case HASH_PROBING: + return CreateModel<ProbingModel>(param); + case TRIE_SORTED: + return CreateModel<TrieModel>(param); + case ARRAY_TRIE_SORTED: + return CreateModel<ArrayTrieModel>(param); + case QUANT_TRIE_SORTED: + return CreateModel<QuantTrieModel>(param); + case QUANT_ARRAY_TRIE_SORTED: + return CreateModel<QuantArrayTrieModel>(param); + default: + UTIL_THROW(util::Exception, "Unrecognized kenlm binary file type " << (unsigned)m); + } +} + +std::string KLanguageModelFactory::usage(bool params,bool verbose) const { + return KLanguageModel<lm::ngram::Model>::usage(params, verbose); +} + +#else + using namespace std; static const unsigned char HAS_FULL_CONTEXT = 1; @@ -469,3 +816,5 @@ boost::shared_ptr<FeatureFunction> KLanguageModelFactory::Create(std::string par std::string KLanguageModelFactory::usage(bool params,bool verbose) const { return KLanguageModel<lm::ngram::Model>::usage(params, verbose); } + +#endif |