diff options
-rw-r--r-- | training/augment_grammar.cc | 15 |
1 file changed, 12 insertions, 3 deletions
diff --git a/training/augment_grammar.cc b/training/augment_grammar.cc index 9b7fc7be..48ef23fc 100644 --- a/training/augment_grammar.cc +++ b/training/augment_grammar.cc @@ -57,6 +57,8 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { return true; } +lm::WordIndex kSOS; + template <class Model> float Score(const vector<WordID>& str, const Model &model) { typename Model::State state, out; lm::FullScoreReturn ret; @@ -65,9 +67,13 @@ template <class Model> float Score(const vector<WordID>& str, const Model &model for (int i = 0; i < str.size(); ++i) { lm::WordIndex vocab = ((str[i] < word_map.size() && str[i] > 0) ? word_map[str[i]] : 0); - ret = model.FullScore(state, vocab, out); - total += ret.prob; - state = out; + if (vocab == kSOS) { + state = model.BeginSentenceState(); + } else { + ret = model.FullScore(state, vocab, out); + total += ret.prob; + state = out; + } } return total; } @@ -76,6 +82,7 @@ static void RuleHelper(const TRulePtr& new_rule, const unsigned int ctf_level, c cout << *new_rule << " SrcLM=" << Score(new_rule->f_, *ngram) << endl; } + int main(int argc, char** argv) { po::variables_map conf; if (!InitCommandLine(argc, argv, &conf)) return 1; @@ -84,7 +91,9 @@ int main(int argc, char** argv) { VMapper vm(&word_map); kconf.enumerate_vocab = &vm; ngram = new lm::ngram::ProbingModel(conf["source_lm"].as<string>().c_str(), kconf); + kSOS = word_map[TD::Convert("<s>")]; cerr << "Loaded " << (int)ngram->Order() << "-gram KenLM (MapSize=" << word_map.size() << ")\n"; + cerr << " <s> = " << kSOS << endl; } else { ngram = NULL; } assert(ngram); RuleLexer::ReadRules(&cin, &RuleHelper, NULL); |