From 93618c0fcce1544bf948172d04e764f53073cf8a Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 10 Feb 2011 18:45:13 -0500 Subject: multipass decoding --- decoder/decoder.cc | 288 ++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 195 insertions(+), 93 deletions(-) (limited to 'decoder/decoder.cc') diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 25f05d8e..478a1cf3 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -59,7 +59,7 @@ namespace Hack { void MaxTrans(const Hypergraph& in, int beam_size); } namespace NgramCache { void Clear(); } DecoderObserver::~DecoderObserver() {} -void DecoderObserver::NotifyDecodingStart(const SentenceMetadata& smeta) {} +void DecoderObserver::NotifyDecodingStart(const SentenceMetadata&) {} void DecoderObserver::NotifySourceParseFailure(const SentenceMetadata&) {} void DecoderObserver::NotifyTranslationForest(const SentenceMetadata&, Hypergraph*) {} void DecoderObserver::NotifyAlignmentFailure(const SentenceMetadata&) {} @@ -72,7 +72,7 @@ struct ELengthWeightFunction { } }; inline void ShowBanner() { - cerr << "cdec v1.0 (c) 2009-2010 by Chris Dyer\n"; + cerr << "cdec v1.0 (c) 2009-2011 by Chris Dyer\n"; } inline void show_models(po::variables_map const& conf,ModelSet &ms,char const* header) { @@ -135,12 +135,40 @@ inline shared_ptr make_fsa_ff(string const& ffp,bool verbose } #endif +// when the translation forest is first built, it is scored by the features associated +// with the rules. To add other features (like language models, etc), cdec applies one or +// more "rescoring passes", which compute new features and optionally apply new weights +// and then prune the resulting (rescored) hypergraph. All feature values from previous +// passes are present in future passes (although their weights may change). +struct RescoringPass { + RescoringPass() : density_prune(), beam_prune() {} + shared_ptr models; + shared_ptr inter_conf; + vector ffs; + shared_ptr w; // null == use previous weights + vector weight_vector; + double density_prune; // 0 == don't density prune + double beam_prune; // 0 == don't beam prune +}; + +ostream& operator<<(ostream& os, const RescoringPass& rp) { + os << "[num_fn=" << rp.ffs.size(); + if (rp.inter_conf) { os << " int_alg=" << *rp.inter_conf; } + if (rp.w) os << " new_weights"; + if (rp.density_prune) os << " density_prune=" << rp.density_prune; + if (rp.beam_prune) os << " beam_prune=" << rp.beam_prune; + os << ']'; + return os; +} + struct DecoderImpl { DecoderImpl(po::variables_map& conf, int argc, char** argv, istream* cfg); ~DecoderImpl(); bool Decode(const string& input, DecoderObserver*); void SetWeights(const vector& weights) { - feature_weights = weights; + init_weights = weights; + for (int i = 0; i < rescoring_passes.size(); ++i) + rescoring_passes[i].weight_vector = weights; } void SetId(int next_sent_id) { sent_id = next_sent_id - 1; } @@ -159,8 +187,8 @@ struct DecoderImpl { } } - void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_features,DenseWeightVector const& feature_weights, bool sd=false) { - WeightVector fw(feature_weights); + void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_features,DenseWeightVector const& weights, bool sd=false) { + WeightVector fw(weights); forest_stats(forest,name,show_tree,show_features,&fw,sd); } @@ -252,21 +280,30 @@ struct DecoderImpl { } } + // used to construct the suffix string to get the name of arguments for multiple passes + // e.g., the "2" in --weights2 + static string StringSuffixForRescoringPass(int pass) { + if (pass == 0) return ""; + string ps = "1"; + assert(pass < 9); + ps[0] += pass; + return ps; + } + + vector rescoring_passes; + po::variables_map& conf; OracleBleu oracle; string formalism; shared_ptr translator; - vector feature_weights; - Weights w; + Weights w_init_weights; // used with initial parse + vector init_weights; // weights used with initial parse vector > pffs; - vector late_ffs; #ifdef FSA_RESCORING CFGOptions cfg_options; vector > fsa_ffs; vector fsa_names; #endif - ModelSet* late_models; - IntersectionConfiguration* inter_conf; shared_ptr > rng; int sample_max_trans; bool aligner_mode; @@ -319,44 +356,58 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("input,i",po::value()->default_value("-"),"Source file") ("grammar,g",po::value >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)") ("per_sentence_grammar_file", po::value(), "Optional (and possibly not implemented) per sentence grammar file enables all per sentence grammars to be stored in a single large file and accessed by offset") - ("weights,w",po::value(),"Feature weights file") - ("warn_0_weight","Warn about any feature id that has a 0 weight (this is perfectly safe if you intend 0 weight, though)") - ("freeze_feature_set,Z", "Freeze feature set after reading feature weights file") - ("feature_function,F",po::value >()->composing(), "Additional feature function(s) (-L for list)") + ("list_feature_functions,L","List available feature functions") + + ("weights,w",po::value(),"Feature weights file (initial forest / pass 1)") + ("feature_function,F",po::value >()->composing(), "Pass 1 additional feature function(s) (-L for list)") + ("intersection_strategy,I",po::value()->default_value("cube_pruning"), "Pass 1 intersection strategy for incorporating finite-state features; values include Cube_pruning, Full") + ("density_prune", po::value(), "Pass 1 pruning: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)") + ("beam_prune", po::value(), "Pass 1 pruning: Prune paths from scored forest, keep paths within exp(alpha>=0)") + + ("weights2",po::value(),"Optional pass 2") + ("feature_function2",po::value >()->composing(), "Optional pass 2") + ("intersection_strategy2",po::value()->default_value("cube_pruning"), "Optional pass 2") + ("density_prune2", po::value(), "Optional pass 2") + ("beam_prune2", po::value(), "Optional pass 2") + + ("weights3",po::value(),"Optional pass 3") + ("feature_function3",po::value >()->composing(), "Optional pass 3") + ("intersection_strategy3",po::value()->default_value("cube_pruning"), "Optional pass 3") + ("density_prune3", po::value(), "Optional pass 3") + ("beam_prune3", po::value(), "Optional pass 3") + #ifdef FSA_RESCORING ("fsa_feature_function,A",po::value >()->composing(), "Additional FSA feature function(s) (-L for list)") ("apply_fsa_by",po::value()->default_value("BU_CUBE"), "Method for applying fsa_feature_functions - BU_FULL BU_CUBE EARLEY") //+ApplyFsaBy::all_names() #endif - ("list_feature_functions,L","List available feature functions") ("add_pass_through_rules,P","Add rules to translate OOV words as themselves") ("k_best,k",po::value(),"Extract the k best derivations") ("unique_k_best,r", "Unique k-best translation list") + ("cubepruning_pop_limit,K",po::value()->default_value(200), "Max number of pops from the candidate heap at each node") ("aligner,a", "Run as a word/phrase aligner (src & ref required)") ("aligner_use_viterbi", "If run in alignment mode, compute the Viterbi (rather than MAP) alignment") - ("intersection_strategy,I",po::value()->default_value("cube_pruning"), "Intersection strategy for incorporating finite-state features; values include Cube_pruning, Full") - ("cubepruning_pop_limit,K",po::value()->default_value(200), "Max number of pops from the candidate heap at each node") ("goal",po::value()->default_value("S"),"Goal symbol (SCFG & FST)") + ("freeze_feature_set,Z", "Freeze feature set after reading feature weights file") + ("warn_0_weight","Warn about any feature id that has a 0 weight (this is perfectly safe if you intend 0 weight, though)") ("scfg_extra_glue_grammar", po::value(), "Extra glue grammar file (Glue grammars apply when i=0 but have no other span restrictions)") ("scfg_no_hiero_glue_grammar,n", "No Hiero glue grammar (nb. by default the SCFG decoder adds Hiero glue rules)") ("scfg_default_nt,d",po::value()->default_value("X"),"Default non-terminal symbol in SCFG") ("scfg_max_span_limit,S",po::value()->default_value(10),"Maximum non-terminal span limit (except \"glue\" grammar)") - ("quiet", "Disable verbose output") - ("show_config", po::bool_switch(&show_config), "show contents of loaded -c config files.") - ("show_weights", po::bool_switch(&show_weights), "show effective feature weights") + ("quiet", "Disable verbose output") + ("show_config", po::bool_switch(&show_config), "show contents of loaded -c config files.") + ("show_weights", po::bool_switch(&show_weights), "show effective feature weights") ("show_joshua_visualization,J", "Produce output compatible with the Joshua visualization tools") ("show_tree_structure", "Show the Viterbi derivation structure") ("show_expected_length", "Show the expected translation length under the model") ("show_partition,z", "Compute and show the partition (inside score)") ("show_conditional_prob", "Output the conditional log prob to STDOUT instead of a translation") ("show_cfg_search_space", "Show the search space as a CFG") - ("show_features","Show the feature vector for the viterbi translation") - ("density_prune", po::value(), "Keep no more than this many times the number of edges used in the best derivation tree (>=1.0)") + ("show_features","Show the feature vector for the viterbi translation") ("coarse_to_fine_beam_prune", po::value(), "Prune paths from coarse parse forest before fine parse, keeping paths within exp(alpha>=0)") ("ctf_beam_widen", po::value()->default_value(2.0), "Expand coarse pass beam by this factor if no fine parse is found") ("ctf_num_widenings", po::value()->default_value(2), "Widen coarse beam this many times before backing off to full parse") ("ctf_no_exhaustive", "Do not fall back to exhaustive parse if coarse-to-fine parsing fails") - ("beam_prune", po::value(), "Prune paths from +LM forest, keep paths within exp(alpha>=0)") - ("scale_prune_srclen", "scale beams by the input length (in # of tokens; may not be what you want for lattices") + ("scale_prune_srclen", "scale beams by the input length (in # of tokens; may not be what you want for lattices") ("lextrans_use_null", "Support source-side null words in lexical translation") ("lextrans_align_only", "Only used in alignment mode. Limit target words generated by reference") ("tagger_tagset,t", po::value(), "(Tagger) file containing tag set") @@ -488,14 +539,79 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream exit(1); } - // load feature weights (and possibly freeze feature set) + // load initial feature weights (and possibly freeze feature set) if (conf.count("weights")) { - w.InitFromFile(str("weights",conf)); - feature_weights.resize(FD::NumFeats()); - w.InitVector(&feature_weights); + w_init_weights.InitFromFile(str("weights",conf)); + w_init_weights.InitVector(&init_weights); + init_weights.resize(FD::NumFeats()); + if (show_weights) - cerr << "+LM weights: " << WeightVector(feature_weights)<(); + + // determine the number of rescoring/pruning/weighting passes configured + const int MAX_PASSES = 3; + for (int pass = 0; pass < MAX_PASSES; ++pass) { + string ws = "weights" + StringSuffixForRescoringPass(pass); + string ff = "feature_function" + StringSuffixForRescoringPass(pass); + string bp = "beam_prune" + StringSuffixForRescoringPass(pass); + string dp = "density_prune" + StringSuffixForRescoringPass(pass); + bool first_pass_condition = ((pass == 0) && (conf.count(ff) || conf.count(bp) || conf.count(dp))); + bool nth_pass_condition = ((pass > 0) && (conf.count(ws) || conf.count(ff) || conf.count(bp) || conf.count(dp))); + if (first_pass_condition || nth_pass_condition) { + rescoring_passes.push_back(RescoringPass()); + RescoringPass& rp = rescoring_passes.back(); + // only configure new weights if pass > 0, otherwise we reuse the initial chart weights + if (nth_pass_condition && conf.count(ws)) { + rp.w.reset(new Weights); + rp.w->InitFromFile(str(ws.c_str(), conf)); + rp.w->InitVector(&rp.weight_vector); + } + bool has_stateful = false; + if (conf.count(ff)) { + vector add_ffs; + store_conf(conf,ff,&add_ffs); + for (int i = 0; i < add_ffs.size(); ++i) { + pffs.push_back(make_ff(add_ffs[i],verbose_feature_functions)); + FeatureFunction const* p=pffs.back().get(); + rp.ffs.push_back(p); + if (p->IsStateful()) { has_stateful = true; } + } + } + if (conf.count(bp)) { rp.beam_prune = conf[bp].as(); } + if (conf.count(dp)) { rp.density_prune = conf[dp].as(); } + int palg = (has_stateful ? 1 : 0); // if there are no stateful featueres, default to FULL + string isn = "intersection_strategy" + StringSuffixForRescoringPass(pass); + if (LowercaseString(str(isn.c_str(),conf)) == "full") { + palg = 0; + } + rp.inter_conf.reset(new IntersectionConfiguration(palg, pop_limit)); + } else { + break; // TODO alert user if there are any future configurations + } + } + + // set up weight vectors since later phases may reuse weights from earlier phases + const vector* prev = &init_weights; + for (int pass = 0; pass < rescoring_passes.size(); ++pass) { + RescoringPass& rp = rescoring_passes[pass]; + if (!rp.w) { rp.weight_vector = *prev; } else { prev = &rp.weight_vector; } + rp.models.reset(new ModelSet(rp.weight_vector, rp.ffs)); + string ps = "Pass1 "; ps[4] += pass; + if (!SILENT) show_models(conf,*rp.models,ps.c_str()); } + + // show configuration of rescoring passes + if (!SILENT) { + int num = rescoring_passes.size(); + cerr << "Configured " << num << " rescoring pass" << (num == 1 ? "" : "es") << endl; + for (int pass = 0; pass < num; ++pass) + cerr << " " << rescoring_passes[pass] << endl; + } + bool warn0=conf.count("warn_0_weight"); bool freeze=conf.count("freeze_feature_set"); bool early_freeze=freeze && !warn0; @@ -523,18 +639,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream else assert(!"error"); - // set up additional scoring features - if (conf.count("feature_function") > 0) { - vector add_ffs; -// const vector& add_ffs = conf["feature_function"].as >(); - store_conf(conf,"feature_function",&add_ffs); - for (int i = 0; i < add_ffs.size(); ++i) { - pffs.push_back(make_ff(add_ffs[i],verbose_feature_functions)); - FeatureFunction const* p=pffs.back().get(); - late_ffs.push_back(p); - } - } - #ifdef FSA_RESCORING store_conf(conf,"fsa_feature_function",&fsa_names); for (int i=0;i(); - inter_conf = new IntersectionConfiguration(palg, pop_limit); - sample_max_trans = conf.count("max_translation_sample") ? conf["max_translation_sample"].as() : 0; if (sample_max_trans) @@ -644,8 +737,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { translator->ProcessMarkupHints(smeta.sgml_); Timer t("Translation"); const bool translation_successful = - translator->Translate(to_translate, &smeta, feature_weights, &forest); - //TODO: modify translator to incorporate all 0-state model scores immediately? + translator->Translate(to_translate, &smeta, init_weights, &forest); translator->SentenceComplete(); if (!translation_successful) { @@ -660,7 +752,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { const bool show_tree_structure=conf.count("show_tree_structure"); const bool show_features=conf.count("show_features"); - if (!SILENT) forest_stats(forest," -LM forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation); + if (!SILENT) forest_stats(forest," Init. forest",show_tree_structure,show_features,init_weights,oracle.show_derivation); if (conf.count("show_expected_length")) { const PRPair res = Inside, @@ -669,53 +761,63 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { } if (conf.count("show_partition")) { const prob_t z = Inside(forest); - cerr << " -LM partition log(Z): " << log(z) << endl; + cerr << " Init. partition log(Z): " << log(z) << endl; } + for (int pass = 0; pass < rescoring_passes.size(); ++pass) { + const RescoringPass& rp = rescoring_passes[pass]; + const vector& cur_weights = rp.weight_vector; + if (!SILENT) cerr << endl << " RESCORING PASS #" << (pass+1) << " " << rp << endl; #ifdef FSA_RESCORING - cfg_options.maybe_output_source(forest); + cfg_options.maybe_output_source(forest); #endif - bool has_late_models = !late_models->empty(); - if (has_late_models) { - Timer t("Forest rescoring:"); - late_models->PrepareForInput(smeta); - forest.Reweight(feature_weights); - Hypergraph lm_forest; - ApplyModelSet(forest, + string passtr = "Pass1"; passtr[4] += pass; + forest.Reweight(cur_weights); + const bool has_rescoring_models = !rp.models->empty(); + if (has_rescoring_models) { + Timer t("Forest rescoring:"); + rp.models->PrepareForInput(smeta); + Hypergraph rescored_forest; + ApplyModelSet(forest, smeta, - *late_models, - *inter_conf, - &lm_forest); - forest.swap(lm_forest); - forest.Reweight(feature_weights); - if (!SILENT) forest_stats(forest," +LM forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation); - } + *rp.models, + *rp.inter_conf, + &rescored_forest); + forest.swap(rescored_forest); + forest.Reweight(cur_weights); + if (!SILENT) forest_stats(forest," " + passtr +" forest",show_tree_structure,show_features,cur_weights,oracle.show_derivation); + } - maybe_prune(forest,conf,"beam_prune","density_prune","+LM",srclen); + string fullbp = "beam_prune" + StringSuffixForRescoringPass(pass); + string fulldp = "density_prune" + StringSuffixForRescoringPass(pass); + maybe_prune(forest,conf,fullbp.c_str(),fulldp.c_str(),passtr,srclen); #ifdef FSA_RESCORING - HgCFG hgcfg(forest); - cfg_options.prepare(hgcfg); - - if (!fsa_ffs.empty()) { - Timer t("Target FSA rescoring:"); - if (!has_late_models) - forest.Reweight(feature_weights); - Hypergraph fsa_forest; - assert(fsa_ffs.size()==1); - ApplyFsaBy cfg(str("apply_fsa_by",conf),pop_limit); - if (!SILENT) cerr << "FSA rescoring with "<describe()<describe()<& last_weights = (rescoring_passes.empty() ? init_weights : rescoring_passes.back().weight_vector); // Oracle Rescoring if(get_oracle_forest) { - Oracle oc=oracle.ComputeOracle(smeta,&forest,FeatureVector(feature_weights),10,conf["forest_output"].as()); + Oracle oc=oracle.ComputeOracle(smeta,&forest,FeatureVector(last_weights),10,conf["forest_output"].as()); if (!SILENT) cerr << " +Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl; if (!SILENT) cerr << " +Oracle BLEU (paths): " << forest.NumberOfPaths() << endl; oc.hope.Print(cerr," +Oracle BLEU"); @@ -798,8 +900,8 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { // if (!SILENT) cerr << " USING UNIFORM WEIGHTS\n"; // for (int i = 0; i < forest.edges_.size(); ++i) // forest.edges_[i].edge_prob_=prob_t::One(); } - forest.Reweight(feature_weights); - if (!SILENT) forest_stats(forest," Constr. forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation); + forest.Reweight(last_weights); + if (!SILENT) forest_stats(forest," Constr. forest",show_tree_structure,show_features,last_weights,oracle.show_derivation); if (!SILENT) cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl; if (conf.count("show_partition")) { const prob_t z = Inside(forest); @@ -830,7 +932,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { const prob_t ref_z = InsideOutside, EdgeFeaturesAndProbWeightFunction>(forest, &ref_exp); ref_exp /= ref_z; // if (crf_uniform_empirical) -// log_ref_z = ref_exp.dot(feature_weights); +// log_ref_z = ref_exp.dot(last_weights); log_ref_z = log(ref_z); //cerr << " MODEL LOG Z: " << log_z << endl; //cerr << " EMPIRICAL LOG Z: " << log_ref_z << endl; -- cgit v1.2.3