author     Chris Dyer <cdyer@cs.cmu.edu>   2011-02-10 00:16:58 -0500
committer  Chris Dyer <cdyer@cs.cmu.edu>   2011-02-10 00:16:58 -0500
commit     70fdb6cd8774cbd0114fe0d630781bab309e0d87 (patch)
tree       64af5934539a5577d7da05d4e7425dc30e43cf9e /decoder/decoder.cc
parent     34a276778ddca51c61f3e9acf5d885d98d34d9cf (diff)
conditional compilation of experimental code, remove prelm scoring code in preparation for multi-phase (re)scoring
Diffstat (limited to 'decoder/decoder.cc')
-rw-r--r--   decoder/decoder.cc   136
1 file changed, 46 insertions, 90 deletions
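
The patch below wraps the experimental FSA rescoring path in an FSA_RESCORING compile-time guard and immediately #undef's the macro, so that code is excluded from the build by default. A minimal sketch of the pattern follows; the FSA_RESCORING name and the guarded headers are taken from the diff, while the surrounding function and its comments are illustrative only, not part of the commit:

    // Experimental path is off by default; remove (or comment out) the #undef
    // to compile the FSA rescoring code back in.
    #undef FSA_RESCORING

    #ifdef FSA_RESCORING
    #include "hg_cfg.h"            // headers needed only by the experimental path
    #include "apply_fsa_models.h"
    #include "cfg_options.h"
    #endif

    void RescoreSketch() {
    #ifdef FSA_RESCORING
      // experimental FSA rescoring would run here
    #endif
      // the ordinary +LM rescoring path compiles and runs either way
    }
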
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index f37e8a37..25f05d8e 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -4,7 +4,7 @@
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
-#include "sampler.h"
+#include "program_options.h"
#include "stringlib.h"
#include "weights.h"
#include "filelib.h"
@@ -24,22 +24,28 @@
#include "sentence_metadata.h"
#include "hg_intersect.h"
-#include "apply_fsa_models.h"
#include "oracle_bleu.h"
#include "apply_models.h"
#include "ff.h"
#include "ff_factory.h"
-#include "cfg_options.h"
#include "viterbi.h"
#include "kbest.h"
#include "inside_outside.h"
#include "exp_semiring.h"
#include "sentence_metadata.h"
-#include "hg_cfg.h"
+#include "sampler.h"
#include "forest_writer.h" // TODO this section should probably be handled by an Observer
#include "hg_io.h"
#include "aligner.h"
+
+#undef FSA_RESCORING
+#ifdef FSA_RESCORING
+#include "hg_cfg.h"
+#include "apply_fsa_models.h"
+#include "cfg_options.h"
+#endif
+
static const double kMINUS_EPSILON = -1e-6; // don't be too strict
using namespace std;
@@ -78,19 +84,6 @@ inline string str(char const* name,po::variables_map const& conf) {
return conf[name].as<string>();
}
-inline bool prelm_weights_string(po::variables_map const& conf,string &s) {
- if (conf.count("prelm_weights")) {
- s=str("prelm_weights",conf);
- return true;
- }
- if (conf.count("prelm_copy_weights")) {
- s=str("weights",conf);
- return true;
- }
- return false;
-}
-
-
// print just the --long_opt names suitable for bash compgen
inline void print_options(std::ostream &out,po::options_description const& opts) {
@@ -127,6 +120,7 @@ inline shared_ptr<FeatureFunction> make_ff(string const& ffp,bool verbose_featur
return pf;
}
+#ifdef FSA_RESCORING
inline shared_ptr<FsaFeatureFunction> make_fsa_ff(string const& ffp,bool verbose_feature_functions,char const* pre="") {
string ff, param;
SplitCommandAndParam(ffp, &ff, &param);
@@ -139,6 +133,7 @@ inline shared_ptr<FsaFeatureFunction> make_fsa_ff(string const& ffp,bool verbose
cerr<<"State is "<<pf->state_bytes()<<" bytes for "<<pre<<"feature "<<ffp<<endl;
return pf;
}
+#endif
struct DecoderImpl {
DecoderImpl(po::variables_map& conf, int argc, char** argv, istream* cfg);
@@ -189,7 +184,7 @@ struct DecoderImpl {
preserve_mask[CompoundSplit::GetFullWordEdgeIndex(forest)] = true;
pm=&preserve_mask;
}
- forest.PruneInsideOutside(beam_prune,density_prune,pm,false,1,conf["promise_power"].as<double>());
+ forest.PruneInsideOutside(beam_prune,density_prune,pm,false,1);
if (!forestname.empty()) forestname=" "+forestname;
forest_stats(forest," Pruned "+forestname+" forest",false,false,0,false);
cerr << " Pruned "<<forestname<<" forest portion of edges kept: "<<forest.edges_.size()/presize<<endl;
@@ -259,21 +254,22 @@ struct DecoderImpl {
po::variables_map& conf;
OracleBleu oracle;
- CFGOptions cfg_options;
string formalism;
shared_ptr<Translator> translator;
- vector<double> feature_weights,prelm_feature_weights;
- Weights w,prelm_w;
- vector<shared_ptr<FeatureFunction> > pffs,prelm_only_ffs;
- vector<const FeatureFunction*> late_ffs,prelm_ffs;
+ vector<double> feature_weights;
+ Weights w;
+ vector<shared_ptr<FeatureFunction> > pffs;
+ vector<const FeatureFunction*> late_ffs;
+#ifdef FSA_RESCORING
+ CFGOptions cfg_options;
vector<shared_ptr<FsaFeatureFunction> > fsa_ffs;
vector<string> fsa_names;
- ModelSet* late_models, *prelm_models;
+#endif
+ ModelSet* late_models;
IntersectionConfiguration* inter_conf;
shared_ptr<RandomNumberGenerator<boost::mt19937> > rng;
int sample_max_trans;
bool aligner_mode;
- bool minimal_forests;
bool graphviz;
bool joshua_viz;
bool encode_b64;
@@ -286,7 +282,6 @@ struct DecoderImpl {
SparseVector<prob_t> acc_vec; // accumulate gradient
double acc_obj; // accumulate objective
int g_count; // number of gradient pieces computed
- bool has_prelm_models;
int pop_limit;
bool csplit_output_plf;
bool write_gradient; // TODO Observer
@@ -325,15 +320,13 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("grammar,g",po::value<vector<string> >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)")
("per_sentence_grammar_file", po::value<string>(), "Optional (and possibly not implemented) per sentence grammar file enables all per sentence grammars to be stored in a single large file and accessed by offset")
("weights,w",po::value<string>(),"Feature weights file")
- ("prelm_weights",po::value<string>(),"Feature weights file for prelm_beam_prune. Requires --weights.")
- ("prelm_copy_weights","use --weights as value for --prelm_weights.")
- ("prelm_feature_function",po::value<vector<string> >()->composing(),"Additional feature functions for prelm pass only (in addition to the 0-state subset of feature_function")
- ("keep_prelm_cube_order","DEPRECATED (always enabled). when forest rescoring with final models, use the edge ordering from the prelm pruning features*weights. only meaningful if --prelm_weights given. UNTESTED but assume that cube pruning gives a sensible result, and that 'good' (as tuned for bleu w/ prelm features) edges come first.")
("warn_0_weight","Warn about any feature id that has a 0 weight (this is perfectly safe if you intend 0 weight, though)")
("freeze_feature_set,Z", "Freeze feature set after reading feature weights file")
("feature_function,F",po::value<vector<string> >()->composing(), "Additional feature function(s) (-L for list)")
+#ifdef FSA_RESCORING
("fsa_feature_function,A",po::value<vector<string> >()->composing(), "Additional FSA feature function(s) (-L for list)")
("apply_fsa_by",po::value<string>()->default_value("BU_CUBE"), "Method for applying fsa_feature_functions - BU_FULL BU_CUBE EARLEY") //+ApplyFsaBy::all_names()
+#endif
("list_feature_functions,L","List available feature functions")
("add_pass_through_rules,P","Add rules to translate OOV words as themselves")
("k_best,k",po::value<int>(),"Extract the k best derivations")
@@ -357,16 +350,13 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("show_conditional_prob", "Output the conditional log prob to STDOUT instead of a translation")
("show_cfg_search_space", "Show the search space as a CFG")
("show_features","Show the feature vector for the viterbi translation")
- ("prelm_density_prune", po::value<double>(), "Applied to -LM forest just before final LM rescoring: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)")
("density_prune", po::value<double>(), "Keep no more than this many times the number of edges used in the best derivation tree (>=1.0)")
- ("prelm_beam_prune", po::value<double>(), "Prune paths from -LM forest before LM rescoring, keeping paths within exp(alpha>=0)")
("coarse_to_fine_beam_prune", po::value<double>(), "Prune paths from coarse parse forest before fine parse, keeping paths within exp(alpha>=0)")
("ctf_beam_widen", po::value<double>()->default_value(2.0), "Expand coarse pass beam by this factor if no fine parse is found")
("ctf_num_widenings", po::value<int>()->default_value(2), "Widen coarse beam this many times before backing off to full parse")
("ctf_no_exhaustive", "Do not fall back to exhaustive parse if coarse-to-fine parsing fails")
("beam_prune", po::value<double>(), "Prune paths from +LM forest, keep paths within exp(alpha>=0)")
("scale_prune_srclen", "scale beams by the input length (in # of tokens; may not be what you want for lattices")
- ("promise_power",po::value<double>()->default_value(0), "Give more beam budget to more promising previous-pass nodes when pruning - but allocate the same average beams. 0 means off, 1 means beam proportional to inside*outside prob, n means nth power (affects just --cubepruning_pop_limit). note: for the same pop_limit, this gives more search error unless very close to 0 (recommend disabled; even 0.01 is slightly worse than 0) which is a bad sign and suggests this isn't doing a good job; further it's slightly slower to LM cube rescore with 0.01 compared to 0, as well as giving (very insignificantly) lower BLEU. TODO: test under more conditions, or try idea with different formula, or prob. cube beams.")
("lextrans_use_null", "Support source-side null words in lexical translation")
("lextrans_align_only", "Only used in alignment mode. Limit target words generated by reference")
("tagger_tagset,t", po::value<string>(), "(Tagger) file containing tag set")
@@ -382,11 +372,12 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)")
("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)")
("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)")
- ("forest_output,O",po::value<string>(),"Directory to write forests to")
- ("minimal_forests,m","Write minimal forests (excludes Rule information). Such forests can be used for ML/MAP training, but not rescoring, etc.");
+ ("forest_output,O",po::value<string>(),"Directory to write forests to");
// ob.AddOptions(&opts);
+#ifdef FSA_RESCORING
po::options_description cfgo(cfg_options.description());
cfg_options.AddOptions(&cfgo);
+#endif
po::options_description clo("Command line options");
clo.add_options()
("config,c", po::value<vector<string> >(&cfg_files), "Configuration file(s) - latest has priority")
@@ -396,8 +387,12 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
;
po::options_description dconfig_options, dcmdline_options;
+#ifdef FSA_RESCORING
dconfig_options.add(opts).add(cfgo);
- //add(opts).add(cfgo)
+#else
+ dconfig_options.add(opts);
+#endif
+
dcmdline_options.add(dconfig_options).add(clo);
if (argc) {
argv_minus_to_underscore(argc,argv);
@@ -442,8 +437,10 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
if (conf.count("list_feature_functions")) {
cerr << "Available feature functions (specify with -F; describe with -u FeatureName):\n";
ff_registry.DisplayList(); //TODO
+#ifdef FSA_RESCORING
cerr << "Available FSA feature functions (specify with --fsa_feature_function):\n";
fsa_ff_registry.DisplayList(); // TODO
+#endif
cerr << endl;
exit(1);
}
@@ -480,7 +477,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
const string formalism = LowercaseString(str("formalism",conf));
const bool csplit_preserve_full_word = conf.count("csplit_preserve_full_word");
if (csplit_preserve_full_word &&
- (formalism != "csplit" || !(conf.count("beam_prune")||conf.count("density_prune")||conf.count("prelm_beam_prune")||conf.count("prelm_density_prune")))) {
+ (formalism != "csplit" || !(conf.count("beam_prune")||conf.count("density_prune")))) {
cerr << "--csplit_preserve_full_word should only be "
<< "used with csplit AND --*_prune!\n";
exit(1);
@@ -492,20 +489,10 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
}
// load feature weights (and possibly freeze feature set)
- has_prelm_models = false;
if (conf.count("weights")) {
w.InitFromFile(str("weights",conf));
feature_weights.resize(FD::NumFeats());
w.InitVector(&feature_weights);
- string plmw;
- if (prelm_weights_string(conf,plmw)) {
- has_prelm_models = true;
- prelm_w.InitFromFile(plmw);
- prelm_feature_weights.resize(FD::NumFeats());
- prelm_w.InitVector(&prelm_feature_weights);
- if (show_weights)
- cerr << "prelm_weights: " << WeightVector(prelm_feature_weights)<<endl;
- }
if (show_weights)
cerr << "+LM weights: " << WeightVector(feature_weights)<<endl;
}
@@ -545,24 +532,10 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
pffs.push_back(make_ff(add_ffs[i],verbose_feature_functions));
FeatureFunction const* p=pffs.back().get();
late_ffs.push_back(p);
- if (has_prelm_models) {
- if (p->NumBytesContext()==0)
- prelm_ffs.push_back(p);
- else
- cerr << "Excluding stateful feature from prelm pruning: "<<add_ffs[i]<<endl;
- }
- }
- }
- if (conf.count("prelm_feature_function") > 0) {
- vector<string> add_ffs;
- store_conf(conf,"prelm_feature_function",&add_ffs);
-// const vector<string>& add_ffs = conf["prelm_feature_function"].as<vector<string> >();
- for (int i = 0; i < add_ffs.size(); ++i) {
- prelm_only_ffs.push_back(make_ff(add_ffs[i],verbose_feature_functions,"prelm-only "));
- prelm_ffs.push_back(prelm_only_ffs.back().get());
}
}
+#ifdef FSA_RESCORING
store_conf(conf,"fsa_feature_function",&fsa_names);
for (int i=0;i<fsa_names.size();++i)
fsa_ffs.push_back(make_fsa_ff(fsa_names[i],verbose_feature_functions,"FSA "));
@@ -575,20 +548,15 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
cerr<<"FSA: ";
show_all_features(fsa_ffs,feature_weights,cerr,cerr,true,true);
}
+#endif
if (late_freeze) {
cerr << "Late freezing feature set (use --no_freeze_feature_set to prevent)." << endl;
FD::Freeze(); // this means we can't see the feature names of not-weighted features
}
- if (has_prelm_models)
- cerr << "prelm rescoring with "<<prelm_ffs.size()<<" 0-state feature functions. +LM pass will use "<<late_ffs.size()<<" features (not counting rule features)."<<endl;
-
late_models = new ModelSet(feature_weights, late_ffs);
if (!SILENT) show_models(conf,*late_models,"late ");
- prelm_models = new ModelSet(prelm_feature_weights, prelm_ffs);
- if (has_prelm_models) {
- if (!SILENT) show_models(conf,*prelm_models,"prelm "); }
int palg = 1;
if (LowercaseString(str("intersection_strategy",conf)) == "full") {
@@ -603,7 +571,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
if (sample_max_trans)
rng.reset(new RandomNumberGenerator<boost::mt19937>);
aligner_mode = conf.count("aligner");
- minimal_forests = conf.count("minimal_forests");
graphviz = conf.count("graphviz");
joshua_viz = conf.count("show_joshua_visualization");
encode_b64 = str("vector_format",conf) == "b64";
@@ -611,7 +578,9 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
unique_kbest = conf.count("unique_k_best");
get_oracle_forest = conf.count("get_oracle_forest");
+#ifdef FSA_RESCORING
cfg_options.Validate();
+#endif
if (conf.count("extract_rules"))
extract_file.reset(new WriteFile(str("extract_rules",conf)));
@@ -703,24 +672,9 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
cerr << " -LM partition log(Z): " << log(z) << endl;
}
- if (has_prelm_models) {
- Timer t("prelm rescoring");
- forest.Reweight(prelm_feature_weights);
- Hypergraph prelm_forest;
- prelm_models->PrepareForInput(smeta);
- ApplyModelSet(forest,
- smeta,
- *prelm_models,
- *inter_conf, // this is now reduced to exhaustive if all are stateless
- &prelm_forest);
- forest.swap(prelm_forest);
- forest.Reweight(prelm_feature_weights); //FIXME: why the reweighting? here and below. maybe in case we already had a featval for that id and ApplyModelSet only adds prob, doesn't recompute it?
- forest_stats(forest," prelm forest",show_tree_structure,show_features,prelm_feature_weights,oracle.show_derivation);
- }
-
- maybe_prune(forest,conf,"prelm_beam_prune","prelm_density_prune","-LM",srclen);
-
+#ifdef FSA_RESCORING
cfg_options.maybe_output_source(forest);
+#endif
bool has_late_models = !late_models->empty();
if (has_late_models) {
@@ -740,6 +694,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
maybe_prune(forest,conf,"beam_prune","density_prune","+LM",srclen);
+#ifdef FSA_RESCORING
HgCFG hgcfg(forest);
cfg_options.prepare(hgcfg);
@@ -756,6 +711,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
forest.Reweight(feature_weights);
if (!SILENT) forest_stats(forest," +FSA forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation);
}
+#endif
// Oracle Rescoring
if(get_oracle_forest) {
@@ -781,10 +737,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
assert(succeeded);
}
new_hg.Union(forest);
- bool succeeded = writer.Write(new_hg, minimal_forests);
+ bool succeeded = writer.Write(new_hg, false);
assert(succeeded);
} else {
- bool succeeded = writer.Write(forest, minimal_forests);
+ bool succeeded = writer.Write(forest, false);
assert(succeeded);
}
}
@@ -861,10 +817,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
assert(succeeded);
}
new_hg.Union(forest);
- bool succeeded = writer.Write(new_hg, minimal_forests);
+ bool succeeded = writer.Write(new_hg, false);
assert(succeeded);
} else {
- bool succeeded = writer.Write(forest, minimal_forests);
+ bool succeeded = writer.Write(forest, false);
assert(succeeded);
}
}