summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-03-08 15:17:28 -0500
committerChris Dyer <cdyer@cs.cmu.edu>2011-03-08 15:17:28 -0500
commit6872cde637b0ee6d308358e5618292da7efba512 (patch)
tree69617e6074fddedd3119800d0ea2fb135f533647 /decoder
parente13ce3fa2533df895b92e484f7a27e78ba0a47ff (diff)
support for class based LMs (without emission probs)
Diffstat (limited to 'decoder')
-rw-r--r--decoder/ff_klm.cc96
1 files changed, 82 insertions, 14 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc
index 854653c3..446485f6 100644
--- a/decoder/ff_klm.cc
+++ b/decoder/ff_klm.cc
@@ -1,7 +1,11 @@
#include "ff_klm.h"
#include <cstring>
+#include <iostream>
+#include <boost/scoped_ptr.hpp>
+
+#include "filelib.h"
#include "stringlib.h"
#include "hg.h"
#include "tdict.h"
@@ -15,10 +19,11 @@ static const unsigned char MASK = 7;
// -x : rules include <s> and </s>
// -n NAME : feature id is NAME
-bool ParseLMArgs(string const& in, string* filename, bool* explicit_markers, string* featname) {
+bool ParseLMArgs(string const& in, string* filename, string* mapfile, bool* explicit_markers, string* featname) {
vector<string> const& argv=SplitOnWhitespace(in);
*explicit_markers = true;
*featname="LanguageModel";
+ *mapfile = "";
#define LMSPEC_NEXTARG if (i==argv.end()) { \
cerr << "Missing argument for "<<*last<<". "; goto usage; \
} else { ++i; }
@@ -31,6 +36,9 @@ bool ParseLMArgs(string const& in, string* filename, bool* explicit_markers, str
case 'x':
*explicit_markers = true;
break;
+ case 'm':
+ LMSPEC_NEXTARG; *mapfile=*i;
+ break;
case 'n':
LMSPEC_NEXTARG; *featname=*i;
break;
@@ -182,7 +190,9 @@ class KLanguageModelImpl {
context_complete = true;
}
} else { // handle terminal
- const lm::WordIndex cur_word = MapWord(e[j]);
+ const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[j]); // in future,
+ // maybe handle emission
+ const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id
double p = 0;
const bool is_oov = (cur_word == 0);
if (cur_word == kSOS_) {
@@ -248,22 +258,38 @@ class KLanguageModelImpl {
}
}
+ // if this is not a class-based LM, returns w untransformed,
+ // otherwise returns a word class mapping of w,
+ // returns TD::Convert("<unk>") if there is no mapping for w
+ WordID ClassifyWordIfNecessary(WordID w) const {
+ if (word2class_map_.empty()) return w;
+ if (w >= word2class_map_.size())
+ return kCDEC_UNK;
+ else
+ return word2class_map_[w];
+ }
+
+ // converts to cdec word id's to KenLM's id space, OOVs and <unk> end up at 0
lm::WordIndex MapWord(WordID w) const {
- if (w >= map_.size())
+ if (w >= cdec2klm_map_.size())
return 0;
else
- return map_[w];
+ return cdec2klm_map_[w];
}
public:
- KLanguageModelImpl(const string& filename, bool explicit_markers) :
+ KLanguageModelImpl(const string& filename, const string& mapfile, bool explicit_markers) :
+ kCDEC_UNK(TD::Convert("<unk>")) ,
add_sos_eos_(!explicit_markers) {
- lm::ngram::Config conf;
- VMapper vm(&map_);
- conf.enumerate_vocab = &vm;
- ngram_ = new Model(filename.c_str(), conf);
+ if (true) {
+ boost::scoped_ptr<lm::ngram::EnumerateVocab> vm;
+ vm.reset(new VMapper(&cdec2klm_map_));
+ lm::ngram::Config conf;
+ conf.enumerate_vocab = vm.get();
+ ngram_ = new Model(filename.c_str(), conf);
+ }
order_ = ngram_->Order();
- cerr << "Loaded " << order_ << "-gram KLM from " << filename << " (MapSize=" << map_.size() << ")\n";
+ cerr << "Loaded " << order_ << "-gram KLM from " << filename << " (MapSize=" << cdec2klm_map_.size() << ")\n";
state_size_ = ngram_->StateSize() + 2 + (order_ - 1) * sizeof(lm::WordIndex);
unscored_size_offset_ = ngram_->StateSize();
is_complete_offset_ = unscored_size_offset_ + 1;
@@ -278,6 +304,46 @@ class KLanguageModelImpl {
assert(kSOS_ > 0);
kEOS_ = MapWord(TD::Convert("</s>"));
assert(kEOS_ > 0);
+ assert(MapWord(kCDEC_UNK) == 0); // KenLM invariant
+
+ // handle class-based LMs (unambiguous word->class mapping reqd.)
+ if (mapfile.size())
+ LoadWordClasses(mapfile);
+ }
+
+ void LoadWordClasses(const string& file) {
+ ReadFile rf(file);
+ istream& in = *rf.stream();
+ string line;
+ vector<WordID> dummy;
+ int lc = 0;
+ cerr << " Loading word classes from " << file << " ...\n";
+ AddWordToClassMapping_(TD::Convert("<s>"), TD::Convert("</s>"));
+ AddWordToClassMapping_(TD::Convert("</s>"), TD::Convert("</s>"));
+ while(in) {
+ getline(in, line);
+ if (!in) continue;
+ dummy.clear();
+ TD::ConvertSentence(line, &dummy);
+ ++lc;
+ if (dummy.size() != 2) {
+ cerr << " Format error in " << file << ", line " << lc << ": " << line << endl;
+ abort();
+ }
+ AddWordToClassMapping_(dummy[0], dummy[1]);
+ }
+ }
+
+ void AddWordToClassMapping_(WordID word, WordID cls) {
+ if (word2class_map_.size() <= word) {
+ word2class_map_.resize((word + 10) * 1.1, kCDEC_UNK);
+ assert(word2class_map_.size() > word);
+ }
+ if(word2class_map_[word] == kCDEC_UNK) {
+ cerr << "Multiple classes for symbol " << TD::Convert(word) << endl;
+ abort();
+ }
+ word2class_map_[word] = cls;
}
~KLanguageModelImpl() {
@@ -288,6 +354,7 @@ class KLanguageModelImpl {
int ReserveStateSize() const { return state_size_; }
private:
+ const WordID kCDEC_UNK;
lm::WordIndex kSOS_; // <s> - requires special handling.
lm::WordIndex kEOS_; // </s>
Model* ngram_;
@@ -304,18 +371,19 @@ class KLanguageModelImpl {
int unscored_words_offset_;
char* dummy_state_;
vector<const void*> dummy_ants_;
- vector<lm::WordIndex> map_;
+ vector<lm::WordIndex> cdec2klm_map_;
+ vector<WordID> word2class_map_; // if this is a class-based LM, this is the word->class mapping
TRulePtr dummy_rule_;
};
template <class Model>
KLanguageModel<Model>::KLanguageModel(const string& param) {
- string filename, featname;
+ string filename, mapfile, featname;
bool explicit_markers;
- if (!ParseLMArgs(param, &filename, &explicit_markers, &featname)) {
+ if (!ParseLMArgs(param, &filename, &mapfile, &explicit_markers, &featname)) {
abort();
}
- pimpl_ = new KLanguageModelImpl<Model>(filename, explicit_markers);
+ pimpl_ = new KLanguageModelImpl<Model>(filename, mapfile, explicit_markers);
fid_ = FD::Convert(featname);
oov_fid_ = FD::Convert(featname+"_OOV");
SetStateSize(pimpl_->ReserveStateSize());