summary | refs | log | tree | commit | diff
path: root/decoder/decoder.cc
diff options
context:
space:
mode:
Diffstat (limited to 'decoder/decoder.cc')
-rw-r--r-- decoder/decoder.cc 46
1 files changed, 28 insertions, 18 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index a6f7b1ce..b5f4b9b6 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -4,6 +4,7 @@
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
#include <boost/make_shared.hpp>
+#include <boost/scoped_ptr.hpp>
#include "program_options.h"
#include "stringlib.h"
@@ -24,10 +25,12 @@
#include "hg.h"
#include "sentence_metadata.h"
#include "hg_intersect.h"
+#include "hg_union.h"
#include "oracle_bleu.h"
#include "apply_models.h"
#include "ff.h"
+#include "ffset.h"
#include "ff_factory.h"
#include "viterbi.h"
#include "kbest.h"
@@ -37,6 +40,7 @@
#include "sampler.h"
#include "forest_writer.h" // TODO this section should probably be handled by an Observer
+#include "incremental.h"
#include "hg_io.h"
#include "aligner.h"
@@ -89,11 +93,6 @@ inline void ShowBanner() {
cerr << "cdec v1.0 (c) 2009-2011 by Chris Dyer\n";
}
-inline void show_models(po::variables_map const& conf,ModelSet &ms,char const* header) {
- cerr<<header<<": ";
- ms.show_features(cerr,cerr,conf.count("warn_0_weight"));
-}
-
inline string str(char const* name,po::variables_map const& conf) {
return conf[name].as<string>();
}
@@ -131,7 +130,7 @@ inline boost::shared_ptr<FeatureFunction> make_ff(string const& ffp,bool verbose
}
boost::shared_ptr<FeatureFunction> pf = ff_registry.Create(ff, param);
if (!pf) exit(1);
- int nbyte=pf->NumBytesContext();
+ int nbyte=pf->StateSize();
if (verbose_feature_functions && !SILENT)
cerr<<"State is "<<nbyte<<" bytes for "<<pre<<"feature "<<ffp<<endl;
return pf;
@@ -327,6 +326,8 @@ struct DecoderImpl {
bool feature_expectations; // TODO Observer
bool output_training_vector; // TODO Observer
bool remove_intersected_rule_annotations;
+ boost::scoped_ptr<IncrementalBase> incremental;
+
static void ConvertSV(const SparseVector<prob_t>& src, SparseVector<double>* trg) {
for (SparseVector<prob_t>::const_iterator it = src.begin(); it != src.end(); ++it)
@@ -414,6 +415,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("show_conditional_prob", "Output the conditional log prob to STDOUT instead of a translation")
("show_cfg_search_space", "Show the search space as a CFG")
("show_target_graph", po::value<string>(), "Directory to write the target hypergraphs to")
+ ("incremental_search", po::value<string>(), "Run lazy search with this language model file")
("coarse_to_fine_beam_prune", po::value<double>(), "Prune paths from coarse parse forest before fine parse, keeping paths within exp(alpha>=0)")
("ctf_beam_widen", po::value<double>()->default_value(2.0), "Expand coarse pass beam by this factor if no fine parse is found")
("ctf_num_widenings", po::value<int>()->default_value(2), "Widen coarse beam this many times before backing off to full parse")
@@ -641,8 +643,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
prev_weights = rp.weight_vector;
}
rp.models.reset(new ModelSet(*rp.weight_vector, rp.ffs));
- string ps = "Pass1 "; ps[4] += pass;
- if (!SILENT) show_models(conf,*rp.models,ps.c_str());
}
// show configuration of rescoring passes
@@ -730,6 +730,10 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
sent_id = -1;
acc_obj = 0; // accumulate objective
g_count = 0; // number of gradient pieces computed
+
+ if (conf.count("incremental_search")) {
+ incremental.reset(IncrementalBase::Load(conf["incremental_search"].as<string>().c_str(), CurrentWeightVector()));
+ }
}
Decoder::Decoder(istream* cfg) { pimpl_.reset(new DecoderImpl(conf,0,0,cfg)); }
@@ -831,6 +835,12 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
if (conf.count("show_target_graph"))
HypergraphIO::WriteTarget(conf["show_target_graph"].as<string>(), sent_id, forest);
+ if (conf.count("incremental_search")) {
+ incremental->Search(pop_limit, forest);
+ o->NotifyDecodingComplete(smeta);
+ return true;
+ }
+
for (int pass = 0; pass < rescoring_passes.size(); ++pass) {
const RescoringPass& rp = rescoring_passes[pass];
const vector<weight_t>& cur_weights = *rp.weight_vector;
@@ -870,13 +880,13 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
if (rp.fid_summary) {
if (summary_feature_type == kEDGE_PROB) {
const prob_t z = forest.PushWeightsToGoal(1.0);
- if (!isfinite(log(z)) || isnan(log(z))) {
+ if (!std::isfinite(log(z)) || std::isnan(log(z))) {
cerr << " " << passtr << " !!! Invalid partition detected, abandoning.\n";
} else {
for (int i = 0; i < forest.edges_.size(); ++i) {
const double log_prob_transition = log(forest.edges_[i].edge_prob_); // locally normalized by the edge
// head node by forest.PushWeightsToGoal
- if (!isfinite(log_prob_transition) || isnan(log_prob_transition)) {
+ if (!std::isfinite(log_prob_transition) || std::isnan(log_prob_transition)) {
cerr << "Edge: i=" << i << " got bad inside prob: " << *forest.edges_[i].rule_ << endl;
abort();
}
@@ -888,7 +898,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
} else if (summary_feature_type == kNODE_RISK) {
Hypergraph::EdgeProbs posts;
const prob_t z = forest.ComputeEdgePosteriors(1.0, &posts);
- if (!isfinite(log(z)) || isnan(log(z))) {
+ if (!std::isfinite(log(z)) || std::isnan(log(z))) {
cerr << " " << passtr << " !!! Invalid partition detected, abandoning.\n";
} else {
for (int i = 0; i < forest.nodes_.size(); ++i) {
@@ -897,7 +907,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
for (int j = 0; j < in_edges.size(); ++j)
node_post += (posts[in_edges[j]] / z);
const double log_np = log(node_post);
- if (!isfinite(log_np) || isnan(log_np)) {
+ if (!std::isfinite(log_np) || std::isnan(log_np)) {
cerr << "got bad posterior prob for node " << i << endl;
abort();
}
@@ -912,13 +922,13 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
} else if (summary_feature_type == kEDGE_RISK) {
Hypergraph::EdgeProbs posts;
const prob_t z = forest.ComputeEdgePosteriors(1.0, &posts);
- if (!isfinite(log(z)) || isnan(log(z))) {
+ if (!std::isfinite(log(z)) || std::isnan(log(z))) {
cerr << " " << passtr << " !!! Invalid partition detected, abandoning.\n";
} else {
assert(posts.size() == forest.edges_.size());
for (int i = 0; i < posts.size(); ++i) {
const double log_np = log(posts[i] / z);
- if (!isfinite(log_np) || isnan(log_np)) {
+ if (!std::isfinite(log_np) || std::isnan(log_np)) {
cerr << "got bad posterior prob for node " << i << endl;
abort();
}
@@ -958,7 +968,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
// Oracle Rescoring
if(get_oracle_forest) {
- assert(!"this is broken"); FeatureVector dummy; // = last_weights
+ assert(!"this is broken"); SparseVector<double> dummy; // = last_weights
Oracle oc=oracle.ComputeOracle(smeta,&forest,dummy,10,conf["forest_output"].as<std::string>());
if (!SILENT) cerr << " +Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
if (!SILENT) cerr << " +Oracle BLEU (paths): " << forest.NumberOfPaths() << endl;
@@ -980,7 +990,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg);
if (!succeeded) abort();
}
- new_hg.Union(forest);
+ HG::Union(forest, &new_hg);
bool succeeded = writer.Write(new_hg, false);
if (!succeeded) abort();
} else {
@@ -1067,7 +1077,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg);
if (!succeeded) abort();
}
- new_hg.Union(forest);
+ HG::Union(forest, &new_hg);
bool succeeded = writer.Write(new_hg, false);
if (!succeeded) abort();
} else {
@@ -1089,7 +1099,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
cerr << "DIFF. ERR! log_z < log_ref_z: " << log_z << " " << log_ref_z << endl;
exit(1);
}
- assert(!isnan(log_ref_z));
+ assert(!std::isnan(log_ref_z));
ref_exp -= full_exp;
acc_vec += ref_exp;
acc_obj += (log_z - log_ref_z);