diff options
| author | Chris Dyer <cdyer@cs.cmu.edu> | 2011-03-01 03:05:25 -0500 | 
|---|---|---|
| committer | Chris Dyer <cdyer@cs.cmu.edu> | 2011-03-01 03:05:25 -0500 | 
| commit | 0c9fe1dd72fb2321a5652e0ee66d1c897a3d9f80 (patch) | |
| tree | 7492acc2fcfae48ee1a0cc03a0fc201d8466124e /decoder | |
| parent | 839cf217e24de58f07d683ab357d27d94791e1e2 (diff) | |
support explicit sentence boundary markers with cdec
Diffstat (limited to 'decoder')
| -rw-r--r-- | decoder/ff_klm.cc | 60 | 
1 files changed, 47 insertions, 13 deletions
| diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 5049f156..5e590053 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -75,13 +75,19 @@ class KLanguageModelImpl {      const vector<WordID>& e = rule.e();      bool context_complete = false;      for (int j = 0; j < e.size(); ++j) { -      if (e[j] < 1) { +      if (e[j] < 1) {   // handle non-terminal substitution          const void* astate = (ant_states[-e[j]]);          int unscored_ant_len = UnscoredSize(astate);          for (int k = 0; k < unscored_ant_len; ++k) {            const lm::WordIndex cur_word = IthUnscoredWord(k, astate); -          const lm::ngram::State scopy(state); -          const double p = ngram_->Score(scopy, cur_word, state); +          double p = 0; +          if (cur_word == kSOS_) { +            if (state.ValidLength() > 0) { p = -100; cerr << rule << endl; } +            state = ngram_->BeginSentenceState(); +          } else { +            const lm::ngram::State scopy(state); +            p = ngram_->Score(scopy, cur_word, state); +          }            ++num_scored;            if (!context_complete) {              if (num_scored >= order_) context_complete = true; @@ -99,10 +105,16 @@ class KLanguageModelImpl {            state = RemnantLMState(astate);            context_complete = true;          } -      } else { +      } else {   // handle terminal          const lm::WordIndex cur_word = MapWord(e[j]); -        const lm::ngram::State scopy(state); -        const double p = ngram_->Score(scopy, cur_word, state); +        double p = 0; +        if (cur_word == kSOS_) { +          if (state.ValidLength() > 0) p = -100; +          state = ngram_->BeginSentenceState(); +        } else { +          const lm::ngram::State scopy(state); +          p = ngram_->Score(scopy, cur_word, state); +        }          ++num_scored;          if (!context_complete) {            if (num_scored >= order_) context_complete = true; @@ -130,11 +142,16 @@ class KLanguageModelImpl {    //FIXME: this assumes no target words on final unary -> goal rule.  is that ok?    // for <s> (n-1 left words) and (n-1 right words) </s>    double FinalTraversalCost(const void* state) { -    SetRemnantLMState(ngram_->BeginSentenceState(), dummy_state_); -    SetHasFullContext(1, dummy_state_); -    SetUnscoredSize(0, dummy_state_); -    dummy_ants_[1] = state; -    return LookupWords(*dummy_rule_, dummy_ants_, NULL, NULL); +    if (add_sos_eos_) { +      SetRemnantLMState(ngram_->BeginSentenceState(), dummy_state_); +      SetHasFullContext(1, dummy_state_); +      SetUnscoredSize(0, dummy_state_); +      dummy_ants_[1] = state; +      return LookupWords(*dummy_rule_, dummy_ants_, NULL, NULL); +    } else { +      // TODO, figure out whether spans are correct +      return 0; +    }    }    lm::WordIndex MapWord(WordID w) const { @@ -146,12 +163,18 @@ class KLanguageModelImpl {   public:    KLanguageModelImpl(const std::string& param) { +    add_sos_eos_ = true; +    string fname = param; +    if (param.find("-x ") == 0) { +       add_sos_eos_ = false; +       fname = param.substr(3); +    }      lm::ngram::Config conf;      VMapper vm(&map_);      conf.enumerate_vocab = &vm;  -    ngram_ = new Model(param.c_str(), conf); +    ngram_ = new Model(fname.c_str(), conf);      order_ = ngram_->Order(); -    cerr << "Loaded " << order_ << "-gram KLM from " << param << " (MapSize=" << map_.size() << ")\n"; +    cerr << "Loaded " << order_ << "-gram KLM from " << fname << " (MapSize=" << map_.size() << ")\n";      state_size_ = ngram_->StateSize() + 2 + (order_ - 1) * sizeof(lm::WordIndex);      unscored_size_offset_ = ngram_->StateSize();      is_complete_offset_ = unscored_size_offset_ + 1; @@ -162,6 +185,9 @@ class KLanguageModelImpl {      dummy_ants_.push_back(dummy_state_);      dummy_ants_.push_back(NULL);      dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] </s> ||| X=0")); +    kSOS_ = MapWord(TD::Convert("<s>")); +    assert(kSOS_ > 0); +    kEOS_ = MapWord(TD::Convert("</s>"));    }    ~KLanguageModelImpl() { @@ -172,7 +198,15 @@ class KLanguageModelImpl {    int ReserveStateSize() const { return state_size_; }   private: +  lm::WordIndex kSOS_;  // <s> - requires special handling. +  lm::WordIndex kEOS_;  // </s>    Model* ngram_; +  bool add_sos_eos_; // flag indicating whether the hypergraph produces <s> and </s> +                     // if this is true, FinalTransitionFeatures will "add" <s> and </s> +                     // if false, FinalTransitionFeatures will score anything with the +                     // markers in the right place (i.e., the beginning and end of +                     // the sentence) with 0, and anything else with -100 +    int order_;    int state_size_;    int unscored_size_offset_; | 
