summary | refs | log | tree | commit | diff
path: root/decoder/decoder.cc
diff options
context:
space:
mode:
author Paul Baltescu <pauldb89@gmail.com> 2013-11-25 23:24:42 +0000
committer Paul Baltescu <pauldb89@gmail.com> 2013-11-25 23:24:42 +0000
commit 2b95390f08d9f556e6207ecff03b4b0fd5ede993 (patch)
tree 7a96e837a3e28cfc8258a3c5293ac333d7c3e29e /decoder/decoder.cc
parent 467ef6ce78cfe7341a696ebf0948e377be619ae5 (diff)
parent 62a2526e69eb1570bf349763fc8bb65179337918 (diff)
Merge branch 'master' of https://github.com/redpony/cdec
Diffstat (limited to 'decoder/decoder.cc')
-rw-r--r-- decoder/decoder.cc 99
1 files changed, 7 insertions, 92 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index da65713a..9b41253b 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -11,7 +11,6 @@ namespace std { using std::tr1::unordered_map; }
#include <boost/make_shared.hpp>
#include <boost/scoped_ptr.hpp>
-#include "program_options.h"
#include "stringlib.h"
#include "weights.h"
#include "filelib.h"
@@ -49,13 +48,6 @@ namespace std { using std::tr1::unordered_map; }
#include "hg_io.h"
#include "aligner.h"
-#undef FSA_RESCORING
-#ifdef FSA_RESCORING
-#include "hg_cfg.h"
-#include "apply_fsa_models.h"
-#include "cfg_options.h"
-#endif
-
#ifdef CP_TIME
clock_t CpTime::time_;
void CpTime::Add(clock_t x){time_+=x;}
@@ -140,21 +132,6 @@ inline boost::shared_ptr<FeatureFunction> make_ff(string const& ffp,bool verbose
return pf;
}
-#ifdef FSA_RESCORING
-inline boost::shared_ptr<FsaFeatureFunction> make_fsa_ff(string const& ffp,bool verbose_feature_functions,char const* pre="") {
- string ff, param;
- SplitCommandAndParam(ffp, &ff, &param);
- cerr << "FSA Feature: " << ff;
- if (param.size() > 0) cerr << " (with config parameters '" << param << "')\n";
- else cerr << " (no config parameters)\n";
- boost::shared_ptr<FsaFeatureFunction> pf = fsa_ff_registry.Create(ff, param);
- if (!pf) exit(1);
- if (verbose_feature_functions && !SILENT)
- cerr<<"State is "<<pf->state_bytes()<<" bytes for "<<pre<<"feature "<<ffp<<endl;
- return pf;
-}
-#endif
-
// when the translation forest is first built, it is scored by the features associated
// with the rules. To add other features (like language models, etc), cdec applies one or
// more "rescoring passes", which compute new features and optionally apply new weights
@@ -304,11 +281,6 @@ struct DecoderImpl {
boost::shared_ptr<Translator> translator;
boost::shared_ptr<vector<weight_t> > init_weights; // weights used with initial parse
vector<boost::shared_ptr<FeatureFunction> > pffs;
-#ifdef FSA_RESCORING
- CFGOptions cfg_options;
- vector<boost::shared_ptr<FsaFeatureFunction> > fsa_ffs;
- vector<string> fsa_names;
-#endif
boost::shared_ptr<RandomNumberGenerator<boost::mt19937> > rng;
int sample_max_trans;
bool aligner_mode;
@@ -324,7 +296,6 @@ struct DecoderImpl {
SparseVector<prob_t> acc_vec; // accumulate gradient
double acc_obj; // accumulate objective
int g_count; // number of gradient pieces computed
- int pop_limit;
bool csplit_output_plf;
bool write_gradient; // TODO Observer
bool feature_expectations; // TODO Observer
@@ -372,6 +343,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("weights,w",po::value<string>(),"Feature weights file (initial forest / pass 1)")
("feature_function,F",po::value<vector<string> >()->composing(), "Pass 1 additional feature function(s) (-L for list)")
("intersection_strategy,I",po::value<string>()->default_value("cube_pruning"), "Pass 1 intersection strategy for incorporating finite-state features; values include Cube_pruning, Full, Fast_cube_pruning, Fast_cube_pruning_2")
+ ("cubepruning_pop_limit,K",po::value<unsigned>()->default_value(200), "Max number of pops from the candidate heap at each node")
("summary_feature", po::value<string>(), "Compute a 'summary feature' at the end of the pass (before any pruning) with name=arg and value=inside-outside/Z")
("summary_feature_type", po::value<string>()->default_value("node_risk"), "Summary feature types: node_risk, edge_risk, edge_prob")
("density_prune", po::value<double>(), "Pass 1 pruning: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)")
@@ -380,6 +352,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("weights2",po::value<string>(),"Optional pass 2")
("feature_function2",po::value<vector<string> >()->composing(), "Optional pass 2")
("intersection_strategy2",po::value<string>()->default_value("cube_pruning"), "Optional pass 2")
+ ("cubepruning_pop_limit2",po::value<unsigned>()->default_value(200), "Optional pass 2")
("summary_feature2", po::value<string>(), "Optional pass 2")
("density_prune2", po::value<double>(), "Optional pass 2")
("beam_prune2", po::value<double>(), "Optional pass 2")
@@ -387,18 +360,14 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("weights3",po::value<string>(),"Optional pass 3")
("feature_function3",po::value<vector<string> >()->composing(), "Optional pass 3")
("intersection_strategy3",po::value<string>()->default_value("cube_pruning"), "Optional pass 3")
+ ("cubepruning_pop_limit3",po::value<unsigned>()->default_value(200), "Optional pass 3")
("summary_feature3", po::value<string>(), "Optional pass 3")
("density_prune3", po::value<double>(), "Optional pass 3")
("beam_prune3", po::value<double>(), "Optional pass 3")
-#ifdef FSA_RESCORING
- ("fsa_feature_function,A",po::value<vector<string> >()->composing(), "Additional FSA feature function(s) (-L for list)")
- ("apply_fsa_by",po::value<string>()->default_value("BU_CUBE"), "Method for applying fsa_feature_functions - BU_FULL BU_CUBE EARLEY") //+ApplyFsaBy::all_names()
-#endif
("add_pass_through_rules,P","Add rules to translate OOV words as themselves")
("k_best,k",po::value<int>(),"Extract the k best derivations")
("unique_k_best,r", "Unique k-best translation list")
- ("cubepruning_pop_limit,K",po::value<int>()->default_value(200), "Max number of pops from the candidate heap at each node")
("aligner,a", "Run as a word/phrase aligner (src & ref required)")
("aligner_use_viterbi", "If run in alignment mode, compute the Viterbi (rather than MAP) alignment")
("goal",po::value<string>()->default_value("S"),"Goal symbol (SCFG & FST)")
@@ -446,10 +415,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)");
// ob.AddOptions(&opts);
-#ifdef FSA_RESCORING
- po::options_description cfgo(cfg_options.description());
- cfg_options.AddOptions(&cfgo);
-#endif
po::options_description clo("Command line options");
clo.add_options()
("config,c", po::value<vector<string> >(&cfg_files), "Configuration file(s) - latest has priority")
@@ -459,15 +424,10 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
;
po::options_description dconfig_options, dcmdline_options;
-#ifdef FSA_RESCORING
- dconfig_options.add(opts).add(cfgo);
-#else
dconfig_options.add(opts);
-#endif
dcmdline_options.add(dconfig_options).add(clo);
if (argc) {
- argv_minus_to_underscore(argc,argv);
po::store(parse_command_line(argc, argv, dcmdline_options), conf);
if (conf.count("compgen")) {
print_options(cout,dcmdline_options);
@@ -511,10 +471,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
if (conf.count("list_feature_functions")) {
cerr << "Available feature functions (specify with -F; describe with -u FeatureName):\n";
ff_registry.DisplayList(); //TODO
-#ifdef FSA_RESCORING
- cerr << "Available FSA feature functions (specify with --fsa_feature_function):\n";
- fsa_ff_registry.DisplayList(); // TODO
-#endif
cerr << endl;
exit(1);
}
@@ -574,9 +530,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
if (conf.count("weights"))
Weights::InitFromFile(str("weights",conf), init_weights.get());
- // cube pruning pop-limit: we may want to configure this on a per-pass basis
- pop_limit = conf["cubepruning_pop_limit"].as<int>();
-
if (conf.count("extract_rules")) {
if (!DirectoryExists(conf["extract_rules"].as<string>()))
MkDirP(conf["extract_rules"].as<string>());
@@ -620,6 +573,9 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
if (conf.count(dp)) { rp.density_prune = conf[dp].as<double>(); }
int palg = (has_stateful ? 1 : 0); // if there are no stateful features, default to FULL
string isn = "intersection_strategy" + StringSuffixForRescoringPass(pass);
+ string spl = "cubepruning_pop_limit" + StringSuffixForRescoringPass(pass);
+ unsigned pop_limit = 200;
+ if (conf.count(spl)) { pop_limit = conf[spl].as<unsigned>(); }
if (LowercaseString(str(isn.c_str(),conf)) == "full") {
palg = 0;
}
@@ -686,21 +642,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
else
assert(!"error");
-#ifdef FSA_RESCORING
- store_conf(conf,"fsa_feature_function",&fsa_names);
- for (int i=0;i<fsa_names.size();++i)
- fsa_ffs.push_back(make_fsa_ff(fsa_names[i],verbose_feature_functions,"FSA "));
- if (fsa_ffs.size()>1) {
- //FIXME: support N fsa ffs.
- cerr<<"Only the first fsa FF will be used (FIXME).\n";
- fsa_ffs.resize(1);
- }
- if (!fsa_ffs.empty()) {
- cerr<<"FSA: ";
- show_all_features(fsa_ffs,*init_weights,cerr,cerr,true,true);
- }
-#endif
-
if (late_freeze) {
cerr << "Late freezing feature set (use --no_freeze_feature_set to prevent)." << endl;
FD::Freeze(); // this means we can't see the feature names of not-weighted features
@@ -720,10 +661,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
oracle.show_derivation=conf.count("show_derivations");
remove_intersected_rule_annotations = conf.count("remove_intersected_rule_annotations");
-#ifdef FSA_RESCORING
- cfg_options.Validate();
-#endif
-
if (conf.count("extract_rules")) {
stringstream ss;
ss << sent_id;
@@ -840,7 +777,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
HypergraphIO::WriteTarget(conf["show_target_graph"].as<string>(), sent_id, forest);
}
if (conf.count("incremental_search")) {
- incremental->Search(pop_limit, forest);
+ incremental->Search(conf["cubepruning_pop_limit"].as<unsigned>(), forest);
}
if (conf.count("show_target_graph") || conf.count("incremental_search")) {
o->NotifyDecodingComplete(smeta);
@@ -851,9 +788,6 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
const RescoringPass& rp = rescoring_passes[pass];
const vector<weight_t>& cur_weights = *rp.weight_vector;
if (!SILENT) cerr << endl << " RESCORING PASS #" << (pass+1) << " " << rp << endl;
-#ifdef FSA_RESCORING
- cfg_options.maybe_output_source(forest);
-#endif
string passtr = "Pass1"; passtr[4] += pass;
forest.Reweight(cur_weights);
@@ -949,25 +883,6 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
string fullbp = "beam_prune" + StringSuffixForRescoringPass(pass);
string fulldp = "density_prune" + StringSuffixForRescoringPass(pass);
maybe_prune(forest,conf,fullbp.c_str(),fulldp.c_str(),passtr,srclen);
-
-#ifdef FSA_RESCORING
- HgCFG hgcfg(forest);
- cfg_options.prepare(hgcfg);
-
- if (!fsa_ffs.empty()) {
- Timer t("Target FSA rescoring:");
- if (!has_late_models)
- forest.Reweight(pass0_weights);
- Hypergraph fsa_forest;
- assert(fsa_ffs.size()==1);
- ApplyFsaBy cfg(str("apply_fsa_by",conf),pop_limit);
- if (!SILENT) cerr << "FSA rescoring with "<<cfg<<" "<<fsa_ffs[0]->describe()<<endl;
- ApplyFsaModels(hgcfg,smeta,*fsa_ffs[0],pass0_weights,cfg,&fsa_forest);
- forest.swap(fsa_forest);
- forest.Reweight(pass0_weights);
- if (!SILENT) forest_stats(forest," +FSA forest",show_tree_structure,oracle.show_derivation);
- }
-#endif
}
const vector<double>& last_weights = (rescoring_passes.empty() ? *init_weights : *rescoring_passes.back().weight_vector);