| author | Chris Dyer <cdyer@cs.cmu.edu> | 2011-03-01 02:40:25 -0500 | 
|---|---|---|
| committer | Chris Dyer <cdyer@cs.cmu.edu> | 2011-03-01 02:40:25 -0500 | 
| commit | 839cf217e24de58f07d683ab357d27d94791e1e2 (patch) | |
| tree | 5a36fc2396590b6dd270bbf0e3f4d84a9b324ef6 /training | |
| parent | bffa08c6a335ac03bc43ab178724ec176921aa5a (diff) | |
deal with SOS
Diffstat (limited to 'training')
| -rw-r--r-- | training/augment_grammar.cc | 15 |
|---|---|---|

1 file changed, 12 insertions, 3 deletions
```diff
diff --git a/training/augment_grammar.cc b/training/augment_grammar.cc
index 9b7fc7be..48ef23fc 100644
--- a/training/augment_grammar.cc
+++ b/training/augment_grammar.cc
@@ -57,6 +57,8 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
   return true;
 }
 
+lm::WordIndex kSOS;
+
 template <class Model> float Score(const vector<WordID>& str, const Model &model) {
   typename Model::State state, out;
   lm::FullScoreReturn ret;
@@ -65,9 +67,13 @@ template <class Model> float Score(const vector<WordID>& str, const Model &model
   for (int i = 0; i < str.size(); ++i) {
     lm::WordIndex vocab = ((str[i] < word_map.size() && str[i] > 0) ? word_map[str[i]] : 0);
-    ret = model.FullScore(state, vocab, out);
-    total += ret.prob;
-    state = out;
+    if (vocab == kSOS) {
+      state = model.BeginSentenceState();
+    } else {
+      ret = model.FullScore(state, vocab, out);
+      total += ret.prob;
+      state = out;
+    }
   }
   return total;
 }
@@ -76,6 +82,7 @@ static void RuleHelper(const TRulePtr& new_rule, const unsigned int ctf_level, c
   cout << *new_rule << " SrcLM=" << Score(new_rule->f_, *ngram) << endl;
 }
 
+
 int main(int argc, char** argv) {
   po::variables_map conf;
   if (!InitCommandLine(argc, argv, &conf)) return 1;
@@ -84,7 +91,9 @@ int main(int argc, char** argv) {
     VMapper vm(&word_map);
     kconf.enumerate_vocab = &vm;
     ngram = new lm::ngram::ProbingModel(conf["source_lm"].as<string>().c_str(), kconf);
+    kSOS = word_map[TD::Convert("<s>")];
     cerr << "Loaded " << (int)ngram->Order() << "-gram KenLM (MapSize=" << word_map.size() << ")\n";
+    cerr << "  <s> = " << kSOS << endl;
   } else { ngram = NULL; }
   assert(ngram);
   RuleLexer::ReadRules(&cin, &RuleHelper, NULL);
```
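For context, the change above special-cases the start-of-sentence token `<s>` when scoring a rule's source side: instead of asking the LM for p(`<s>` | ...), the loop resets to KenLM's begin-sentence state and adds no probability. Below is a minimal standalone sketch of the same pattern, not the commit's code: it looks `<s>` up through KenLM's public `GetVocabulary().Index()` rather than cdec's `word_map`/`TD::Convert` mapping, and the `ScoreTokens` helper, its token-string input, and the null-context starting state are illustrative assumptions.

```cpp
#include <iostream>
#include <string>
#include <vector>

#include "lm/model.hh"  // KenLM: lm::ngram::ProbingModel, lm::FullScoreReturn

// Sum log10 probabilities of a token sequence, treating <s> as a state reset
// rather than a scorable word.
float ScoreTokens(const std::vector<std::string>& toks,
                  const lm::ngram::ProbingModel& model) {
  const lm::WordIndex sos = model.GetVocabulary().Index("<s>");
  // Start from a null context; an <s> token, if present, switches to the
  // begin-sentence state.
  lm::ngram::ProbingModel::State state = model.NullContextState();
  lm::ngram::ProbingModel::State out;
  float total = 0.0f;
  for (size_t i = 0; i < toks.size(); ++i) {
    const lm::WordIndex w = model.GetVocabulary().Index(toks[i]);
    if (w == sos) {
      // <s> is a context marker, not an event: reset the LM state instead of
      // adding p(<s> | ...) to the total.
      state = model.BeginSentenceState();
    } else {
      const lm::FullScoreReturn ret = model.FullScore(state, w, out);
      total += ret.prob;  // log10 probability
      state = out;
    }
  }
  return total;
}

int main(int argc, char** argv) {
  if (argc < 2) {
    std::cerr << "usage: " << argv[0] << " lm_file\n";
    return 1;
  }
  lm::ngram::ProbingModel model(argv[1]);
  std::vector<std::string> toks;
  toks.push_back("<s>");
  toks.push_back("the");
  toks.push_back("cat");
  std::cout << "log10 p = " << ScoreTokens(toks, model) << std::endl;
  return 0;
}
```

In the commit itself, `kSOS` is obtained by mapping cdec's ID for `<s>` through `word_map` (filled in by `VMapper` while KenLM enumerates its vocabulary), so it is compared against the same `word_map[str[i]]` values used in the scoring loop.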
