summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-02-10 18:45:13 -0500
committerChris Dyer <cdyer@cs.cmu.edu>2011-02-10 18:45:13 -0500
commit9a695967a5e4efc987b61bb3df90c0558c678512 (patch)
tree2ba3feb890e5fa1565fdffed6b58756d5ab30e08 /decoder
parent70fdb6cd8774cbd0114fe0d630781bab309e0d87 (diff)
multipass decoding
Diffstat (limited to 'decoder')
-rw-r--r--decoder/apply_models.h10
-rw-r--r--decoder/decoder.cc288
-rw-r--r--decoder/ff.h1
-rw-r--r--decoder/scfg_translator.cc4
4 files changed, 208 insertions, 95 deletions
diff --git a/decoder/apply_models.h b/decoder/apply_models.h
index 81fa068e..a85694aa 100644
--- a/decoder/apply_models.h
+++ b/decoder/apply_models.h
@@ -1,6 +1,8 @@
#ifndef _APPLY_MODELS_H_
#define _APPLY_MODELS_H_
+#include <iostream>
+
struct ModelSet;
struct Hypergraph;
struct SentenceMetadata;
@@ -20,6 +22,14 @@ enum {
IntersectionConfiguration(exhaustive_t /* t */) : algorithm(0), pop_limit() {}
};
+inline std::ostream& operator<<(std::ostream& os, const IntersectionConfiguration& c) {
+ if (c.algorithm == 0) { os << "FULL"; }
+ else if (c.algorithm == 1) { os << "CUBE:k=" << c.pop_limit; }
+ else if (c.algorithm == 2) { os << "N_ALGORITHMS"; }
+ else os << "OTHER";
+ return os;
+}
+
void ApplyModelSet(const Hypergraph& in,
const SentenceMetadata& smeta,
const ModelSet& models,
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 25f05d8e..478a1cf3 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -59,7 +59,7 @@ namespace Hack { void MaxTrans(const Hypergraph& in, int beam_size); }
namespace NgramCache { void Clear(); }
DecoderObserver::~DecoderObserver() {}
-void DecoderObserver::NotifyDecodingStart(const SentenceMetadata& smeta) {}
+void DecoderObserver::NotifyDecodingStart(const SentenceMetadata&) {}
void DecoderObserver::NotifySourceParseFailure(const SentenceMetadata&) {}
void DecoderObserver::NotifyTranslationForest(const SentenceMetadata&, Hypergraph*) {}
void DecoderObserver::NotifyAlignmentFailure(const SentenceMetadata&) {}
@@ -72,7 +72,7 @@ struct ELengthWeightFunction {
}
};
inline void ShowBanner() {
- cerr << "cdec v1.0 (c) 2009-2010 by Chris Dyer\n";
+ cerr << "cdec v1.0 (c) 2009-2011 by Chris Dyer\n";
}
inline void show_models(po::variables_map const& conf,ModelSet &ms,char const* header) {
@@ -135,12 +135,40 @@ inline shared_ptr<FsaFeatureFunction> make_fsa_ff(string const& ffp,bool verbose
}
#endif
+// when the translation forest is first built, it is scored by the features associated
+// with the rules. To add other features (like language models, etc), cdec applies one or
+// more "rescoring passes", which compute new features and optionally apply new weights
+// and then prune the resulting (rescored) hypergraph. All feature values from previous
+// passes are present in future passes (although their weights may change).
+struct RescoringPass {
+ RescoringPass() : density_prune(), beam_prune() {}
+ shared_ptr<ModelSet> models;
+ shared_ptr<IntersectionConfiguration> inter_conf;
+ vector<const FeatureFunction*> ffs;
+ shared_ptr<Weights> w; // null == use previous weights
+ vector<double> weight_vector;
+ double density_prune; // 0 == don't density prune
+ double beam_prune; // 0 == don't beam prune
+};
+
+ostream& operator<<(ostream& os, const RescoringPass& rp) {
+ os << "[num_fn=" << rp.ffs.size();
+ if (rp.inter_conf) { os << " int_alg=" << *rp.inter_conf; }
+ if (rp.w) os << " new_weights";
+ if (rp.density_prune) os << " density_prune=" << rp.density_prune;
+ if (rp.beam_prune) os << " beam_prune=" << rp.beam_prune;
+ os << ']';
+ return os;
+}
+
struct DecoderImpl {
DecoderImpl(po::variables_map& conf, int argc, char** argv, istream* cfg);
~DecoderImpl();
bool Decode(const string& input, DecoderObserver*);
void SetWeights(const vector<double>& weights) {
- feature_weights = weights;
+ init_weights = weights;
+ for (int i = 0; i < rescoring_passes.size(); ++i)
+ rescoring_passes[i].weight_vector = weights;
}
void SetId(int next_sent_id) { sent_id = next_sent_id - 1; }
@@ -159,8 +187,8 @@ struct DecoderImpl {
}
}
- void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_features,DenseWeightVector const& feature_weights, bool sd=false) {
- WeightVector fw(feature_weights);
+ void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_features,DenseWeightVector const& weights, bool sd=false) {
+ WeightVector fw(weights);
forest_stats(forest,name,show_tree,show_features,&fw,sd);
}
@@ -252,21 +280,30 @@ struct DecoderImpl {
}
}
+ // used to construct the suffix string to get the name of arguments for multiple passes
+ // e.g., the "2" in --weights2
+ static string StringSuffixForRescoringPass(int pass) {
+ if (pass == 0) return "";
+ string ps = "1";
+ assert(pass < 9);
+ ps[0] += pass;
+ return ps;
+ }
+
+ vector<RescoringPass> rescoring_passes;
+
po::variables_map& conf;
OracleBleu oracle;
string formalism;
shared_ptr<Translator> translator;
- vector<double> feature_weights;
- Weights w;
+ Weights w_init_weights; // used with initial parse
+ vector<double> init_weights; // weights used with initial parse
vector<shared_ptr<FeatureFunction> > pffs;
- vector<const FeatureFunction*> late_ffs;
#ifdef FSA_RESCORING
CFGOptions cfg_options;
vector<shared_ptr<FsaFeatureFunction> > fsa_ffs;
vector<string> fsa_names;
#endif
- ModelSet* late_models;
- IntersectionConfiguration* inter_conf;
shared_ptr<RandomNumberGenerator<boost::mt19937> > rng;
int sample_max_trans;
bool aligner_mode;
@@ -319,44 +356,58 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("input,i",po::value<string>()->default_value("-"),"Source file")
("grammar,g",po::value<vector<string> >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)")
("per_sentence_grammar_file", po::value<string>(), "Optional (and possibly not implemented) per sentence grammar file enables all per sentence grammars to be stored in a single large file and accessed by offset")
- ("weights,w",po::value<string>(),"Feature weights file")
- ("warn_0_weight","Warn about any feature id that has a 0 weight (this is perfectly safe if you intend 0 weight, though)")
- ("freeze_feature_set,Z", "Freeze feature set after reading feature weights file")
- ("feature_function,F",po::value<vector<string> >()->composing(), "Additional feature function(s) (-L for list)")
+ ("list_feature_functions,L","List available feature functions")
+
+ ("weights,w",po::value<string>(),"Feature weights file (initial forest / pass 1)")
+ ("feature_function,F",po::value<vector<string> >()->composing(), "Pass 1 additional feature function(s) (-L for list)")
+ ("intersection_strategy,I",po::value<string>()->default_value("cube_pruning"), "Pass 1 intersection strategy for incorporating finite-state features; values include Cube_pruning, Full")
+ ("density_prune", po::value<double>(), "Pass 1 pruning: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)")
+ ("beam_prune", po::value<double>(), "Pass 1 pruning: Prune paths from scored forest, keep paths within exp(alpha>=0)")
+
+ ("weights2",po::value<string>(),"Optional pass 2")
+ ("feature_function2",po::value<vector<string> >()->composing(), "Optional pass 2")
+ ("intersection_strategy2",po::value<string>()->default_value("cube_pruning"), "Optional pass 2")
+ ("density_prune2", po::value<double>(), "Optional pass 2")
+ ("beam_prune2", po::value<double>(), "Optional pass 2")
+
+ ("weights3",po::value<string>(),"Optional pass 3")
+ ("feature_function3",po::value<vector<string> >()->composing(), "Optional pass 3")
+ ("intersection_strategy3",po::value<string>()->default_value("cube_pruning"), "Optional pass 3")
+ ("density_prune3", po::value<double>(), "Optional pass 3")
+ ("beam_prune3", po::value<double>(), "Optional pass 3")
+
#ifdef FSA_RESCORING
("fsa_feature_function,A",po::value<vector<string> >()->composing(), "Additional FSA feature function(s) (-L for list)")
("apply_fsa_by",po::value<string>()->default_value("BU_CUBE"), "Method for applying fsa_feature_functions - BU_FULL BU_CUBE EARLEY") //+ApplyFsaBy::all_names()
#endif
- ("list_feature_functions,L","List available feature functions")
("add_pass_through_rules,P","Add rules to translate OOV words as themselves")
("k_best,k",po::value<int>(),"Extract the k best derivations")
("unique_k_best,r", "Unique k-best translation list")
+ ("cubepruning_pop_limit,K",po::value<int>()->default_value(200), "Max number of pops from the candidate heap at each node")
("aligner,a", "Run as a word/phrase aligner (src & ref required)")
("aligner_use_viterbi", "If run in alignment mode, compute the Viterbi (rather than MAP) alignment")
- ("intersection_strategy,I",po::value<string>()->default_value("cube_pruning"), "Intersection strategy for incorporating finite-state features; values include Cube_pruning, Full")
- ("cubepruning_pop_limit,K",po::value<int>()->default_value(200), "Max number of pops from the candidate heap at each node")
("goal",po::value<string>()->default_value("S"),"Goal symbol (SCFG & FST)")
+ ("freeze_feature_set,Z", "Freeze feature set after reading feature weights file")
+ ("warn_0_weight","Warn about any feature id that has a 0 weight (this is perfectly safe if you intend 0 weight, though)")
("scfg_extra_glue_grammar", po::value<string>(), "Extra glue grammar file (Glue grammars apply when i=0 but have no other span restrictions)")
("scfg_no_hiero_glue_grammar,n", "No Hiero glue grammar (nb. by default the SCFG decoder adds Hiero glue rules)")
("scfg_default_nt,d",po::value<string>()->default_value("X"),"Default non-terminal symbol in SCFG")
("scfg_max_span_limit,S",po::value<int>()->default_value(10),"Maximum non-terminal span limit (except \"glue\" grammar)")
- ("quiet", "Disable verbose output")
- ("show_config", po::bool_switch(&show_config), "show contents of loaded -c config files.")
- ("show_weights", po::bool_switch(&show_weights), "show effective feature weights")
+ ("quiet", "Disable verbose output")
+ ("show_config", po::bool_switch(&show_config), "show contents of loaded -c config files.")
+ ("show_weights", po::bool_switch(&show_weights), "show effective feature weights")
("show_joshua_visualization,J", "Produce output compatible with the Joshua visualization tools")
("show_tree_structure", "Show the Viterbi derivation structure")
("show_expected_length", "Show the expected translation length under the model")
("show_partition,z", "Compute and show the partition (inside score)")
("show_conditional_prob", "Output the conditional log prob to STDOUT instead of a translation")
("show_cfg_search_space", "Show the search space as a CFG")
- ("show_features","Show the feature vector for the viterbi translation")
- ("density_prune", po::value<double>(), "Keep no more than this many times the number of edges used in the best derivation tree (>=1.0)")
+ ("show_features","Show the feature vector for the viterbi translation")
("coarse_to_fine_beam_prune", po::value<double>(), "Prune paths from coarse parse forest before fine parse, keeping paths within exp(alpha>=0)")
("ctf_beam_widen", po::value<double>()->default_value(2.0), "Expand coarse pass beam by this factor if no fine parse is found")
("ctf_num_widenings", po::value<int>()->default_value(2), "Widen coarse beam this many times before backing off to full parse")
("ctf_no_exhaustive", "Do not fall back to exhaustive parse if coarse-to-fine parsing fails")
- ("beam_prune", po::value<double>(), "Prune paths from +LM forest, keep paths within exp(alpha>=0)")
- ("scale_prune_srclen", "scale beams by the input length (in # of tokens; may not be what you want for lattices")
+ ("scale_prune_srclen", "scale beams by the input length (in # of tokens; may not be what you want for lattices")
("lextrans_use_null", "Support source-side null words in lexical translation")
("lextrans_align_only", "Only used in alignment mode. Limit target words generated by reference")
("tagger_tagset,t", po::value<string>(), "(Tagger) file containing tag set")
@@ -488,14 +539,79 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
exit(1);
}
- // load feature weights (and possibly freeze feature set)
+ // load initial feature weights (and possibly freeze feature set)
if (conf.count("weights")) {
- w.InitFromFile(str("weights",conf));
- feature_weights.resize(FD::NumFeats());
- w.InitVector(&feature_weights);
+ w_init_weights.InitFromFile(str("weights",conf));
+ w_init_weights.InitVector(&init_weights);
+ init_weights.resize(FD::NumFeats());
+
if (show_weights)
- cerr << "+LM weights: " << WeightVector(feature_weights)<<endl;
+ cerr << "Initial weights: " << WeightVector(init_weights)<<endl;
+ }
+
+ // cube pruning pop-limit: we may want to configure this on a per-pass basis
+ pop_limit = conf["cubepruning_pop_limit"].as<int>();
+
+ // determine the number of rescoring/pruning/weighting passes configured
+ const int MAX_PASSES = 3;
+ for (int pass = 0; pass < MAX_PASSES; ++pass) {
+ string ws = "weights" + StringSuffixForRescoringPass(pass);
+ string ff = "feature_function" + StringSuffixForRescoringPass(pass);
+ string bp = "beam_prune" + StringSuffixForRescoringPass(pass);
+ string dp = "density_prune" + StringSuffixForRescoringPass(pass);
+ bool first_pass_condition = ((pass == 0) && (conf.count(ff) || conf.count(bp) || conf.count(dp)));
+ bool nth_pass_condition = ((pass > 0) && (conf.count(ws) || conf.count(ff) || conf.count(bp) || conf.count(dp)));
+ if (first_pass_condition || nth_pass_condition) {
+ rescoring_passes.push_back(RescoringPass());
+ RescoringPass& rp = rescoring_passes.back();
+ // only configure new weights if pass > 0, otherwise we reuse the initial chart weights
+ if (nth_pass_condition && conf.count(ws)) {
+ rp.w.reset(new Weights);
+ rp.w->InitFromFile(str(ws.c_str(), conf));
+ rp.w->InitVector(&rp.weight_vector);
+ }
+ bool has_stateful = false;
+ if (conf.count(ff)) {
+ vector<string> add_ffs;
+ store_conf(conf,ff,&add_ffs);
+ for (int i = 0; i < add_ffs.size(); ++i) {
+ pffs.push_back(make_ff(add_ffs[i],verbose_feature_functions));
+ FeatureFunction const* p=pffs.back().get();
+ rp.ffs.push_back(p);
+ if (p->IsStateful()) { has_stateful = true; }
+ }
+ }
+ if (conf.count(bp)) { rp.beam_prune = conf[bp].as<double>(); }
+ if (conf.count(dp)) { rp.density_prune = conf[dp].as<double>(); }
+ int palg = (has_stateful ? 1 : 0); // if there are no stateful featueres, default to FULL
+ string isn = "intersection_strategy" + StringSuffixForRescoringPass(pass);
+ if (LowercaseString(str(isn.c_str(),conf)) == "full") {
+ palg = 0;
+ }
+ rp.inter_conf.reset(new IntersectionConfiguration(palg, pop_limit));
+ } else {
+ break; // TODO alert user if there are any future configurations
+ }
+ }
+
+ // set up weight vectors since later phases may reuse weights from earlier phases
+ const vector<double>* prev = &init_weights;
+ for (int pass = 0; pass < rescoring_passes.size(); ++pass) {
+ RescoringPass& rp = rescoring_passes[pass];
+ if (!rp.w) { rp.weight_vector = *prev; } else { prev = &rp.weight_vector; }
+ rp.models.reset(new ModelSet(rp.weight_vector, rp.ffs));
+ string ps = "Pass1 "; ps[4] += pass;
+ if (!SILENT) show_models(conf,*rp.models,ps.c_str());
}
+
+ // show configuration of rescoring passes
+ if (!SILENT) {
+ int num = rescoring_passes.size();
+ cerr << "Configured " << num << " rescoring pass" << (num == 1 ? "" : "es") << endl;
+ for (int pass = 0; pass < num; ++pass)
+ cerr << " " << rescoring_passes[pass] << endl;
+ }
+
bool warn0=conf.count("warn_0_weight");
bool freeze=conf.count("freeze_feature_set");
bool early_freeze=freeze && !warn0;
@@ -523,18 +639,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
else
assert(!"error");
- // set up additional scoring features
- if (conf.count("feature_function") > 0) {
- vector<string> add_ffs;
-// const vector<string>& add_ffs = conf["feature_function"].as<vector<string> >();
- store_conf(conf,"feature_function",&add_ffs);
- for (int i = 0; i < add_ffs.size(); ++i) {
- pffs.push_back(make_ff(add_ffs[i],verbose_feature_functions));
- FeatureFunction const* p=pffs.back().get();
- late_ffs.push_back(p);
- }
- }
-
#ifdef FSA_RESCORING
store_conf(conf,"fsa_feature_function",&fsa_names);
for (int i=0;i<fsa_names.size();++i)
@@ -546,7 +650,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
}
if (!fsa_ffs.empty()) {
cerr<<"FSA: ";
- show_all_features(fsa_ffs,feature_weights,cerr,cerr,true,true);
+ show_all_features(fsa_ffs,init_weights,cerr,cerr,true,true);
}
#endif
@@ -555,17 +659,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
FD::Freeze(); // this means we can't see the feature names of not-weighted features
}
- late_models = new ModelSet(feature_weights, late_ffs);
- if (!SILENT) show_models(conf,*late_models,"late ");
-
- int palg = 1;
- if (LowercaseString(str("intersection_strategy",conf)) == "full") {
- palg = 0;
- cerr << "Using full intersection (no pruning).\n";
- }
- pop_limit=conf["cubepruning_pop_limit"].as<int>();
- inter_conf = new IntersectionConfiguration(palg, pop_limit);
-
sample_max_trans = conf.count("max_translation_sample") ?
conf["max_translation_sample"].as<int>() : 0;
if (sample_max_trans)
@@ -644,8 +737,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
translator->ProcessMarkupHints(smeta.sgml_);
Timer t("Translation");
const bool translation_successful =
- translator->Translate(to_translate, &smeta, feature_weights, &forest);
- //TODO: modify translator to incorporate all 0-state model scores immediately?
+ translator->Translate(to_translate, &smeta, init_weights, &forest);
translator->SentenceComplete();
if (!translation_successful) {
@@ -660,7 +752,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
const bool show_tree_structure=conf.count("show_tree_structure");
const bool show_features=conf.count("show_features");
- if (!SILENT) forest_stats(forest," -LM forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation);
+ if (!SILENT) forest_stats(forest," Init. forest",show_tree_structure,show_features,init_weights,oracle.show_derivation);
if (conf.count("show_expected_length")) {
const PRPair<double, double> res =
Inside<PRPair<double, double>,
@@ -669,53 +761,63 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
}
if (conf.count("show_partition")) {
const prob_t z = Inside<prob_t, EdgeProb>(forest);
- cerr << " -LM partition log(Z): " << log(z) << endl;
+ cerr << " Init. partition log(Z): " << log(z) << endl;
}
+ for (int pass = 0; pass < rescoring_passes.size(); ++pass) {
+ const RescoringPass& rp = rescoring_passes[pass];
+ const vector<double>& cur_weights = rp.weight_vector;
+ if (!SILENT) cerr << endl << " RESCORING PASS #" << (pass+1) << " " << rp << endl;
#ifdef FSA_RESCORING
- cfg_options.maybe_output_source(forest);
+ cfg_options.maybe_output_source(forest);
#endif
- bool has_late_models = !late_models->empty();
- if (has_late_models) {
- Timer t("Forest rescoring:");
- late_models->PrepareForInput(smeta);
- forest.Reweight(feature_weights);
- Hypergraph lm_forest;
- ApplyModelSet(forest,
+ string passtr = "Pass1"; passtr[4] += pass;
+ forest.Reweight(cur_weights);
+ const bool has_rescoring_models = !rp.models->empty();
+ if (has_rescoring_models) {
+ Timer t("Forest rescoring:");
+ rp.models->PrepareForInput(smeta);
+ Hypergraph rescored_forest;
+ ApplyModelSet(forest,
smeta,
- *late_models,
- *inter_conf,
- &lm_forest);
- forest.swap(lm_forest);
- forest.Reweight(feature_weights);
- if (!SILENT) forest_stats(forest," +LM forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation);
- }
+ *rp.models,
+ *rp.inter_conf,
+ &rescored_forest);
+ forest.swap(rescored_forest);
+ forest.Reweight(cur_weights);
+ if (!SILENT) forest_stats(forest," " + passtr +" forest",show_tree_structure,show_features,cur_weights,oracle.show_derivation);
+ }
- maybe_prune(forest,conf,"beam_prune","density_prune","+LM",srclen);
+ string fullbp = "beam_prune" + StringSuffixForRescoringPass(pass);
+ string fulldp = "density_prune" + StringSuffixForRescoringPass(pass);
+ maybe_prune(forest,conf,fullbp.c_str(),fulldp.c_str(),passtr,srclen);
#ifdef FSA_RESCORING
- HgCFG hgcfg(forest);
- cfg_options.prepare(hgcfg);
-
- if (!fsa_ffs.empty()) {
- Timer t("Target FSA rescoring:");
- if (!has_late_models)
- forest.Reweight(feature_weights);
- Hypergraph fsa_forest;
- assert(fsa_ffs.size()==1);
- ApplyFsaBy cfg(str("apply_fsa_by",conf),pop_limit);
- if (!SILENT) cerr << "FSA rescoring with "<<cfg<<" "<<fsa_ffs[0]->describe()<<endl;
- ApplyFsaModels(hgcfg,smeta,*fsa_ffs[0],feature_weights,cfg,&fsa_forest);
- forest.swap(fsa_forest);
- forest.Reweight(feature_weights);
- if (!SILENT) forest_stats(forest," +FSA forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation);
- }
+ HgCFG hgcfg(forest);
+ cfg_options.prepare(hgcfg);
+
+ if (!fsa_ffs.empty()) {
+ Timer t("Target FSA rescoring:");
+ if (!has_late_models)
+ forest.Reweight(pass0_weights);
+ Hypergraph fsa_forest;
+ assert(fsa_ffs.size()==1);
+ ApplyFsaBy cfg(str("apply_fsa_by",conf),pop_limit);
+ if (!SILENT) cerr << "FSA rescoring with "<<cfg<<" "<<fsa_ffs[0]->describe()<<endl;
+ ApplyFsaModels(hgcfg,smeta,*fsa_ffs[0],pass0_weights,cfg,&fsa_forest);
+ forest.swap(fsa_forest);
+ forest.Reweight(pass0_weights);
+ if (!SILENT) forest_stats(forest," +FSA forest",show_tree_structure,show_features,pass0_weights,oracle.show_derivation);
+ }
#endif
+ }
+
+ const vector<double>& last_weights = (rescoring_passes.empty() ? init_weights : rescoring_passes.back().weight_vector);
// Oracle Rescoring
if(get_oracle_forest) {
- Oracle oc=oracle.ComputeOracle(smeta,&forest,FeatureVector(feature_weights),10,conf["forest_output"].as<std::string>());
+ Oracle oc=oracle.ComputeOracle(smeta,&forest,FeatureVector(last_weights),10,conf["forest_output"].as<std::string>());
if (!SILENT) cerr << " +Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
if (!SILENT) cerr << " +Oracle BLEU (paths): " << forest.NumberOfPaths() << endl;
oc.hope.Print(cerr," +Oracle BLEU");
@@ -798,8 +900,8 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
// if (!SILENT) cerr << " USING UNIFORM WEIGHTS\n";
// for (int i = 0; i < forest.edges_.size(); ++i)
// forest.edges_[i].edge_prob_=prob_t::One(); }
- forest.Reweight(feature_weights);
- if (!SILENT) forest_stats(forest," Constr. forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation);
+ forest.Reweight(last_weights);
+ if (!SILENT) forest_stats(forest," Constr. forest",show_tree_structure,show_features,last_weights,oracle.show_derivation);
if (!SILENT) cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl;
if (conf.count("show_partition")) {
const prob_t z = Inside<prob_t, EdgeProb>(forest);
@@ -830,7 +932,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
const prob_t ref_z = InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, EdgeFeaturesAndProbWeightFunction>(forest, &ref_exp);
ref_exp /= ref_z;
// if (crf_uniform_empirical)
-// log_ref_z = ref_exp.dot(feature_weights);
+// log_ref_z = ref_exp.dot(last_weights);
log_ref_z = log(ref_z);
//cerr << " MODEL LOG Z: " << log_z << endl;
//cerr << " EMPIRICAL LOG Z: " << log_ref_z << endl;
diff --git a/decoder/ff.h b/decoder/ff.h
index e470f9a9..89b8b067 100644
--- a/decoder/ff.h
+++ b/decoder/ff.h
@@ -40,6 +40,7 @@ class FeatureFunction {
FeatureFunction() : state_size_() {}
explicit FeatureFunction(int state_size) : state_size_(state_size) {}
virtual ~FeatureFunction();
+ bool IsStateful() const { return state_size_ > 0; }
// override this. not virtual because we want to expose this to factory template for help before creating a FF
static std::string usage(bool show_params,bool show_details) {
diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc
index a19e9d75..0f7e40bd 100644
--- a/decoder/scfg_translator.cc
+++ b/decoder/scfg_translator.cc
@@ -124,10 +124,10 @@ struct SCFGTranslatorImpl {
if (!SILENT) cerr << "First pass parse... " << endl;
ExhaustiveBottomUpParser parser(goal, glist);
if (!parser.Parse(lattice, forest)){
- if (!SILENT) cerr << "parse failed." << endl;
+ if (!SILENT) cerr << " parse failed." << endl;
return false;
} else {
- if (!SILENT) cerr << "parse succeeded." << endl;
+ // if (!SILENT) cerr << " parse succeeded." << endl;
}
forest->Reweight(weights);
if (use_ctf_) {