summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2012-12-14 12:39:04 -0800
committerKenneth Heafield <github@kheafield.com>2012-12-14 12:39:04 -0800
commit5d42134cc676278189b0f77708908542fbb5ccc9 (patch)
treeab55a95ae3b38095467733323e7162578b308164 /decoder
parent212decb4382b84c2370c369b0507a5534399aa56 (diff)
Updated incremental, updated kenlm. Incremental assumes <s>
Diffstat (limited to 'decoder')
-rw-r--r--decoder/incremental.cc44
1 files changed, 22 insertions, 22 deletions
diff --git a/decoder/incremental.cc b/decoder/incremental.cc
index 46615b0b..85647a44 100644
--- a/decoder/incremental.cc
+++ b/decoder/incremental.cc
@@ -6,6 +6,7 @@
#include "lm/enumerate_vocab.hh"
#include "lm/model.hh"
+#include "search/applied.hh"
#include "search/config.hh"
#include "search/context.hh"
#include "search/edge.hh"
@@ -48,16 +49,16 @@ template <class Model> class Incremental : public IncrementalBase {
Incremental(const char *model_file, const std::vector<weight_t> &weights) :
IncrementalBase(weights),
m_(model_file, GetConfig()),
- weights_(
- weights[FD::Convert("KLanguageModel")],
- weights[FD::Convert("KLanguageModel_OOV")],
- weights[FD::Convert("WordPenalty")]) {
- std::cerr << "Weights KLanguageModel " << weights_.LM() << " KLanguageModel_OOV " << weights_.OOV() << " WordPenalty " << weights_.WordPenalty() << std::endl;
+ lm_(weights[FD::Convert("KLanguageModel")]),
+ oov_(weights[FD::Convert("KLanguageModel_OOV")]),
+ word_penalty_(weights[FD::Convert("WordPenalty")]) {
+ std::cerr << "Weights KLanguageModel " << lm_ << " KLanguageModel_OOV " << oov_ << " WordPenalty " << word_penalty_ << std::endl;
}
+
void Search(unsigned int pop_limit, const Hypergraph &hg) const;
private:
- void ConvertEdge(const search::Context<Model> &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const;
+ void ConvertEdge(const search::Context<Model> &context, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const;
lm::ngram::Config GetConfig() {
lm::ngram::Config ret;
@@ -69,46 +70,47 @@ template <class Model> class Incremental : public IncrementalBase {
const Model m_;
- const search::Weights weights_;
+ const float lm_, oov_, word_penalty_;
};
-void PrintFinal(const Hypergraph &hg, const search::Final final) {
+void PrintApplied(const Hypergraph &hg, const search::Applied final) {
const std::vector<WordID> &words = static_cast<const Hypergraph::Edge*>(final.GetNote().vp)->rule_->e();
- const search::Final *child(final.Children());
+ const search::Applied *child(final.Children());
for (std::vector<WordID>::const_iterator i = words.begin(); i != words.end(); ++i) {
if (*i > 0) {
std::cout << TD::Convert(*i) << ' ';
} else {
- PrintFinal(hg, *child++);
+ PrintApplied(hg, *child++);
}
}
}
template <class Model> void Incremental<Model>::Search(unsigned int pop_limit, const Hypergraph &hg) const {
boost::scoped_array<search::Vertex> out_vertices(new search::Vertex[hg.nodes_.size()]);
- search::Config config(weights_, pop_limit);
+ search::Config config(lm_, pop_limit, search::NBestConfig(1));
search::Context<Model> context(config, m_);
+ search::SingleBest best;
for (unsigned int i = 0; i < hg.nodes_.size() - 1; ++i) {
search::EdgeGenerator gen;
const Hypergraph::EdgesVector &down_edges = hg.nodes_[i].in_edges_;
for (unsigned int j = 0; j < down_edges.size(); ++j) {
unsigned int edge_index = down_edges[j];
- ConvertEdge(context, i == hg.nodes_.size() - 2, out_vertices.get(), hg.edges_[edge_index], gen);
+ ConvertEdge(context, out_vertices.get(), hg.edges_[edge_index], gen);
}
- search::VertexGenerator vertex_gen(context, out_vertices[i]);
+ search::VertexGenerator<search::SingleBest> vertex_gen(context, out_vertices[i], best);
gen.Search(context, vertex_gen);
}
- const search::Final top = out_vertices[hg.nodes_.size() - 2].BestChild();
+ const search::Applied top = out_vertices[hg.nodes_.size() - 2].BestChild();
if (!top.Valid()) {
std::cout << "NO PATH FOUND" << std::endl;
} else {
- PrintFinal(hg, top);
+ PrintApplied(hg, top);
std::cout << "||| " << top.GetScore() << std::endl;
}
}
-template <class Model> void Incremental<Model>::ConvertEdge(const search::Context<Model> &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const {
+template <class Model> void Incremental<Model>::ConvertEdge(const search::Context<Model> &context, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const {
const std::vector<WordID> &e = in.rule_->e();
std::vector<lm::WordIndex> words;
words.reserve(e.size());
@@ -127,10 +129,6 @@ template <class Model> void Incremental<Model>::ConvertEdge(const search::Contex
}
}
- if (final) {
- words.push_back(m_.GetVocabulary().EndSentence());
- }
-
search::PartialEdge out(gen.AllocateEdge(nts.size()));
memcpy(out.NT(), &nts[0], sizeof(search::PartialVertex) * nts.size());
@@ -140,8 +138,10 @@ template <class Model> void Incremental<Model>::ConvertEdge(const search::Contex
out.SetNote(note);
score += in.rule_->GetFeatureValues().dot(cdec_weights_);
- score -= static_cast<float>(terminals) * context.GetWeights().WordPenalty() / M_LN10;
- score += search::ScoreRule(context, words, final, out.Between());
+ score -= static_cast<float>(terminals) * word_penalty_ / M_LN10;
+ search::ScoreRuleRet res(search::ScoreRule(context.LanguageModel(), words, out.Between()));
+ score += res.prob * lm_ + static_cast<float>(res.oov) * oov_;
+
out.SetScore(score);
gen.AddEdge(out);