diff options
| -rw-r--r-- | decoder/ff_klm.cc | 349 | 
1 files changed, 349 insertions, 0 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 24dcb9c3..016aad26 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -12,6 +12,353 @@  #include "lm/model.hh"  #include "lm/enumerate_vocab.hh" +#undef NEW_KENLM +#ifdef NEW_KENLM + +#include "lm/left.hh" + +using namespace std; + +// -x : rules include <s> and </s> +// -n NAME : feature id is NAME +bool ParseLMArgs(string const& in, string* filename, string* mapfile, bool* explicit_markers, string* featname) { +  vector<string> const& argv=SplitOnWhitespace(in); +  *explicit_markers = false; +  *featname="LanguageModel"; +  *mapfile = ""; +#define LMSPEC_NEXTARG if (i==argv.end()) {            \ +    cerr << "Missing argument for "<<*last<<". "; goto usage; \ +    } else { ++i; } + +  for (vector<string>::const_iterator last,i=argv.begin(),e=argv.end();i!=e;++i) { +    string const& s=*i; +    if (s[0]=='-') { +      if (s.size()>2) goto fail; +      switch (s[1]) { +      case 'x': +        *explicit_markers = true; +        break; +      case 'm': +        LMSPEC_NEXTARG; *mapfile=*i; +        break; +      case 'n': +        LMSPEC_NEXTARG; *featname=*i; +        break; +#undef LMSPEC_NEXTARG +      default: +      fail: +        cerr<<"Unknown KLanguageModel option "<<s<<" ; "; +        goto usage; +      } +    } else { +      if (filename->empty()) +        *filename=s; +      else { +        cerr<<"More than one filename provided. "; +        goto usage; +      } +    } +  } +  if (!filename->empty()) +    return true; +usage: +  cerr << "KLanguageModel is incorrect!\n"; +  return false; +} + +template <class Model> +string KLanguageModel<Model>::usage(bool /*param*/,bool /*verbose*/) { +  return "KLanguageModel"; +} + +struct VMapper : public lm::ngram::EnumerateVocab { +  VMapper(vector<lm::WordIndex>* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); } +  void Add(lm::WordIndex index, const StringPiece &str) { +    const WordID cdec_id = TD::Convert(str.as_string()); +    if (cdec_id >= out_->size()) +      out_->resize(cdec_id + 1, kLM_UNKNOWN_TOKEN); +    (*out_)[cdec_id] = index; +  } +  vector<lm::WordIndex>* out_; +  const lm::WordIndex kLM_UNKNOWN_TOKEN; +}; + +template <class Model> +class KLanguageModelImpl { + +  static inline const lm::ngram::ChartState& RemnantLMState(const void* state) { +    return *static_cast<const lm::ngram::ChartState*>(state); +  } + +  inline void SetRemnantLMState(const lm::ngram::ChartState& lmstate, void* state) const { +    // if we were clever, we could use the memory pointed to by state to do all +    // the work, avoiding this copy +    memcpy(state, &lmstate, ngram_->StateSize()); +  } + + public: +  double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* oovs, void* remnant) { +    double sum = 0.0; +    if (oovs) *oovs = 0; +    const vector<WordID>& e = rule.e(); +    lm::ngram::ChartState state; +    lm::ngram::RuleScore<Model> ruleScore(*ngram_, state); +    unsigned i = 0; +    if (e.size()) { +      if (e[i] == kCDEC_SOS) { +        ++i; +        ruleScore.BeginSentence(); +      } else if (e[i] <= 0) {  // special case for left-edge NT +        const lm::ngram::ChartState& prevState = RemnantLMState(ant_states[-e[0]]); +        ruleScore.BeginNonTerminal(prevState, 0.0f);  // TODO +        ++i; +      } +    } +    for (; i < e.size(); ++i) { +      if (e[i] <= 0) { +        const lm::ngram::ChartState& prevState = RemnantLMState(ant_states[-e[i]]); +        ruleScore.NonTerminal(prevState, 0.0f);  // TODO +      } else { +        const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]);  // in future, +                                                                          // maybe handle emission +        const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id +        const bool is_oov = (cur_word == 0); +        if (is_oov) (*oovs) += 1.0; +        ruleScore.Terminal(cur_word); +      } +    } +    if (remnant) SetRemnantLMState(state, remnant); +    return ruleScore.Finish(); +  } + +  // this assumes no target words on final unary -> goal rule.  is that ok? +  // for <s> (n-1 left words) and (n-1 right words) </s> +  double FinalTraversalCost(const void* state, double* oovs) { +    if (add_sos_eos_) {  // rules do not produce <s> </s>, so do it here +      lm::ngram::ChartState cstate; +      lm::ngram::RuleScore<Model> ruleScore(*ngram_, cstate); +      ruleScore.BeginSentence(); +      SetRemnantLMState(cstate, dummy_state_); +      dummy_ants_[1] = state; +      *oovs = 0; +      return LookupWords(*dummy_rule_, dummy_ants_, oovs, NULL); +    } else {  // rules DO produce <s> ... </s> +      double p = 0; +      cerr << "not implemented"; abort(); // TODO +      //if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; } +      //if (UnscoredSize(state) > 0) {  // are there unscored words +      //  if (kSOS_ != IthUnscoredWord(0, state)) { +      //    p -= 100 * UnscoredSize(state); +      //  } +      //} +      return p; +    } +  } + +  // if this is not a class-based LM, returns w untransformed, +  // otherwise returns a word class mapping of w, +  // returns TD::Convert("<unk>") if there is no mapping for w +  WordID ClassifyWordIfNecessary(WordID w) const { +    if (word2class_map_.empty()) return w; +    if (w >= word2class_map_.size()) +      return kCDEC_UNK; +    else +      return word2class_map_[w]; +  } + +  // converts to cdec word id's to KenLM's id space, OOVs and <unk> end up at 0 +  lm::WordIndex MapWord(WordID w) const { +    if (w >= cdec2klm_map_.size()) +      return 0; +    else +      return cdec2klm_map_[w]; +  } + + public: +  KLanguageModelImpl(const string& filename, const string& mapfile, bool explicit_markers) : +      kCDEC_UNK(TD::Convert("<unk>")) , +      kCDEC_SOS(TD::Convert("<s>")) , +      add_sos_eos_(!explicit_markers) { +    { +      VMapper vm(&cdec2klm_map_); +      lm::ngram::Config conf; +      conf.enumerate_vocab = &vm; +      ngram_ = new Model(filename.c_str(), conf); +    } +    order_ = ngram_->Order(); +    cerr << "Loaded " << order_ << "-gram KLM from " << filename << " (MapSize=" << cdec2klm_map_.size() << ")\n"; +    state_size_ = sizeof(lm::ngram::ChartState); + +    // special handling of beginning / ending sentence markers +    dummy_state_ = new char[state_size_]; +    memset(dummy_state_, 0, state_size_); +    dummy_ants_.push_back(dummy_state_); +    dummy_ants_.push_back(NULL); +    dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] </s> ||| X=0")); +    kSOS_ = MapWord(kCDEC_SOS); +    assert(kSOS_ > 0); +    kEOS_ = MapWord(TD::Convert("</s>")); +    assert(kEOS_ > 0); +    assert(MapWord(kCDEC_UNK) == 0); // KenLM invariant + +    // handle class-based LMs (unambiguous word->class mapping reqd.) +    if (mapfile.size()) +      LoadWordClasses(mapfile); +  } + +  void LoadWordClasses(const string& file) { +    ReadFile rf(file); +    istream& in = *rf.stream(); +    string line; +    vector<WordID> dummy; +    int lc = 0; +    cerr << "  Loading word classes from " << file << " ...\n"; +    AddWordToClassMapping_(TD::Convert("<s>"), TD::Convert("<s>")); +    AddWordToClassMapping_(TD::Convert("</s>"), TD::Convert("</s>")); +    while(in) { +      getline(in, line); +      if (!in) continue; +      dummy.clear(); +      TD::ConvertSentence(line, &dummy); +      ++lc; +      if (dummy.size() != 2) { +        cerr << "    Format error in " << file << ", line " << lc << ": " << line << endl; +        abort(); +      } +      AddWordToClassMapping_(dummy[0], dummy[1]); +    } +  } + +  void AddWordToClassMapping_(WordID word, WordID cls) { +    if (word2class_map_.size() <= word) { +      word2class_map_.resize((word + 10) * 1.1, kCDEC_UNK); +      assert(word2class_map_.size() > word); +    } +    if(word2class_map_[word] != kCDEC_UNK) { +      cerr << "Multiple classes for symbol " << TD::Convert(word) << endl; +      abort(); +    } +    word2class_map_[word] = cls; +  } + +  ~KLanguageModelImpl() { +    delete ngram_; +    delete[] dummy_state_; +  } + +  int ReserveStateSize() const { return state_size_; } + + private: +  const WordID kCDEC_UNK; +  const WordID kCDEC_SOS; +  lm::WordIndex kSOS_;  // <s> - requires special handling. +  lm::WordIndex kEOS_;  // </s> +  Model* ngram_; +  const bool add_sos_eos_; // flag indicating whether the hypergraph produces <s> and </s> +                     // if this is true, FinalTransitionFeatures will "add" <s> and </s> +                     // if false, FinalTransitionFeatures will score anything with the +                     // markers in the right place (i.e., the beginning and end of +                     // the sentence) with 0, and anything else with -100 + +  int order_; +  int state_size_; +  char* dummy_state_; +  vector<const void*> dummy_ants_; +  vector<lm::WordIndex> cdec2klm_map_; +  vector<WordID> word2class_map_;        // if this is a class-based LM, this is the word->class mapping +  TRulePtr dummy_rule_; +}; + +template <class Model> +KLanguageModel<Model>::KLanguageModel(const string& param) { +  string filename, mapfile, featname; +  bool explicit_markers; +  if (!ParseLMArgs(param, &filename, &mapfile, &explicit_markers, &featname)) { +    abort(); +  } +  try { +    pimpl_ = new KLanguageModelImpl<Model>(filename, mapfile, explicit_markers); +  } catch (std::exception &e) { +    std::cerr << e.what() << std::endl; +    abort(); +  } +  fid_ = FD::Convert(featname); +  oov_fid_ = FD::Convert(featname+"_OOV"); +  // cerr << "FID: " << oov_fid_ << endl; +  SetStateSize(pimpl_->ReserveStateSize()); +} + +template <class Model> +Features KLanguageModel<Model>::features() const { +  return single_feature(fid_); +} + +template <class Model> +KLanguageModel<Model>::~KLanguageModel() { +  delete pimpl_; +} + +template <class Model> +void KLanguageModel<Model>::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, +                                          const Hypergraph::Edge& edge, +                                          const vector<const void*>& ant_states, +                                          SparseVector<double>* features, +                                          SparseVector<double>* estimated_features, +                                          void* state) const { +  double est = 0; +  double oovs = 0; +  features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &oovs, state)); +  if (oovs && oov_fid_) +    features->set_value(oov_fid_, oovs); +} + +template <class Model> +void KLanguageModel<Model>::FinalTraversalFeatures(const void* ant_state, +                                           SparseVector<double>* features) const { +  double oovs = 0; +  double lm = pimpl_->FinalTraversalCost(ant_state, &oovs); +  features->set_value(fid_, lm); +  if (oov_fid_ && oovs) +    features->set_value(oov_fid_, oovs); +} + +template <class Model> boost::shared_ptr<FeatureFunction> CreateModel(const std::string ¶m) { +  KLanguageModel<Model> *ret = new KLanguageModel<Model>(param); +  ret->Init(); +  return boost::shared_ptr<FeatureFunction>(ret); +} + +boost::shared_ptr<FeatureFunction> KLanguageModelFactory::Create(std::string param) const { +  using namespace lm::ngram; +  std::string filename, ignored_map; +  bool ignored_markers; +  std::string ignored_featname; +  ParseLMArgs(param, &filename, &ignored_map, &ignored_markers, &ignored_featname); +  ModelType m; +  if (!RecognizeBinary(filename.c_str(), m)) m = HASH_PROBING; + +  switch (m) { +    case HASH_PROBING: +      return CreateModel<ProbingModel>(param); +    case TRIE_SORTED: +      return CreateModel<TrieModel>(param); +    case ARRAY_TRIE_SORTED: +      return CreateModel<ArrayTrieModel>(param); +    case QUANT_TRIE_SORTED: +      return CreateModel<QuantTrieModel>(param); +    case QUANT_ARRAY_TRIE_SORTED: +      return CreateModel<QuantArrayTrieModel>(param); +    default: +      UTIL_THROW(util::Exception, "Unrecognized kenlm binary file type " << (unsigned)m); +  } +} + +std::string  KLanguageModelFactory::usage(bool params,bool verbose) const { +  return KLanguageModel<lm::ngram::Model>::usage(params, verbose); +} + +#else +  using namespace std;  static const unsigned char HAS_FULL_CONTEXT = 1; @@ -469,3 +816,5 @@ boost::shared_ptr<FeatureFunction> KLanguageModelFactory::Create(std::string par  std::string  KLanguageModelFactory::usage(bool params,bool verbose) const {    return KLanguageModel<lm::ngram::Model>::usage(params, verbose);  } + +#endif  | 
