diff options
author | Kenneth Heafield <github@kheafield.com> | 2012-12-14 12:39:04 -0800 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2012-12-14 12:39:04 -0800 |
commit | 5d42134cc676278189b0f77708908542fbb5ccc9 (patch) | |
tree | ab55a95ae3b38095467733323e7162578b308164 /decoder | |
parent | 212decb4382b84c2370c369b0507a5534399aa56 (diff) |
Updated incremental, updated kenlm. Incremental assumes <s>
Diffstat (limited to 'decoder')
-rw-r--r-- | decoder/incremental.cc | 44 |
1 files changed, 22 insertions, 22 deletions
diff --git a/decoder/incremental.cc b/decoder/incremental.cc index 46615b0b..85647a44 100644 --- a/decoder/incremental.cc +++ b/decoder/incremental.cc @@ -6,6 +6,7 @@ #include "lm/enumerate_vocab.hh" #include "lm/model.hh" +#include "search/applied.hh" #include "search/config.hh" #include "search/context.hh" #include "search/edge.hh" @@ -48,16 +49,16 @@ template <class Model> class Incremental : public IncrementalBase { Incremental(const char *model_file, const std::vector<weight_t> &weights) : IncrementalBase(weights), m_(model_file, GetConfig()), - weights_( - weights[FD::Convert("KLanguageModel")], - weights[FD::Convert("KLanguageModel_OOV")], - weights[FD::Convert("WordPenalty")]) { - std::cerr << "Weights KLanguageModel " << weights_.LM() << " KLanguageModel_OOV " << weights_.OOV() << " WordPenalty " << weights_.WordPenalty() << std::endl; + lm_(weights[FD::Convert("KLanguageModel")]), + oov_(weights[FD::Convert("KLanguageModel_OOV")]), + word_penalty_(weights[FD::Convert("WordPenalty")]) { + std::cerr << "Weights KLanguageModel " << lm_ << " KLanguageModel_OOV " << oov_ << " WordPenalty " << word_penalty_ << std::endl; } + void Search(unsigned int pop_limit, const Hypergraph &hg) const; private: - void ConvertEdge(const search::Context<Model> &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const; + void ConvertEdge(const search::Context<Model> &context, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const; lm::ngram::Config GetConfig() { lm::ngram::Config ret; @@ -69,46 +70,47 @@ template <class Model> class Incremental : public IncrementalBase { const Model m_; - const search::Weights weights_; + const float lm_, oov_, word_penalty_; }; -void PrintFinal(const Hypergraph &hg, const search::Final final) { +void PrintApplied(const Hypergraph &hg, const search::Applied final) { const std::vector<WordID> &words = static_cast<const Hypergraph::Edge*>(final.GetNote().vp)->rule_->e(); - const search::Final *child(final.Children()); + const search::Applied *child(final.Children()); for (std::vector<WordID>::const_iterator i = words.begin(); i != words.end(); ++i) { if (*i > 0) { std::cout << TD::Convert(*i) << ' '; } else { - PrintFinal(hg, *child++); + PrintApplied(hg, *child++); } } } template <class Model> void Incremental<Model>::Search(unsigned int pop_limit, const Hypergraph &hg) const { boost::scoped_array<search::Vertex> out_vertices(new search::Vertex[hg.nodes_.size()]); - search::Config config(weights_, pop_limit); + search::Config config(lm_, pop_limit, search::NBestConfig(1)); search::Context<Model> context(config, m_); + search::SingleBest best; for (unsigned int i = 0; i < hg.nodes_.size() - 1; ++i) { search::EdgeGenerator gen; const Hypergraph::EdgesVector &down_edges = hg.nodes_[i].in_edges_; for (unsigned int j = 0; j < down_edges.size(); ++j) { unsigned int edge_index = down_edges[j]; - ConvertEdge(context, i == hg.nodes_.size() - 2, out_vertices.get(), hg.edges_[edge_index], gen); + ConvertEdge(context, out_vertices.get(), hg.edges_[edge_index], gen); } - search::VertexGenerator vertex_gen(context, out_vertices[i]); + search::VertexGenerator<search::SingleBest> vertex_gen(context, out_vertices[i], best); gen.Search(context, vertex_gen); } - const search::Final top = out_vertices[hg.nodes_.size() - 2].BestChild(); + const search::Applied top = out_vertices[hg.nodes_.size() - 2].BestChild(); if (!top.Valid()) { std::cout << "NO PATH FOUND" << std::endl; } else { - PrintFinal(hg, top); + PrintApplied(hg, top); std::cout << "||| " << top.GetScore() << std::endl; } } -template <class Model> void Incremental<Model>::ConvertEdge(const search::Context<Model> &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const { +template <class Model> void Incremental<Model>::ConvertEdge(const search::Context<Model> &context, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const { const std::vector<WordID> &e = in.rule_->e(); std::vector<lm::WordIndex> words; words.reserve(e.size()); @@ -127,10 +129,6 @@ template <class Model> void Incremental<Model>::ConvertEdge(const search::Contex } } - if (final) { - words.push_back(m_.GetVocabulary().EndSentence()); - } - search::PartialEdge out(gen.AllocateEdge(nts.size())); memcpy(out.NT(), &nts[0], sizeof(search::PartialVertex) * nts.size()); @@ -140,8 +138,10 @@ template <class Model> void Incremental<Model>::ConvertEdge(const search::Contex out.SetNote(note); score += in.rule_->GetFeatureValues().dot(cdec_weights_); - score -= static_cast<float>(terminals) * context.GetWeights().WordPenalty() / M_LN10; - score += search::ScoreRule(context, words, final, out.Between()); + score -= static_cast<float>(terminals) * word_penalty_ / M_LN10; + search::ScoreRuleRet res(search::ScoreRule(context.LanguageModel(), words, out.Between())); + score += res.prob * lm_ + static_cast<float>(res.oov) * oov_; + out.SetScore(score); gen.AddEdge(out); |