diff options
Diffstat (limited to 'decoder/decoder.cc')
-rw-r--r-- | decoder/decoder.cc | 124 |
1 files changed, 85 insertions, 39 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 55d9f1d7..b93925d1 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -46,6 +46,13 @@ #include "cfg_options.h" #endif +#ifdef CP_TIME + clock_t CpTime::time_; + void CpTime::Add(clock_t x){time_+=x;} + void CpTime::Sub(clock_t x){time_-=x;} + double CpTime::Get(){return (double)(time_)/CLOCKS_PER_SEC;} +#endif + static const double kMINUS_EPSILON = -1e-6; // don't be too strict using namespace std; @@ -152,8 +159,7 @@ struct RescoringPass { shared_ptr<ModelSet> models; shared_ptr<IntersectionConfiguration> inter_conf; vector<const FeatureFunction*> ffs; - shared_ptr<Weights> w; // null == use previous weights - vector<double> weight_vector; + shared_ptr<vector<weight_t> > weight_vector; int fid_summary; // 0 == no summary feature double density_prune; // 0 == don't density prune double beam_prune; // 0 == don't beam prune @@ -162,7 +168,7 @@ struct RescoringPass { ostream& operator<<(ostream& os, const RescoringPass& rp) { os << "[num_fn=" << rp.ffs.size(); if (rp.inter_conf) { os << " int_alg=" << *rp.inter_conf; } - if (rp.w) os << " new_weights"; + //if (rp.weight_vector.size() > 0) os << " new_weights"; if (rp.fid_summary) os << " summary_feature=" << FD::Convert(rp.fid_summary); if (rp.density_prune) os << " density_prune=" << rp.density_prune; if (rp.beam_prune) os << " beam_prune=" << rp.beam_prune; @@ -174,13 +180,8 @@ struct DecoderImpl { DecoderImpl(po::variables_map& conf, int argc, char** argv, istream* cfg); ~DecoderImpl(); bool Decode(const string& input, DecoderObserver*); - void SetWeights(const vector<double>& weights) { - init_weights = weights; - for (int i = 0; i < rescoring_passes.size(); ++i) { - if (rescoring_passes[i].models) - rescoring_passes[i].models->SetWeights(weights); - rescoring_passes[i].weight_vector = weights; - } + vector<weight_t>& CurrentWeightVector() { + return (rescoring_passes.empty() ? *init_weights : *rescoring_passes.back().weight_vector); } void SetId(int next_sent_id) { sent_id = next_sent_id - 1; } @@ -293,8 +294,7 @@ struct DecoderImpl { OracleBleu oracle; string formalism; shared_ptr<Translator> translator; - Weights w_init_weights; // used with initial parse - vector<double> init_weights; // weights used with initial parse + shared_ptr<vector<weight_t> > init_weights; // weights used with initial parse vector<shared_ptr<FeatureFunction> > pffs; #ifdef FSA_RESCORING CFGOptions cfg_options; @@ -321,10 +321,11 @@ struct DecoderImpl { bool write_gradient; // TODO Observer bool feature_expectations; // TODO Observer bool output_training_vector; // TODO Observer + bool remove_intersected_rule_annotations; static void ConvertSV(const SparseVector<prob_t>& src, SparseVector<double>* trg) { for (SparseVector<prob_t>::const_iterator it = src.begin(); it != src.end(); ++it) - trg->set_value(it->first, it->second); + trg->set_value(it->first, it->second.as_float()); } }; @@ -354,10 +355,13 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("grammar,g",po::value<vector<string> >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)") ("per_sentence_grammar_file", po::value<string>(), "Optional (and possibly not implemented) per sentence grammar file enables all per sentence grammars to be stored in a single large file and accessed by offset") ("list_feature_functions,L","List available feature functions") +#ifdef HAVE_CMPH + ("cmph_perfect_feature_hash,h", po::value<string>(), "Load perfect hash function for features") +#endif ("weights,w",po::value<string>(),"Feature weights file (initial forest / pass 1)") ("feature_function,F",po::value<vector<string> >()->composing(), "Pass 1 additional feature function(s) (-L for list)") - ("intersection_strategy,I",po::value<string>()->default_value("cube_pruning"), "Pass 1 intersection strategy for incorporating finite-state features; values include Cube_pruning, Full") + ("intersection_strategy,I",po::value<string>()->default_value("cube_pruning"), "Pass 1 intersection strategy for incorporating finite-state features; values include Cube_pruning, Full, Fast_cube_pruning, Fast_cube_pruning_2") ("summary_feature", po::value<string>(), "Compute a 'summary feature' at the end of the pass (before any pruning) with name=arg and value=inside-outside/Z") ("summary_feature_type", po::value<string>()->default_value("node_risk"), "Summary feature types: node_risk, edge_risk, edge_prob") ("density_prune", po::value<double>(), "Pass 1 pruning: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)") @@ -416,6 +420,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("csplit_output_plf", "(Compound splitter) Output lattice in PLF format") ("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice") ("extract_rules", po::value<string>(), "Extract the rules used in translation (de-duped) to this file") + ("show_derivations", po::value<string>(), "Directory to print the derivation structures to") ("graphviz","Show (constrained) translation forest in GraphViz format") ("max_translation_beam,x", po::value<int>(), "Beam approximation to get max translation from the chart") ("max_translation_sample,X", po::value<int>(), "Sample the max translation from the chart") @@ -425,7 +430,9 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)") ("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)") ("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)") - ("forest_output,O",po::value<string>(),"Directory to write forests to"); + ("forest_output,O",po::value<string>(),"Directory to write forests to") + ("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)"); + // ob.AddOptions(&opts); #ifdef FSA_RESCORING po::options_description cfgo(cfg_options.description()); @@ -434,7 +441,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream po::options_description clo("Command line options"); clo.add_options() ("config,c", po::value<vector<string> >(&cfg_files), "Configuration file(s) - latest has priority") - ("help,h", "Print this help message and exit") + ("help,?", "Print this help message and exit") ("usage,u", po::value<string>(), "Describe a feature function type") ("compgen", "Print just option names suitable for bash command line completion builtin 'compgen'") ; @@ -543,13 +550,18 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream exit(1); } - // load initial feature weights (and possibly freeze feature set) - if (conf.count("weights")) { - w_init_weights.InitFromFile(str("weights",conf)); - w_init_weights.InitVector(&init_weights); - init_weights.resize(FD::NumFeats()); + // load perfect hash function for features + if (conf.count("cmph_perfect_feature_hash")) { + cerr << "Loading perfect hash function from " << conf["cmph_perfect_feature_hash"].as<string>() << " ...\n"; + FD::EnableHash(conf["cmph_perfect_feature_hash"].as<string>()); + cerr << " " << FD::NumFeats() << " features in map\n"; } + // load initial feature weights (and possibly freeze feature set) + init_weights.reset(new vector<weight_t>); + if (conf.count("weights")) + Weights::InitFromFile(str("weights",conf), init_weights.get()); + // cube pruning pop-limit: we may want to configure this on a per-pass basis pop_limit = conf["cubepruning_pop_limit"].as<int>(); @@ -568,9 +580,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream RescoringPass& rp = rescoring_passes.back(); // only configure new weights if pass > 0, otherwise we reuse the initial chart weights if (nth_pass_condition && conf.count(ws)) { - rp.w.reset(new Weights); - rp.w->InitFromFile(str(ws.c_str(), conf)); - rp.w->InitVector(&rp.weight_vector); + rp.weight_vector.reset(new vector<weight_t>()); + Weights::InitFromFile(str(ws.c_str(), conf), rp.weight_vector.get()); } bool has_stateful = false; if (conf.count(ff)) { @@ -595,6 +606,14 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream if (LowercaseString(str(isn.c_str(),conf)) == "full") { palg = 0; } + if (LowercaseString(conf["intersection_strategy"].as<string>()) == "fast_cube_pruning") { + palg = 2; + cerr << "Using Fast Cube Pruning intersection (see Algorithm 2 described in: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010).\n"; + } + if (LowercaseString(conf["intersection_strategy"].as<string>()) == "fast_cube_pruning_2") { + palg = 3; + cerr << "Using Fast Cube Pruning 2 intersection (see Algorithm 3 described in: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010).\n"; + } rp.inter_conf.reset(new IntersectionConfiguration(palg, pop_limit)); } else { break; // TODO alert user if there are any future configurations @@ -602,11 +621,15 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream } // set up weight vectors since later phases may reuse weights from earlier phases - const vector<double>* prev = &init_weights; + shared_ptr<vector<weight_t> > prev_weights = init_weights; for (int pass = 0; pass < rescoring_passes.size(); ++pass) { RescoringPass& rp = rescoring_passes[pass]; - if (!rp.w) { rp.weight_vector = *prev; } else { prev = &rp.weight_vector; } - rp.models.reset(new ModelSet(rp.weight_vector, rp.ffs)); + if (!rp.weight_vector) { + rp.weight_vector = prev_weights; + } else { + prev_weights = rp.weight_vector; + } + rp.models.reset(new ModelSet(*rp.weight_vector, rp.ffs)); string ps = "Pass1 "; ps[4] += pass; if (!SILENT) show_models(conf,*rp.models,ps.c_str()); } @@ -657,7 +680,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream } if (!fsa_ffs.empty()) { cerr<<"FSA: "; - show_all_features(fsa_ffs,init_weights,cerr,cerr,true,true); + show_all_features(fsa_ffs,*init_weights,cerr,cerr,true,true); } #endif @@ -677,6 +700,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream kbest = conf.count("k_best"); unique_kbest = conf.count("unique_k_best"); get_oracle_forest = conf.count("get_oracle_forest"); + oracle.show_derivation=conf.count("show_derivations"); + remove_intersected_rule_annotations = conf.count("remove_intersected_rule_annotations"); #ifdef FSA_RESCORING cfg_options.Validate(); @@ -703,7 +728,8 @@ bool Decoder::Decode(const string& input, DecoderObserver* o) { if (del) delete o; return res; } -void Decoder::SetWeights(const vector<double>& weights) { pimpl_->SetWeights(weights); } +vector<weight_t>& Decoder::CurrentWeightVector() { return pimpl_->CurrentWeightVector(); } +const vector<weight_t>& Decoder::CurrentWeightVector() const { return pimpl_->CurrentWeightVector(); } void Decoder::SetSupplementalGrammar(const std::string& grammar_string) { assert(pimpl_->translator->GetDecoderType() == "SCFG"); static_cast<SCFGTranslator&>(*pimpl_->translator).SetSupplementalGrammar(grammar_string); @@ -748,7 +774,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { translator->ProcessMarkupHints(smeta.sgml_); Timer t("Translation"); const bool translation_successful = - translator->Translate(to_translate, &smeta, init_weights, &forest); + translator->Translate(to_translate, &smeta, *init_weights, &forest); translator->SentenceComplete(); if (!translation_successful) { @@ -766,10 +792,15 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { const bool show_tree_structure=conf.count("show_tree_structure"); if (!SILENT) forest_stats(forest," Init. forest",show_tree_structure,oracle.show_derivation); if (conf.count("show_expected_length")) { - const PRPair<double, double> res = - Inside<PRPair<double, double>, - PRWeightFunction<double, EdgeProb, double, ELengthWeightFunction> >(forest); - cerr << " Expected length (words): " << res.r / res.p << "\t" << res << endl; + const PRPair<prob_t, prob_t> res = + Inside<PRPair<prob_t, prob_t>, + PRWeightFunction<prob_t, EdgeProb, prob_t, ELengthWeightFunction> >(forest); + cerr << " Expected length (words): " << (res.r / res.p).as_float() << "\t" << res << endl; + } + + if (conf.count("show_partition")) { + const prob_t z = Inside<prob_t, EdgeProb>(forest); + cerr << " Partition log(Z): " << log(z) << endl; } SummaryFeature summary_feature_type = kNODE_RISK; @@ -786,7 +817,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { for (int pass = 0; pass < rescoring_passes.size(); ++pass) { const RescoringPass& rp = rescoring_passes[pass]; - const vector<double>& cur_weights = rp.weight_vector; + const vector<weight_t>& cur_weights = *rp.weight_vector; if (!SILENT) cerr << endl << " RESCORING PASS #" << (pass+1) << " " << rp << endl; #ifdef FSA_RESCORING cfg_options.maybe_output_source(forest); @@ -799,11 +830,17 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { Timer t("Forest rescoring:"); rp.models->PrepareForInput(smeta); Hypergraph rescored_forest; +#ifdef CP_TIME + CpTime::Sub(clock()); +#endif ApplyModelSet(forest, smeta, *rp.models, *rp.inter_conf, &rescored_forest); +#ifdef CP_TIME + CpTime::Add(clock()); +#endif forest.swap(rescored_forest); forest.Reweight(cur_weights); if (!SILENT) forest_stats(forest," " + passtr +" forest",show_tree_structure,oracle.show_derivation); @@ -901,7 +938,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { #endif } - const vector<double>& last_weights = (rescoring_passes.empty() ? init_weights : rescoring_passes.back().weight_vector); + const vector<double>& last_weights = (rescoring_passes.empty() ? *init_weights : *rescoring_passes.back().weight_vector); // Oracle Rescoring if(get_oracle_forest) { @@ -942,7 +979,8 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { } else { if (kbest && !has_ref) { //TODO: does this work properly? - oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-"); + const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-"; + oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname); } else if (csplit_output_plf) { cout << HypergraphIO::AsPLF(forest, false) << endl; } else { @@ -989,6 +1027,12 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { // if (!SILENT) cerr << " USING UNIFORM WEIGHTS\n"; // for (int i = 0; i < forest.edges_.size(); ++i) // forest.edges_[i].edge_prob_=prob_t::One(); } + if (remove_intersected_rule_annotations) { + for (unsigned i = 0; i < forest.edges_.size(); ++i) + if (forest.edges_[i].rule_ && + forest.edges_[i].rule_->parent_rule_) + forest.edges_[i].rule_ = forest.edges_[i].rule_->parent_rule_; + } forest.Reweight(last_weights); if (!SILENT) forest_stats(forest," Constr. forest",show_tree_structure,oracle.show_derivation); if (!SILENT) cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl; @@ -1059,8 +1103,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { } } if (conf.count("graphviz")) forest.PrintGraphviz(); - if (kbest) - oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-"); + if (kbest) { + const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-"; + oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname); + } if (conf.count("show_conditional_prob")) { const prob_t ref_z = Inside<prob_t, EdgeProb>(forest); cout << (log(ref_z) - log(first_z)) << endl << flush; |