Diffstat (limited to 'decoder/decoder.cc')
-rw-r--r-- | decoder/decoder.cc | 136 |
1 files changed, 46 insertions, 90 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index f37e8a37..25f05d8e 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -4,7 +4,7 @@
 #include <boost/program_options.hpp>
 #include <boost/program_options/variables_map.hpp>
 
-#include "sampler.h"
+#include "program_options.h"
 #include "stringlib.h"
 #include "weights.h"
 #include "filelib.h"
@@ -24,22 +24,28 @@
 #include "sentence_metadata.h"
 #include "hg_intersect.h"
-#include "apply_fsa_models.h"
 #include "oracle_bleu.h"
 #include "apply_models.h"
 #include "ff.h"
 #include "ff_factory.h"
-#include "cfg_options.h"
 #include "viterbi.h"
 #include "kbest.h"
 #include "inside_outside.h"
 #include "exp_semiring.h"
 #include "sentence_metadata.h"
-#include "hg_cfg.h"
+#include "sampler.h"
 #include "forest_writer.h" // TODO this section should probably be handled by an Observer
 #include "hg_io.h"
 #include "aligner.h"
+
+#undef FSA_RESCORING
+#ifdef FSA_RESCORING
+#include "hg_cfg.h"
+#include "apply_fsa_models.h"
+#include "cfg_options.h"
+#endif
+
 static const double kMINUS_EPSILON = -1e-6; // don't be too strict
 
 using namespace std;
@@ -78,19 +84,6 @@ inline string str(char const* name,po::variables_map const& conf) {
   return conf[name].as<string>();
 }
 
-inline bool prelm_weights_string(po::variables_map const& conf,string &s) {
-  if (conf.count("prelm_weights")) {
-    s=str("prelm_weights",conf);
-    return true;
-  }
-  if (conf.count("prelm_copy_weights")) {
-    s=str("weights",conf);
-    return true;
-  }
-  return false;
-}
-
-
 // print just the --long_opt names suitable for bash compgen
 inline void print_options(std::ostream &out,po::options_description const& opts) {
@@ -127,6 +120,7 @@ inline shared_ptr<FeatureFunction> make_ff(string const& ffp,bool verbose_featur
   return pf;
 }
 
+#ifdef FSA_RESCORING
 inline shared_ptr<FsaFeatureFunction> make_fsa_ff(string const& ffp,bool verbose_feature_functions,char const* pre="") {
   string ff, param;
   SplitCommandAndParam(ffp, &ff, &param);
@@ -139,6 +133,7 @@ inline shared_ptr<FsaFeatureFunction> make_fsa_ff(string const& ffp,bool verbose
     cerr<<"State is "<<pf->state_bytes()<<" bytes for "<<pre<<"feature "<<ffp<<endl;
   return pf;
 }
+#endif
 
 struct DecoderImpl {
   DecoderImpl(po::variables_map& conf, int argc, char** argv, istream* cfg);
@@ -189,7 +184,7 @@ struct DecoderImpl {
       preserve_mask[CompoundSplit::GetFullWordEdgeIndex(forest)] = true;
       pm=&preserve_mask;
     }
-    forest.PruneInsideOutside(beam_prune,density_prune,pm,false,1,conf["promise_power"].as<double>());
+    forest.PruneInsideOutside(beam_prune,density_prune,pm,false,1);
    if (!forestname.empty()) forestname=" "+forestname;
    forest_stats(forest," Pruned "+forestname+" forest",false,false,0,false);
    cerr << " Pruned "<<forestname<<" forest portion of edges kept: "<<forest.edges_.size()/presize<<endl;
@@ -259,21 +254,22 @@ struct DecoderImpl {
   po::variables_map& conf;
   OracleBleu oracle;
-  CFGOptions cfg_options;
   string formalism;
   shared_ptr<Translator> translator;
-  vector<double> feature_weights,prelm_feature_weights;
-  Weights w,prelm_w;
-  vector<shared_ptr<FeatureFunction> > pffs,prelm_only_ffs;
-  vector<const FeatureFunction*> late_ffs,prelm_ffs;
+  vector<double> feature_weights;
+  Weights w;
+  vector<shared_ptr<FeatureFunction> > pffs;
+  vector<const FeatureFunction*> late_ffs;
+#ifdef FSA_RESCORING
+  CFGOptions cfg_options;
   vector<shared_ptr<FsaFeatureFunction> > fsa_ffs;
   vector<string> fsa_names;
-  ModelSet* late_models, *prelm_models;
+#endif
+  ModelSet* late_models;
   IntersectionConfiguration* inter_conf;
   shared_ptr<RandomNumberGenerator<boost::mt19937> > rng;
   int sample_max_trans;
   bool aligner_mode;
-  bool minimal_forests;
   bool graphviz;
   bool joshua_viz;
   bool encode_b64;
@@ -286,7 +282,6 @@ struct DecoderImpl {
   SparseVector<prob_t> acc_vec; // accumulate gradient
   double acc_obj; // accumulate objective
   int g_count; // number of gradient pieces computed
-  bool has_prelm_models;
   int pop_limit;
   bool csplit_output_plf;
   bool write_gradient; // TODO Observer
@@ -325,15 +320,13 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
        ("grammar,g",po::value<vector<string> >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)")
        ("per_sentence_grammar_file", po::value<string>(), "Optional (and possibly not implemented) per sentence grammar file enables all per sentence grammars to be stored in a single large file and accessed by offset")
        ("weights,w",po::value<string>(),"Feature weights file")
-        ("prelm_weights",po::value<string>(),"Feature weights file for prelm_beam_prune. Requires --weights.")
-        ("prelm_copy_weights","use --weights as value for --prelm_weights.")
-        ("prelm_feature_function",po::value<vector<string> >()->composing(),"Additional feature functions for prelm pass only (in addition to the 0-state subset of feature_function")
-        ("keep_prelm_cube_order","DEPRECATED (always enabled). when forest rescoring with final models, use the edge ordering from the prelm pruning features*weights. only meaningful if --prelm_weights given. UNTESTED but assume that cube pruning gives a sensible result, and that 'good' (as tuned for bleu w/ prelm features) edges come first.")
        ("warn_0_weight","Warn about any feature id that has a 0 weight (this is perfectly safe if you intend 0 weight, though)")
        ("freeze_feature_set,Z", "Freeze feature set after reading feature weights file")
        ("feature_function,F",po::value<vector<string> >()->composing(), "Additional feature function(s) (-L for list)")
+#ifdef FSA_RESCORING
        ("fsa_feature_function,A",po::value<vector<string> >()->composing(), "Additional FSA feature function(s) (-L for list)")
        ("apply_fsa_by",po::value<string>()->default_value("BU_CUBE"), "Method for applying fsa_feature_functions - BU_FULL BU_CUBE EARLEY") //+ApplyFsaBy::all_names()
+#endif
        ("list_feature_functions,L","List available feature functions")
        ("add_pass_through_rules,P","Add rules to translate OOV words as themselves")
        ("k_best,k",po::value<int>(),"Extract the k best derivations")
@@ -357,16 +350,13 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
        ("show_conditional_prob", "Output the conditional log prob to STDOUT instead of a translation")
        ("show_cfg_search_space", "Show the search space as a CFG")
        ("show_features","Show the feature vector for the viterbi translation")
-        ("prelm_density_prune", po::value<double>(), "Applied to -LM forest just before final LM rescoring: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)")
        ("density_prune", po::value<double>(), "Keep no more than this many times the number of edges used in the best derivation tree (>=1.0)")
-        ("prelm_beam_prune", po::value<double>(), "Prune paths from -LM forest before LM rescoring, keeping paths within exp(alpha>=0)")
        ("coarse_to_fine_beam_prune", po::value<double>(), "Prune paths from coarse parse forest before fine parse, keeping paths within exp(alpha>=0)")
        ("ctf_beam_widen", po::value<double>()->default_value(2.0), "Expand coarse pass beam by this factor if no fine parse is found")
        ("ctf_num_widenings", po::value<int>()->default_value(2), "Widen coarse beam this many times before backing off to full parse")
        ("ctf_no_exhaustive", "Do not fall back to exhaustive parse if coarse-to-fine parsing fails")
        ("beam_prune", po::value<double>(), "Prune paths from +LM forest, keep paths within exp(alpha>=0)")
        ("scale_prune_srclen", "scale beams by the input length (in # of tokens; may not be what you want for lattices")
-        ("promise_power",po::value<double>()->default_value(0), "Give more beam budget to more promising previous-pass nodes when pruning - but allocate the same average beams. 0 means off, 1 means beam proportional to inside*outside prob, n means nth power (affects just --cubepruning_pop_limit). note: for the same pop_limit, this gives more search error unless very close to 0 (recommend disabled; even 0.01 is slightly worse than 0) which is a bad sign and suggests this isn't doing a good job; further it's slightly slower to LM cube rescore with 0.01 compared to 0, as well as giving (very insignificantly) lower BLEU. TODO: test under more conditions, or try idea with different formula, or prob. cube beams.")
        ("lextrans_use_null", "Support source-side null words in lexical translation")
        ("lextrans_align_only", "Only used in alignment mode. Limit target words generated by reference")
        ("tagger_tagset,t", po::value<string>(), "(Tagger) file containing tag set")
@@ -382,11 +372,12 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
        ("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)")
        ("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)")
        ("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)")
-        ("forest_output,O",po::value<string>(),"Directory to write forests to")
-        ("minimal_forests,m","Write minimal forests (excludes Rule information). Such forests can be used for ML/MAP training, but not rescoring, etc.");
+        ("forest_output,O",po::value<string>(),"Directory to write forests to");
   // ob.AddOptions(&opts);
+#ifdef FSA_RESCORING
   po::options_description cfgo(cfg_options.description());
   cfg_options.AddOptions(&cfgo);
+#endif
   po::options_description clo("Command line options");
   clo.add_options()
        ("config,c", po::value<vector<string> >(&cfg_files), "Configuration file(s) - latest has priority")
@@ -396,8 +387,12 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
        ;
 
   po::options_description dconfig_options, dcmdline_options;
+#ifdef FSA_RESCORING
   dconfig_options.add(opts).add(cfgo);
-  //add(opts).add(cfgo)
+#else
+  dconfig_options.add(opts);
+#endif
+
   dcmdline_options.add(dconfig_options).add(clo);
   if (argc) {
     argv_minus_to_underscore(argc,argv);
@@ -442,8 +437,10 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
   if (conf.count("list_feature_functions")) {
     cerr << "Available feature functions (specify with -F; describe with -u FeatureName):\n";
     ff_registry.DisplayList(); //TODO
+#ifdef FSA_RESCORING
     cerr << "Available FSA feature functions (specify with --fsa_feature_function):\n";
     fsa_ff_registry.DisplayList(); // TODO
+#endif
     cerr << endl;
     exit(1);
   }
@@ -480,7 +477,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
   const string formalism = LowercaseString(str("formalism",conf));
   const bool csplit_preserve_full_word = conf.count("csplit_preserve_full_word");
   if (csplit_preserve_full_word &&
-      (formalism != "csplit" || !(conf.count("beam_prune")||conf.count("density_prune")||conf.count("prelm_beam_prune")||conf.count("prelm_density_prune")))) {
+      (formalism != "csplit" || !(conf.count("beam_prune")||conf.count("density_prune")))) {
     cerr << "--csplit_preserve_full_word should only be "
         << "used with csplit AND --*_prune!\n";
     exit(1);
@@ -492,20 +489,10 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
   }
 
   // load feature weights (and possibly freeze feature set)
-  has_prelm_models = false;
   if (conf.count("weights")) {
    w.InitFromFile(str("weights",conf));
    feature_weights.resize(FD::NumFeats());
    w.InitVector(&feature_weights);
-    string plmw;
-    if (prelm_weights_string(conf,plmw)) {
-      has_prelm_models = true;
-      prelm_w.InitFromFile(plmw);
-      prelm_feature_weights.resize(FD::NumFeats());
-      prelm_w.InitVector(&prelm_feature_weights);
-      if (show_weights)
-        cerr << "prelm_weights: " << WeightVector(prelm_feature_weights)<<endl;
-    }
    if (show_weights)
      cerr << "+LM weights: " << WeightVector(feature_weights)<<endl;
   }
@@ -545,24 +532,10 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
      pffs.push_back(make_ff(add_ffs[i],verbose_feature_functions));
      FeatureFunction const* p=pffs.back().get();
      late_ffs.push_back(p);
-      if (has_prelm_models) {
-        if (p->NumBytesContext()==0)
-          prelm_ffs.push_back(p);
-        else
-          cerr << "Excluding stateful feature from prelm pruning: "<<add_ffs[i]<<endl;
-      }
-    }
-  }
-  if (conf.count("prelm_feature_function") > 0) {
-    vector<string> add_ffs;
-    store_conf(conf,"prelm_feature_function",&add_ffs);
-//    const vector<string>& add_ffs = conf["prelm_feature_function"].as<vector<string> >();
-    for (int i = 0; i < add_ffs.size(); ++i) {
-      prelm_only_ffs.push_back(make_ff(add_ffs[i],verbose_feature_functions,"prelm-only "));
-      prelm_ffs.push_back(prelm_only_ffs.back().get());
    }
  }
+#ifdef FSA_RESCORING
  store_conf(conf,"fsa_feature_function",&fsa_names);
  for (int i=0;i<fsa_names.size();++i)
    fsa_ffs.push_back(make_fsa_ff(fsa_names[i],verbose_feature_functions,"FSA "));
@@ -575,20 +548,15 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
    cerr<<"FSA: ";
    show_all_features(fsa_ffs,feature_weights,cerr,cerr,true,true);
  }
+#endif
 
  if (late_freeze) {
    cerr << "Late freezing feature set (use --no_freeze_feature_set to prevent)." << endl;
    FD::Freeze(); // this means we can't see the feature names of not-weighted features
  }
-  if (has_prelm_models)
-    cerr << "prelm rescoring with "<<prelm_ffs.size()<<" 0-state feature functions. +LM pass will use "<<late_ffs.size()<<" features (not counting rule features)."<<endl;
-
  late_models = new ModelSet(feature_weights, late_ffs);
  if (!SILENT) show_models(conf,*late_models,"late ");
-  prelm_models = new ModelSet(prelm_feature_weights, prelm_ffs);
-  if (has_prelm_models) {
-    if (!SILENT) show_models(conf,*prelm_models,"prelm ");
-  }
 
  int palg = 1;
  if (LowercaseString(str("intersection_strategy",conf)) == "full") {
@@ -603,7 +571,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
  if (sample_max_trans)
    rng.reset(new RandomNumberGenerator<boost::mt19937>);
  aligner_mode = conf.count("aligner");
-  minimal_forests = conf.count("minimal_forests");
  graphviz = conf.count("graphviz");
  joshua_viz = conf.count("show_joshua_visualization");
  encode_b64 = str("vector_format",conf) == "b64";
@@ -611,7 +578,9 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
  unique_kbest = conf.count("unique_k_best");
  get_oracle_forest = conf.count("get_oracle_forest");
 
+#ifdef FSA_RESCORING
  cfg_options.Validate();
+#endif
 
  if (conf.count("extract_rules"))
    extract_file.reset(new WriteFile(str("extract_rules",conf)));
@@ -703,24 +672,9 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
      cerr << " -LM partition log(Z): " << log(z) << endl;
    }
 
-    if (has_prelm_models) {
-      Timer t("prelm rescoring");
-      forest.Reweight(prelm_feature_weights);
-      Hypergraph prelm_forest;
-      prelm_models->PrepareForInput(smeta);
-      ApplyModelSet(forest,
-                    smeta,
-                    *prelm_models,
-                    *inter_conf, // this is now reduced to exhaustive if all are stateless
-                    &prelm_forest);
-      forest.swap(prelm_forest);
-      forest.Reweight(prelm_feature_weights); //FIXME: why the reweighting? here and below. maybe in case we already had a featval for that id and ApplyModelSet only adds prob, doesn't recompute it?
-      forest_stats(forest," prelm forest",show_tree_structure,show_features,prelm_feature_weights,oracle.show_derivation);
-    }
-
-    maybe_prune(forest,conf,"prelm_beam_prune","prelm_density_prune","-LM",srclen);
-
+#ifdef FSA_RESCORING
    cfg_options.maybe_output_source(forest);
+#endif
 
    bool has_late_models = !late_models->empty();
    if (has_late_models) {
@@ -740,6 +694,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
 
    maybe_prune(forest,conf,"beam_prune","density_prune","+LM",srclen);
 
+#ifdef FSA_RESCORING
    HgCFG hgcfg(forest);
    cfg_options.prepare(hgcfg);
 
@@ -756,6 +711,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
      forest.Reweight(feature_weights);
      if (!SILENT) forest_stats(forest," +FSA forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation);
    }
+#endif
 
    // Oracle Rescoring
    if(get_oracle_forest) {
@@ -781,10 +737,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
          assert(succeeded);
        }
        new_hg.Union(forest);
-        bool succeeded = writer.Write(new_hg, minimal_forests);
+        bool succeeded = writer.Write(new_hg, false);
        assert(succeeded);
      } else {
-        bool succeeded = writer.Write(forest, minimal_forests);
+        bool succeeded = writer.Write(forest, false);
        assert(succeeded);
      }
    }
@@ -861,10 +817,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
          assert(succeeded);
        }
        new_hg.Union(forest);
-        bool succeeded = writer.Write(new_hg, minimal_forests);
+        bool succeeded = writer.Write(new_hg, false);
        assert(succeeded);
      } else {
-        bool succeeded = writer.Write(forest, minimal_forests);
+        bool succeeded = writer.Write(forest, false);
        assert(succeeded);
      }
    }