diff options
Diffstat (limited to 'decoder')
| -rw-r--r-- | decoder/ff_klm.cc | 464 | 
1 files changed, 0 insertions, 464 deletions
| diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 3c941fbf..ed6f731e 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -12,9 +12,6 @@  #include "lm/model.hh"  #include "lm/enumerate_vocab.hh" -#define NEW_KENLM -#undef NEW_KENLM -  #include "lm/left.hh"  using namespace std; @@ -395,464 +392,3 @@ std::string  KLanguageModelFactory::usage(bool params,bool verbose) const {    return KLanguageModel<lm::ngram::Model>::usage(params, verbose);  } -#else - -using namespace std; - -static const unsigned char HAS_FULL_CONTEXT = 1; -static const unsigned char HAS_EOS_ON_RIGHT = 2; -static const unsigned char MASK             = 7; - -// -x : rules include <s> and </s> -// -n NAME : feature id is NAME -bool ParseLMArgs(string const& in, string* filename, string* mapfile, bool* explicit_markers, string* featname) { -  vector<string> const& argv=SplitOnWhitespace(in); -  *explicit_markers = false; -  *featname="LanguageModel"; -  *mapfile = ""; -#define LMSPEC_NEXTARG if (i==argv.end()) {            \ -    cerr << "Missing argument for "<<*last<<". "; goto usage; \ -    } else { ++i; } - -  for (vector<string>::const_iterator last,i=argv.begin(),e=argv.end();i!=e;++i) { -    string const& s=*i; -    if (s[0]=='-') { -      if (s.size()>2) goto fail; -      switch (s[1]) { -      case 'x': -        *explicit_markers = true; -        break; -      case 'm': -        LMSPEC_NEXTARG; *mapfile=*i; -        break; -      case 'n': -        LMSPEC_NEXTARG; *featname=*i; -        break; -#undef LMSPEC_NEXTARG -      default: -      fail: -        cerr<<"Unknown KLanguageModel option "<<s<<" ; "; -        goto usage; -      } -    } else { -      if (filename->empty()) -        *filename=s; -      else { -        cerr<<"More than one filename provided. "; -        goto usage; -      } -    } -  } -  if (!filename->empty()) -    return true; -usage: -  cerr << "KLanguageModel is incorrect!\n"; -  return false; -} - -template <class Model> -string KLanguageModel<Model>::usage(bool /*param*/,bool /*verbose*/) { -  return "KLanguageModel"; -} - -struct VMapper : public lm::ngram::EnumerateVocab { -  VMapper(vector<lm::WordIndex>* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); } -  void Add(lm::WordIndex index, const StringPiece &str) { -    const WordID cdec_id = TD::Convert(str.as_string()); -    if (cdec_id >= out_->size()) -      out_->resize(cdec_id + 1, kLM_UNKNOWN_TOKEN); -    (*out_)[cdec_id] = index; -  } -  vector<lm::WordIndex>* out_; -  const lm::WordIndex kLM_UNKNOWN_TOKEN; -}; - -template <class Model> -class KLanguageModelImpl { - -  // returns the number of unscored words at the left edge of a span -  inline int UnscoredSize(const void* state) const { -    return *(static_cast<const char*>(state) + unscored_size_offset_); -  } - -  inline void SetUnscoredSize(int size, void* state) const { -    *(static_cast<char*>(state) + unscored_size_offset_) = size; -  } - -  static inline const lm::ngram::State& RemnantLMState(const void* state) { -    return *static_cast<const lm::ngram::State*>(state); -  } - -  inline void SetRemnantLMState(const lm::ngram::State& lmstate, void* state) const { -    // if we were clever, we could use the memory pointed to by state to do all -    // the work, avoiding this copy -    memcpy(state, &lmstate, ngram_->StateSize()); -  } - -  lm::WordIndex IthUnscoredWord(int i, const void* state) const { -    const lm::WordIndex* const mem = reinterpret_cast<const lm::WordIndex*>(static_cast<const char*>(state) + unscored_words_offset_); -    return mem[i]; -  } - -  void SetIthUnscoredWord(int i, lm::WordIndex index, void *state) const { -    lm::WordIndex* mem = reinterpret_cast<lm::WordIndex*>(static_cast<char*>(state) + unscored_words_offset_); -    mem[i] = index; -  } - -  inline bool GetFlag(const void *state, unsigned char flag) const { -    return (*(static_cast<const char*>(state) + is_complete_offset_) & flag); -  } - -  inline void SetFlag(bool on, unsigned char flag, void *state) const { -    if (on) { -      *(static_cast<char*>(state) + is_complete_offset_) |= flag; -    } else { -      *(static_cast<char*>(state) + is_complete_offset_) &= (MASK ^ flag); -    } -  } - -  inline bool HasFullContext(const void *state) const { -    return GetFlag(state, HAS_FULL_CONTEXT); -  } - -  inline void SetHasFullContext(bool flag, void *state) const { -    SetFlag(flag, HAS_FULL_CONTEXT, state); -  } - - public: -  double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* pest_sum, double* oovs, double* est_oovs, void* remnant) { -    double sum = 0.0; -    double est_sum = 0.0; -    int num_scored = 0; -    int num_estimated = 0; -    if (oovs) *oovs = 0; -    if (est_oovs) *est_oovs = 0; -    bool saw_eos = false; -    bool has_some_history = false; -    lm::ngram::State state = ngram_->NullContextState(); -    const vector<WordID>& e = rule.e(); -    bool context_complete = false; -    for (int j = 0; j < e.size(); ++j) { -      if (e[j] < 1) {   // handle non-terminal substitution -        const void* astate = (ant_states[-e[j]]); -        int unscored_ant_len = UnscoredSize(astate); -        for (int k = 0; k < unscored_ant_len; ++k) { -          const lm::WordIndex cur_word = IthUnscoredWord(k, astate); -          const bool is_oov = (cur_word == 0); -          double p = 0; -          if (cur_word == kSOS_) { -            state = ngram_->BeginSentenceState(); -            if (has_some_history) {  // this is immediately fully scored, and bad -              p = -100; -              context_complete = true; -            } else {  // this might be a real <s> -              num_scored = max(0, order_ - 2); -            } -          } else { -            const lm::ngram::State scopy(state); -            p = ngram_->Score(scopy, cur_word, state); -            if (saw_eos) { p = -100; } -            saw_eos = (cur_word == kEOS_); -          } -          has_some_history = true; -          ++num_scored; -          if (!context_complete) { -            if (num_scored >= order_) context_complete = true; -          } -          if (context_complete) { -            sum += p; -            if (oovs && is_oov) (*oovs)++; -          } else { -            if (remnant) -              SetIthUnscoredWord(num_estimated, cur_word, remnant); -            ++num_estimated; -            est_sum += p; -            if (est_oovs && is_oov) (*est_oovs)++; -          } -        } -        saw_eos = GetFlag(astate, HAS_EOS_ON_RIGHT); -        if (HasFullContext(astate)) { // this is equivalent to the "star" in Chiang 2007 -          state = RemnantLMState(astate); -          context_complete = true; -        } -      } else {   // handle terminal -        const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[j]);  // in future, -                                                                          // maybe handle emission -        const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id -        double p = 0; -        const bool is_oov = (cur_word == 0); -        if (cur_word == kSOS_) { -          state = ngram_->BeginSentenceState(); -          if (has_some_history) {  // this is immediately fully scored, and bad -            p = -100; -            context_complete = true; -          } else {  // this might be a real <s> -            num_scored = max(0, order_ - 2); -          } -        } else { -          const lm::ngram::State scopy(state); -          p = ngram_->Score(scopy, cur_word, state); -          if (saw_eos) { p = -100; } -          saw_eos = (cur_word == kEOS_); -        } -        has_some_history = true; -        ++num_scored; -        if (!context_complete) { -          if (num_scored >= order_) context_complete = true; -        } -        if (context_complete) { -          sum += p; -          if (oovs && is_oov) (*oovs)++; -        } else { -          if (remnant) -            SetIthUnscoredWord(num_estimated, cur_word, remnant); -          ++num_estimated; -          est_sum += p; -          if (est_oovs && is_oov) (*est_oovs)++; -        } -      } -    } -    if (pest_sum) *pest_sum = est_sum; -    if (remnant) { -      state.ZeroRemaining(); -      SetFlag(saw_eos, HAS_EOS_ON_RIGHT, remnant); -      SetRemnantLMState(state, remnant); -      SetUnscoredSize(num_estimated, remnant); -      SetHasFullContext(context_complete || (num_scored >= order_), remnant); -    } -    return sum; -  } - -  // this assumes no target words on final unary -> goal rule.  is that ok? -  // for <s> (n-1 left words) and (n-1 right words) </s> -  double FinalTraversalCost(const void* state, double* oovs) { -    if (add_sos_eos_) {  // rules do not produce <s> </s>, so do it here -      SetRemnantLMState(ngram_->BeginSentenceState(), dummy_state_); -      SetHasFullContext(1, dummy_state_); -      SetUnscoredSize(0, dummy_state_); -      dummy_ants_[1] = state; -      *oovs = 0; -      return LookupWords(*dummy_rule_, dummy_ants_, NULL, oovs, NULL, NULL); -    } else {  // rules DO produce <s> ... </s> -      double p = 0; -      if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; } -      if (UnscoredSize(state) > 0) {  // are there unscored words -        if (kSOS_ != IthUnscoredWord(0, state)) { -          p -= 100 * UnscoredSize(state); -        } -      } -      return p; -    } -  } - -  // if this is not a class-based LM, returns w untransformed, -  // otherwise returns a word class mapping of w, -  // returns TD::Convert("<unk>") if there is no mapping for w -  WordID ClassifyWordIfNecessary(WordID w) const { -    if (word2class_map_.empty()) return w; -    if (w >= word2class_map_.size()) -      return kCDEC_UNK; -    else -      return word2class_map_[w]; -  } - -  // converts to cdec word id's to KenLM's id space, OOVs and <unk> end up at 0 -  lm::WordIndex MapWord(WordID w) const { -    if (w >= cdec2klm_map_.size()) -      return 0; -    else -      return cdec2klm_map_[w]; -  } - - public: -  KLanguageModelImpl(const string& filename, const string& mapfile, bool explicit_markers) : -      kCDEC_UNK(TD::Convert("<unk>")) , -      add_sos_eos_(!explicit_markers) { -    { -      VMapper vm(&cdec2klm_map_); -      lm::ngram::Config conf; -      conf.enumerate_vocab = &vm; -      ngram_ = new Model(filename.c_str(), conf); -    } -    order_ = ngram_->Order(); -    cerr << "Loaded " << order_ << "-gram KLM from " << filename << " (MapSize=" << cdec2klm_map_.size() << ")\n"; -    state_size_ = ngram_->StateSize() + 2 + (order_ - 1) * sizeof(lm::WordIndex); -    unscored_size_offset_ = ngram_->StateSize(); -    is_complete_offset_ = unscored_size_offset_ + 1; -    unscored_words_offset_ = is_complete_offset_ + 1; - -    // special handling of beginning / ending sentence markers -    dummy_state_ = new char[state_size_]; -    memset(dummy_state_, 0, state_size_); -    dummy_ants_.push_back(dummy_state_); -    dummy_ants_.push_back(NULL); -    dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] </s> ||| X=0")); -    kSOS_ = MapWord(TD::Convert("<s>")); -    assert(kSOS_ > 0); -    kEOS_ = MapWord(TD::Convert("</s>")); -    assert(kEOS_ > 0); -    assert(MapWord(kCDEC_UNK) == 0); // KenLM invariant - -    // handle class-based LMs (unambiguous word->class mapping reqd.) -    if (mapfile.size()) -      LoadWordClasses(mapfile); -  } - -  void LoadWordClasses(const string& file) { -    ReadFile rf(file); -    istream& in = *rf.stream(); -    string line; -    vector<WordID> dummy; -    int lc = 0; -    cerr << "  Loading word classes from " << file << " ...\n"; -    AddWordToClassMapping_(TD::Convert("<s>"), TD::Convert("<s>")); -    AddWordToClassMapping_(TD::Convert("</s>"), TD::Convert("</s>")); -    while(in) { -      getline(in, line); -      if (!in) continue; -      dummy.clear(); -      TD::ConvertSentence(line, &dummy); -      ++lc; -      if (dummy.size() != 2) { -        cerr << "    Format error in " << file << ", line " << lc << ": " << line << endl; -        abort(); -      } -      AddWordToClassMapping_(dummy[0], dummy[1]); -    } -  } - -  void AddWordToClassMapping_(WordID word, WordID cls) { -    if (word2class_map_.size() <= word) { -      word2class_map_.resize((word + 10) * 1.1, kCDEC_UNK); -      assert(word2class_map_.size() > word); -    } -    if(word2class_map_[word] != kCDEC_UNK) { -      cerr << "Multiple classes for symbol " << TD::Convert(word) << endl; -      abort(); -    } -    word2class_map_[word] = cls; -  } - -  ~KLanguageModelImpl() { -    delete ngram_; -    delete[] dummy_state_; -  } - -  int ReserveStateSize() const { return state_size_; } - - private: -  const WordID kCDEC_UNK; -  lm::WordIndex kSOS_;  // <s> - requires special handling. -  lm::WordIndex kEOS_;  // </s> -  Model* ngram_; -  const bool add_sos_eos_; // flag indicating whether the hypergraph produces <s> and </s> -                     // if this is true, FinalTransitionFeatures will "add" <s> and </s> -                     // if false, FinalTransitionFeatures will score anything with the -                     // markers in the right place (i.e., the beginning and end of -                     // the sentence) with 0, and anything else with -100 - -  int order_; -  int state_size_; -  int unscored_size_offset_; -  int is_complete_offset_; -  int unscored_words_offset_; -  char* dummy_state_; -  vector<const void*> dummy_ants_; -  vector<lm::WordIndex> cdec2klm_map_; -  vector<WordID> word2class_map_;        // if this is a class-based LM, this is the word->class mapping -  TRulePtr dummy_rule_; -}; - -template <class Model> -KLanguageModel<Model>::KLanguageModel(const string& param) { -  string filename, mapfile, featname; -  bool explicit_markers; -  if (!ParseLMArgs(param, &filename, &mapfile, &explicit_markers, &featname)) { -    abort(); -  } -  try { -    pimpl_ = new KLanguageModelImpl<Model>(filename, mapfile, explicit_markers); -  } catch (std::exception &e) { -    std::cerr << e.what() << std::endl; -    abort(); -  } -  fid_ = FD::Convert(featname); -  oov_fid_ = FD::Convert(featname+"_OOV"); -  cerr << "FID: " << oov_fid_ << endl; -  SetStateSize(pimpl_->ReserveStateSize()); -} - -template <class Model> -Features KLanguageModel<Model>::features() const { -  return single_feature(fid_); -} - -template <class Model> -KLanguageModel<Model>::~KLanguageModel() { -  delete pimpl_; -} - -template <class Model> -void KLanguageModel<Model>::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, -                                          const Hypergraph::Edge& edge, -                                          const vector<const void*>& ant_states, -                                          SparseVector<double>* features, -                                          SparseVector<double>* estimated_features, -                                          void* state) const { -  double est = 0; -  double oovs = 0; -  double est_oovs = 0; -  features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &est, &oovs, &est_oovs, state)); -  estimated_features->set_value(fid_, est); -  if (oov_fid_) { -    if (oovs) features->set_value(oov_fid_, oovs); -    if (est_oovs) estimated_features->set_value(oov_fid_, est_oovs); -  } -} - -template <class Model> -void KLanguageModel<Model>::FinalTraversalFeatures(const void* ant_state, -                                           SparseVector<double>* features) const { -  double oovs = 0; -  double lm = pimpl_->FinalTraversalCost(ant_state, &oovs); -  features->set_value(fid_, lm); -  if (oov_fid_ && oovs) -    features->set_value(oov_fid_, oovs); -} - -template <class Model> boost::shared_ptr<FeatureFunction> CreateModel(const std::string ¶m) { -  KLanguageModel<Model> *ret = new KLanguageModel<Model>(param); -  ret->Init(); -  return boost::shared_ptr<FeatureFunction>(ret); -} - -boost::shared_ptr<FeatureFunction> KLanguageModelFactory::Create(std::string param) const { -  using namespace lm::ngram; -  std::string filename, ignored_map; -  bool ignored_markers; -  std::string ignored_featname; -  ParseLMArgs(param, &filename, &ignored_map, &ignored_markers, &ignored_featname); -  ModelType m; -  if (!RecognizeBinary(filename.c_str(), m)) m = HASH_PROBING; - -  switch (m) { -    case HASH_PROBING: -      return CreateModel<ProbingModel>(param); -    case TRIE_SORTED: -      return CreateModel<TrieModel>(param); -    case ARRAY_TRIE_SORTED: -      return CreateModel<ArrayTrieModel>(param); -    case QUANT_TRIE_SORTED: -      return CreateModel<QuantTrieModel>(param); -    case QUANT_ARRAY_TRIE_SORTED: -      return CreateModel<QuantArrayTrieModel>(param); -    default: -      UTIL_THROW(util::Exception, "Unrecognized kenlm binary file type " << (unsigned)m); -  } -} - -std::string  KLanguageModelFactory::usage(bool params,bool verbose) const { -  return KLanguageModel<lm::ngram::Model>::usage(params, verbose); -} - -#endif | 
