From 20ac14856accd1532fe7cbeb3ad4cf26dbeb80b1 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Mon, 22 Oct 2012 11:35:31 -0400 Subject: Remove global variable, have decoder hold a pointer --- decoder/decoder.cc | 9 ++++++- decoder/incremental.cc | 69 +++++++++++++++++++------------------------------- decoder/incremental.h | 14 +++++++++- 3 files changed, 47 insertions(+), 45 deletions(-) diff --git a/decoder/decoder.cc b/decoder/decoder.cc index fe812011..b5f4b9b6 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -4,6 +4,7 @@ #include #include #include +#include #include "program_options.h" #include "stringlib.h" @@ -325,6 +326,8 @@ struct DecoderImpl { bool feature_expectations; // TODO Observer bool output_training_vector; // TODO Observer bool remove_intersected_rule_annotations; + boost::scoped_ptr incremental; + static void ConvertSV(const SparseVector& src, SparseVector* trg) { for (SparseVector::const_iterator it = src.begin(); it != src.end(); ++it) @@ -727,6 +730,10 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream sent_id = -1; acc_obj = 0; // accumulate objective g_count = 0; // number of gradient pieces computed + + if (conf.count("incremental_search")) { + incremental.reset(IncrementalBase::Load(conf["incremental_search"].as().c_str(), CurrentWeightVector())); + } } Decoder::Decoder(istream* cfg) { pimpl_.reset(new DecoderImpl(conf,0,0,cfg)); } @@ -829,7 +836,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { HypergraphIO::WriteTarget(conf["show_target_graph"].as(), sent_id, forest); if (conf.count("incremental_search")) { - PassToIncremental(conf["incremental_search"].as().c_str(), CurrentWeightVector(), pop_limit, forest); + incremental->Search(pop_limit, forest); o->NotifyDecodingComplete(smeta); return true; } diff --git a/decoder/incremental.cc b/decoder/incremental.cc index a9369374..46615b0b 100644 --- a/decoder/incremental.cc +++ b/decoder/incremental.cc @@ -43,21 +43,22 @@ struct MapVocab : public lm::EnumerateVocab { std::vector out_; }; -class IncrementalBase { +template class Incremental : public IncrementalBase { public: - IncrementalBase(const std::vector &weights) : - cdec_weights_(weights), - weights_(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]) { + Incremental(const char *model_file, const std::vector &weights) : + IncrementalBase(weights), + m_(model_file, GetConfig()), + weights_( + weights[FD::Convert("KLanguageModel")], + weights[FD::Convert("KLanguageModel_OOV")], + weights[FD::Convert("WordPenalty")]) { std::cerr << "Weights KLanguageModel " << weights_.LM() << " KLanguageModel_OOV " << weights_.OOV() << " WordPenalty " << weights_.WordPenalty() << std::endl; } + void Search(unsigned int pop_limit, const Hypergraph &hg) const; - virtual ~IncrementalBase() {} - - virtual void Search(unsigned int pop_limit, const Hypergraph &hg) const = 0; - - static IncrementalBase *Load(const char *model_file, const std::vector &weights); + private: + void ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const; - protected: lm::ngram::Config GetConfig() { lm::ngram::Config ret; ret.enumerate_vocab = &vocab_; @@ -66,36 +67,11 @@ class IncrementalBase { MapVocab vocab_; - const std::vector &cdec_weights_; + const Model m_; const search::Weights weights_; }; -template class Incremental : public IncrementalBase { - public: - Incremental(const char *model_file, const std::vector &weights) : IncrementalBase(weights), m_(model_file, GetConfig()) {} - - void Search(unsigned int pop_limit, const Hypergraph &hg) const; - - private: - void ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const; - - const Model m_; -}; - -IncrementalBase *IncrementalBase::Load(const char *model_file, const std::vector &weights) { - lm::ngram::ModelType model_type; - if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING; - switch (model_type) { - case lm::ngram::PROBING: - return new Incremental(model_file, weights); - case lm::ngram::REST_PROBING: - return new Incremental(model_file, weights); - default: - UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet."); - } -} - void PrintFinal(const Hypergraph &hg, const search::Final final) { const std::vector &words = static_cast(final.GetNote().vp)->rule_->e(); const search::Final *child(final.Children()); @@ -171,14 +147,21 @@ template void Incremental::ConvertEdge(const search::Contex gen.AddEdge(out); } -boost::scoped_ptr AwfulGlobalIncremental; - } // namespace -void PassToIncremental(const char *model_file, const std::vector &weights, unsigned int pop_limit, const Hypergraph &hg) { - if (!AwfulGlobalIncremental.get()) { - std::cerr << "Pop limit " << pop_limit << std::endl; - AwfulGlobalIncremental.reset(IncrementalBase::Load(model_file, weights)); +IncrementalBase *IncrementalBase::Load(const char *model_file, const std::vector &weights) { + lm::ngram::ModelType model_type; + if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING; + switch (model_type) { + case lm::ngram::PROBING: + return new Incremental(model_file, weights); + case lm::ngram::REST_PROBING: + return new Incremental(model_file, weights); + default: + UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet."); } - AwfulGlobalIncremental->Search(pop_limit, hg); } + +IncrementalBase::~IncrementalBase() {} + +IncrementalBase::IncrementalBase(const std::vector &weights) : cdec_weights_(weights) {} diff --git a/decoder/incremental.h b/decoder/incremental.h index 180383ce..f791a626 100644 --- a/decoder/incremental.h +++ b/decoder/incremental.h @@ -6,6 +6,18 @@ class Hypergraph; -void PassToIncremental(const char *model_file, const std::vector &weights, unsigned int pop_limit, const Hypergraph &hg); +class IncrementalBase { + public: + static IncrementalBase *Load(const char *model_file, const std::vector &weights); + + virtual ~IncrementalBase(); + + virtual void Search(unsigned int pop_limit, const Hypergraph &hg) const = 0; + + protected: + IncrementalBase(const std::vector &weights); + + const std::vector &cdec_weights_; +}; #endif // _INCREMENTAL_H_ -- cgit v1.2.3