summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-10-18 14:19:09 +0100
committerChris Dyer <cdyer@cs.cmu.edu>2011-10-18 14:19:09 +0100
commit09297047e446f49804d3f48bf320cdbd38d6396a (patch)
treec2b2eea5565e444db59863affa06ae9d93666c02
parent97520627383196a8ba8c4e0656573546f4a03fbc (diff)
incorporate kenneth's fixes
-rw-r--r--decoder/ff_klm.cc464
1 files changed, 0 insertions, 464 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc
index 3c941fbf..ed6f731e 100644
--- a/decoder/ff_klm.cc
+++ b/decoder/ff_klm.cc
@@ -12,9 +12,6 @@
#include "lm/model.hh"
#include "lm/enumerate_vocab.hh"
-#define NEW_KENLM
-#undef NEW_KENLM
-
#include "lm/left.hh"
using namespace std;
@@ -395,464 +392,3 @@ std::string KLanguageModelFactory::usage(bool params,bool verbose) const {
return KLanguageModel<lm::ngram::Model>::usage(params, verbose);
}
-#else
-
-using namespace std;
-
-static const unsigned char HAS_FULL_CONTEXT = 1;
-static const unsigned char HAS_EOS_ON_RIGHT = 2;
-static const unsigned char MASK = 7;
-
-// -x : rules include <s> and </s>
-// -n NAME : feature id is NAME
-bool ParseLMArgs(string const& in, string* filename, string* mapfile, bool* explicit_markers, string* featname) {
- vector<string> const& argv=SplitOnWhitespace(in);
- *explicit_markers = false;
- *featname="LanguageModel";
- *mapfile = "";
-#define LMSPEC_NEXTARG if (i==argv.end()) { \
- cerr << "Missing argument for "<<*last<<". "; goto usage; \
- } else { ++i; }
-
- for (vector<string>::const_iterator last,i=argv.begin(),e=argv.end();i!=e;++i) {
- string const& s=*i;
- if (s[0]=='-') {
- if (s.size()>2) goto fail;
- switch (s[1]) {
- case 'x':
- *explicit_markers = true;
- break;
- case 'm':
- LMSPEC_NEXTARG; *mapfile=*i;
- break;
- case 'n':
- LMSPEC_NEXTARG; *featname=*i;
- break;
-#undef LMSPEC_NEXTARG
- default:
- fail:
- cerr<<"Unknown KLanguageModel option "<<s<<" ; ";
- goto usage;
- }
- } else {
- if (filename->empty())
- *filename=s;
- else {
- cerr<<"More than one filename provided. ";
- goto usage;
- }
- }
- }
- if (!filename->empty())
- return true;
-usage:
- cerr << "KLanguageModel is incorrect!\n";
- return false;
-}
-
-template <class Model>
-string KLanguageModel<Model>::usage(bool /*param*/,bool /*verbose*/) {
- return "KLanguageModel";
-}
-
-struct VMapper : public lm::ngram::EnumerateVocab {
- VMapper(vector<lm::WordIndex>* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); }
- void Add(lm::WordIndex index, const StringPiece &str) {
- const WordID cdec_id = TD::Convert(str.as_string());
- if (cdec_id >= out_->size())
- out_->resize(cdec_id + 1, kLM_UNKNOWN_TOKEN);
- (*out_)[cdec_id] = index;
- }
- vector<lm::WordIndex>* out_;
- const lm::WordIndex kLM_UNKNOWN_TOKEN;
-};
-
-template <class Model>
-class KLanguageModelImpl {
-
- // returns the number of unscored words at the left edge of a span
- inline int UnscoredSize(const void* state) const {
- return *(static_cast<const char*>(state) + unscored_size_offset_);
- }
-
- inline void SetUnscoredSize(int size, void* state) const {
- *(static_cast<char*>(state) + unscored_size_offset_) = size;
- }
-
- static inline const lm::ngram::State& RemnantLMState(const void* state) {
- return *static_cast<const lm::ngram::State*>(state);
- }
-
- inline void SetRemnantLMState(const lm::ngram::State& lmstate, void* state) const {
- // if we were clever, we could use the memory pointed to by state to do all
- // the work, avoiding this copy
- memcpy(state, &lmstate, ngram_->StateSize());
- }
-
- lm::WordIndex IthUnscoredWord(int i, const void* state) const {
- const lm::WordIndex* const mem = reinterpret_cast<const lm::WordIndex*>(static_cast<const char*>(state) + unscored_words_offset_);
- return mem[i];
- }
-
- void SetIthUnscoredWord(int i, lm::WordIndex index, void *state) const {
- lm::WordIndex* mem = reinterpret_cast<lm::WordIndex*>(static_cast<char*>(state) + unscored_words_offset_);
- mem[i] = index;
- }
-
- inline bool GetFlag(const void *state, unsigned char flag) const {
- return (*(static_cast<const char*>(state) + is_complete_offset_) & flag);
- }
-
- inline void SetFlag(bool on, unsigned char flag, void *state) const {
- if (on) {
- *(static_cast<char*>(state) + is_complete_offset_) |= flag;
- } else {
- *(static_cast<char*>(state) + is_complete_offset_) &= (MASK ^ flag);
- }
- }
-
- inline bool HasFullContext(const void *state) const {
- return GetFlag(state, HAS_FULL_CONTEXT);
- }
-
- inline void SetHasFullContext(bool flag, void *state) const {
- SetFlag(flag, HAS_FULL_CONTEXT, state);
- }
-
- public:
- double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* pest_sum, double* oovs, double* est_oovs, void* remnant) {
- double sum = 0.0;
- double est_sum = 0.0;
- int num_scored = 0;
- int num_estimated = 0;
- if (oovs) *oovs = 0;
- if (est_oovs) *est_oovs = 0;
- bool saw_eos = false;
- bool has_some_history = false;
- lm::ngram::State state = ngram_->NullContextState();
- const vector<WordID>& e = rule.e();
- bool context_complete = false;
- for (int j = 0; j < e.size(); ++j) {
- if (e[j] < 1) { // handle non-terminal substitution
- const void* astate = (ant_states[-e[j]]);
- int unscored_ant_len = UnscoredSize(astate);
- for (int k = 0; k < unscored_ant_len; ++k) {
- const lm::WordIndex cur_word = IthUnscoredWord(k, astate);
- const bool is_oov = (cur_word == 0);
- double p = 0;
- if (cur_word == kSOS_) {
- state = ngram_->BeginSentenceState();
- if (has_some_history) { // this is immediately fully scored, and bad
- p = -100;
- context_complete = true;
- } else { // this might be a real <s>
- num_scored = max(0, order_ - 2);
- }
- } else {
- const lm::ngram::State scopy(state);
- p = ngram_->Score(scopy, cur_word, state);
- if (saw_eos) { p = -100; }
- saw_eos = (cur_word == kEOS_);
- }
- has_some_history = true;
- ++num_scored;
- if (!context_complete) {
- if (num_scored >= order_) context_complete = true;
- }
- if (context_complete) {
- sum += p;
- if (oovs && is_oov) (*oovs)++;
- } else {
- if (remnant)
- SetIthUnscoredWord(num_estimated, cur_word, remnant);
- ++num_estimated;
- est_sum += p;
- if (est_oovs && is_oov) (*est_oovs)++;
- }
- }
- saw_eos = GetFlag(astate, HAS_EOS_ON_RIGHT);
- if (HasFullContext(astate)) { // this is equivalent to the "star" in Chiang 2007
- state = RemnantLMState(astate);
- context_complete = true;
- }
- } else { // handle terminal
- const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[j]); // in future,
- // maybe handle emission
- const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id
- double p = 0;
- const bool is_oov = (cur_word == 0);
- if (cur_word == kSOS_) {
- state = ngram_->BeginSentenceState();
- if (has_some_history) { // this is immediately fully scored, and bad
- p = -100;
- context_complete = true;
- } else { // this might be a real <s>
- num_scored = max(0, order_ - 2);
- }
- } else {
- const lm::ngram::State scopy(state);
- p = ngram_->Score(scopy, cur_word, state);
- if (saw_eos) { p = -100; }
- saw_eos = (cur_word == kEOS_);
- }
- has_some_history = true;
- ++num_scored;
- if (!context_complete) {
- if (num_scored >= order_) context_complete = true;
- }
- if (context_complete) {
- sum += p;
- if (oovs && is_oov) (*oovs)++;
- } else {
- if (remnant)
- SetIthUnscoredWord(num_estimated, cur_word, remnant);
- ++num_estimated;
- est_sum += p;
- if (est_oovs && is_oov) (*est_oovs)++;
- }
- }
- }
- if (pest_sum) *pest_sum = est_sum;
- if (remnant) {
- state.ZeroRemaining();
- SetFlag(saw_eos, HAS_EOS_ON_RIGHT, remnant);
- SetRemnantLMState(state, remnant);
- SetUnscoredSize(num_estimated, remnant);
- SetHasFullContext(context_complete || (num_scored >= order_), remnant);
- }
- return sum;
- }
-
- // this assumes no target words on final unary -> goal rule. is that ok?
- // for <s> (n-1 left words) and (n-1 right words) </s>
- double FinalTraversalCost(const void* state, double* oovs) {
- if (add_sos_eos_) { // rules do not produce <s> </s>, so do it here
- SetRemnantLMState(ngram_->BeginSentenceState(), dummy_state_);
- SetHasFullContext(1, dummy_state_);
- SetUnscoredSize(0, dummy_state_);
- dummy_ants_[1] = state;
- *oovs = 0;
- return LookupWords(*dummy_rule_, dummy_ants_, NULL, oovs, NULL, NULL);
- } else { // rules DO produce <s> ... </s>
- double p = 0;
- if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; }
- if (UnscoredSize(state) > 0) { // are there unscored words
- if (kSOS_ != IthUnscoredWord(0, state)) {
- p -= 100 * UnscoredSize(state);
- }
- }
- return p;
- }
- }
-
- // if this is not a class-based LM, returns w untransformed,
- // otherwise returns a word class mapping of w,
- // returns TD::Convert("<unk>") if there is no mapping for w
- WordID ClassifyWordIfNecessary(WordID w) const {
- if (word2class_map_.empty()) return w;
- if (w >= word2class_map_.size())
- return kCDEC_UNK;
- else
- return word2class_map_[w];
- }
-
- // converts to cdec word id's to KenLM's id space, OOVs and <unk> end up at 0
- lm::WordIndex MapWord(WordID w) const {
- if (w >= cdec2klm_map_.size())
- return 0;
- else
- return cdec2klm_map_[w];
- }
-
- public:
- KLanguageModelImpl(const string& filename, const string& mapfile, bool explicit_markers) :
- kCDEC_UNK(TD::Convert("<unk>")) ,
- add_sos_eos_(!explicit_markers) {
- {
- VMapper vm(&cdec2klm_map_);
- lm::ngram::Config conf;
- conf.enumerate_vocab = &vm;
- ngram_ = new Model(filename.c_str(), conf);
- }
- order_ = ngram_->Order();
- cerr << "Loaded " << order_ << "-gram KLM from " << filename << " (MapSize=" << cdec2klm_map_.size() << ")\n";
- state_size_ = ngram_->StateSize() + 2 + (order_ - 1) * sizeof(lm::WordIndex);
- unscored_size_offset_ = ngram_->StateSize();
- is_complete_offset_ = unscored_size_offset_ + 1;
- unscored_words_offset_ = is_complete_offset_ + 1;
-
- // special handling of beginning / ending sentence markers
- dummy_state_ = new char[state_size_];
- memset(dummy_state_, 0, state_size_);
- dummy_ants_.push_back(dummy_state_);
- dummy_ants_.push_back(NULL);
- dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] </s> ||| X=0"));
- kSOS_ = MapWord(TD::Convert("<s>"));
- assert(kSOS_ > 0);
- kEOS_ = MapWord(TD::Convert("</s>"));
- assert(kEOS_ > 0);
- assert(MapWord(kCDEC_UNK) == 0); // KenLM invariant
-
- // handle class-based LMs (unambiguous word->class mapping reqd.)
- if (mapfile.size())
- LoadWordClasses(mapfile);
- }
-
- void LoadWordClasses(const string& file) {
- ReadFile rf(file);
- istream& in = *rf.stream();
- string line;
- vector<WordID> dummy;
- int lc = 0;
- cerr << " Loading word classes from " << file << " ...\n";
- AddWordToClassMapping_(TD::Convert("<s>"), TD::Convert("<s>"));
- AddWordToClassMapping_(TD::Convert("</s>"), TD::Convert("</s>"));
- while(in) {
- getline(in, line);
- if (!in) continue;
- dummy.clear();
- TD::ConvertSentence(line, &dummy);
- ++lc;
- if (dummy.size() != 2) {
- cerr << " Format error in " << file << ", line " << lc << ": " << line << endl;
- abort();
- }
- AddWordToClassMapping_(dummy[0], dummy[1]);
- }
- }
-
- void AddWordToClassMapping_(WordID word, WordID cls) {
- if (word2class_map_.size() <= word) {
- word2class_map_.resize((word + 10) * 1.1, kCDEC_UNK);
- assert(word2class_map_.size() > word);
- }
- if(word2class_map_[word] != kCDEC_UNK) {
- cerr << "Multiple classes for symbol " << TD::Convert(word) << endl;
- abort();
- }
- word2class_map_[word] = cls;
- }
-
- ~KLanguageModelImpl() {
- delete ngram_;
- delete[] dummy_state_;
- }
-
- int ReserveStateSize() const { return state_size_; }
-
- private:
- const WordID kCDEC_UNK;
- lm::WordIndex kSOS_; // <s> - requires special handling.
- lm::WordIndex kEOS_; // </s>
- Model* ngram_;
- const bool add_sos_eos_; // flag indicating whether the hypergraph produces <s> and </s>
- // if this is true, FinalTransitionFeatures will "add" <s> and </s>
- // if false, FinalTransitionFeatures will score anything with the
- // markers in the right place (i.e., the beginning and end of
- // the sentence) with 0, and anything else with -100
-
- int order_;
- int state_size_;
- int unscored_size_offset_;
- int is_complete_offset_;
- int unscored_words_offset_;
- char* dummy_state_;
- vector<const void*> dummy_ants_;
- vector<lm::WordIndex> cdec2klm_map_;
- vector<WordID> word2class_map_; // if this is a class-based LM, this is the word->class mapping
- TRulePtr dummy_rule_;
-};
-
-template <class Model>
-KLanguageModel<Model>::KLanguageModel(const string& param) {
- string filename, mapfile, featname;
- bool explicit_markers;
- if (!ParseLMArgs(param, &filename, &mapfile, &explicit_markers, &featname)) {
- abort();
- }
- try {
- pimpl_ = new KLanguageModelImpl<Model>(filename, mapfile, explicit_markers);
- } catch (std::exception &e) {
- std::cerr << e.what() << std::endl;
- abort();
- }
- fid_ = FD::Convert(featname);
- oov_fid_ = FD::Convert(featname+"_OOV");
- cerr << "FID: " << oov_fid_ << endl;
- SetStateSize(pimpl_->ReserveStateSize());
-}
-
-template <class Model>
-Features KLanguageModel<Model>::features() const {
- return single_feature(fid_);
-}
-
-template <class Model>
-KLanguageModel<Model>::~KLanguageModel() {
- delete pimpl_;
-}
-
-template <class Model>
-void KLanguageModel<Model>::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */,
- const Hypergraph::Edge& edge,
- const vector<const void*>& ant_states,
- SparseVector<double>* features,
- SparseVector<double>* estimated_features,
- void* state) const {
- double est = 0;
- double oovs = 0;
- double est_oovs = 0;
- features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &est, &oovs, &est_oovs, state));
- estimated_features->set_value(fid_, est);
- if (oov_fid_) {
- if (oovs) features->set_value(oov_fid_, oovs);
- if (est_oovs) estimated_features->set_value(oov_fid_, est_oovs);
- }
-}
-
-template <class Model>
-void KLanguageModel<Model>::FinalTraversalFeatures(const void* ant_state,
- SparseVector<double>* features) const {
- double oovs = 0;
- double lm = pimpl_->FinalTraversalCost(ant_state, &oovs);
- features->set_value(fid_, lm);
- if (oov_fid_ && oovs)
- features->set_value(oov_fid_, oovs);
-}
-
-template <class Model> boost::shared_ptr<FeatureFunction> CreateModel(const std::string &param) {
- KLanguageModel<Model> *ret = new KLanguageModel<Model>(param);
- ret->Init();
- return boost::shared_ptr<FeatureFunction>(ret);
-}
-
-boost::shared_ptr<FeatureFunction> KLanguageModelFactory::Create(std::string param) const {
- using namespace lm::ngram;
- std::string filename, ignored_map;
- bool ignored_markers;
- std::string ignored_featname;
- ParseLMArgs(param, &filename, &ignored_map, &ignored_markers, &ignored_featname);
- ModelType m;
- if (!RecognizeBinary(filename.c_str(), m)) m = HASH_PROBING;
-
- switch (m) {
- case HASH_PROBING:
- return CreateModel<ProbingModel>(param);
- case TRIE_SORTED:
- return CreateModel<TrieModel>(param);
- case ARRAY_TRIE_SORTED:
- return CreateModel<ArrayTrieModel>(param);
- case QUANT_TRIE_SORTED:
- return CreateModel<QuantTrieModel>(param);
- case QUANT_ARRAY_TRIE_SORTED:
- return CreateModel<QuantArrayTrieModel>(param);
- default:
- UTIL_THROW(util::Exception, "Unrecognized kenlm binary file type " << (unsigned)m);
- }
-}
-
-std::string KLanguageModelFactory::usage(bool params,bool verbose) const {
- return KLanguageModel<lm::ngram::Model>::usage(params, verbose);
-}
-
-#endif