diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2011-03-08 14:40:38 -0500 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2011-03-08 14:40:38 -0500 |
commit | e13ce3fa2533df895b92e484f7a27e78ba0a47ff (patch) | |
tree | 2619c5f7d300dcfa5920713a5d41aef62841d58c | |
parent | e5ef74a308f9fc2d131481214325a26ca1a895dc (diff) |
support multiple LMs with different feature names
-rw-r--r-- | decoder/ff_klm.cc | 70 |
1 files changed, 57 insertions, 13 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 203bced5..854653c3 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -2,6 +2,7 @@ #include <cstring> +#include "stringlib.h" #include "hg.h" #include "tdict.h" #include "lm/enumerate_vocab.hh" @@ -12,6 +13,49 @@ static const unsigned char HAS_FULL_CONTEXT = 1; static const unsigned char HAS_EOS_ON_RIGHT = 2; static const unsigned char MASK = 7; +// -x : rules include <s> and </s> +// -n NAME : feature id is NAME +bool ParseLMArgs(string const& in, string* filename, bool* explicit_markers, string* featname) { + vector<string> const& argv=SplitOnWhitespace(in); + *explicit_markers = true; + *featname="LanguageModel"; +#define LMSPEC_NEXTARG if (i==argv.end()) { \ + cerr << "Missing argument for "<<*last<<". "; goto usage; \ + } else { ++i; } + + for (vector<string>::const_iterator last,i=argv.begin(),e=argv.end();i!=e;++i) { + string const& s=*i; + if (s[0]=='-') { + if (s.size()>2) goto fail; + switch (s[1]) { + case 'x': + *explicit_markers = true; + break; + case 'n': + LMSPEC_NEXTARG; *featname=*i; + break; +#undef LMSPEC_NEXTARG + default: + fail: + cerr<<"Unknown KLanguageModel option "<<s<<" ; "; + goto usage; + } + } else { + if (filename->empty()) + *filename=s; + else { + cerr<<"More than one filename provided. "; + goto usage; + } + } + } + if (!filename->empty()) + return true; +usage: + cerr << "KLanguageModel is incorrect!\n"; + return false; +} + template <class Model> string KLanguageModel<Model>::usage(bool /*param*/,bool /*verbose*/) { return "KLanguageModel"; @@ -212,19 +256,14 @@ class KLanguageModelImpl { } public: - KLanguageModelImpl(const std::string& param) { - add_sos_eos_ = true; - string fname = param; - if (param.find("-x ") == 0) { - add_sos_eos_ = false; - fname = param.substr(3); - } + KLanguageModelImpl(const string& filename, bool explicit_markers) : + add_sos_eos_(!explicit_markers) { lm::ngram::Config conf; VMapper vm(&map_); conf.enumerate_vocab = &vm; - ngram_ = new Model(fname.c_str(), conf); + ngram_ = new Model(filename.c_str(), conf); order_ = ngram_->Order(); - cerr << "Loaded " << order_ << "-gram KLM from " << fname << " (MapSize=" << map_.size() << ")\n"; + cerr << "Loaded " << order_ << "-gram KLM from " << filename << " (MapSize=" << map_.size() << ")\n"; state_size_ = ngram_->StateSize() + 2 + (order_ - 1) * sizeof(lm::WordIndex); unscored_size_offset_ = ngram_->StateSize(); is_complete_offset_ = unscored_size_offset_ + 1; @@ -252,7 +291,7 @@ class KLanguageModelImpl { lm::WordIndex kSOS_; // <s> - requires special handling. lm::WordIndex kEOS_; // </s> Model* ngram_; - bool add_sos_eos_; // flag indicating whether the hypergraph produces <s> and </s> + const bool add_sos_eos_; // flag indicating whether the hypergraph produces <s> and </s> // if this is true, FinalTransitionFeatures will "add" <s> and </s> // if false, FinalTransitionFeatures will score anything with the // markers in the right place (i.e., the beginning and end of @@ -271,9 +310,14 @@ class KLanguageModelImpl { template <class Model> KLanguageModel<Model>::KLanguageModel(const string& param) { - pimpl_ = new KLanguageModelImpl<Model>(param); - fid_ = FD::Convert("LanguageModel"); // todo support LM feature name - oov_fid_ = FD::Convert("OOV"); // should also be named + string filename, featname; + bool explicit_markers; + if (!ParseLMArgs(param, &filename, &explicit_markers, &featname)) { + abort(); + } + pimpl_ = new KLanguageModelImpl<Model>(filename, explicit_markers); + fid_ = FD::Convert(featname); + oov_fid_ = FD::Convert(featname+"_OOV"); SetStateSize(pimpl_->ReserveStateSize()); } |