diff options
author | graehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-06 15:31:18 +0000 |
---|---|---|
committer | graehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-06 15:31:18 +0000 |
commit | a3ff76f53ee33638eaf1f723578a4a7edff9cb38 (patch) | |
tree | 0f146585f7d3d4d18d0d29bc4116bdc4ffcb432a | |
parent | 2a5dd01ec240ff4fdbe977da5bb3f3a067b27423 (diff) |
cdec --[prelm_]density_prune
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@150 ec762483-ff6d-05da-a07a-a48fb63a330f
-rw-r--r-- | decoder/cdec.cc | 49 | ||||
-rw-r--r-- | decoder/hg.cc | 48 |
2 files changed, 55 insertions, 42 deletions
diff --git a/decoder/cdec.cc b/decoder/cdec.cc index e3a2435d..bb29bafb 100644 --- a/decoder/cdec.cc +++ b/decoder/cdec.cc @@ -86,6 +86,8 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("show_expected_length", "Show the expected translation length under the model") ("show_partition,z", "Compute and show the partition (inside score)") ("show_cfg_search_space", "Show the search space as a CFG") + ("prelm_density_prune", po::value<double>(), "Applied to -LM forest just before final LM rescoring: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)") + ("density_prune", po::value<double>(), "Keep no more than this many times the number of edges used in the best derivation tree (>=1.0)") ("prelm_beam_prune", po::value<double>(), "Prune paths from -LM forest before LM rescoring, keeping paths within exp(alpha>=0)") ("beam_prune", po::value<double>(), "Prune paths from +LM forest, keep paths within exp(alpha>=0)") ("scale_prune_srclen", "scale beams by the input length (in # of tokens; may not be what you want for lattices") @@ -241,7 +243,7 @@ static void ExtractRulesDedupe(const Hypergraph& hg, ostream* os) { void register_feature_functions(); -bool beam_param(po::variables_map const& conf,char const* name,double *val,bool scale_srclen=false,double srclen=1) +bool beam_param(po::variables_map const& conf,string const& name,double *val,bool scale_srclen=false,double srclen=1) { if (conf.count(name)) { *val=conf[name].as<double>()*(scale_srclen?srclen:1); @@ -263,6 +265,28 @@ bool prelm_weights_string(po::variables_map const& conf,string &s) return false; } +void maybe_prune(Hypergraph &forest,po::variables_map const& conf,string nbeam,string ndensity,string forestname,double srclen) { + double beam_prune,density_prune; + bool use_beam_prune=beam_param(conf,nbeam,&beam_prune,conf.count("scale_prune_srclen"),srclen); + bool use_density_prune=beam_param(conf,ndensity,&density_prune); + if (use_beam_prune || use_density_prune) { + double presize=forest.edges_.size(); + vector<bool> preserve_mask,*pm=0; + if (conf.count("csplit_preserve_full_word")) { + preserve_mask.resize(forest.edges_.size()); + preserve_mask[CompoundSplit::GetFullWordEdgeIndex(forest)] = true; + pm=&preserve_mask; + } + if (use_beam_prune) + forest.BeamPruneInsideOutside(1.0, false, beam_prune, pm); + if (use_density_prune) + forest.DensityPruneInsideOutside(1.0 ,false, density_prune, pm); + if (!forestname.empty()) forestname=" "+forestname; + cerr << viterbi_stats(forest," Pruned "+forestname+" forest",false,false); + cerr << " Pruned "<<forestname<<" forest portion of edges kept: "<<forest.edges_.size()/presize<<endl; + } +} + int main(int argc, char** argv) { global_ff_registry.reset(new FFRegistry); @@ -282,12 +306,11 @@ int main(int argc, char** argv) { const string formalism = LowercaseString(conf["formalism"].as<string>()); const bool csplit_preserve_full_word = conf.count("csplit_preserve_full_word"); if (csplit_preserve_full_word && - (formalism != "csplit" || !conf.count("beam_prune"))) { + (formalism != "csplit" || !(conf.count("beam_prune")||conf.count("density_prune")||conf.count("prelm_beam_prune")||conf.count("prelm_density_prune")))) { cerr << "--csplit_preserve_full_word should only be " - << "used with csplit AND --beam_prune!\n"; + << "used with csplit AND --*_prune!\n"; exit(1); } - const bool scale_prune_srclen=conf.count("scale_prune_srclen"); const bool csplit_output_plf = conf.count("csplit_output_plf"); if (csplit_output_plf && formalism != "csplit") { cerr << "--csplit_output_plf should only be used with csplit!\n"; @@ -477,13 +500,7 @@ int main(int argc, char** argv) { cerr << viterbi_stats(forest," prelm forest",true,show_tree_structure); } - double prelm_beam_prune; - if (beam_param(conf,"prelm_beam_prune",&prelm_beam_prune,scale_prune_srclen,srclen)) { - double presize=forest.edges_.size(); - forest.BeamPruneInsideOutside(1.0, false, prelm_beam_prune, NULL); - cerr << viterbi_stats(forest," Pruned -LM forest",false,false); - cerr << " Pruned -LM forest (beam="<<prelm_beam_prune<<") portion of edges kept: "<<forest.edges_.size()/presize<<endl; - } + maybe_prune(forest,conf,"prelm_beam_prune","prelm_density_prune","-LM",srclen); bool has_late_models = !late_models.empty(); if (has_late_models) { @@ -501,14 +518,8 @@ int main(int argc, char** argv) { forest.Reweight(feature_weights); cerr << viterbi_stats(forest," +LM forest",true,show_tree_structure); } - double beam_prune; - if (beam_param(conf,"beam_prune",&beam_prune,scale_prune_srclen,srclen)) { - vector<bool> preserve_mask(forest.edges_.size(), false); - if (csplit_preserve_full_word) - preserve_mask[CompoundSplit::GetFullWordEdgeIndex(forest)] = true; - forest.BeamPruneInsideOutside(1.0, false, beam_prune, &preserve_mask); - cerr << viterbi_stats(forest," Pruned forest",false,false); - } + + maybe_prune(forest,conf,"beam_prune","density_prune","+LM",srclen); if (conf.count("forest_output") && !has_ref) { ForestWriter writer(conf["forest_output"].as<string>(), sent_id); diff --git a/decoder/hg.cc b/decoder/hg.cc index 70511c07..11dd6f44 100644 --- a/decoder/hg.cc +++ b/decoder/hg.cc @@ -175,10 +175,31 @@ void Hypergraph::PruneEdges(const std::vector<bool>& prune_edge, bool run_inside TopologicallySortNodesAndEdges(nodes_.size() - 1, &filtered); } +void Hypergraph_finish_prune(Hypergraph &hg,vector<prob_t> const& io,double cutoff,vector<bool> const* preserve_mask,bool verbose=false) +{ + vector<bool> prune(hg.NumberOfEdges()); + if (verbose) { + if (preserve_mask) cerr << preserve_mask->size() << " " << prune.size() << endl; + cerr<<"Finishing prune for "<<prune.size()<<" edges; CUTOFF=" << cutoff << endl; + } + unsigned pc = 0; + for (int i = 0; i < io.size(); ++i) { + const bool prune_edge = (io[i] < cutoff); + if (prune_edge) { + ++pc; + prune[i] = !(preserve_mask && (*preserve_mask)[i]); + } + } + if (verbose) + cerr << "Finished pruning; removed " << pc << "/" << io.size() << " edges\n"; + hg.PruneEdges(prune); +} + void Hypergraph::DensityPruneInsideOutside(const double scale, const bool use_sum_prod_semiring, const double density, - const vector<bool>* preserve_mask) { + const vector<bool>* preserve_mask) +{ assert(density >= 1.0); const int plen = ViterbiPathLength(*this); vector<WordID> bp; @@ -195,13 +216,7 @@ void Hypergraph::DensityPruneInsideOutside(const double scale, assert(edges_.size() == io.size()); vector<prob_t> sorted = io; nth_element(sorted.begin(), sorted.begin() + rnum, sorted.end(), greater<prob_t>()); - const double cutoff = sorted[rnum]; - vector<bool> prune(edges_.size()); - for (int i = 0; i < edges_.size(); ++i) { - prune[i] = (io[i] < cutoff); - if (preserve_mask && (*preserve_mask)[i]) prune[i] = false; - } - PruneEdges(prune); + Hypergraph_finish_prune(*this,io,sorted[rnum],preserve_mask); } void Hypergraph::BeamPruneInsideOutside( @@ -209,7 +224,7 @@ void Hypergraph::BeamPruneInsideOutside( const bool use_sum_prod_semiring, const double alpha, const vector<bool>* preserve_mask) { - assert(alpha >= 0.0); + assert(alpha > 0.0); assert(scale > 0.0); vector<prob_t> io(edges_.size()); if (use_sum_prod_semiring) @@ -220,20 +235,7 @@ void Hypergraph::BeamPruneInsideOutside( prob_t best; // initializes to zero for (int i = 0; i < io.size(); ++i) if (io[i] > best) best = io[i]; - const prob_t aprob(exp(-alpha)); - const prob_t cutoff = best * aprob; - // cerr << "aprob = " << aprob << "\t CUTOFF=" << cutoff << endl; - vector<bool> prune(edges_.size()); - //cerr << preserve_mask.size() << " " << edges_.size() << endl; - int pc = 0; - for (int i = 0; i < io.size(); ++i) { - const bool prune_edge = (io[i] < cutoff); - if (prune_edge) ++pc; - prune[i] = (io[i] < cutoff); - if (preserve_mask && (*preserve_mask)[i]) prune[i] = false; - } - // cerr << "Beam pruning " << pc << "/" << io.size() << " edges\n"; - PruneEdges(prune); + Hypergraph_finish_prune(*this,io,best*exp(-alpha),preserve_mask); } void Hypergraph::PrintGraphviz() const { |