diff options
Diffstat (limited to 'decoder')
| -rw-r--r-- | decoder/decoder.cc | 9 | ||||
| -rw-r--r-- | decoder/incremental.cc | 69 | ||||
| -rw-r--r-- | decoder/incremental.h | 14 | 
3 files changed, 47 insertions, 45 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc index fe812011..b5f4b9b6 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -4,6 +4,7 @@  #include <boost/program_options.hpp>  #include <boost/program_options/variables_map.hpp>  #include <boost/make_shared.hpp> +#include <boost/scoped_ptr.hpp>  #include "program_options.h"  #include "stringlib.h" @@ -325,6 +326,8 @@ struct DecoderImpl {    bool feature_expectations; // TODO Observer    bool output_training_vector; // TODO Observer    bool remove_intersected_rule_annotations; +  boost::scoped_ptr<IncrementalBase> incremental; +    static void ConvertSV(const SparseVector<prob_t>& src, SparseVector<double>* trg) {      for (SparseVector<prob_t>::const_iterator it = src.begin(); it != src.end(); ++it) @@ -727,6 +730,10 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream    sent_id = -1;    acc_obj = 0; // accumulate objective    g_count = 0;    // number of gradient pieces computed + +  if (conf.count("incremental_search")) { +    incremental.reset(IncrementalBase::Load(conf["incremental_search"].as<string>().c_str(), CurrentWeightVector())); +  }  }  Decoder::Decoder(istream* cfg) { pimpl_.reset(new DecoderImpl(conf,0,0,cfg)); } @@ -829,7 +836,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {      HypergraphIO::WriteTarget(conf["show_target_graph"].as<string>(), sent_id, forest);    if (conf.count("incremental_search")) { -    PassToIncremental(conf["incremental_search"].as<string>().c_str(), CurrentWeightVector(), pop_limit, forest); +    incremental->Search(pop_limit, forest);      o->NotifyDecodingComplete(smeta);      return true;    } diff --git a/decoder/incremental.cc b/decoder/incremental.cc index a9369374..46615b0b 100644 --- a/decoder/incremental.cc +++ b/decoder/incremental.cc @@ -43,21 +43,22 @@ struct MapVocab : public lm::EnumerateVocab {      std::vector<lm::WordIndex> out_;  }; -class IncrementalBase { +template <class Model> class Incremental : public IncrementalBase {    public: -    IncrementalBase(const std::vector<weight_t> &weights) :  -      cdec_weights_(weights), -      weights_(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]) { +    Incremental(const char *model_file, const std::vector<weight_t> &weights) : +      IncrementalBase(weights),  +      m_(model_file, GetConfig()), +      weights_( +          weights[FD::Convert("KLanguageModel")], +          weights[FD::Convert("KLanguageModel_OOV")], +          weights[FD::Convert("WordPenalty")]) {        std::cerr << "Weights KLanguageModel " << weights_.LM() << " KLanguageModel_OOV " << weights_.OOV() << " WordPenalty " << weights_.WordPenalty() << std::endl;      } +    void Search(unsigned int pop_limit, const Hypergraph &hg) const; -    virtual ~IncrementalBase() {} - -    virtual void Search(unsigned int pop_limit, const Hypergraph &hg) const = 0; - -    static IncrementalBase *Load(const char *model_file, const std::vector<weight_t> &weights); +  private: +    void ConvertEdge(const search::Context<Model> &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const; -  protected:      lm::ngram::Config GetConfig() {        lm::ngram::Config ret;        ret.enumerate_vocab = &vocab_; @@ -66,36 +67,11 @@ class IncrementalBase {      MapVocab vocab_; -    const std::vector<weight_t> &cdec_weights_; +    const Model m_;      const search::Weights weights_;  }; -template <class Model> class Incremental : public IncrementalBase { -  public: -    Incremental(const char *model_file, const std::vector<weight_t> &weights) : IncrementalBase(weights), m_(model_file, GetConfig()) {} - -    void Search(unsigned int pop_limit, const Hypergraph &hg) const; - -  private: -    void ConvertEdge(const search::Context<Model> &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const; - -    const Model m_; -}; - -IncrementalBase *IncrementalBase::Load(const char *model_file, const std::vector<weight_t> &weights) { -  lm::ngram::ModelType model_type; -  if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING; -  switch (model_type) { -    case lm::ngram::PROBING: -      return new Incremental<lm::ngram::ProbingModel>(model_file, weights); -    case lm::ngram::REST_PROBING: -      return new Incremental<lm::ngram::RestProbingModel>(model_file, weights); -    default: -      UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet."); -  } -} -  void PrintFinal(const Hypergraph &hg, const search::Final final) {    const std::vector<WordID> &words = static_cast<const Hypergraph::Edge*>(final.GetNote().vp)->rule_->e();    const search::Final *child(final.Children()); @@ -171,14 +147,21 @@ template <class Model> void Incremental<Model>::ConvertEdge(const search::Contex    gen.AddEdge(out);  } -boost::scoped_ptr<IncrementalBase> AwfulGlobalIncremental; -  } // namespace -void PassToIncremental(const char *model_file, const std::vector<weight_t> &weights, unsigned int pop_limit, const Hypergraph &hg) { -  if (!AwfulGlobalIncremental.get()) { -    std::cerr << "Pop limit " << pop_limit << std::endl; -    AwfulGlobalIncremental.reset(IncrementalBase::Load(model_file, weights)); +IncrementalBase *IncrementalBase::Load(const char *model_file, const std::vector<weight_t> &weights) { +  lm::ngram::ModelType model_type; +  if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING; +  switch (model_type) { +    case lm::ngram::PROBING: +      return new Incremental<lm::ngram::ProbingModel>(model_file, weights); +    case lm::ngram::REST_PROBING: +      return new Incremental<lm::ngram::RestProbingModel>(model_file, weights); +    default: +      UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet.");    } -  AwfulGlobalIncremental->Search(pop_limit, hg);  } + +IncrementalBase::~IncrementalBase() {} + +IncrementalBase::IncrementalBase(const std::vector<weight_t> &weights) : cdec_weights_(weights) {} diff --git a/decoder/incremental.h b/decoder/incremental.h index 180383ce..f791a626 100644 --- a/decoder/incremental.h +++ b/decoder/incremental.h @@ -6,6 +6,18 @@  class Hypergraph; -void PassToIncremental(const char *model_file, const std::vector<weight_t> &weights, unsigned int pop_limit, const Hypergraph &hg); +class IncrementalBase { +  public: +    static IncrementalBase *Load(const char *model_file, const std::vector<weight_t> &weights); + +    virtual ~IncrementalBase(); + +    virtual void Search(unsigned int pop_limit, const Hypergraph &hg) const = 0; + +  protected: +    IncrementalBase(const std::vector<weight_t> &weights); + +    const std::vector<weight_t> &cdec_weights_; +};  #endif // _INCREMENTAL_H_  | 
