summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--decoder/ff_klm.cc70
1 files changed, 57 insertions, 13 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc
index 203bced5..854653c3 100644
--- a/decoder/ff_klm.cc
+++ b/decoder/ff_klm.cc
@@ -2,6 +2,7 @@
#include <cstring>
+#include "stringlib.h"
#include "hg.h"
#include "tdict.h"
#include "lm/enumerate_vocab.hh"
@@ -12,6 +13,49 @@ static const unsigned char HAS_FULL_CONTEXT = 1;
static const unsigned char HAS_EOS_ON_RIGHT = 2;
static const unsigned char MASK = 7;
+// -x : rules include <s> and </s>
+// -n NAME : feature id is NAME
+bool ParseLMArgs(string const& in, string* filename, bool* explicit_markers, string* featname) {
+ vector<string> const& argv=SplitOnWhitespace(in);
+ *explicit_markers = true;
+ *featname="LanguageModel";
+#define LMSPEC_NEXTARG if (i==argv.end()) { \
+ cerr << "Missing argument for "<<*last<<". "; goto usage; \
+ } else { ++i; }
+
+ for (vector<string>::const_iterator last,i=argv.begin(),e=argv.end();i!=e;++i) {
+ string const& s=*i;
+ if (s[0]=='-') {
+ if (s.size()>2) goto fail;
+ switch (s[1]) {
+ case 'x':
+ *explicit_markers = true;
+ break;
+ case 'n':
+ LMSPEC_NEXTARG; *featname=*i;
+ break;
+#undef LMSPEC_NEXTARG
+ default:
+ fail:
+ cerr<<"Unknown KLanguageModel option "<<s<<" ; ";
+ goto usage;
+ }
+ } else {
+ if (filename->empty())
+ *filename=s;
+ else {
+ cerr<<"More than one filename provided. ";
+ goto usage;
+ }
+ }
+ }
+ if (!filename->empty())
+ return true;
+usage:
+ cerr << "KLanguageModel is incorrect!\n";
+ return false;
+}
+
template <class Model>
string KLanguageModel<Model>::usage(bool /*param*/,bool /*verbose*/) {
return "KLanguageModel";
@@ -212,19 +256,14 @@ class KLanguageModelImpl {
}
public:
- KLanguageModelImpl(const std::string& param) {
- add_sos_eos_ = true;
- string fname = param;
- if (param.find("-x ") == 0) {
- add_sos_eos_ = false;
- fname = param.substr(3);
- }
+ KLanguageModelImpl(const string& filename, bool explicit_markers) :
+ add_sos_eos_(!explicit_markers) {
lm::ngram::Config conf;
VMapper vm(&map_);
conf.enumerate_vocab = &vm;
- ngram_ = new Model(fname.c_str(), conf);
+ ngram_ = new Model(filename.c_str(), conf);
order_ = ngram_->Order();
- cerr << "Loaded " << order_ << "-gram KLM from " << fname << " (MapSize=" << map_.size() << ")\n";
+ cerr << "Loaded " << order_ << "-gram KLM from " << filename << " (MapSize=" << map_.size() << ")\n";
state_size_ = ngram_->StateSize() + 2 + (order_ - 1) * sizeof(lm::WordIndex);
unscored_size_offset_ = ngram_->StateSize();
is_complete_offset_ = unscored_size_offset_ + 1;
@@ -252,7 +291,7 @@ class KLanguageModelImpl {
lm::WordIndex kSOS_; // <s> - requires special handling.
lm::WordIndex kEOS_; // </s>
Model* ngram_;
- bool add_sos_eos_; // flag indicating whether the hypergraph produces <s> and </s>
+ const bool add_sos_eos_; // flag indicating whether the hypergraph produces <s> and </s>
// if this is true, FinalTransitionFeatures will "add" <s> and </s>
// if false, FinalTransitionFeatures will score anything with the
// markers in the right place (i.e., the beginning and end of
@@ -271,9 +310,14 @@ class KLanguageModelImpl {
template <class Model>
KLanguageModel<Model>::KLanguageModel(const string& param) {
- pimpl_ = new KLanguageModelImpl<Model>(param);
- fid_ = FD::Convert("LanguageModel"); // todo support LM feature name
- oov_fid_ = FD::Convert("OOV"); // should also be named
+ string filename, featname;
+ bool explicit_markers;
+ if (!ParseLMArgs(param, &filename, &explicit_markers, &featname)) {
+ abort();
+ }
+ pimpl_ = new KLanguageModelImpl<Model>(filename, explicit_markers);
+ fid_ = FD::Convert(featname);
+ oov_fid_ = FD::Convert(featname+"_OOV");
SetStateSize(pimpl_->ReserveStateSize());
}