summaryrefslogtreecommitdiff
path: root/decoder/incremental.cc
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2012-10-22 11:35:31 -0400
committerKenneth Heafield <github@kheafield.com>2012-10-22 11:36:10 -0400
commit7278218f4581ed8da3dacbff9c7ff3834c292dab (patch)
treef343175050c24ba50cc8316d210f3fff52ff8323 /decoder/incremental.cc
parent310e06dea1f4fd1eb1d3a8a80ee3ad57358188e1 (diff)
Remove global variable, have decoder hold a pointer
Diffstat (limited to 'decoder/incremental.cc')
-rw-r--r--decoder/incremental.cc69
1 files changed, 26 insertions, 43 deletions
diff --git a/decoder/incremental.cc b/decoder/incremental.cc
index a9369374..46615b0b 100644
--- a/decoder/incremental.cc
+++ b/decoder/incremental.cc
@@ -43,21 +43,22 @@ struct MapVocab : public lm::EnumerateVocab {
std::vector<lm::WordIndex> out_;
};
-class IncrementalBase {
+template <class Model> class Incremental : public IncrementalBase {
public:
- IncrementalBase(const std::vector<weight_t> &weights) :
- cdec_weights_(weights),
- weights_(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]) {
+ Incremental(const char *model_file, const std::vector<weight_t> &weights) :
+ IncrementalBase(weights),
+ m_(model_file, GetConfig()),
+ weights_(
+ weights[FD::Convert("KLanguageModel")],
+ weights[FD::Convert("KLanguageModel_OOV")],
+ weights[FD::Convert("WordPenalty")]) {
std::cerr << "Weights KLanguageModel " << weights_.LM() << " KLanguageModel_OOV " << weights_.OOV() << " WordPenalty " << weights_.WordPenalty() << std::endl;
}
+ void Search(unsigned int pop_limit, const Hypergraph &hg) const;
- virtual ~IncrementalBase() {}
-
- virtual void Search(unsigned int pop_limit, const Hypergraph &hg) const = 0;
-
- static IncrementalBase *Load(const char *model_file, const std::vector<weight_t> &weights);
+ private:
+ void ConvertEdge(const search::Context<Model> &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const;
- protected:
lm::ngram::Config GetConfig() {
lm::ngram::Config ret;
ret.enumerate_vocab = &vocab_;
@@ -66,36 +67,11 @@ class IncrementalBase {
MapVocab vocab_;
- const std::vector<weight_t> &cdec_weights_;
+ const Model m_;
const search::Weights weights_;
};
-template <class Model> class Incremental : public IncrementalBase {
- public:
- Incremental(const char *model_file, const std::vector<weight_t> &weights) : IncrementalBase(weights), m_(model_file, GetConfig()) {}
-
- void Search(unsigned int pop_limit, const Hypergraph &hg) const;
-
- private:
- void ConvertEdge(const search::Context<Model> &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const;
-
- const Model m_;
-};
-
-IncrementalBase *IncrementalBase::Load(const char *model_file, const std::vector<weight_t> &weights) {
- lm::ngram::ModelType model_type;
- if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING;
- switch (model_type) {
- case lm::ngram::PROBING:
- return new Incremental<lm::ngram::ProbingModel>(model_file, weights);
- case lm::ngram::REST_PROBING:
- return new Incremental<lm::ngram::RestProbingModel>(model_file, weights);
- default:
- UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet.");
- }
-}
-
void PrintFinal(const Hypergraph &hg, const search::Final final) {
const std::vector<WordID> &words = static_cast<const Hypergraph::Edge*>(final.GetNote().vp)->rule_->e();
const search::Final *child(final.Children());
@@ -171,14 +147,21 @@ template <class Model> void Incremental<Model>::ConvertEdge(const search::Contex
gen.AddEdge(out);
}
-boost::scoped_ptr<IncrementalBase> AwfulGlobalIncremental;
-
} // namespace
-void PassToIncremental(const char *model_file, const std::vector<weight_t> &weights, unsigned int pop_limit, const Hypergraph &hg) {
- if (!AwfulGlobalIncremental.get()) {
- std::cerr << "Pop limit " << pop_limit << std::endl;
- AwfulGlobalIncremental.reset(IncrementalBase::Load(model_file, weights));
+IncrementalBase *IncrementalBase::Load(const char *model_file, const std::vector<weight_t> &weights) {
+ lm::ngram::ModelType model_type;
+ if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING;
+ switch (model_type) {
+ case lm::ngram::PROBING:
+ return new Incremental<lm::ngram::ProbingModel>(model_file, weights);
+ case lm::ngram::REST_PROBING:
+ return new Incremental<lm::ngram::RestProbingModel>(model_file, weights);
+ default:
+ UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet.");
}
- AwfulGlobalIncremental->Search(pop_limit, hg);
}
+
+IncrementalBase::~IncrementalBase() {}
+
+IncrementalBase::IncrementalBase(const std::vector<weight_t> &weights) : cdec_weights_(weights) {}