summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Chris Dyer <cdyer@cs.cmu.edu> 2011-03-01 03:05:25 -0500
committer Chris Dyer <cdyer@cs.cmu.edu> 2011-03-01 03:05:25 -0500
commit c96e484a59b4cf3f39b801f162b76882aca36ea7 (patch)
tree 34b1b4966d647dcf29c54ced14a3edfb28db44a9
parent 3f608d4c3fa7b98db4fbc64ff2b66e64072f38d6 (diff)
support explicit sentence boundary markers with cdec
-rw-r--r-- decoder/ff_klm.cc 60
1 file changed, 47 insertions, 13 deletions
diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc
index 5049f156..5e590053 100644
--- a/decoder/ff_klm.cc
+++ b/decoder/ff_klm.cc
@@ -75,13 +75,19 @@ class KLanguageModelImpl {
const vector<WordID>& e = rule.e();
bool context_complete = false;
for (int j = 0; j < e.size(); ++j) {
- if (e[j] < 1) {
+ if (e[j] < 1) { // handle non-terminal substitution
const void* astate = (ant_states[-e[j]]);
int unscored_ant_len = UnscoredSize(astate);
for (int k = 0; k < unscored_ant_len; ++k) {
const lm::WordIndex cur_word = IthUnscoredWord(k, astate);
- const lm::ngram::State scopy(state);
- const double p = ngram_->Score(scopy, cur_word, state);
+ double p = 0;
+ if (cur_word == kSOS_) {
+ if (state.ValidLength() > 0) { p = -100; cerr << rule << endl; }
+ state = ngram_->BeginSentenceState();
+ } else {
+ const lm::ngram::State scopy(state);
+ p = ngram_->Score(scopy, cur_word, state);
+ }
++num_scored;
if (!context_complete) {
if (num_scored >= order_) context_complete = true;
@@ -99,10 +105,16 @@ class KLanguageModelImpl {
state = RemnantLMState(astate);
context_complete = true;
}
- } else {
+ } else { // handle terminal
const lm::WordIndex cur_word = MapWord(e[j]);
- const lm::ngram::State scopy(state);
- const double p = ngram_->Score(scopy, cur_word, state);
+ double p = 0;
+ if (cur_word == kSOS_) {
+ if (state.ValidLength() > 0) p = -100;
+ state = ngram_->BeginSentenceState();
+ } else {
+ const lm::ngram::State scopy(state);
+ p = ngram_->Score(scopy, cur_word, state);
+ }
++num_scored;
if (!context_complete) {
if (num_scored >= order_) context_complete = true;
@@ -130,11 +142,16 @@ class KLanguageModelImpl {
//FIXME: this assumes no target words on final unary -> goal rule. is that ok?
// for <s> (n-1 left words) and (n-1 right words) </s>
double FinalTraversalCost(const void* state) {
- SetRemnantLMState(ngram_->BeginSentenceState(), dummy_state_);
- SetHasFullContext(1, dummy_state_);
- SetUnscoredSize(0, dummy_state_);
- dummy_ants_[1] = state;
- return LookupWords(*dummy_rule_, dummy_ants_, NULL, NULL);
+ if (add_sos_eos_) {
+ SetRemnantLMState(ngram_->BeginSentenceState(), dummy_state_);
+ SetHasFullContext(1, dummy_state_);
+ SetUnscoredSize(0, dummy_state_);
+ dummy_ants_[1] = state;
+ return LookupWords(*dummy_rule_, dummy_ants_, NULL, NULL);
+ } else {
+ // TODO, figure out whether spans are correct
+ return 0;
+ }
}
lm::WordIndex MapWord(WordID w) const {
@@ -146,12 +163,18 @@ class KLanguageModelImpl {
public:
KLanguageModelImpl(const std::string& param) {
+ add_sos_eos_ = true;
+ string fname = param;
+ if (param.find("-x ") == 0) {
+ add_sos_eos_ = false;
+ fname = param.substr(3);
+ }
lm::ngram::Config conf;
VMapper vm(&map_);
conf.enumerate_vocab = &vm;
- ngram_ = new Model(param.c_str(), conf);
+ ngram_ = new Model(fname.c_str(), conf);
order_ = ngram_->Order();
- cerr << "Loaded " << order_ << "-gram KLM from " << param << " (MapSize=" << map_.size() << ")\n";
+ cerr << "Loaded " << order_ << "-gram KLM from " << fname << " (MapSize=" << map_.size() << ")\n";
state_size_ = ngram_->StateSize() + 2 + (order_ - 1) * sizeof(lm::WordIndex);
unscored_size_offset_ = ngram_->StateSize();
is_complete_offset_ = unscored_size_offset_ + 1;
@@ -162,6 +185,9 @@ class KLanguageModelImpl {
dummy_ants_.push_back(dummy_state_);
dummy_ants_.push_back(NULL);
dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] </s> ||| X=0"));
+ kSOS_ = MapWord(TD::Convert("<s>"));
+ assert(kSOS_ > 0);
+ kEOS_ = MapWord(TD::Convert("</s>"));
}
~KLanguageModelImpl() {
@@ -172,7 +198,15 @@ class KLanguageModelImpl {
int ReserveStateSize() const { return state_size_; }
private:
+ lm::WordIndex kSOS_; // <s> - requires special handling.
+ lm::WordIndex kEOS_; // </s>
Model* ngram_;
+ bool add_sos_eos_; // flag indicating whether this model must add <s> and </s> itself
+ // if this is true, FinalTraversalCost will "add" <s> and </s>
+ // if false (parameter prefixed with "-x "), the markers are expected to
+ // come from the hypergraph: FinalTraversalCost returns 0, and a <s> scored
+ // while the LM state is non-empty (i.e., not sentence-initial) costs -100
+
int order_;
int state_size_;
int unscored_size_offset_;