summaryrefslogtreecommitdiff
path: root/decoder/decoder.cc
diff options
context:
space:
mode:
Diffstat (limited to 'decoder/decoder.cc')
-rw-r--r--decoder/decoder.cc124
1 files changed, 85 insertions, 39 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 55d9f1d7..b93925d1 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -46,6 +46,13 @@
#include "cfg_options.h"
#endif
+#ifdef CP_TIME
+ clock_t CpTime::time_;
+ void CpTime::Add(clock_t x){time_+=x;}
+ void CpTime::Sub(clock_t x){time_-=x;}
+ double CpTime::Get(){return (double)(time_)/CLOCKS_PER_SEC;}
+#endif
+
static const double kMINUS_EPSILON = -1e-6; // don't be too strict
using namespace std;
@@ -152,8 +159,7 @@ struct RescoringPass {
shared_ptr<ModelSet> models;
shared_ptr<IntersectionConfiguration> inter_conf;
vector<const FeatureFunction*> ffs;
- shared_ptr<Weights> w; // null == use previous weights
- vector<double> weight_vector;
+ shared_ptr<vector<weight_t> > weight_vector;
int fid_summary; // 0 == no summary feature
double density_prune; // 0 == don't density prune
double beam_prune; // 0 == don't beam prune
@@ -162,7 +168,7 @@ struct RescoringPass {
ostream& operator<<(ostream& os, const RescoringPass& rp) {
os << "[num_fn=" << rp.ffs.size();
if (rp.inter_conf) { os << " int_alg=" << *rp.inter_conf; }
- if (rp.w) os << " new_weights";
+ //if (rp.weight_vector.size() > 0) os << " new_weights";
if (rp.fid_summary) os << " summary_feature=" << FD::Convert(rp.fid_summary);
if (rp.density_prune) os << " density_prune=" << rp.density_prune;
if (rp.beam_prune) os << " beam_prune=" << rp.beam_prune;
@@ -174,13 +180,8 @@ struct DecoderImpl {
DecoderImpl(po::variables_map& conf, int argc, char** argv, istream* cfg);
~DecoderImpl();
bool Decode(const string& input, DecoderObserver*);
- void SetWeights(const vector<double>& weights) {
- init_weights = weights;
- for (int i = 0; i < rescoring_passes.size(); ++i) {
- if (rescoring_passes[i].models)
- rescoring_passes[i].models->SetWeights(weights);
- rescoring_passes[i].weight_vector = weights;
- }
+ vector<weight_t>& CurrentWeightVector() {
+ return (rescoring_passes.empty() ? *init_weights : *rescoring_passes.back().weight_vector);
}
void SetId(int next_sent_id) { sent_id = next_sent_id - 1; }
@@ -293,8 +294,7 @@ struct DecoderImpl {
OracleBleu oracle;
string formalism;
shared_ptr<Translator> translator;
- Weights w_init_weights; // used with initial parse
- vector<double> init_weights; // weights used with initial parse
+ shared_ptr<vector<weight_t> > init_weights; // weights used with initial parse
vector<shared_ptr<FeatureFunction> > pffs;
#ifdef FSA_RESCORING
CFGOptions cfg_options;
@@ -321,10 +321,11 @@ struct DecoderImpl {
bool write_gradient; // TODO Observer
bool feature_expectations; // TODO Observer
bool output_training_vector; // TODO Observer
+ bool remove_intersected_rule_annotations;
static void ConvertSV(const SparseVector<prob_t>& src, SparseVector<double>* trg) {
for (SparseVector<prob_t>::const_iterator it = src.begin(); it != src.end(); ++it)
- trg->set_value(it->first, it->second);
+ trg->set_value(it->first, it->second.as_float());
}
};
@@ -354,10 +355,13 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("grammar,g",po::value<vector<string> >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)")
("per_sentence_grammar_file", po::value<string>(), "Optional (and possibly not implemented) per sentence grammar file enables all per sentence grammars to be stored in a single large file and accessed by offset")
("list_feature_functions,L","List available feature functions")
+#ifdef HAVE_CMPH
+ ("cmph_perfect_feature_hash,h", po::value<string>(), "Load perfect hash function for features")
+#endif
("weights,w",po::value<string>(),"Feature weights file (initial forest / pass 1)")
("feature_function,F",po::value<vector<string> >()->composing(), "Pass 1 additional feature function(s) (-L for list)")
- ("intersection_strategy,I",po::value<string>()->default_value("cube_pruning"), "Pass 1 intersection strategy for incorporating finite-state features; values include Cube_pruning, Full")
+ ("intersection_strategy,I",po::value<string>()->default_value("cube_pruning"), "Pass 1 intersection strategy for incorporating finite-state features; values include Cube_pruning, Full, Fast_cube_pruning, Fast_cube_pruning_2")
("summary_feature", po::value<string>(), "Compute a 'summary feature' at the end of the pass (before any pruning) with name=arg and value=inside-outside/Z")
("summary_feature_type", po::value<string>()->default_value("node_risk"), "Summary feature types: node_risk, edge_risk, edge_prob")
("density_prune", po::value<double>(), "Pass 1 pruning: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)")
@@ -416,6 +420,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("csplit_output_plf", "(Compound splitter) Output lattice in PLF format")
("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice")
("extract_rules", po::value<string>(), "Extract the rules used in translation (de-duped) to this file")
+ ("show_derivations", po::value<string>(), "Directory to print the derivation structures to")
("graphviz","Show (constrained) translation forest in GraphViz format")
("max_translation_beam,x", po::value<int>(), "Beam approximation to get max translation from the chart")
("max_translation_sample,X", po::value<int>(), "Sample the max translation from the chart")
@@ -425,7 +430,9 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)")
("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)")
("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)")
- ("forest_output,O",po::value<string>(),"Directory to write forests to");
+ ("forest_output,O",po::value<string>(),"Directory to write forests to")
+ ("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)");
+
// ob.AddOptions(&opts);
#ifdef FSA_RESCORING
po::options_description cfgo(cfg_options.description());
@@ -434,7 +441,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
po::options_description clo("Command line options");
clo.add_options()
("config,c", po::value<vector<string> >(&cfg_files), "Configuration file(s) - latest has priority")
- ("help,h", "Print this help message and exit")
+ ("help,?", "Print this help message and exit")
("usage,u", po::value<string>(), "Describe a feature function type")
("compgen", "Print just option names suitable for bash command line completion builtin 'compgen'")
;
@@ -543,13 +550,18 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
exit(1);
}
- // load initial feature weights (and possibly freeze feature set)
- if (conf.count("weights")) {
- w_init_weights.InitFromFile(str("weights",conf));
- w_init_weights.InitVector(&init_weights);
- init_weights.resize(FD::NumFeats());
+ // load perfect hash function for features
+ if (conf.count("cmph_perfect_feature_hash")) {
+ cerr << "Loading perfect hash function from " << conf["cmph_perfect_feature_hash"].as<string>() << " ...\n";
+ FD::EnableHash(conf["cmph_perfect_feature_hash"].as<string>());
+ cerr << " " << FD::NumFeats() << " features in map\n";
}
+ // load initial feature weights (and possibly freeze feature set)
+ init_weights.reset(new vector<weight_t>);
+ if (conf.count("weights"))
+ Weights::InitFromFile(str("weights",conf), init_weights.get());
+
// cube pruning pop-limit: we may want to configure this on a per-pass basis
pop_limit = conf["cubepruning_pop_limit"].as<int>();
@@ -568,9 +580,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
RescoringPass& rp = rescoring_passes.back();
// only configure new weights if pass > 0, otherwise we reuse the initial chart weights
if (nth_pass_condition && conf.count(ws)) {
- rp.w.reset(new Weights);
- rp.w->InitFromFile(str(ws.c_str(), conf));
- rp.w->InitVector(&rp.weight_vector);
+ rp.weight_vector.reset(new vector<weight_t>());
+ Weights::InitFromFile(str(ws.c_str(), conf), rp.weight_vector.get());
}
bool has_stateful = false;
if (conf.count(ff)) {
@@ -595,6 +606,14 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
if (LowercaseString(str(isn.c_str(),conf)) == "full") {
palg = 0;
}
+ if (LowercaseString(conf["intersection_strategy"].as<string>()) == "fast_cube_pruning") {
+ palg = 2;
+ cerr << "Using Fast Cube Pruning intersection (see Algorithm 2 described in: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010).\n";
+ }
+ if (LowercaseString(conf["intersection_strategy"].as<string>()) == "fast_cube_pruning_2") {
+ palg = 3;
+ cerr << "Using Fast Cube Pruning 2 intersection (see Algorithm 3 described in: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010).\n";
+ }
rp.inter_conf.reset(new IntersectionConfiguration(palg, pop_limit));
} else {
break; // TODO alert user if there are any future configurations
@@ -602,11 +621,15 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
}
// set up weight vectors since later phases may reuse weights from earlier phases
- const vector<double>* prev = &init_weights;
+ shared_ptr<vector<weight_t> > prev_weights = init_weights;
for (int pass = 0; pass < rescoring_passes.size(); ++pass) {
RescoringPass& rp = rescoring_passes[pass];
- if (!rp.w) { rp.weight_vector = *prev; } else { prev = &rp.weight_vector; }
- rp.models.reset(new ModelSet(rp.weight_vector, rp.ffs));
+ if (!rp.weight_vector) {
+ rp.weight_vector = prev_weights;
+ } else {
+ prev_weights = rp.weight_vector;
+ }
+ rp.models.reset(new ModelSet(*rp.weight_vector, rp.ffs));
string ps = "Pass1 "; ps[4] += pass;
if (!SILENT) show_models(conf,*rp.models,ps.c_str());
}
@@ -657,7 +680,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
}
if (!fsa_ffs.empty()) {
cerr<<"FSA: ";
- show_all_features(fsa_ffs,init_weights,cerr,cerr,true,true);
+ show_all_features(fsa_ffs,*init_weights,cerr,cerr,true,true);
}
#endif
@@ -677,6 +700,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
kbest = conf.count("k_best");
unique_kbest = conf.count("unique_k_best");
get_oracle_forest = conf.count("get_oracle_forest");
+ oracle.show_derivation=conf.count("show_derivations");
+ remove_intersected_rule_annotations = conf.count("remove_intersected_rule_annotations");
#ifdef FSA_RESCORING
cfg_options.Validate();
@@ -703,7 +728,8 @@ bool Decoder::Decode(const string& input, DecoderObserver* o) {
if (del) delete o;
return res;
}
-void Decoder::SetWeights(const vector<double>& weights) { pimpl_->SetWeights(weights); }
+vector<weight_t>& Decoder::CurrentWeightVector() { return pimpl_->CurrentWeightVector(); }
+const vector<weight_t>& Decoder::CurrentWeightVector() const { return pimpl_->CurrentWeightVector(); }
void Decoder::SetSupplementalGrammar(const std::string& grammar_string) {
assert(pimpl_->translator->GetDecoderType() == "SCFG");
static_cast<SCFGTranslator&>(*pimpl_->translator).SetSupplementalGrammar(grammar_string);
@@ -748,7 +774,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
translator->ProcessMarkupHints(smeta.sgml_);
Timer t("Translation");
const bool translation_successful =
- translator->Translate(to_translate, &smeta, init_weights, &forest);
+ translator->Translate(to_translate, &smeta, *init_weights, &forest);
translator->SentenceComplete();
if (!translation_successful) {
@@ -766,10 +792,15 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
const bool show_tree_structure=conf.count("show_tree_structure");
if (!SILENT) forest_stats(forest," Init. forest",show_tree_structure,oracle.show_derivation);
if (conf.count("show_expected_length")) {
- const PRPair<double, double> res =
- Inside<PRPair<double, double>,
- PRWeightFunction<double, EdgeProb, double, ELengthWeightFunction> >(forest);
- cerr << " Expected length (words): " << res.r / res.p << "\t" << res << endl;
+ const PRPair<prob_t, prob_t> res =
+ Inside<PRPair<prob_t, prob_t>,
+ PRWeightFunction<prob_t, EdgeProb, prob_t, ELengthWeightFunction> >(forest);
+ cerr << " Expected length (words): " << (res.r / res.p).as_float() << "\t" << res << endl;
+ }
+
+ if (conf.count("show_partition")) {
+ const prob_t z = Inside<prob_t, EdgeProb>(forest);
+ cerr << " Partition log(Z): " << log(z) << endl;
}
SummaryFeature summary_feature_type = kNODE_RISK;
@@ -786,7 +817,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
for (int pass = 0; pass < rescoring_passes.size(); ++pass) {
const RescoringPass& rp = rescoring_passes[pass];
- const vector<double>& cur_weights = rp.weight_vector;
+ const vector<weight_t>& cur_weights = *rp.weight_vector;
if (!SILENT) cerr << endl << " RESCORING PASS #" << (pass+1) << " " << rp << endl;
#ifdef FSA_RESCORING
cfg_options.maybe_output_source(forest);
@@ -799,11 +830,17 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
Timer t("Forest rescoring:");
rp.models->PrepareForInput(smeta);
Hypergraph rescored_forest;
+#ifdef CP_TIME
+ CpTime::Sub(clock());
+#endif
ApplyModelSet(forest,
smeta,
*rp.models,
*rp.inter_conf,
&rescored_forest);
+#ifdef CP_TIME
+ CpTime::Add(clock());
+#endif
forest.swap(rescored_forest);
forest.Reweight(cur_weights);
if (!SILENT) forest_stats(forest," " + passtr +" forest",show_tree_structure,oracle.show_derivation);
@@ -901,7 +938,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
#endif
}
- const vector<double>& last_weights = (rescoring_passes.empty() ? init_weights : rescoring_passes.back().weight_vector);
+ const vector<double>& last_weights = (rescoring_passes.empty() ? *init_weights : *rescoring_passes.back().weight_vector);
// Oracle Rescoring
if(get_oracle_forest) {
@@ -942,7 +979,8 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
} else {
if (kbest && !has_ref) {
//TODO: does this work properly?
- oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-");
+ const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-";
+ oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname);
} else if (csplit_output_plf) {
cout << HypergraphIO::AsPLF(forest, false) << endl;
} else {
@@ -989,6 +1027,12 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
// if (!SILENT) cerr << " USING UNIFORM WEIGHTS\n";
// for (int i = 0; i < forest.edges_.size(); ++i)
// forest.edges_[i].edge_prob_=prob_t::One(); }
+ if (remove_intersected_rule_annotations) {
+ for (unsigned i = 0; i < forest.edges_.size(); ++i)
+ if (forest.edges_[i].rule_ &&
+ forest.edges_[i].rule_->parent_rule_)
+ forest.edges_[i].rule_ = forest.edges_[i].rule_->parent_rule_;
+ }
forest.Reweight(last_weights);
if (!SILENT) forest_stats(forest," Constr. forest",show_tree_structure,oracle.show_derivation);
if (!SILENT) cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl;
@@ -1059,8 +1103,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
}
}
if (conf.count("graphviz")) forest.PrintGraphviz();
- if (kbest)
- oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-");
+ if (kbest) {
+ const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-";
+ oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname);
+ }
if (conf.count("show_conditional_prob")) {
const prob_t ref_z = Inside<prob_t, EdgeProb>(forest);
cout << (log(ref_z) - log(first_z)) << endl << flush;