summaryrefslogtreecommitdiff
path: root/decoder/ff_klm.cc
diff options
context:
space:
mode:
Diffstat (limited to 'decoder/ff_klm.cc')
-rw-r--r--decoder/ff_klm.cc49
1 files changed, 30 insertions, 19 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc
index fefa90bd..c8ca917a 100644
--- a/decoder/ff_klm.cc
+++ b/decoder/ff_klm.cc
@@ -1,6 +1,7 @@
#include "ff_klm.h"
#include <cstring>
+#include <cstdlib>
#include <iostream>
#include <boost/scoped_ptr.hpp>
@@ -151,8 +152,9 @@ template <class Model> class BoundaryRuleScore {
template <class Model>
class KLanguageModelImpl {
public:
- double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* oovs, void* remnant) {
+ double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* oovs, double* emit, void* remnant) {
*oovs = 0;
+ *emit = 0;
const vector<WordID>& e = rule.e();
BoundaryRuleScore<Model> ruleScore(*ngram_, *static_cast<BoundaryAnnotatedState*>(remnant));
unsigned i = 0;
@@ -169,8 +171,9 @@ class KLanguageModelImpl {
if (e[i] <= 0) {
ruleScore.NonTerminal(*static_cast<const BoundaryAnnotatedState*>(ant_states[-e[i]]));
} else {
- const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]); // in future,
- // maybe handle emission
+ float ep = 0.f;
+ const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i], &ep);
+ if (ep) { *emit += ep; }
const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id
if (cur_word == 0) (*oovs) += 1.0;
ruleScore.Terminal(cur_word);
@@ -205,12 +208,14 @@ class KLanguageModelImpl {
// if this is not a class-based LM, returns w untransformed,
// otherwise returns a word class mapping of w,
// returns TD::Convert("<unk>") if there is no mapping for w
- WordID ClassifyWordIfNecessary(WordID w) const {
+ WordID ClassifyWordIfNecessary(WordID w, float* emitp) const {
if (word2class_map_.empty()) return w;
if (w >= word2class_map_.size())
return kCDEC_UNK;
- else
- return word2class_map_[w];
+ else {
+ *emitp = word2class_map_[w].second;
+ return word2class_map_[w].first;
+ }
}
// converts to cdec word id's to KenLM's id space, OOVs and <unk> end up at 0
@@ -256,32 +261,32 @@ class KLanguageModelImpl {
int lc = 0;
if (!SILENT)
cerr << " Loading word classes from " << file << " ...\n";
- AddWordToClassMapping_(TD::Convert("<s>"), TD::Convert("<s>"));
- AddWordToClassMapping_(TD::Convert("</s>"), TD::Convert("</s>"));
- while(in) {
- getline(in, line);
- if (!in) continue;
+ AddWordToClassMapping_(TD::Convert("<s>"), TD::Convert("<s>"), 0.0);
+ AddWordToClassMapping_(TD::Convert("</s>"), TD::Convert("</s>"), 0.0);
+ while(getline(in, line)) {
dummy.clear();
TD::ConvertSentence(line, &dummy);
++lc;
- if (dummy.size() != 2) {
+ if (dummy.size() != 3) {
+ cerr << " Class map file expects: CLASS WORD logp(WORD|CLASS)\n";
cerr << " Format error in " << file << ", line " << lc << ": " << line << endl;
abort();
}
- AddWordToClassMapping_(dummy[0], dummy[1]);
+ AddWordToClassMapping_(dummy[1], dummy[0], strtof(TD::Convert(dummy[2]).c_str(), NULL));
}
}
- void AddWordToClassMapping_(WordID word, WordID cls) {
+ void AddWordToClassMapping_(WordID word, WordID cls, float emit) {
if (word2class_map_.size() <= word) {
- word2class_map_.resize((word + 10) * 1.1, kCDEC_UNK);
+ word2class_map_.resize((word + 10) * 1.1, pair<WordID,float>(kCDEC_UNK,0.f));
assert(word2class_map_.size() > word);
}
- if(word2class_map_[word] != kCDEC_UNK) {
+ if(word2class_map_[word].first != kCDEC_UNK) {
cerr << "Multiple classes for symbol " << TD::Convert(word) << endl;
abort();
}
- word2class_map_[word] = cls;
+ word2class_map_[word].first = cls;
+ word2class_map_[word].second = emit;
}
~KLanguageModelImpl() {
@@ -304,7 +309,9 @@ class KLanguageModelImpl {
int order_;
vector<lm::WordIndex> cdec2klm_map_;
- vector<WordID> word2class_map_; // if this is a class-based LM, this is the word->class mapping
+ vector<pair<WordID,float> > word2class_map_; // if this is a class-based LM,
+ // .first is the word->class mapping
+ // .second is the emission log probability
};
template <class Model>
@@ -322,6 +329,7 @@ KLanguageModel<Model>::KLanguageModel(const string& param) {
}
fid_ = FD::Convert(featname);
oov_fid_ = FD::Convert(featname+"_OOV");
+ emit_fid_ = FD::Convert(featname+"_Emit");
// cerr << "FID: " << oov_fid_ << endl;
SetStateSize(pimpl_->ReserveStateSize());
}
@@ -340,9 +348,12 @@ void KLanguageModel<Model>::TraversalFeaturesImpl(const SentenceMetadata& /* sme
void* state) const {
double est = 0;
double oovs = 0;
- features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &oovs, state));
+ double emit = 0;
+ features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &oovs, &emit, state));
if (oovs && oov_fid_)
features->set_value(oov_fid_, oovs);
+ if (emit && emit_fid_)
+ features->set_value(emit_fid_, emit);
}
template <class Model>