From 70fdb6cd8774cbd0114fe0d630781bab309e0d87 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 10 Feb 2011 00:16:58 -0500 Subject: conditional compilation of experimental code, remove prelm scoring code in preparation for multi-phase (re)scoring --- decoder/decoder.cc | 136 ++++++++++++++++++----------------------------------- 1 file changed, 46 insertions(+), 90 deletions(-) (limited to 'decoder/decoder.cc') diff --git a/decoder/decoder.cc b/decoder/decoder.cc index f37e8a37..25f05d8e 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -4,7 +4,7 @@ #include #include -#include "sampler.h" +#include "program_options.h" #include "stringlib.h" #include "weights.h" #include "filelib.h" @@ -24,22 +24,28 @@ #include "sentence_metadata.h" #include "hg_intersect.h" -#include "apply_fsa_models.h" #include "oracle_bleu.h" #include "apply_models.h" #include "ff.h" #include "ff_factory.h" -#include "cfg_options.h" #include "viterbi.h" #include "kbest.h" #include "inside_outside.h" #include "exp_semiring.h" #include "sentence_metadata.h" -#include "hg_cfg.h" +#include "sampler.h" #include "forest_writer.h" // TODO this section should probably be handled by an Observer #include "hg_io.h" #include "aligner.h" + +#undef FSA_RESCORING +#ifdef FSA_RESCORING +#include "hg_cfg.h" +#include "apply_fsa_models.h" +#include "cfg_options.h" +#endif + static const double kMINUS_EPSILON = -1e-6; // don't be too strict using namespace std; @@ -78,19 +84,6 @@ inline string str(char const* name,po::variables_map const& conf) { return conf[name].as(); } -inline bool prelm_weights_string(po::variables_map const& conf,string &s) { - if (conf.count("prelm_weights")) { - s=str("prelm_weights",conf); - return true; - } - if (conf.count("prelm_copy_weights")) { - s=str("weights",conf); - return true; - } - return false; -} - - // print just the --long_opt names suitable for bash compgen inline void print_options(std::ostream &out,po::options_description const& opts) { @@ -127,6 +120,7 @@ inline shared_ptr make_ff(string const& ffp,bool verbose_featur return pf; } +#ifdef FSA_RESCORING inline shared_ptr make_fsa_ff(string const& ffp,bool verbose_feature_functions,char const* pre="") { string ff, param; SplitCommandAndParam(ffp, &ff, ¶m); @@ -139,6 +133,7 @@ inline shared_ptr make_fsa_ff(string const& ffp,bool verbose cerr<<"State is "<state_bytes()<<" bytes for "<()); + forest.PruneInsideOutside(beam_prune,density_prune,pm,false,1); if (!forestname.empty()) forestname=" "+forestname; forest_stats(forest," Pruned "+forestname+" forest",false,false,0,false); cerr << " Pruned "< translator; - vector feature_weights,prelm_feature_weights; - Weights w,prelm_w; - vector > pffs,prelm_only_ffs; - vector late_ffs,prelm_ffs; + vector feature_weights; + Weights w; + vector > pffs; + vector late_ffs; +#ifdef FSA_RESCORING + CFGOptions cfg_options; vector > fsa_ffs; vector fsa_names; - ModelSet* late_models, *prelm_models; +#endif + ModelSet* late_models; IntersectionConfiguration* inter_conf; shared_ptr > rng; int sample_max_trans; bool aligner_mode; - bool minimal_forests; bool graphviz; bool joshua_viz; bool encode_b64; @@ -286,7 +282,6 @@ struct DecoderImpl { SparseVector acc_vec; // accumulate gradient double acc_obj; // accumulate objective int g_count; // number of gradient pieces computed - bool has_prelm_models; int pop_limit; bool csplit_output_plf; bool write_gradient; // TODO Observer @@ -325,15 +320,13 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("grammar,g",po::value >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)") ("per_sentence_grammar_file", po::value(), "Optional (and possibly not implemented) per sentence grammar file enables all per sentence grammars to be stored in a single large file and accessed by offset") ("weights,w",po::value(),"Feature weights file") - ("prelm_weights",po::value(),"Feature weights file for prelm_beam_prune. Requires --weights.") - ("prelm_copy_weights","use --weights as value for --prelm_weights.") - ("prelm_feature_function",po::value >()->composing(),"Additional feature functions for prelm pass only (in addition to the 0-state subset of feature_function") - ("keep_prelm_cube_order","DEPRECATED (always enabled). when forest rescoring with final models, use the edge ordering from the prelm pruning features*weights. only meaningful if --prelm_weights given. UNTESTED but assume that cube pruning gives a sensible result, and that 'good' (as tuned for bleu w/ prelm features) edges come first.") ("warn_0_weight","Warn about any feature id that has a 0 weight (this is perfectly safe if you intend 0 weight, though)") ("freeze_feature_set,Z", "Freeze feature set after reading feature weights file") ("feature_function,F",po::value >()->composing(), "Additional feature function(s) (-L for list)") +#ifdef FSA_RESCORING ("fsa_feature_function,A",po::value >()->composing(), "Additional FSA feature function(s) (-L for list)") ("apply_fsa_by",po::value()->default_value("BU_CUBE"), "Method for applying fsa_feature_functions - BU_FULL BU_CUBE EARLEY") //+ApplyFsaBy::all_names() +#endif ("list_feature_functions,L","List available feature functions") ("add_pass_through_rules,P","Add rules to translate OOV words as themselves") ("k_best,k",po::value(),"Extract the k best derivations") @@ -357,16 +350,13 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("show_conditional_prob", "Output the conditional log prob to STDOUT instead of a translation") ("show_cfg_search_space", "Show the search space as a CFG") ("show_features","Show the feature vector for the viterbi translation") - ("prelm_density_prune", po::value(), "Applied to -LM forest just before final LM rescoring: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)") ("density_prune", po::value(), "Keep no more than this many times the number of edges used in the best derivation tree (>=1.0)") - ("prelm_beam_prune", po::value(), "Prune paths from -LM forest before LM rescoring, keeping paths within exp(alpha>=0)") ("coarse_to_fine_beam_prune", po::value(), "Prune paths from coarse parse forest before fine parse, keeping paths within exp(alpha>=0)") ("ctf_beam_widen", po::value()->default_value(2.0), "Expand coarse pass beam by this factor if no fine parse is found") ("ctf_num_widenings", po::value()->default_value(2), "Widen coarse beam this many times before backing off to full parse") ("ctf_no_exhaustive", "Do not fall back to exhaustive parse if coarse-to-fine parsing fails") ("beam_prune", po::value(), "Prune paths from +LM forest, keep paths within exp(alpha>=0)") ("scale_prune_srclen", "scale beams by the input length (in # of tokens; may not be what you want for lattices") - ("promise_power",po::value()->default_value(0), "Give more beam budget to more promising previous-pass nodes when pruning - but allocate the same average beams. 0 means off, 1 means beam proportional to inside*outside prob, n means nth power (affects just --cubepruning_pop_limit). note: for the same pop_limit, this gives more search error unless very close to 0 (recommend disabled; even 0.01 is slightly worse than 0) which is a bad sign and suggests this isn't doing a good job; further it's slightly slower to LM cube rescore with 0.01 compared to 0, as well as giving (very insignificantly) lower BLEU. TODO: test under more conditions, or try idea with different formula, or prob. cube beams.") ("lextrans_use_null", "Support source-side null words in lexical translation") ("lextrans_align_only", "Only used in alignment mode. Limit target words generated by reference") ("tagger_tagset,t", po::value(), "(Tagger) file containing tag set") @@ -382,11 +372,12 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)") ("vector_format",po::value()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)") ("combine_size,C",po::value()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)") - ("forest_output,O",po::value(),"Directory to write forests to") - ("minimal_forests,m","Write minimal forests (excludes Rule information). Such forests can be used for ML/MAP training, but not rescoring, etc."); + ("forest_output,O",po::value(),"Directory to write forests to"); // ob.AddOptions(&opts); +#ifdef FSA_RESCORING po::options_description cfgo(cfg_options.description()); cfg_options.AddOptions(&cfgo); +#endif po::options_description clo("Command line options"); clo.add_options() ("config,c", po::value >(&cfg_files), "Configuration file(s) - latest has priority") @@ -396,8 +387,12 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ; po::options_description dconfig_options, dcmdline_options; +#ifdef FSA_RESCORING dconfig_options.add(opts).add(cfgo); - //add(opts).add(cfgo) +#else + dconfig_options.add(opts); +#endif + dcmdline_options.add(dconfig_options).add(clo); if (argc) { argv_minus_to_underscore(argc,argv); @@ -442,8 +437,10 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream if (conf.count("list_feature_functions")) { cerr << "Available feature functions (specify with -F; describe with -u FeatureName):\n"; ff_registry.DisplayList(); //TODO +#ifdef FSA_RESCORING cerr << "Available FSA feature functions (specify with --fsa_feature_function):\n"; fsa_ff_registry.DisplayList(); // TODO +#endif cerr << endl; exit(1); } @@ -480,7 +477,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream const string formalism = LowercaseString(str("formalism",conf)); const bool csplit_preserve_full_word = conf.count("csplit_preserve_full_word"); if (csplit_preserve_full_word && - (formalism != "csplit" || !(conf.count("beam_prune")||conf.count("density_prune")||conf.count("prelm_beam_prune")||conf.count("prelm_density_prune")))) { + (formalism != "csplit" || !(conf.count("beam_prune")||conf.count("density_prune")))) { cerr << "--csplit_preserve_full_word should only be " << "used with csplit AND --*_prune!\n"; exit(1); @@ -492,20 +489,10 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream } // load feature weights (and possibly freeze feature set) - has_prelm_models = false; if (conf.count("weights")) { w.InitFromFile(str("weights",conf)); feature_weights.resize(FD::NumFeats()); w.InitVector(&feature_weights); - string plmw; - if (prelm_weights_string(conf,plmw)) { - has_prelm_models = true; - prelm_w.InitFromFile(plmw); - prelm_feature_weights.resize(FD::NumFeats()); - prelm_w.InitVector(&prelm_feature_weights); - if (show_weights) - cerr << "prelm_weights: " << WeightVector(prelm_feature_weights)<NumBytesContext()==0) - prelm_ffs.push_back(p); - else - cerr << "Excluding stateful feature from prelm pruning: "< 0) { - vector add_ffs; - store_conf(conf,"prelm_feature_function",&add_ffs); -// const vector& add_ffs = conf["prelm_feature_function"].as >(); - for (int i = 0; i < add_ffs.size(); ++i) { - prelm_only_ffs.push_back(make_ff(add_ffs[i],verbose_feature_functions,"prelm-only ")); - prelm_ffs.push_back(prelm_only_ffs.back().get()); } } +#ifdef FSA_RESCORING store_conf(conf,"fsa_feature_function",&fsa_names); for (int i=0;i); aligner_mode = conf.count("aligner"); - minimal_forests = conf.count("minimal_forests"); graphviz = conf.count("graphviz"); joshua_viz = conf.count("show_joshua_visualization"); encode_b64 = str("vector_format",conf) == "b64"; @@ -611,7 +578,9 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream unique_kbest = conf.count("unique_k_best"); get_oracle_forest = conf.count("get_oracle_forest"); +#ifdef FSA_RESCORING cfg_options.Validate(); +#endif if (conf.count("extract_rules")) extract_file.reset(new WriteFile(str("extract_rules",conf))); @@ -703,24 +672,9 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { cerr << " -LM partition log(Z): " << log(z) << endl; } - if (has_prelm_models) { - Timer t("prelm rescoring"); - forest.Reweight(prelm_feature_weights); - Hypergraph prelm_forest; - prelm_models->PrepareForInput(smeta); - ApplyModelSet(forest, - smeta, - *prelm_models, - *inter_conf, // this is now reduced to exhaustive if all are stateless - &prelm_forest); - forest.swap(prelm_forest); - forest.Reweight(prelm_feature_weights); //FIXME: why the reweighting? here and below. maybe in case we already had a featval for that id and ApplyModelSet only adds prob, doesn't recompute it? - forest_stats(forest," prelm forest",show_tree_structure,show_features,prelm_feature_weights,oracle.show_derivation); - } - - maybe_prune(forest,conf,"prelm_beam_prune","prelm_density_prune","-LM",srclen); - +#ifdef FSA_RESCORING cfg_options.maybe_output_source(forest); +#endif bool has_late_models = !late_models->empty(); if (has_late_models) { @@ -740,6 +694,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { maybe_prune(forest,conf,"beam_prune","density_prune","+LM",srclen); +#ifdef FSA_RESCORING HgCFG hgcfg(forest); cfg_options.prepare(hgcfg); @@ -756,6 +711,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { forest.Reweight(feature_weights); if (!SILENT) forest_stats(forest," +FSA forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation); } +#endif // Oracle Rescoring if(get_oracle_forest) { @@ -781,10 +737,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { assert(succeeded); } new_hg.Union(forest); - bool succeeded = writer.Write(new_hg, minimal_forests); + bool succeeded = writer.Write(new_hg, false); assert(succeeded); } else { - bool succeeded = writer.Write(forest, minimal_forests); + bool succeeded = writer.Write(forest, false); assert(succeeded); } } @@ -861,10 +817,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { assert(succeeded); } new_hg.Union(forest); - bool succeeded = writer.Write(new_hg, minimal_forests); + bool succeeded = writer.Write(new_hg, false); assert(succeeded); } else { - bool succeeded = writer.Write(forest, minimal_forests); + bool succeeded = writer.Write(forest, false); assert(succeeded); } } -- cgit v1.2.3