diff options
author | Kenneth Heafield <github@kheafield.com> | 2012-10-22 11:35:31 -0400 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2012-10-22 11:36:10 -0400 |
commit | 7278218f4581ed8da3dacbff9c7ff3834c292dab (patch) | |
tree | f343175050c24ba50cc8316d210f3fff52ff8323 | |
parent | 310e06dea1f4fd1eb1d3a8a80ee3ad57358188e1 (diff) |
Remove global variable, have decoder hold a pointer
-rw-r--r-- | decoder/decoder.cc | 9 | ||||
-rw-r--r-- | decoder/incremental.cc | 69 | ||||
-rw-r--r-- | decoder/incremental.h | 14 |
3 files changed, 47 insertions, 45 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc index fe812011..b5f4b9b6 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -4,6 +4,7 @@ #include <boost/program_options.hpp> #include <boost/program_options/variables_map.hpp> #include <boost/make_shared.hpp> +#include <boost/scoped_ptr.hpp> #include "program_options.h" #include "stringlib.h" @@ -325,6 +326,8 @@ struct DecoderImpl { bool feature_expectations; // TODO Observer bool output_training_vector; // TODO Observer bool remove_intersected_rule_annotations; + boost::scoped_ptr<IncrementalBase> incremental; + static void ConvertSV(const SparseVector<prob_t>& src, SparseVector<double>* trg) { for (SparseVector<prob_t>::const_iterator it = src.begin(); it != src.end(); ++it) @@ -727,6 +730,10 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream sent_id = -1; acc_obj = 0; // accumulate objective g_count = 0; // number of gradient pieces computed + + if (conf.count("incremental_search")) { + incremental.reset(IncrementalBase::Load(conf["incremental_search"].as<string>().c_str(), CurrentWeightVector())); + } } Decoder::Decoder(istream* cfg) { pimpl_.reset(new DecoderImpl(conf,0,0,cfg)); } @@ -829,7 +836,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { HypergraphIO::WriteTarget(conf["show_target_graph"].as<string>(), sent_id, forest); if (conf.count("incremental_search")) { - PassToIncremental(conf["incremental_search"].as<string>().c_str(), CurrentWeightVector(), pop_limit, forest); + incremental->Search(pop_limit, forest); o->NotifyDecodingComplete(smeta); return true; } diff --git a/decoder/incremental.cc b/decoder/incremental.cc index a9369374..46615b0b 100644 --- a/decoder/incremental.cc +++ b/decoder/incremental.cc @@ -43,21 +43,22 @@ struct MapVocab : public lm::EnumerateVocab { std::vector<lm::WordIndex> out_; }; -class IncrementalBase { +template <class Model> class Incremental : public IncrementalBase { public: - IncrementalBase(const std::vector<weight_t> &weights) : - cdec_weights_(weights), - weights_(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]) { + Incremental(const char *model_file, const std::vector<weight_t> &weights) : + IncrementalBase(weights), + m_(model_file, GetConfig()), + weights_( + weights[FD::Convert("KLanguageModel")], + weights[FD::Convert("KLanguageModel_OOV")], + weights[FD::Convert("WordPenalty")]) { std::cerr << "Weights KLanguageModel " << weights_.LM() << " KLanguageModel_OOV " << weights_.OOV() << " WordPenalty " << weights_.WordPenalty() << std::endl; } + void Search(unsigned int pop_limit, const Hypergraph &hg) const; - virtual ~IncrementalBase() {} - - virtual void Search(unsigned int pop_limit, const Hypergraph &hg) const = 0; - - static IncrementalBase *Load(const char *model_file, const std::vector<weight_t> &weights); + private: + void ConvertEdge(const search::Context<Model> &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const; - protected: lm::ngram::Config GetConfig() { lm::ngram::Config ret; ret.enumerate_vocab = &vocab_; @@ -66,36 +67,11 @@ class IncrementalBase { MapVocab vocab_; - const std::vector<weight_t> &cdec_weights_; + const Model m_; const search::Weights weights_; }; -template <class Model> class Incremental : public IncrementalBase { - public: - Incremental(const char *model_file, const std::vector<weight_t> &weights) : IncrementalBase(weights), m_(model_file, GetConfig()) {} - - void Search(unsigned int pop_limit, const Hypergraph &hg) const; - - private: - void ConvertEdge(const search::Context<Model> &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const; - - const Model m_; -}; - -IncrementalBase *IncrementalBase::Load(const char *model_file, const std::vector<weight_t> &weights) { - lm::ngram::ModelType model_type; - if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING; - switch (model_type) { - case lm::ngram::PROBING: - return new Incremental<lm::ngram::ProbingModel>(model_file, weights); - case lm::ngram::REST_PROBING: - return new Incremental<lm::ngram::RestProbingModel>(model_file, weights); - default: - UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet."); - } -} - void PrintFinal(const Hypergraph &hg, const search::Final final) { const std::vector<WordID> &words = static_cast<const Hypergraph::Edge*>(final.GetNote().vp)->rule_->e(); const search::Final *child(final.Children()); @@ -171,14 +147,21 @@ template <class Model> void Incremental<Model>::ConvertEdge(const search::Contex gen.AddEdge(out); } -boost::scoped_ptr<IncrementalBase> AwfulGlobalIncremental; - } // namespace -void PassToIncremental(const char *model_file, const std::vector<weight_t> &weights, unsigned int pop_limit, const Hypergraph &hg) { - if (!AwfulGlobalIncremental.get()) { - std::cerr << "Pop limit " << pop_limit << std::endl; - AwfulGlobalIncremental.reset(IncrementalBase::Load(model_file, weights)); +IncrementalBase *IncrementalBase::Load(const char *model_file, const std::vector<weight_t> &weights) { + lm::ngram::ModelType model_type; + if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING; + switch (model_type) { + case lm::ngram::PROBING: + return new Incremental<lm::ngram::ProbingModel>(model_file, weights); + case lm::ngram::REST_PROBING: + return new Incremental<lm::ngram::RestProbingModel>(model_file, weights); + default: + UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet."); } - AwfulGlobalIncremental->Search(pop_limit, hg); } + +IncrementalBase::~IncrementalBase() {} + +IncrementalBase::IncrementalBase(const std::vector<weight_t> &weights) : cdec_weights_(weights) {} diff --git a/decoder/incremental.h b/decoder/incremental.h index 180383ce..f791a626 100644 --- a/decoder/incremental.h +++ b/decoder/incremental.h @@ -6,6 +6,18 @@ class Hypergraph; -void PassToIncremental(const char *model_file, const std::vector<weight_t> &weights, unsigned int pop_limit, const Hypergraph &hg); +class IncrementalBase { + public: + static IncrementalBase *Load(const char *model_file, const std::vector<weight_t> &weights); + + virtual ~IncrementalBase(); + + virtual void Search(unsigned int pop_limit, const Hypergraph &hg) const = 0; + + protected: + IncrementalBase(const std::vector<weight_t> &weights); + + const std::vector<weight_t> &cdec_weights_; +}; #endif // _INCREMENTAL_H_ |