diff options
author | Paul Baltescu <pauldb89@gmail.com> | 2013-11-25 23:24:42 +0000 |
---|---|---|
committer | Paul Baltescu <pauldb89@gmail.com> | 2013-11-25 23:24:42 +0000 |
commit | 2b95390f08d9f556e6207ecff03b4b0fd5ede993 (patch) | |
tree | 7a96e837a3e28cfc8258a3c5293ac333d7c3e29e /decoder/decoder.cc | |
parent | 467ef6ce78cfe7341a696ebf0948e377be619ae5 (diff) | |
parent | 62a2526e69eb1570bf349763fc8bb65179337918 (diff) |
Merge branch 'master' of https://github.com/redpony/cdec
Diffstat (limited to 'decoder/decoder.cc')
-rw-r--r-- | decoder/decoder.cc | 99 |
1 files changed, 7 insertions, 92 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc index da65713a..9b41253b 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -11,7 +11,6 @@ namespace std { using std::tr1::unordered_map; } #include <boost/make_shared.hpp> #include <boost/scoped_ptr.hpp> -#include "program_options.h" #include "stringlib.h" #include "weights.h" #include "filelib.h" @@ -49,13 +48,6 @@ namespace std { using std::tr1::unordered_map; } #include "hg_io.h" #include "aligner.h" -#undef FSA_RESCORING -#ifdef FSA_RESCORING -#include "hg_cfg.h" -#include "apply_fsa_models.h" -#include "cfg_options.h" -#endif - #ifdef CP_TIME clock_t CpTime::time_; void CpTime::Add(clock_t x){time_+=x;} @@ -140,21 +132,6 @@ inline boost::shared_ptr<FeatureFunction> make_ff(string const& ffp,bool verbose return pf; } -#ifdef FSA_RESCORING -inline boost::shared_ptr<FsaFeatureFunction> make_fsa_ff(string const& ffp,bool verbose_feature_functions,char const* pre="") { - string ff, param; - SplitCommandAndParam(ffp, &ff, ¶m); - cerr << "FSA Feature: " << ff; - if (param.size() > 0) cerr << " (with config parameters '" << param << "')\n"; - else cerr << " (no config parameters)\n"; - boost::shared_ptr<FsaFeatureFunction> pf = fsa_ff_registry.Create(ff, param); - if (!pf) exit(1); - if (verbose_feature_functions && !SILENT) - cerr<<"State is "<<pf->state_bytes()<<" bytes for "<<pre<<"feature "<<ffp<<endl; - return pf; -} -#endif - // when the translation forest is first built, it is scored by the features associated // with the rules. To add other features (like language models, etc), cdec applies one or // more "rescoring passes", which compute new features and optionally apply new weights @@ -304,11 +281,6 @@ struct DecoderImpl { boost::shared_ptr<Translator> translator; boost::shared_ptr<vector<weight_t> > init_weights; // weights used with initial parse vector<boost::shared_ptr<FeatureFunction> > pffs; -#ifdef FSA_RESCORING - CFGOptions cfg_options; - vector<boost::shared_ptr<FsaFeatureFunction> > fsa_ffs; - vector<string> fsa_names; -#endif boost::shared_ptr<RandomNumberGenerator<boost::mt19937> > rng; int sample_max_trans; bool aligner_mode; @@ -324,7 +296,6 @@ struct DecoderImpl { SparseVector<prob_t> acc_vec; // accumulate gradient double acc_obj; // accumulate objective int g_count; // number of gradient pieces computed - int pop_limit; bool csplit_output_plf; bool write_gradient; // TODO Observer bool feature_expectations; // TODO Observer @@ -372,6 +343,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("weights,w",po::value<string>(),"Feature weights file (initial forest / pass 1)") ("feature_function,F",po::value<vector<string> >()->composing(), "Pass 1 additional feature function(s) (-L for list)") ("intersection_strategy,I",po::value<string>()->default_value("cube_pruning"), "Pass 1 intersection strategy for incorporating finite-state features; values include Cube_pruning, Full, Fast_cube_pruning, Fast_cube_pruning_2") + ("cubepruning_pop_limit,K",po::value<unsigned>()->default_value(200), "Max number of pops from the candidate heap at each node") ("summary_feature", po::value<string>(), "Compute a 'summary feature' at the end of the pass (before any pruning) with name=arg and value=inside-outside/Z") ("summary_feature_type", po::value<string>()->default_value("node_risk"), "Summary feature types: node_risk, edge_risk, edge_prob") ("density_prune", po::value<double>(), "Pass 1 pruning: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)") @@ -380,6 +352,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("weights2",po::value<string>(),"Optional pass 2") ("feature_function2",po::value<vector<string> >()->composing(), "Optional pass 2") ("intersection_strategy2",po::value<string>()->default_value("cube_pruning"), "Optional pass 2") + ("cubepruning_pop_limit2",po::value<unsigned>()->default_value(200), "Optional pass 2") ("summary_feature2", po::value<string>(), "Optional pass 2") ("density_prune2", po::value<double>(), "Optional pass 2") ("beam_prune2", po::value<double>(), "Optional pass 2") @@ -387,18 +360,14 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("weights3",po::value<string>(),"Optional pass 3") ("feature_function3",po::value<vector<string> >()->composing(), "Optional pass 3") ("intersection_strategy3",po::value<string>()->default_value("cube_pruning"), "Optional pass 3") + ("cubepruning_pop_limit3",po::value<unsigned>()->default_value(200), "Optional pass 3") ("summary_feature3", po::value<string>(), "Optional pass 3") ("density_prune3", po::value<double>(), "Optional pass 3") ("beam_prune3", po::value<double>(), "Optional pass 3") -#ifdef FSA_RESCORING - ("fsa_feature_function,A",po::value<vector<string> >()->composing(), "Additional FSA feature function(s) (-L for list)") - ("apply_fsa_by",po::value<string>()->default_value("BU_CUBE"), "Method for applying fsa_feature_functions - BU_FULL BU_CUBE EARLEY") //+ApplyFsaBy::all_names() -#endif ("add_pass_through_rules,P","Add rules to translate OOV words as themselves") ("k_best,k",po::value<int>(),"Extract the k best derivations") ("unique_k_best,r", "Unique k-best translation list") - ("cubepruning_pop_limit,K",po::value<int>()->default_value(200), "Max number of pops from the candidate heap at each node") ("aligner,a", "Run as a word/phrase aligner (src & ref required)") ("aligner_use_viterbi", "If run in alignment mode, compute the Viterbi (rather than MAP) alignment") ("goal",po::value<string>()->default_value("S"),"Goal symbol (SCFG & FST)") @@ -446,10 +415,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)"); // ob.AddOptions(&opts); -#ifdef FSA_RESCORING - po::options_description cfgo(cfg_options.description()); - cfg_options.AddOptions(&cfgo); -#endif po::options_description clo("Command line options"); clo.add_options() ("config,c", po::value<vector<string> >(&cfg_files), "Configuration file(s) - latest has priority") @@ -459,15 +424,10 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ; po::options_description dconfig_options, dcmdline_options; -#ifdef FSA_RESCORING - dconfig_options.add(opts).add(cfgo); -#else dconfig_options.add(opts); -#endif dcmdline_options.add(dconfig_options).add(clo); if (argc) { - argv_minus_to_underscore(argc,argv); po::store(parse_command_line(argc, argv, dcmdline_options), conf); if (conf.count("compgen")) { print_options(cout,dcmdline_options); @@ -511,10 +471,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream if (conf.count("list_feature_functions")) { cerr << "Available feature functions (specify with -F; describe with -u FeatureName):\n"; ff_registry.DisplayList(); //TODO -#ifdef FSA_RESCORING - cerr << "Available FSA feature functions (specify with --fsa_feature_function):\n"; - fsa_ff_registry.DisplayList(); // TODO -#endif cerr << endl; exit(1); } @@ -574,9 +530,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream if (conf.count("weights")) Weights::InitFromFile(str("weights",conf), init_weights.get()); - // cube pruning pop-limit: we may want to configure this on a per-pass basis - pop_limit = conf["cubepruning_pop_limit"].as<int>(); - if (conf.count("extract_rules")) { if (!DirectoryExists(conf["extract_rules"].as<string>())) MkDirP(conf["extract_rules"].as<string>()); @@ -620,6 +573,9 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream if (conf.count(dp)) { rp.density_prune = conf[dp].as<double>(); } int palg = (has_stateful ? 1 : 0); // if there are no stateful featueres, default to FULL string isn = "intersection_strategy" + StringSuffixForRescoringPass(pass); + string spl = "cubepruning_pop_limit" + StringSuffixForRescoringPass(pass); + unsigned pop_limit = 200; + if (conf.count(spl)) { pop_limit = conf[spl].as<unsigned>(); } if (LowercaseString(str(isn.c_str(),conf)) == "full") { palg = 0; } @@ -686,21 +642,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream else assert(!"error"); -#ifdef FSA_RESCORING - store_conf(conf,"fsa_feature_function",&fsa_names); - for (int i=0;i<fsa_names.size();++i) - fsa_ffs.push_back(make_fsa_ff(fsa_names[i],verbose_feature_functions,"FSA ")); - if (fsa_ffs.size()>1) { - //FIXME: support N fsa ffs. - cerr<<"Only the first fsa FF will be used (FIXME).\n"; - fsa_ffs.resize(1); - } - if (!fsa_ffs.empty()) { - cerr<<"FSA: "; - show_all_features(fsa_ffs,*init_weights,cerr,cerr,true,true); - } -#endif - if (late_freeze) { cerr << "Late freezing feature set (use --no_freeze_feature_set to prevent)." << endl; FD::Freeze(); // this means we can't see the feature names of not-weighted features @@ -720,10 +661,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream oracle.show_derivation=conf.count("show_derivations"); remove_intersected_rule_annotations = conf.count("remove_intersected_rule_annotations"); -#ifdef FSA_RESCORING - cfg_options.Validate(); -#endif - if (conf.count("extract_rules")) { stringstream ss; ss << sent_id; @@ -840,7 +777,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { HypergraphIO::WriteTarget(conf["show_target_graph"].as<string>(), sent_id, forest); } if (conf.count("incremental_search")) { - incremental->Search(pop_limit, forest); + incremental->Search(conf["cubepruning_pop_limit"].as<unsigned>(), forest); } if (conf.count("show_target_graph") || conf.count("incremental_search")) { o->NotifyDecodingComplete(smeta); @@ -851,9 +788,6 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { const RescoringPass& rp = rescoring_passes[pass]; const vector<weight_t>& cur_weights = *rp.weight_vector; if (!SILENT) cerr << endl << " RESCORING PASS #" << (pass+1) << " " << rp << endl; -#ifdef FSA_RESCORING - cfg_options.maybe_output_source(forest); -#endif string passtr = "Pass1"; passtr[4] += pass; forest.Reweight(cur_weights); @@ -949,25 +883,6 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { string fullbp = "beam_prune" + StringSuffixForRescoringPass(pass); string fulldp = "density_prune" + StringSuffixForRescoringPass(pass); maybe_prune(forest,conf,fullbp.c_str(),fulldp.c_str(),passtr,srclen); - -#ifdef FSA_RESCORING - HgCFG hgcfg(forest); - cfg_options.prepare(hgcfg); - - if (!fsa_ffs.empty()) { - Timer t("Target FSA rescoring:"); - if (!has_late_models) - forest.Reweight(pass0_weights); - Hypergraph fsa_forest; - assert(fsa_ffs.size()==1); - ApplyFsaBy cfg(str("apply_fsa_by",conf),pop_limit); - if (!SILENT) cerr << "FSA rescoring with "<<cfg<<" "<<fsa_ffs[0]->describe()<<endl; - ApplyFsaModels(hgcfg,smeta,*fsa_ffs[0],pass0_weights,cfg,&fsa_forest); - forest.swap(fsa_forest); - forest.Reweight(pass0_weights); - if (!SILENT) forest_stats(forest," +FSA forest",show_tree_structure,oracle.show_derivation); - } -#endif } const vector<double>& last_weights = (rescoring_passes.empty() ? *init_weights : *rescoring_passes.back().weight_vector); |