summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Chris Dyer <cdyer@cs.cmu.edu> 2012-10-25 16:06:02 -0400
committer Chris Dyer <cdyer@cs.cmu.edu> 2012-10-25 16:06:02 -0400
commit 8682110b8162fad3bd59d8244fe3fd56fa5d354e (patch)
tree d17ef352fdf5eab29888a22a0ffc7f273d533f4e
parent df5445c3651fa1cc99ed4bdb682dcf57092dd4e2 (diff)
parent b015577f42314efe57c1791c6d41885ef6b3487c (diff)
Merge branch 'master' of https://github.com/redpony/cdec
-rw-r--r-- decoder/decoder.cc 9
-rw-r--r-- decoder/incremental.cc 69
-rw-r--r-- decoder/incremental.h 14
3 files changed, 47 insertions, 45 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index fe812011..b5f4b9b6 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -4,6 +4,7 @@
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
#include <boost/make_shared.hpp>
+#include <boost/scoped_ptr.hpp>
#include "program_options.h"
#include "stringlib.h"
@@ -325,6 +326,8 @@ struct DecoderImpl {
bool feature_expectations; // TODO Observer
bool output_training_vector; // TODO Observer
bool remove_intersected_rule_annotations;
+ boost::scoped_ptr<IncrementalBase> incremental;
+
static void ConvertSV(const SparseVector<prob_t>& src, SparseVector<double>* trg) {
for (SparseVector<prob_t>::const_iterator it = src.begin(); it != src.end(); ++it)
@@ -727,6 +730,10 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
sent_id = -1;
acc_obj = 0; // accumulate objective
g_count = 0; // number of gradient pieces computed
+
+ if (conf.count("incremental_search")) {
+ incremental.reset(IncrementalBase::Load(conf["incremental_search"].as<string>().c_str(), CurrentWeightVector()));
+ }
}
Decoder::Decoder(istream* cfg) { pimpl_.reset(new DecoderImpl(conf,0,0,cfg)); }
@@ -829,7 +836,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
HypergraphIO::WriteTarget(conf["show_target_graph"].as<string>(), sent_id, forest);
if (conf.count("incremental_search")) {
- PassToIncremental(conf["incremental_search"].as<string>().c_str(), CurrentWeightVector(), pop_limit, forest);
+ incremental->Search(pop_limit, forest);
o->NotifyDecodingComplete(smeta);
return true;
}
diff --git a/decoder/incremental.cc b/decoder/incremental.cc
index a9369374..46615b0b 100644
--- a/decoder/incremental.cc
+++ b/decoder/incremental.cc
@@ -43,21 +43,22 @@ struct MapVocab : public lm::EnumerateVocab {
std::vector<lm::WordIndex> out_;
};
-class IncrementalBase {
+template <class Model> class Incremental : public IncrementalBase {
public:
- IncrementalBase(const std::vector<weight_t> &weights) :
- cdec_weights_(weights),
- weights_(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]) {
+ Incremental(const char *model_file, const std::vector<weight_t> &weights) :
+ IncrementalBase(weights),
+ m_(model_file, GetConfig()),
+ weights_(
+ weights[FD::Convert("KLanguageModel")],
+ weights[FD::Convert("KLanguageModel_OOV")],
+ weights[FD::Convert("WordPenalty")]) {
std::cerr << "Weights KLanguageModel " << weights_.LM() << " KLanguageModel_OOV " << weights_.OOV() << " WordPenalty " << weights_.WordPenalty() << std::endl;
}
+ void Search(unsigned int pop_limit, const Hypergraph &hg) const;
- virtual ~IncrementalBase() {}
-
- virtual void Search(unsigned int pop_limit, const Hypergraph &hg) const = 0;
-
- static IncrementalBase *Load(const char *model_file, const std::vector<weight_t> &weights);
+ private:
+ void ConvertEdge(const search::Context<Model> &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const;
- protected:
lm::ngram::Config GetConfig() {
lm::ngram::Config ret;
ret.enumerate_vocab = &vocab_;
@@ -66,36 +67,11 @@ class IncrementalBase {
MapVocab vocab_;
- const std::vector<weight_t> &cdec_weights_;
+ const Model m_;
const search::Weights weights_;
};
-template <class Model> class Incremental : public IncrementalBase {
- public:
- Incremental(const char *model_file, const std::vector<weight_t> &weights) : IncrementalBase(weights), m_(model_file, GetConfig()) {}
-
- void Search(unsigned int pop_limit, const Hypergraph &hg) const;
-
- private:
- void ConvertEdge(const search::Context<Model> &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const;
-
- const Model m_;
-};
-
-IncrementalBase *IncrementalBase::Load(const char *model_file, const std::vector<weight_t> &weights) {
- lm::ngram::ModelType model_type;
- if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING;
- switch (model_type) {
- case lm::ngram::PROBING:
- return new Incremental<lm::ngram::ProbingModel>(model_file, weights);
- case lm::ngram::REST_PROBING:
- return new Incremental<lm::ngram::RestProbingModel>(model_file, weights);
- default:
- UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet.");
- }
-}
-
void PrintFinal(const Hypergraph &hg, const search::Final final) {
const std::vector<WordID> &words = static_cast<const Hypergraph::Edge*>(final.GetNote().vp)->rule_->e();
const search::Final *child(final.Children());
@@ -171,14 +147,21 @@ template <class Model> void Incremental<Model>::ConvertEdge(const search::Contex
gen.AddEdge(out);
}
-boost::scoped_ptr<IncrementalBase> AwfulGlobalIncremental;
-
} // namespace
-void PassToIncremental(const char *model_file, const std::vector<weight_t> &weights, unsigned int pop_limit, const Hypergraph &hg) {
- if (!AwfulGlobalIncremental.get()) {
- std::cerr << "Pop limit " << pop_limit << std::endl;
- AwfulGlobalIncremental.reset(IncrementalBase::Load(model_file, weights));
+IncrementalBase *IncrementalBase::Load(const char *model_file, const std::vector<weight_t> &weights) {
+ lm::ngram::ModelType model_type;
+ if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING;
+ switch (model_type) {
+ case lm::ngram::PROBING:
+ return new Incremental<lm::ngram::ProbingModel>(model_file, weights);
+ case lm::ngram::REST_PROBING:
+ return new Incremental<lm::ngram::RestProbingModel>(model_file, weights);
+ default:
+ UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet.");
}
- AwfulGlobalIncremental->Search(pop_limit, hg);
}
+
+IncrementalBase::~IncrementalBase() {}
+
+IncrementalBase::IncrementalBase(const std::vector<weight_t> &weights) : cdec_weights_(weights) {}
diff --git a/decoder/incremental.h b/decoder/incremental.h
index 180383ce..f791a626 100644
--- a/decoder/incremental.h
+++ b/decoder/incremental.h
@@ -6,6 +6,18 @@
class Hypergraph;
-void PassToIncremental(const char *model_file, const std::vector<weight_t> &weights, unsigned int pop_limit, const Hypergraph &hg);
+class IncrementalBase {
+ public:
+ static IncrementalBase *Load(const char *model_file, const std::vector<weight_t> &weights);
+
+ virtual ~IncrementalBase();
+
+ virtual void Search(unsigned int pop_limit, const Hypergraph &hg) const = 0;
+
+ protected:
+ IncrementalBase(const std::vector<weight_t> &weights);
+
+ const std::vector<weight_t> &cdec_weights_;
+};
#endif // _INCREMENTAL_H_