summaryrefslogtreecommitdiff
path: root/decoder/ff_klm.cc
diff options
context:
space:
mode:
authorGuest_account Guest_account prguest11 <prguest11@taipan.cs>2011-09-23 20:49:43 +0100
committerGuest_account Guest_account prguest11 <prguest11@taipan.cs>2011-09-23 20:49:43 +0100
commit8ecf63852d730f99e7c1bbacfbffdf518d5a0c3f (patch)
tree9cc80e9a47ca8c6d543667c97af5162b9e251516 /decoder/ff_klm.cc
parente1b61419329c83709018ca397a29d069e4294bd1 (diff)
stub work to talk to new kenlm
Diffstat (limited to 'decoder/ff_klm.cc')
-rw-r--r--decoder/ff_klm.cc349
1 files changed, 349 insertions, 0 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc
index 24dcb9c3..016aad26 100644
--- a/decoder/ff_klm.cc
+++ b/decoder/ff_klm.cc
@@ -12,6 +12,353 @@
#include "lm/model.hh"
#include "lm/enumerate_vocab.hh"
+#undef NEW_KENLM
+#ifdef NEW_KENLM
+
+#include "lm/left.hh"
+
+using namespace std;
+
+// -x : rules include <s> and </s>
+// -n NAME : feature id is NAME
+bool ParseLMArgs(string const& in, string* filename, string* mapfile, bool* explicit_markers, string* featname) {
+ vector<string> const& argv=SplitOnWhitespace(in);
+ *explicit_markers = false;
+ *featname="LanguageModel";
+ *mapfile = "";
+#define LMSPEC_NEXTARG if (i==argv.end()) { \
+ cerr << "Missing argument for "<<*last<<". "; goto usage; \
+ } else { ++i; }
+
+ for (vector<string>::const_iterator last,i=argv.begin(),e=argv.end();i!=e;++i) {
+ string const& s=*i;
+ if (s[0]=='-') {
+ if (s.size()>2) goto fail;
+ switch (s[1]) {
+ case 'x':
+ *explicit_markers = true;
+ break;
+ case 'm':
+ LMSPEC_NEXTARG; *mapfile=*i;
+ break;
+ case 'n':
+ LMSPEC_NEXTARG; *featname=*i;
+ break;
+#undef LMSPEC_NEXTARG
+ default:
+ fail:
+ cerr<<"Unknown KLanguageModel option "<<s<<" ; ";
+ goto usage;
+ }
+ } else {
+ if (filename->empty())
+ *filename=s;
+ else {
+ cerr<<"More than one filename provided. ";
+ goto usage;
+ }
+ }
+ }
+ if (!filename->empty())
+ return true;
+usage:
+ cerr << "KLanguageModel is incorrect!\n";
+ return false;
+}
+
+template <class Model>
+string KLanguageModel<Model>::usage(bool /*param*/,bool /*verbose*/) {
+ return "KLanguageModel";
+}
+
+struct VMapper : public lm::ngram::EnumerateVocab {
+ VMapper(vector<lm::WordIndex>* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); }
+ void Add(lm::WordIndex index, const StringPiece &str) {
+ const WordID cdec_id = TD::Convert(str.as_string());
+ if (cdec_id >= out_->size())
+ out_->resize(cdec_id + 1, kLM_UNKNOWN_TOKEN);
+ (*out_)[cdec_id] = index;
+ }
+ vector<lm::WordIndex>* out_;
+ const lm::WordIndex kLM_UNKNOWN_TOKEN;
+};
+
+template <class Model>
+class KLanguageModelImpl {
+
+ static inline const lm::ngram::ChartState& RemnantLMState(const void* state) {
+ return *static_cast<const lm::ngram::ChartState*>(state);
+ }
+
+ inline void SetRemnantLMState(const lm::ngram::ChartState& lmstate, void* state) const {
+ // if we were clever, we could use the memory pointed to by state to do all
+ // the work, avoiding this copy
+ memcpy(state, &lmstate, ngram_->StateSize());
+ }
+
+ public:
+ double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* oovs, void* remnant) {
+ double sum = 0.0;
+ if (oovs) *oovs = 0;
+ const vector<WordID>& e = rule.e();
+ lm::ngram::ChartState state;
+ lm::ngram::RuleScore<Model> ruleScore(*ngram_, state);
+ unsigned i = 0;
+ if (e.size()) {
+ if (e[i] == kCDEC_SOS) {
+ ++i;
+ ruleScore.BeginSentence();
+ } else if (e[i] <= 0) { // special case for left-edge NT
+ const lm::ngram::ChartState& prevState = RemnantLMState(ant_states[-e[0]]);
+ ruleScore.BeginNonTerminal(prevState, 0.0f); // TODO
+ ++i;
+ }
+ }
+ for (; i < e.size(); ++i) {
+ if (e[i] <= 0) {
+ const lm::ngram::ChartState& prevState = RemnantLMState(ant_states[-e[i]]);
+ ruleScore.NonTerminal(prevState, 0.0f); // TODO
+ } else {
+ const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]); // in future,
+ // maybe handle emission
+ const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id
+ const bool is_oov = (cur_word == 0);
+ if (is_oov) (*oovs) += 1.0;
+ ruleScore.Terminal(cur_word);
+ }
+ }
+ if (remnant) SetRemnantLMState(state, remnant);
+ return ruleScore.Finish();
+ }
+
+ // this assumes no target words on final unary -> goal rule. is that ok?
+ // for <s> (n-1 left words) and (n-1 right words) </s>
+ double FinalTraversalCost(const void* state, double* oovs) {
+ if (add_sos_eos_) { // rules do not produce <s> </s>, so do it here
+ lm::ngram::ChartState cstate;
+ lm::ngram::RuleScore<Model> ruleScore(*ngram_, cstate);
+ ruleScore.BeginSentence();
+ SetRemnantLMState(cstate, dummy_state_);
+ dummy_ants_[1] = state;
+ *oovs = 0;
+ return LookupWords(*dummy_rule_, dummy_ants_, oovs, NULL);
+ } else { // rules DO produce <s> ... </s>
+ double p = 0;
+ cerr << "not implemented"; abort(); // TODO
+ //if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; }
+ //if (UnscoredSize(state) > 0) { // are there unscored words
+ // if (kSOS_ != IthUnscoredWord(0, state)) {
+ // p -= 100 * UnscoredSize(state);
+ // }
+ //}
+ return p;
+ }
+ }
+
+ // if this is not a class-based LM, returns w untransformed,
+ // otherwise returns a word class mapping of w,
+ // returns TD::Convert("<unk>") if there is no mapping for w
+ WordID ClassifyWordIfNecessary(WordID w) const {
+ if (word2class_map_.empty()) return w;
+ if (w >= word2class_map_.size())
+ return kCDEC_UNK;
+ else
+ return word2class_map_[w];
+ }
+
+ // converts to cdec word id's to KenLM's id space, OOVs and <unk> end up at 0
+ lm::WordIndex MapWord(WordID w) const {
+ if (w >= cdec2klm_map_.size())
+ return 0;
+ else
+ return cdec2klm_map_[w];
+ }
+
+ public:
+ KLanguageModelImpl(const string& filename, const string& mapfile, bool explicit_markers) :
+ kCDEC_UNK(TD::Convert("<unk>")) ,
+ kCDEC_SOS(TD::Convert("<s>")) ,
+ add_sos_eos_(!explicit_markers) {
+ {
+ VMapper vm(&cdec2klm_map_);
+ lm::ngram::Config conf;
+ conf.enumerate_vocab = &vm;
+ ngram_ = new Model(filename.c_str(), conf);
+ }
+ order_ = ngram_->Order();
+ cerr << "Loaded " << order_ << "-gram KLM from " << filename << " (MapSize=" << cdec2klm_map_.size() << ")\n";
+ state_size_ = sizeof(lm::ngram::ChartState);
+
+ // special handling of beginning / ending sentence markers
+ dummy_state_ = new char[state_size_];
+ memset(dummy_state_, 0, state_size_);
+ dummy_ants_.push_back(dummy_state_);
+ dummy_ants_.push_back(NULL);
+ dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] </s> ||| X=0"));
+ kSOS_ = MapWord(kCDEC_SOS);
+ assert(kSOS_ > 0);
+ kEOS_ = MapWord(TD::Convert("</s>"));
+ assert(kEOS_ > 0);
+ assert(MapWord(kCDEC_UNK) == 0); // KenLM invariant
+
+ // handle class-based LMs (unambiguous word->class mapping reqd.)
+ if (mapfile.size())
+ LoadWordClasses(mapfile);
+ }
+
+ void LoadWordClasses(const string& file) {
+ ReadFile rf(file);
+ istream& in = *rf.stream();
+ string line;
+ vector<WordID> dummy;
+ int lc = 0;
+ cerr << " Loading word classes from " << file << " ...\n";
+ AddWordToClassMapping_(TD::Convert("<s>"), TD::Convert("<s>"));
+ AddWordToClassMapping_(TD::Convert("</s>"), TD::Convert("</s>"));
+ while(in) {
+ getline(in, line);
+ if (!in) continue;
+ dummy.clear();
+ TD::ConvertSentence(line, &dummy);
+ ++lc;
+ if (dummy.size() != 2) {
+ cerr << " Format error in " << file << ", line " << lc << ": " << line << endl;
+ abort();
+ }
+ AddWordToClassMapping_(dummy[0], dummy[1]);
+ }
+ }
+
+ void AddWordToClassMapping_(WordID word, WordID cls) {
+ if (word2class_map_.size() <= word) {
+ word2class_map_.resize((word + 10) * 1.1, kCDEC_UNK);
+ assert(word2class_map_.size() > word);
+ }
+ if(word2class_map_[word] != kCDEC_UNK) {
+ cerr << "Multiple classes for symbol " << TD::Convert(word) << endl;
+ abort();
+ }
+ word2class_map_[word] = cls;
+ }
+
+ ~KLanguageModelImpl() {
+ delete ngram_;
+ delete[] dummy_state_;
+ }
+
+ int ReserveStateSize() const { return state_size_; }
+
+ private:
+ const WordID kCDEC_UNK;
+ const WordID kCDEC_SOS;
+ lm::WordIndex kSOS_; // <s> - requires special handling.
+ lm::WordIndex kEOS_; // </s>
+ Model* ngram_;
+ const bool add_sos_eos_; // flag indicating whether the hypergraph produces <s> and </s>
+ // if this is true, FinalTransitionFeatures will "add" <s> and </s>
+ // if false, FinalTransitionFeatures will score anything with the
+ // markers in the right place (i.e., the beginning and end of
+ // the sentence) with 0, and anything else with -100
+
+ int order_;
+ int state_size_;
+ char* dummy_state_;
+ vector<const void*> dummy_ants_;
+ vector<lm::WordIndex> cdec2klm_map_;
+ vector<WordID> word2class_map_; // if this is a class-based LM, this is the word->class mapping
+ TRulePtr dummy_rule_;
+};
+
+template <class Model>
+KLanguageModel<Model>::KLanguageModel(const string& param) {
+ string filename, mapfile, featname;
+ bool explicit_markers;
+ if (!ParseLMArgs(param, &filename, &mapfile, &explicit_markers, &featname)) {
+ abort();
+ }
+ try {
+ pimpl_ = new KLanguageModelImpl<Model>(filename, mapfile, explicit_markers);
+ } catch (std::exception &e) {
+ std::cerr << e.what() << std::endl;
+ abort();
+ }
+ fid_ = FD::Convert(featname);
+ oov_fid_ = FD::Convert(featname+"_OOV");
+ // cerr << "FID: " << oov_fid_ << endl;
+ SetStateSize(pimpl_->ReserveStateSize());
+}
+
+template <class Model>
+Features KLanguageModel<Model>::features() const {
+ return single_feature(fid_);
+}
+
+template <class Model>
+KLanguageModel<Model>::~KLanguageModel() {
+ delete pimpl_;
+}
+
+template <class Model>
+void KLanguageModel<Model>::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_states,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* state) const {
+ double est = 0;
+ double oovs = 0;
+ features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &oovs, state));
+ if (oovs && oov_fid_)
+ features->set_value(oov_fid_, oovs);
+}
+
+template <class Model>
+void KLanguageModel<Model>::FinalTraversalFeatures(const void* ant_state,
+ SparseVector<double>* features) const {
+ double oovs = 0;
+ double lm = pimpl_->FinalTraversalCost(ant_state, &oovs);
+ features->set_value(fid_, lm);
+ if (oov_fid_ && oovs)
+ features->set_value(oov_fid_, oovs);
+}
+
+template <class Model> boost::shared_ptr<FeatureFunction> CreateModel(const std::string &param) {
+ KLanguageModel<Model> *ret = new KLanguageModel<Model>(param);
+ ret->Init();
+ return boost::shared_ptr<FeatureFunction>(ret);
+}
+
+boost::shared_ptr<FeatureFunction> KLanguageModelFactory::Create(std::string param) const {
+ using namespace lm::ngram;
+ std::string filename, ignored_map;
+ bool ignored_markers;
+ std::string ignored_featname;
+ ParseLMArgs(param, &filename, &ignored_map, &ignored_markers, &ignored_featname);
+ ModelType m;
+ if (!RecognizeBinary(filename.c_str(), m)) m = HASH_PROBING;
+
+ switch (m) {
+ case HASH_PROBING:
+ return CreateModel<ProbingModel>(param);
+ case TRIE_SORTED:
+ return CreateModel<TrieModel>(param);
+ case ARRAY_TRIE_SORTED:
+ return CreateModel<ArrayTrieModel>(param);
+ case QUANT_TRIE_SORTED:
+ return CreateModel<QuantTrieModel>(param);
+ case QUANT_ARRAY_TRIE_SORTED:
+ return CreateModel<QuantArrayTrieModel>(param);
+ default:
+ UTIL_THROW(util::Exception, "Unrecognized kenlm binary file type " << (unsigned)m);
+ }
+}
+
+std::string KLanguageModelFactory::usage(bool params,bool verbose) const {
+ return KLanguageModel<lm::ngram::Model>::usage(params, verbose);
+}
+
+#else
+
using namespace std;
static const unsigned char HAS_FULL_CONTEXT = 1;
@@ -469,3 +816,5 @@ boost::shared_ptr<FeatureFunction> KLanguageModelFactory::Create(std::string par
std::string KLanguageModelFactory::usage(bool params,bool verbose) const {
return KLanguageModel<lm::ngram::Model>::usage(params, verbose);
}
+
+#endif