summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--decoder/decoder.cc3
-rw-r--r--decoder/lazy.cc10
-rw-r--r--decoder/lazy.h2
3 files changed, 10 insertions, 5 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 3a410cf2..525c6ba6 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -416,6 +416,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("show_conditional_prob", "Output the conditional log prob to STDOUT instead of a translation")
("show_cfg_search_space", "Show the search space as a CFG")
("show_target_graph", po::value<string>(), "Directory to write the target hypergraphs to")
+ ("lazy_search", po::value<string>(), "Run lazy search with this language model file")
("coarse_to_fine_beam_prune", po::value<double>(), "Prune paths from coarse parse forest before fine parse, keeping paths within exp(alpha>=0)")
("ctf_beam_widen", po::value<double>()->default_value(2.0), "Expand coarse pass beam by this factor if no fine parse is found")
("ctf_num_widenings", po::value<int>()->default_value(2), "Widen coarse beam this many times before backing off to full parse")
@@ -834,7 +835,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
HypergraphIO::WriteTarget(conf["show_target_graph"].as<string>(), sent_id, forest);
if (conf.count("lazy_search"))
- PassToLazy(forest, CurrentWeightVector());
+ PassToLazy(conf["lazy_search"].as<string>().c_str(), CurrentWeightVector(), forest);
for (int pass = 0; pass < rescoring_passes.size(); ++pass) {
const RescoringPass& rp = rescoring_passes[pass];
diff --git a/decoder/lazy.cc b/decoder/lazy.cc
index 4776c1b8..58a9e08a 100644
--- a/decoder/lazy.cc
+++ b/decoder/lazy.cc
@@ -44,7 +44,9 @@ class LazyBase {
public:
LazyBase(const std::vector<weight_t> &weights) :
cdec_weights_(weights),
- config_(search::Weights(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]), 1000) {}
+ config_(search::Weights(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]), 1000) {
+ std::cerr << "Weights KLanguageModel " << config_.GetWeights().LM() << " KLanguageModel_OOV " << config_.GetWeights().OOV() << " WordPenalty " << config_.GetWeights().WordPenalty() << std::endl;
+ }
virtual ~LazyBase() {}
@@ -95,6 +97,7 @@ template <class Model> void Lazy<Model>::Search(const Hypergraph &hg) const {
boost::scoped_array<search::Vertex> out_vertices(new search::Vertex[hg.nodes_.size()]);
boost::scoped_array<search::Edge> out_edges(new search::Edge[hg.edges_.size()]);
+
search::Context<Model> context(config_, m_);
for (unsigned int i = 0; i < hg.nodes_.size(); ++i) {
@@ -141,6 +144,7 @@ template <class Model> void Lazy<Model>::ConvertEdge(const search::Context<Model
}
float additive = in.rule_->GetFeatureValues().dot(cdec_weights_);
+ UTIL_THROW_IF(isnan(additive), util::Exception, "Bad dot product");
additive -= terminals * context.GetWeights().WordPenalty() * static_cast<float>(terminals) / M_LN10;
out.InitRule().Init(context, additive, words, final);
@@ -150,9 +154,9 @@ boost::scoped_ptr<LazyBase> AwfulGlobalLazy;
} // namespace
-void PassToLazy(const Hypergraph &hg, const std::vector<weight_t> &weights) {
+void PassToLazy(const char *model_file, const std::vector<weight_t> &weights, const Hypergraph &hg) {
if (!AwfulGlobalLazy.get()) {
- AwfulGlobalLazy.reset(LazyBase::Load("lm", weights));
+ AwfulGlobalLazy.reset(LazyBase::Load(model_file, weights));
}
AwfulGlobalLazy->Search(hg);
}
diff --git a/decoder/lazy.h b/decoder/lazy.h
index 3e71a3b0..d1f030d1 100644
--- a/decoder/lazy.h
+++ b/decoder/lazy.h
@@ -6,6 +6,6 @@
class Hypergraph;
-void PassToLazy(const Hypergraph &hg, const std::vector<weight_t> &weights);
+void PassToLazy(const char *model_file, const std::vector<weight_t> &weights, const Hypergraph &hg);
#endif // _LAZY_H_