summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2011-10-17 16:58:26 +0100
committerKenneth Heafield <github@kheafield.com>2011-10-17 16:58:26 +0100
commit957d90991b4ec80b9877126c736bd60768b094aa (patch)
tree958d2b33bd37e42713505924a95a9efa9d94038b
parentf036d4ec5c79db95df3470adb7cd317ff258ab7d (diff)
Chris, I'd like you to review this for use with your rules that contain <s> and </s>.
-rw-r--r--decoder/ff_klm.cc72
1 files changed, 49 insertions, 23 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc
index 6d9aca54..658aef80 100644
--- a/decoder/ff_klm.cc
+++ b/decoder/ff_klm.cc
@@ -71,6 +71,8 @@ string KLanguageModel<Model>::usage(bool /*param*/,bool /*verbose*/) {
return "KLanguageModel";
}
+namespace {
+
struct VMapper : public lm::ngram::EnumerateVocab {
VMapper(vector<lm::WordIndex>* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); }
void Add(lm::WordIndex index, const StringPiece &str) {
@@ -83,66 +85,90 @@ struct VMapper : public lm::ngram::EnumerateVocab {
const lm::WordIndex kLM_UNKNOWN_TOKEN;
};
-template <class Model>
-class KLanguageModelImpl {
+#pragma pack(push)
+#pragma pack(1)
- static inline const lm::ngram::ChartState& RemnantLMState(const void* state) {
- return *static_cast<const lm::ngram::ChartState*>(state);
+struct BoundaryAnnotatedState {
+ lm::ngram::ChartState state;
+ bool seen_bos, seen_eos;
+};
+
+#pragma pack(pop)
+
+void BoundaryCheck(bool &annotated, bool sub, double &ret) {
+ if (!sub) return;
+ if (annotated) {
+ ret -= 100.0;
+ } else {
+ annotated = true;
}
+}
+} // namespace
+
+template <class Model>
+class KLanguageModelImpl {
public:
double LookupWords(const TRule& rule, const vector<const void*>& ant_states, double* oovs, void* remnant) {
*oovs = 0;
const vector<WordID>& e = rule.e();
- lm::ngram::RuleScore<Model> ruleScore(*ngram_, *static_cast<lm::ngram::ChartState*>(remnant));
+ BoundaryAnnotatedState &annotated = *static_cast<BoundaryAnnotatedState*>(remnant);
+ lm::ngram::RuleScore<Model> ruleScore(*ngram_, annotated.state);
+ annotated.seen_bos = false;
+ annotated.seen_eos = false;
unsigned i = 0;
+ double ret = 0.0;
if (e.size()) {
if (e[i] == kCDEC_SOS) {
++i;
ruleScore.BeginSentence();
+ annotated.seen_bos = true;
} else if (e[i] <= 0) { // special case for left-edge NT
- const lm::ngram::ChartState& prevState = RemnantLMState(ant_states[-e[0]]);
- ruleScore.BeginNonTerminal(prevState, 0.0f); // TODO
+ const BoundaryAnnotatedState &sub = *static_cast<const BoundaryAnnotatedState*>(ant_states[-e[0]]);
+ ruleScore.BeginNonTerminal(sub.state, 0.0f);
+ annotated.seen_bos = sub.seen_bos;
+ annotated.seen_eos = sub.seen_eos;
++i;
}
}
for (; i < e.size(); ++i) {
if (e[i] <= 0) {
- const lm::ngram::ChartState& prevState = RemnantLMState(ant_states[-e[i]]);
- ruleScore.NonTerminal(prevState, 0.0f); // TODO
+ const BoundaryAnnotatedState &sub = *static_cast<const BoundaryAnnotatedState*>(ant_states[-e[i]]);
+ ruleScore.NonTerminal(sub.state, 0.0f);
+ BoundaryCheck(annotated.seen_bos, sub.seen_bos, ret);
+ BoundaryCheck(annotated.seen_eos, sub.seen_eos, ret);
} else {
const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]); // in future,
// maybe handle emission
const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id
if (cur_word == 0) (*oovs) += 1.0;
+ BoundaryCheck(annotated.seen_eos, cur_word == kEOS_, ret);
ruleScore.Terminal(cur_word);
}
}
- double ret = ruleScore.Finish();
- static_cast<lm::ngram::ChartState*>(remnant)->ZeroRemaining();
+ ret += ruleScore.Finish();
+ annotated.state.ZeroRemaining();
return ret;
}
// this assumes no target words on final unary -> goal rule. is that ok?
// for <s> (n-1 left words) and (n-1 right words) </s>
- double FinalTraversalCost(const void* state, double* oovs) {
+ double FinalTraversalCost(const void* state_void, double* oovs) {
+ const BoundaryAnnotatedState &annotated = *static_cast<const BoundaryAnnotatedState*>(state_void);
if (add_sos_eos_) { // rules do not produce <s> </s>, so do it here
+ assert(!annotated.seen_bos);
+ assert(!annotated.seen_eos);
lm::ngram::ChartState cstate;
lm::ngram::RuleScore<Model> ruleScore(*ngram_, cstate);
ruleScore.BeginSentence();
- ruleScore.NonTerminal(RemnantLMState(state), 0.0f);
+ ruleScore.NonTerminal(annotated.state, 0.0f);
ruleScore.Terminal(kEOS_);
return ruleScore.Finish();
} else { // rules DO produce <s> ... </s>
- double p = 0;
- cerr << "not implemented"; abort(); // TODO
- //if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; }
- //if (UnscoredSize(state) > 0) { // are there unscored words
- // if (kSOS_ != IthUnscoredWord(0, state)) {
- // p -= 100 * UnscoredSize(state);
- // }
- //}
- return p;
+ double ret = 0.0;
+ if (!annotated.seen_bos) ret -= 100.0;
+ if (!annotated.seen_eos) ret -= 100.0;
+ return ret;
}
}
@@ -230,7 +256,7 @@ class KLanguageModelImpl {
delete ngram_;
}
- int ReserveStateSize() const { return sizeof(lm::ngram::ChartState); }
+ int ReserveStateSize() const { return sizeof(BoundaryAnnotatedState); }
private:
const WordID kCDEC_UNK;