diff options
Diffstat (limited to 'decoder/ff_klm.cc')
-rw-r--r-- | decoder/ff_klm.cc | 464 |
1 files changed, 0 insertions, 464 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 3c941fbf..ed6f731e 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -12,9 +12,6 @@ #include "lm/model.hh" #include "lm/enumerate_vocab.hh" -#define NEW_KENLM -#undef NEW_KENLM - #include "lm/left.hh" using namespace std; @@ -395,464 +392,3 @@ std::string KLanguageModelFactory::usage(bool params,bool verbose) const { return KLanguageModel<lm::ngram::Model>::usage(params, verbose); } -#else - -using namespace std; - -static const unsigned char HAS_FULL_CONTEXT = 1; -static const unsigned char HAS_EOS_ON_RIGHT = 2; -static const unsigned char MASK = 7; - -// -x : rules include <s> and </s> -// -n NAME : feature id is NAME -bool ParseLMArgs(string const& in, string* filename, string* mapfile, bool* explicit_markers, string* featname) { - vector<string> const& argv=SplitOnWhitespace(in); - *explicit_markers = false; - *featname="LanguageModel"; - *mapfile = ""; -#define LMSPEC_NEXTARG if (i==argv.end()) { \ - cerr << "Missing argument for "<<*last<<". "; goto usage; \ - } else { ++i; } - - for (vector<string>::const_iterator last,i=argv.begin(),e=argv.end();i!=e;++i) { - string const& s=*i; - if (s[0]=='-') { - if (s.size()>2) goto fail; - switch (s[1]) { - case 'x': - *explicit_markers = true; - break; - case 'm': - LMSPEC_NEXTARG; *mapfile=*i; - break; - case 'n': - LMSPEC_NEXTARG; *featname=*i; - break; -#undef LMSPEC_NEXTARG - default: - fail: - cerr<<"Unknown KLanguageModel option "<<s<<" ; "; - goto usage; - } - } else { - if (filename->empty()) - *filename=s; - else { - cerr<<"More than one filename provided. "; - goto usage; - } - } - } - if (!filename->empty()) - return true; -usage: - cerr << "KLanguageModel is incorrect!\n"; - return false; -} - -template <class Model> -string KLanguageModel<Model>::usage(bool /*param*/,bool /*verbose*/) { - return "KLanguageModel"; -} - -struct VMapper : public lm::ngram::EnumerateVocab { - VMapper(vector<lm::WordIndex>* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); } - void Add(lm::WordIndex index, const StringPiece &str) { - const WordID cdec_id = TD::Convert(str.as_string()); - if (cdec_id >= out_->size()) - out_->resize(cdec_id + 1, kLM_UNKNOWN_TOKEN); - (*out_)[cdec_id] = index; - } - vector<lm::WordIndex>* out_; - const lm::WordIndex kLM_UNKNOWN_TOKEN; -}; - -template <class Model> -class KLanguageModelImpl { - - // returns the number of unscored words at the left edge of a span - inline int UnscoredSize(const void* state) const { - return *(static_cast<const char*>(state) + unscored_size_offset_); - } - - inline void SetUnscoredSize(int size, void* state) const { - *(static_cast<char*>(state) + unscored_size_offset_) = size; - } - - static inline const lm::ngram::State& RemnantLMState(const void* state) { - return *static_cast<const lm::ngram::State*>(state); - } - - inline void SetRemnantLMState(const lm::ngram::State& lmstate, void* state) const { - // if we were clever, we could use the memory pointed to by state to do all - // the work, avoiding this copy - memcpy(state, &lmstate, ngram_->StateSize()); - } - - lm::WordIndex IthUnscoredWord(int i, const void* state) const { - const lm::WordIndex* const mem = reinterpret_cast<const lm::WordIndex*>(static_cast<const char*>(state) + unscored_words_offset_); - return mem[i]; - } - - void SetIthUnscoredWord(int i, lm::WordIndex index, void *state) const { - lm::WordIndex* mem = reinterpret_cast<lm::WordIndex*>(static_cast<char*>(state) + unscored_words_offset_); - mem[i] = index; - } - - inline bool GetFlag(const void *state, unsigned char flag) const { - return (*(static_cast<const char*>(state) + is_complete_offset_) & flag); - } - - inline void SetFlag(bool on, unsigned char flag, void *state) const { - if (on) { - *(static_cast<char*>(state) + is_complete_offset_) |= flag; - } else { - *(static_cast<char*>(state) + is_complete_offset_) &= (MASK ^ flag); - } - } - - inline bool HasFullContext(const void *state) const { - return GetFlag(state, HAS_FULL_CONTEXT); - } - - inline void SetHasFullContext(bool flag, void *state) const { - SetFlag(flag, HAS_FULL_CONTEXT, state); - } - - public: - double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* pest_sum, double* oovs, double* est_oovs, void* remnant) { - double sum = 0.0; - double est_sum = 0.0; - int num_scored = 0; - int num_estimated = 0; - if (oovs) *oovs = 0; - if (est_oovs) *est_oovs = 0; - bool saw_eos = false; - bool has_some_history = false; - lm::ngram::State state = ngram_->NullContextState(); - const vector<WordID>& e = rule.e(); - bool context_complete = false; - for (int j = 0; j < e.size(); ++j) { - if (e[j] < 1) { // handle non-terminal substitution - const void* astate = (ant_states[-e[j]]); - int unscored_ant_len = UnscoredSize(astate); - for (int k = 0; k < unscored_ant_len; ++k) { - const lm::WordIndex cur_word = IthUnscoredWord(k, astate); - const bool is_oov = (cur_word == 0); - double p = 0; - if (cur_word == kSOS_) { - state = ngram_->BeginSentenceState(); - if (has_some_history) { // this is immediately fully scored, and bad - p = -100; - context_complete = true; - } else { // this might be a real <s> - num_scored = max(0, order_ - 2); - } - } else { - const lm::ngram::State scopy(state); - p = ngram_->Score(scopy, cur_word, state); - if (saw_eos) { p = -100; } - saw_eos = (cur_word == kEOS_); - } - has_some_history = true; - ++num_scored; - if (!context_complete) { - if (num_scored >= order_) context_complete = true; - } - if (context_complete) { - sum += p; - if (oovs && is_oov) (*oovs)++; - } else { - if (remnant) - SetIthUnscoredWord(num_estimated, cur_word, remnant); - ++num_estimated; - est_sum += p; - if (est_oovs && is_oov) (*est_oovs)++; - } - } - saw_eos = GetFlag(astate, HAS_EOS_ON_RIGHT); - if (HasFullContext(astate)) { // this is equivalent to the "star" in Chiang 2007 - state = RemnantLMState(astate); - context_complete = true; - } - } else { // handle terminal - const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[j]); // in future, - // maybe handle emission - const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id - double p = 0; - const bool is_oov = (cur_word == 0); - if (cur_word == kSOS_) { - state = ngram_->BeginSentenceState(); - if (has_some_history) { // this is immediately fully scored, and bad - p = -100; - context_complete = true; - } else { // this might be a real <s> - num_scored = max(0, order_ - 2); - } - } else { - const lm::ngram::State scopy(state); - p = ngram_->Score(scopy, cur_word, state); - if (saw_eos) { p = -100; } - saw_eos = (cur_word == kEOS_); - } - has_some_history = true; - ++num_scored; - if (!context_complete) { - if (num_scored >= order_) context_complete = true; - } - if (context_complete) { - sum += p; - if (oovs && is_oov) (*oovs)++; - } else { - if (remnant) - SetIthUnscoredWord(num_estimated, cur_word, remnant); - ++num_estimated; - est_sum += p; - if (est_oovs && is_oov) (*est_oovs)++; - } - } - } - if (pest_sum) *pest_sum = est_sum; - if (remnant) { - state.ZeroRemaining(); - SetFlag(saw_eos, HAS_EOS_ON_RIGHT, remnant); - SetRemnantLMState(state, remnant); - SetUnscoredSize(num_estimated, remnant); - SetHasFullContext(context_complete || (num_scored >= order_), remnant); - } - return sum; - } - - // this assumes no target words on final unary -> goal rule. is that ok? - // for <s> (n-1 left words) and (n-1 right words) </s> - double FinalTraversalCost(const void* state, double* oovs) { - if (add_sos_eos_) { // rules do not produce <s> </s>, so do it here - SetRemnantLMState(ngram_->BeginSentenceState(), dummy_state_); - SetHasFullContext(1, dummy_state_); - SetUnscoredSize(0, dummy_state_); - dummy_ants_[1] = state; - *oovs = 0; - return LookupWords(*dummy_rule_, dummy_ants_, NULL, oovs, NULL, NULL); - } else { // rules DO produce <s> ... </s> - double p = 0; - if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; } - if (UnscoredSize(state) > 0) { // are there unscored words - if (kSOS_ != IthUnscoredWord(0, state)) { - p -= 100 * UnscoredSize(state); - } - } - return p; - } - } - - // if this is not a class-based LM, returns w untransformed, - // otherwise returns a word class mapping of w, - // returns TD::Convert("<unk>") if there is no mapping for w - WordID ClassifyWordIfNecessary(WordID w) const { - if (word2class_map_.empty()) return w; - if (w >= word2class_map_.size()) - return kCDEC_UNK; - else - return word2class_map_[w]; - } - - // converts to cdec word id's to KenLM's id space, OOVs and <unk> end up at 0 - lm::WordIndex MapWord(WordID w) const { - if (w >= cdec2klm_map_.size()) - return 0; - else - return cdec2klm_map_[w]; - } - - public: - KLanguageModelImpl(const string& filename, const string& mapfile, bool explicit_markers) : - kCDEC_UNK(TD::Convert("<unk>")) , - add_sos_eos_(!explicit_markers) { - { - VMapper vm(&cdec2klm_map_); - lm::ngram::Config conf; - conf.enumerate_vocab = &vm; - ngram_ = new Model(filename.c_str(), conf); - } - order_ = ngram_->Order(); - cerr << "Loaded " << order_ << "-gram KLM from " << filename << " (MapSize=" << cdec2klm_map_.size() << ")\n"; - state_size_ = ngram_->StateSize() + 2 + (order_ - 1) * sizeof(lm::WordIndex); - unscored_size_offset_ = ngram_->StateSize(); - is_complete_offset_ = unscored_size_offset_ + 1; - unscored_words_offset_ = is_complete_offset_ + 1; - - // special handling of beginning / ending sentence markers - dummy_state_ = new char[state_size_]; - memset(dummy_state_, 0, state_size_); - dummy_ants_.push_back(dummy_state_); - dummy_ants_.push_back(NULL); - dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] </s> ||| X=0")); - kSOS_ = MapWord(TD::Convert("<s>")); - assert(kSOS_ > 0); - kEOS_ = MapWord(TD::Convert("</s>")); - assert(kEOS_ > 0); - assert(MapWord(kCDEC_UNK) == 0); // KenLM invariant - - // handle class-based LMs (unambiguous word->class mapping reqd.) - if (mapfile.size()) - LoadWordClasses(mapfile); - } - - void LoadWordClasses(const string& file) { - ReadFile rf(file); - istream& in = *rf.stream(); - string line; - vector<WordID> dummy; - int lc = 0; - cerr << " Loading word classes from " << file << " ...\n"; - AddWordToClassMapping_(TD::Convert("<s>"), TD::Convert("<s>")); - AddWordToClassMapping_(TD::Convert("</s>"), TD::Convert("</s>")); - while(in) { - getline(in, line); - if (!in) continue; - dummy.clear(); - TD::ConvertSentence(line, &dummy); - ++lc; - if (dummy.size() != 2) { - cerr << " Format error in " << file << ", line " << lc << ": " << line << endl; - abort(); - } - AddWordToClassMapping_(dummy[0], dummy[1]); - } - } - - void AddWordToClassMapping_(WordID word, WordID cls) { - if (word2class_map_.size() <= word) { - word2class_map_.resize((word + 10) * 1.1, kCDEC_UNK); - assert(word2class_map_.size() > word); - } - if(word2class_map_[word] != kCDEC_UNK) { - cerr << "Multiple classes for symbol " << TD::Convert(word) << endl; - abort(); - } - word2class_map_[word] = cls; - } - - ~KLanguageModelImpl() { - delete ngram_; - delete[] dummy_state_; - } - - int ReserveStateSize() const { return state_size_; } - - private: - const WordID kCDEC_UNK; - lm::WordIndex kSOS_; // <s> - requires special handling. - lm::WordIndex kEOS_; // </s> - Model* ngram_; - const bool add_sos_eos_; // flag indicating whether the hypergraph produces <s> and </s> - // if this is true, FinalTransitionFeatures will "add" <s> and </s> - // if false, FinalTransitionFeatures will score anything with the - // markers in the right place (i.e., the beginning and end of - // the sentence) with 0, and anything else with -100 - - int order_; - int state_size_; - int unscored_size_offset_; - int is_complete_offset_; - int unscored_words_offset_; - char* dummy_state_; - vector<const void*> dummy_ants_; - vector<lm::WordIndex> cdec2klm_map_; - vector<WordID> word2class_map_; // if this is a class-based LM, this is the word->class mapping - TRulePtr dummy_rule_; -}; - -template <class Model> -KLanguageModel<Model>::KLanguageModel(const string& param) { - string filename, mapfile, featname; - bool explicit_markers; - if (!ParseLMArgs(param, &filename, &mapfile, &explicit_markers, &featname)) { - abort(); - } - try { - pimpl_ = new KLanguageModelImpl<Model>(filename, mapfile, explicit_markers); - } catch (std::exception &e) { - std::cerr << e.what() << std::endl; - abort(); - } - fid_ = FD::Convert(featname); - oov_fid_ = FD::Convert(featname+"_OOV"); - cerr << "FID: " << oov_fid_ << endl; - SetStateSize(pimpl_->ReserveStateSize()); -} - -template <class Model> -Features KLanguageModel<Model>::features() const { - return single_feature(fid_); -} - -template <class Model> -KLanguageModel<Model>::~KLanguageModel() { - delete pimpl_; -} - -template <class Model> -void KLanguageModel<Model>::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, - const Hypergraph::Edge& edge, - const vector<const void*>& ant_states, - SparseVector<double>* features, - SparseVector<double>* estimated_features, - void* state) const { - double est = 0; - double oovs = 0; - double est_oovs = 0; - features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &est, &oovs, &est_oovs, state)); - estimated_features->set_value(fid_, est); - if (oov_fid_) { - if (oovs) features->set_value(oov_fid_, oovs); - if (est_oovs) estimated_features->set_value(oov_fid_, est_oovs); - } -} - -template <class Model> -void KLanguageModel<Model>::FinalTraversalFeatures(const void* ant_state, - SparseVector<double>* features) const { - double oovs = 0; - double lm = pimpl_->FinalTraversalCost(ant_state, &oovs); - features->set_value(fid_, lm); - if (oov_fid_ && oovs) - features->set_value(oov_fid_, oovs); -} - -template <class Model> boost::shared_ptr<FeatureFunction> CreateModel(const std::string ¶m) { - KLanguageModel<Model> *ret = new KLanguageModel<Model>(param); - ret->Init(); - return boost::shared_ptr<FeatureFunction>(ret); -} - -boost::shared_ptr<FeatureFunction> KLanguageModelFactory::Create(std::string param) const { - using namespace lm::ngram; - std::string filename, ignored_map; - bool ignored_markers; - std::string ignored_featname; - ParseLMArgs(param, &filename, &ignored_map, &ignored_markers, &ignored_featname); - ModelType m; - if (!RecognizeBinary(filename.c_str(), m)) m = HASH_PROBING; - - switch (m) { - case HASH_PROBING: - return CreateModel<ProbingModel>(param); - case TRIE_SORTED: - return CreateModel<TrieModel>(param); - case ARRAY_TRIE_SORTED: - return CreateModel<ArrayTrieModel>(param); - case QUANT_TRIE_SORTED: - return CreateModel<QuantTrieModel>(param); - case QUANT_ARRAY_TRIE_SORTED: - return CreateModel<QuantArrayTrieModel>(param); - default: - UTIL_THROW(util::Exception, "Unrecognized kenlm binary file type " << (unsigned)m); - } -} - -std::string KLanguageModelFactory::usage(bool params,bool verbose) const { - return KLanguageModel<lm::ngram::Model>::usage(params, verbose); -} - -#endif |