From 18c5b12399eb078dbb8c764a205caf9c610f9a2f Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Thu, 13 Sep 2012 04:28:30 -0700 Subject: Allow lm file name, print weights --- decoder/decoder.cc | 3 ++- decoder/lazy.cc | 10 +++++++--- decoder/lazy.h | 2 +- 3 files changed, 10 insertions(+), 5 deletions(-) (limited to 'decoder') diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 3a410cf2..525c6ba6 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -416,6 +416,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("show_conditional_prob", "Output the conditional log prob to STDOUT instead of a translation") ("show_cfg_search_space", "Show the search space as a CFG") ("show_target_graph", po::value(), "Directory to write the target hypergraphs to") + ("lazy_search", po::value(), "Run lazy search with this language model file") ("coarse_to_fine_beam_prune", po::value(), "Prune paths from coarse parse forest before fine parse, keeping paths within exp(alpha>=0)") ("ctf_beam_widen", po::value()->default_value(2.0), "Expand coarse pass beam by this factor if no fine parse is found") ("ctf_num_widenings", po::value()->default_value(2), "Widen coarse beam this many times before backing off to full parse") @@ -834,7 +835,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { HypergraphIO::WriteTarget(conf["show_target_graph"].as(), sent_id, forest); if (conf.count("lazy_search")) - PassToLazy(forest, CurrentWeightVector()); + PassToLazy(conf["lazy_search"].as().c_str(), CurrentWeightVector(), forest); for (int pass = 0; pass < rescoring_passes.size(); ++pass) { const RescoringPass& rp = rescoring_passes[pass]; diff --git a/decoder/lazy.cc b/decoder/lazy.cc index 4776c1b8..58a9e08a 100644 --- a/decoder/lazy.cc +++ b/decoder/lazy.cc @@ -44,7 +44,9 @@ class LazyBase { public: LazyBase(const std::vector &weights) : cdec_weights_(weights), - config_(search::Weights(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]), 1000) {} + config_(search::Weights(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]), 1000) { + std::cerr << "Weights KLanguageModel " << config_.GetWeights().LM() << " KLanguageModel_OOV " << config_.GetWeights().OOV() << " WordPenalty " << config_.GetWeights().WordPenalty() << std::endl; + } virtual ~LazyBase() {} @@ -95,6 +97,7 @@ template void Lazy::Search(const Hypergraph &hg) const { boost::scoped_array out_vertices(new search::Vertex[hg.nodes_.size()]); boost::scoped_array out_edges(new search::Edge[hg.edges_.size()]); + search::Context context(config_, m_); for (unsigned int i = 0; i < hg.nodes_.size(); ++i) { @@ -141,6 +144,7 @@ template void Lazy::ConvertEdge(const search::ContextGetFeatureValues().dot(cdec_weights_); + UTIL_THROW_IF(isnan(additive), util::Exception, "Bad dot product"); additive -= terminals * context.GetWeights().WordPenalty() * static_cast(terminals) / M_LN10; out.InitRule().Init(context, additive, words, final); @@ -150,9 +154,9 @@ boost::scoped_ptr AwfulGlobalLazy; } // namespace -void PassToLazy(const Hypergraph &hg, const std::vector &weights) { +void PassToLazy(const char *model_file, const std::vector &weights, const Hypergraph &hg) { if (!AwfulGlobalLazy.get()) { - AwfulGlobalLazy.reset(LazyBase::Load("lm", weights)); + AwfulGlobalLazy.reset(LazyBase::Load(model_file, weights)); } AwfulGlobalLazy->Search(hg); } diff --git a/decoder/lazy.h b/decoder/lazy.h index 3e71a3b0..d1f030d1 100644 --- a/decoder/lazy.h +++ b/decoder/lazy.h @@ -6,6 +6,6 @@ class Hypergraph; -void PassToLazy(const Hypergraph &hg, const std::vector &weights); +void PassToLazy(const char *model_file, const std::vector &weights, const Hypergraph &hg); #endif // _LAZY_H_ -- cgit v1.2.3