From 7278218f4581ed8da3dacbff9c7ff3834c292dab Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Mon, 22 Oct 2012 11:35:31 -0400 Subject: Remove global variable, have decoder hold a pointer --- decoder/incremental.cc | 69 +++++++++++++++++++------------------------------- 1 file changed, 26 insertions(+), 43 deletions(-) (limited to 'decoder/incremental.cc') diff --git a/decoder/incremental.cc b/decoder/incremental.cc index a9369374..46615b0b 100644 --- a/decoder/incremental.cc +++ b/decoder/incremental.cc @@ -43,21 +43,22 @@ struct MapVocab : public lm::EnumerateVocab { std::vector out_; }; -class IncrementalBase { +template class Incremental : public IncrementalBase { public: - IncrementalBase(const std::vector &weights) : - cdec_weights_(weights), - weights_(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]) { + Incremental(const char *model_file, const std::vector &weights) : + IncrementalBase(weights), + m_(model_file, GetConfig()), + weights_( + weights[FD::Convert("KLanguageModel")], + weights[FD::Convert("KLanguageModel_OOV")], + weights[FD::Convert("WordPenalty")]) { std::cerr << "Weights KLanguageModel " << weights_.LM() << " KLanguageModel_OOV " << weights_.OOV() << " WordPenalty " << weights_.WordPenalty() << std::endl; } + void Search(unsigned int pop_limit, const Hypergraph &hg) const; - virtual ~IncrementalBase() {} - - virtual void Search(unsigned int pop_limit, const Hypergraph &hg) const = 0; - - static IncrementalBase *Load(const char *model_file, const std::vector &weights); + private: + void ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const; - protected: lm::ngram::Config GetConfig() { lm::ngram::Config ret; ret.enumerate_vocab = &vocab_; @@ -66,36 +67,11 @@ class IncrementalBase { MapVocab vocab_; - const std::vector &cdec_weights_; + const Model m_; const search::Weights weights_; }; -template class Incremental : public IncrementalBase { - public: - Incremental(const char *model_file, const std::vector &weights) : IncrementalBase(weights), m_(model_file, GetConfig()) {} - - void Search(unsigned int pop_limit, const Hypergraph &hg) const; - - private: - void ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const; - - const Model m_; -}; - -IncrementalBase *IncrementalBase::Load(const char *model_file, const std::vector &weights) { - lm::ngram::ModelType model_type; - if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING; - switch (model_type) { - case lm::ngram::PROBING: - return new Incremental(model_file, weights); - case lm::ngram::REST_PROBING: - return new Incremental(model_file, weights); - default: - UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet."); - } -} - void PrintFinal(const Hypergraph &hg, const search::Final final) { const std::vector &words = static_cast(final.GetNote().vp)->rule_->e(); const search::Final *child(final.Children()); @@ -171,14 +147,21 @@ template void Incremental::ConvertEdge(const search::Contex gen.AddEdge(out); } -boost::scoped_ptr AwfulGlobalIncremental; - } // namespace -void PassToIncremental(const char *model_file, const std::vector &weights, unsigned int pop_limit, const Hypergraph &hg) { - if (!AwfulGlobalIncremental.get()) { - std::cerr << "Pop limit " << pop_limit << std::endl; - AwfulGlobalIncremental.reset(IncrementalBase::Load(model_file, weights)); +IncrementalBase *IncrementalBase::Load(const char *model_file, const std::vector &weights) { + lm::ngram::ModelType model_type; + if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING; + switch (model_type) { + case lm::ngram::PROBING: + return new Incremental(model_file, weights); + case lm::ngram::REST_PROBING: + return new Incremental(model_file, weights); + default: + UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet."); } - AwfulGlobalIncremental->Search(pop_limit, hg); } + +IncrementalBase::~IncrementalBase() {} + +IncrementalBase::IncrementalBase(const std::vector &weights) : cdec_weights_(weights) {} -- cgit v1.2.3