diff options
Diffstat (limited to 'decoder')
-rw-r--r-- | decoder/decoder.cc | 3 | ||||
-rw-r--r-- | decoder/lazy.cc | 10 | ||||
-rw-r--r-- | decoder/lazy.h | 2 |
3 files changed, 10 insertions, 5 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 3a410cf2..525c6ba6 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -416,6 +416,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("show_conditional_prob", "Output the conditional log prob to STDOUT instead of a translation") ("show_cfg_search_space", "Show the search space as a CFG") ("show_target_graph", po::value<string>(), "Directory to write the target hypergraphs to") + ("lazy_search", po::value<string>(), "Run lazy search with this language model file") ("coarse_to_fine_beam_prune", po::value<double>(), "Prune paths from coarse parse forest before fine parse, keeping paths within exp(alpha>=0)") ("ctf_beam_widen", po::value<double>()->default_value(2.0), "Expand coarse pass beam by this factor if no fine parse is found") ("ctf_num_widenings", po::value<int>()->default_value(2), "Widen coarse beam this many times before backing off to full parse") @@ -834,7 +835,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { HypergraphIO::WriteTarget(conf["show_target_graph"].as<string>(), sent_id, forest); if (conf.count("lazy_search")) - PassToLazy(forest, CurrentWeightVector()); + PassToLazy(conf["lazy_search"].as<string>().c_str(), CurrentWeightVector(), forest); for (int pass = 0; pass < rescoring_passes.size(); ++pass) { const RescoringPass& rp = rescoring_passes[pass]; diff --git a/decoder/lazy.cc b/decoder/lazy.cc index 4776c1b8..58a9e08a 100644 --- a/decoder/lazy.cc +++ b/decoder/lazy.cc @@ -44,7 +44,9 @@ class LazyBase { public: LazyBase(const std::vector<weight_t> &weights) : cdec_weights_(weights), - config_(search::Weights(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]), 1000) {} + config_(search::Weights(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]), 1000) { + std::cerr << "Weights KLanguageModel " << config_.GetWeights().LM() << " KLanguageModel_OOV " << config_.GetWeights().OOV() << " WordPenalty " << config_.GetWeights().WordPenalty() << std::endl; + } virtual ~LazyBase() {} @@ -95,6 +97,7 @@ template <class Model> void Lazy<Model>::Search(const Hypergraph &hg) const { boost::scoped_array<search::Vertex> out_vertices(new search::Vertex[hg.nodes_.size()]); boost::scoped_array<search::Edge> out_edges(new search::Edge[hg.edges_.size()]); + search::Context<Model> context(config_, m_); for (unsigned int i = 0; i < hg.nodes_.size(); ++i) { @@ -141,6 +144,7 @@ template <class Model> void Lazy<Model>::ConvertEdge(const search::Context<Model } float additive = in.rule_->GetFeatureValues().dot(cdec_weights_); + UTIL_THROW_IF(isnan(additive), util::Exception, "Bad dot product"); additive -= terminals * context.GetWeights().WordPenalty() * static_cast<float>(terminals) / M_LN10; out.InitRule().Init(context, additive, words, final); @@ -150,9 +154,9 @@ boost::scoped_ptr<LazyBase> AwfulGlobalLazy; } // namespace -void PassToLazy(const Hypergraph &hg, const std::vector<weight_t> &weights) { +void PassToLazy(const char *model_file, const std::vector<weight_t> &weights, const Hypergraph &hg) { if (!AwfulGlobalLazy.get()) { - AwfulGlobalLazy.reset(LazyBase::Load("lm", weights)); + AwfulGlobalLazy.reset(LazyBase::Load(model_file, weights)); } AwfulGlobalLazy->Search(hg); } diff --git a/decoder/lazy.h b/decoder/lazy.h index 3e71a3b0..d1f030d1 100644 --- a/decoder/lazy.h +++ b/decoder/lazy.h @@ -6,6 +6,6 @@ class Hypergraph; -void PassToLazy(const Hypergraph &hg, const std::vector<weight_t> &weights); +void PassToLazy(const char *model_file, const std::vector<weight_t> &weights, const Hypergraph &hg); #endif // _LAZY_H_ |