summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--decoder/ff_klm.cc36
1 files changed, 9 insertions, 27 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc
index 016aad26..3b2113ad 100644
--- a/decoder/ff_klm.cc
+++ b/decoder/ff_klm.cc
@@ -90,19 +90,12 @@ class KLanguageModelImpl {
return *static_cast<const lm::ngram::ChartState*>(state);
}
- inline void SetRemnantLMState(const lm::ngram::ChartState& lmstate, void* state) const {
- // if we were clever, we could use the memory pointed to by state to do all
- // the work, avoiding this copy
- memcpy(state, &lmstate, ngram_->StateSize());
- }
-
public:
double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* oovs, void* remnant) {
- double sum = 0.0;
if (oovs) *oovs = 0;
const vector<WordID>& e = rule.e();
lm::ngram::ChartState state;
- lm::ngram::RuleScore<Model> ruleScore(*ngram_, state);
+ lm::ngram::RuleScore<Model> ruleScore(*ngram_, remnant ? *static_cast<lm::ngram::ChartState*>(remnant) : state);
unsigned i = 0;
if (e.size()) {
if (e[i] == kCDEC_SOS) {
@@ -123,12 +116,13 @@ class KLanguageModelImpl {
// maybe handle emission
const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id
const bool is_oov = (cur_word == 0);
- if (is_oov) (*oovs) += 1.0;
+ if (is_oov && oovs) (*oovs) += 1.0;
ruleScore.Terminal(cur_word);
}
}
- if (remnant) SetRemnantLMState(state, remnant);
- return ruleScore.Finish();
+ double ret = ruleScore.Finish();
+ state.ZeroRemaining();
+ return ret;
}
// this assumes no target words on final unary -> goal rule. is that ok?
@@ -138,10 +132,9 @@ class KLanguageModelImpl {
lm::ngram::ChartState cstate;
lm::ngram::RuleScore<Model> ruleScore(*ngram_, cstate);
ruleScore.BeginSentence();
- SetRemnantLMState(cstate, dummy_state_);
- dummy_ants_[1] = state;
- *oovs = 0;
- return LookupWords(*dummy_rule_, dummy_ants_, oovs, NULL);
+ ruleScore.NonTerminal(RemnantLMState(state), 0.0f);
+ ruleScore.Terminal(kEOS_);
+ return ruleScore.Finish();
} else { // rules DO produce <s> ... </s>
double p = 0;
cerr << "not implemented"; abort(); // TODO
@@ -187,14 +180,8 @@ class KLanguageModelImpl {
}
order_ = ngram_->Order();
cerr << "Loaded " << order_ << "-gram KLM from " << filename << " (MapSize=" << cdec2klm_map_.size() << ")\n";
- state_size_ = sizeof(lm::ngram::ChartState);
// special handling of beginning / ending sentence markers
- dummy_state_ = new char[state_size_];
- memset(dummy_state_, 0, state_size_);
- dummy_ants_.push_back(dummy_state_);
- dummy_ants_.push_back(NULL);
- dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] </s> ||| X=0"));
kSOS_ = MapWord(kCDEC_SOS);
assert(kSOS_ > 0);
kEOS_ = MapWord(TD::Convert("</s>"));
@@ -243,10 +230,9 @@ class KLanguageModelImpl {
~KLanguageModelImpl() {
delete ngram_;
- delete[] dummy_state_;
}
- int ReserveStateSize() const { return state_size_; }
+ int ReserveStateSize() const { return sizeof(lm::ngram::ChartState); }
private:
const WordID kCDEC_UNK;
@@ -261,12 +247,8 @@ class KLanguageModelImpl {
// the sentence) with 0, and anything else with -100
int order_;
- int state_size_;
- char* dummy_state_;
- vector<const void*> dummy_ants_;
vector<lm::WordIndex> cdec2klm_map_;
vector<WordID> word2class_map_; // if this is a class-based LM, this is the word->class mapping
- TRulePtr dummy_rule_;
};
template <class Model>