diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2012-10-25 16:06:02 -0400 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2012-10-25 16:06:02 -0400 |
commit | 8682110b8162fad3bd59d8244fe3fd56fa5d354e (patch) | |
tree | d17ef352fdf5eab29888a22a0ffc7f273d533f4e /decoder/incremental.cc | |
parent | df5445c3651fa1cc99ed4bdb682dcf57092dd4e2 (diff) | |
parent | b015577f42314efe57c1791c6d41885ef6b3487c (diff) |
Merge branch 'master' of https://github.com/redpony/cdec
Diffstat (limited to 'decoder/incremental.cc')
-rw-r--r-- | decoder/incremental.cc | 69 |
1 files changed, 26 insertions, 43 deletions
diff --git a/decoder/incremental.cc b/decoder/incremental.cc index a9369374..46615b0b 100644 --- a/decoder/incremental.cc +++ b/decoder/incremental.cc @@ -43,21 +43,22 @@ struct MapVocab : public lm::EnumerateVocab { std::vector<lm::WordIndex> out_; }; -class IncrementalBase { +template <class Model> class Incremental : public IncrementalBase { public: - IncrementalBase(const std::vector<weight_t> &weights) : - cdec_weights_(weights), - weights_(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]) { + Incremental(const char *model_file, const std::vector<weight_t> &weights) : + IncrementalBase(weights), + m_(model_file, GetConfig()), + weights_( + weights[FD::Convert("KLanguageModel")], + weights[FD::Convert("KLanguageModel_OOV")], + weights[FD::Convert("WordPenalty")]) { std::cerr << "Weights KLanguageModel " << weights_.LM() << " KLanguageModel_OOV " << weights_.OOV() << " WordPenalty " << weights_.WordPenalty() << std::endl; } + void Search(unsigned int pop_limit, const Hypergraph &hg) const; - virtual ~IncrementalBase() {} - - virtual void Search(unsigned int pop_limit, const Hypergraph &hg) const = 0; - - static IncrementalBase *Load(const char *model_file, const std::vector<weight_t> &weights); + private: + void ConvertEdge(const search::Context<Model> &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const; - protected: lm::ngram::Config GetConfig() { lm::ngram::Config ret; ret.enumerate_vocab = &vocab_; @@ -66,36 +67,11 @@ class IncrementalBase { MapVocab vocab_; - const std::vector<weight_t> &cdec_weights_; + const Model m_; const search::Weights weights_; }; -template <class Model> class Incremental : public IncrementalBase { - public: - Incremental(const char *model_file, const std::vector<weight_t> &weights) : IncrementalBase(weights), m_(model_file, GetConfig()) {} - - void Search(unsigned int pop_limit, const Hypergraph &hg) const; - - private: - void ConvertEdge(const search::Context<Model> &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const; - - const Model m_; -}; - -IncrementalBase *IncrementalBase::Load(const char *model_file, const std::vector<weight_t> &weights) { - lm::ngram::ModelType model_type; - if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING; - switch (model_type) { - case lm::ngram::PROBING: - return new Incremental<lm::ngram::ProbingModel>(model_file, weights); - case lm::ngram::REST_PROBING: - return new Incremental<lm::ngram::RestProbingModel>(model_file, weights); - default: - UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet."); - } -} - void PrintFinal(const Hypergraph &hg, const search::Final final) { const std::vector<WordID> &words = static_cast<const Hypergraph::Edge*>(final.GetNote().vp)->rule_->e(); const search::Final *child(final.Children()); @@ -171,14 +147,21 @@ template <class Model> void Incremental<Model>::ConvertEdge(const search::Contex gen.AddEdge(out); } -boost::scoped_ptr<IncrementalBase> AwfulGlobalIncremental; - } // namespace -void PassToIncremental(const char *model_file, const std::vector<weight_t> &weights, unsigned int pop_limit, const Hypergraph &hg) { - if (!AwfulGlobalIncremental.get()) { - std::cerr << "Pop limit " << pop_limit << std::endl; - AwfulGlobalIncremental.reset(IncrementalBase::Load(model_file, weights)); +IncrementalBase *IncrementalBase::Load(const char *model_file, const std::vector<weight_t> &weights) { + lm::ngram::ModelType model_type; + if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING; + switch (model_type) { + case lm::ngram::PROBING: + return new Incremental<lm::ngram::ProbingModel>(model_file, weights); + case lm::ngram::REST_PROBING: + return new Incremental<lm::ngram::RestProbingModel>(model_file, weights); + default: + UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet."); } - AwfulGlobalIncremental->Search(pop_limit, hg); } + +IncrementalBase::~IncrementalBase() {} + +IncrementalBase::IncrementalBase(const std::vector<weight_t> &weights) : cdec_weights_(weights) {} |