path: root/training/augment_grammar.cc
Diffstat (limited to 'training/augment_grammar.cc')
-rw-r--r--  training/augment_grammar.cc | 15
1 file changed, 12 insertions(+), 3 deletions(-)
diff --git a/training/augment_grammar.cc b/training/augment_grammar.cc
index 9b7fc7be..48ef23fc 100644
--- a/training/augment_grammar.cc
+++ b/training/augment_grammar.cc
@@ -57,6 +57,8 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
   return true;
 }
 
+lm::WordIndex kSOS;
+
 template <class Model> float Score(const vector<WordID>& str, const Model &model) {
   typename Model::State state, out;
   lm::FullScoreReturn ret;
@@ -65,9 +67,13 @@ template <class Model> float Score(const vector<WordID>& str, const Model &model
 
   for (int i = 0; i < str.size(); ++i) {
     lm::WordIndex vocab = ((str[i] < word_map.size() && str[i] > 0) ? word_map[str[i]] : 0);
-    ret = model.FullScore(state, vocab, out);
-    total += ret.prob;
-    state = out;
+    if (vocab == kSOS) {
+      state = model.BeginSentenceState();
+    } else {
+      ret = model.FullScore(state, vocab, out);
+      total += ret.prob;
+      state = out;
+    }
   }
   return total;
 }
@@ -76,6 +82,7 @@ static void RuleHelper(const TRulePtr& new_rule, const unsigned int ctf_level, c
   cout << *new_rule << " SrcLM=" << Score(new_rule->f_, *ngram) << endl;
 }
 
+
 int main(int argc, char** argv) {
   po::variables_map conf;
   if (!InitCommandLine(argc, argv, &conf)) return 1;
@@ -84,7 +91,9 @@ int main(int argc, char** argv) {
     VMapper vm(&word_map);
     kconf.enumerate_vocab = &vm;
     ngram = new lm::ngram::ProbingModel(conf["source_lm"].as<string>().c_str(), kconf);
+    kSOS = word_map[TD::Convert("<s>")];
     cerr << "Loaded " << (int)ngram->Order() << "-gram KenLM (MapSize=" << word_map.size() << ")\n";
+    cerr << " <s> = " << kSOS << endl;
   } else { ngram = NULL; }
   assert(ngram);
   RuleLexer::ReadRules(&cin, &RuleHelper, NULL);
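
Note on the change: before this patch, Score() passed the <s> symbol to FullScore() like any other word, and since <s> is a context marker that typically never occurs in predictive position, each rule's SrcLM feature would pick up a large spurious penalty. The patch instead resets the query state to BeginSentenceState() whenever <s> is seen, so the words after it are conditioned on the sentence boundary without <s> itself being scored. The standalone sketch below shows the same idiom against KenLM's public API; the model path "lm.klm" and the token sequence are placeholders, and it uses the model's own vocabulary via GetVocabulary().Index() in place of cdec's word_map / TD::Convert.

// Minimal sketch of the BOS-handling idiom from the patch (assumptions:
// placeholder model path "lm.klm" and example tokens; KenLM's own
// vocabulary lookup stands in for cdec's word_map).
#include <iostream>
#include "lm/model.hh"

int main() {
  lm::ngram::ProbingModel model("lm.klm");  // placeholder path
  const lm::WordIndex kSOS = model.GetVocabulary().Index("<s>");
  const char* toks[] = { "<s>", "this", "is", "a", "test" };
  lm::ngram::ProbingModel::State state = model.NullContextState(), out;
  float total = 0;
  for (unsigned i = 0; i < sizeof(toks) / sizeof(toks[0]); ++i) {
    const lm::WordIndex vocab = model.GetVocabulary().Index(toks[i]);
    if (vocab == kSOS) {
      // <s> only establishes context; reset the state instead of scoring it.
      state = model.BeginSentenceState();
    } else {
      total += model.FullScore(state, vocab, out).prob;  // log10 probability
      state = out;
    }
  }
  std::cout << "total log10 prob = " << total << std::endl;
  return 0;
}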