summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorgraehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-06 15:31:18 +0000
committergraehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-06 15:31:18 +0000
commit825b1fc172a4f097c94b0fe8137ba2356262b5f4 (patch)
treeb1c1bdd5824eb0edb5be0a306308013a1d67be71
parent45b298139d494ee81ce9ea23424faaba5f177230 (diff)
cdec --[prelm_]density_prune
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@150 ec762483-ff6d-05da-a07a-a48fb63a330f
-rw-r--r--decoder/cdec.cc49
-rw-r--r--decoder/hg.cc48
2 files changed, 55 insertions, 42 deletions
diff --git a/decoder/cdec.cc b/decoder/cdec.cc
index e3a2435d..bb29bafb 100644
--- a/decoder/cdec.cc
+++ b/decoder/cdec.cc
@@ -86,6 +86,8 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
("show_expected_length", "Show the expected translation length under the model")
("show_partition,z", "Compute and show the partition (inside score)")
("show_cfg_search_space", "Show the search space as a CFG")
+ ("prelm_density_prune", po::value<double>(), "Applied to -LM forest just before final LM rescoring: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)")
+ ("density_prune", po::value<double>(), "Keep no more than this many times the number of edges used in the best derivation tree (>=1.0)")
("prelm_beam_prune", po::value<double>(), "Prune paths from -LM forest before LM rescoring, keeping paths within exp(alpha>=0)")
("beam_prune", po::value<double>(), "Prune paths from +LM forest, keep paths within exp(alpha>=0)")
("scale_prune_srclen", "scale beams by the input length (in # of tokens; may not be what you want for lattices")
@@ -241,7 +243,7 @@ static void ExtractRulesDedupe(const Hypergraph& hg, ostream* os) {
void register_feature_functions();
-bool beam_param(po::variables_map const& conf,char const* name,double *val,bool scale_srclen=false,double srclen=1)
+bool beam_param(po::variables_map const& conf,string const& name,double *val,bool scale_srclen=false,double srclen=1)
{
if (conf.count(name)) {
*val=conf[name].as<double>()*(scale_srclen?srclen:1);
@@ -263,6 +265,28 @@ bool prelm_weights_string(po::variables_map const& conf,string &s)
return false;
}
+void maybe_prune(Hypergraph &forest,po::variables_map const& conf,string nbeam,string ndensity,string forestname,double srclen) {
+ double beam_prune,density_prune;
+ bool use_beam_prune=beam_param(conf,nbeam,&beam_prune,conf.count("scale_prune_srclen"),srclen);
+ bool use_density_prune=beam_param(conf,ndensity,&density_prune);
+ if (use_beam_prune || use_density_prune) {
+ double presize=forest.edges_.size();
+ vector<bool> preserve_mask,*pm=0;
+ if (conf.count("csplit_preserve_full_word")) {
+ preserve_mask.resize(forest.edges_.size());
+ preserve_mask[CompoundSplit::GetFullWordEdgeIndex(forest)] = true;
+ pm=&preserve_mask;
+ }
+ if (use_beam_prune)
+ forest.BeamPruneInsideOutside(1.0, false, beam_prune, pm);
+ if (use_density_prune)
+ forest.DensityPruneInsideOutside(1.0 ,false, density_prune, pm);
+ if (!forestname.empty()) forestname=" "+forestname;
+ cerr << viterbi_stats(forest," Pruned "+forestname+" forest",false,false);
+ cerr << " Pruned "<<forestname<<" forest portion of edges kept: "<<forest.edges_.size()/presize<<endl;
+ }
+}
+
int main(int argc, char** argv) {
global_ff_registry.reset(new FFRegistry);
@@ -282,12 +306,11 @@ int main(int argc, char** argv) {
const string formalism = LowercaseString(conf["formalism"].as<string>());
const bool csplit_preserve_full_word = conf.count("csplit_preserve_full_word");
if (csplit_preserve_full_word &&
- (formalism != "csplit" || !conf.count("beam_prune"))) {
+ (formalism != "csplit" || !(conf.count("beam_prune")||conf.count("density_prune")||conf.count("prelm_beam_prune")||conf.count("prelm_density_prune")))) {
cerr << "--csplit_preserve_full_word should only be "
- << "used with csplit AND --beam_prune!\n";
+ << "used with csplit AND --*_prune!\n";
exit(1);
}
- const bool scale_prune_srclen=conf.count("scale_prune_srclen");
const bool csplit_output_plf = conf.count("csplit_output_plf");
if (csplit_output_plf && formalism != "csplit") {
cerr << "--csplit_output_plf should only be used with csplit!\n";
@@ -477,13 +500,7 @@ int main(int argc, char** argv) {
cerr << viterbi_stats(forest," prelm forest",true,show_tree_structure);
}
- double prelm_beam_prune;
- if (beam_param(conf,"prelm_beam_prune",&prelm_beam_prune,scale_prune_srclen,srclen)) {
- double presize=forest.edges_.size();
- forest.BeamPruneInsideOutside(1.0, false, prelm_beam_prune, NULL);
- cerr << viterbi_stats(forest," Pruned -LM forest",false,false);
- cerr << " Pruned -LM forest (beam="<<prelm_beam_prune<<") portion of edges kept: "<<forest.edges_.size()/presize<<endl;
- }
+ maybe_prune(forest,conf,"prelm_beam_prune","prelm_density_prune","-LM",srclen);
bool has_late_models = !late_models.empty();
if (has_late_models) {
@@ -501,14 +518,8 @@ int main(int argc, char** argv) {
forest.Reweight(feature_weights);
cerr << viterbi_stats(forest," +LM forest",true,show_tree_structure);
}
- double beam_prune;
- if (beam_param(conf,"beam_prune",&beam_prune,scale_prune_srclen,srclen)) {
- vector<bool> preserve_mask(forest.edges_.size(), false);
- if (csplit_preserve_full_word)
- preserve_mask[CompoundSplit::GetFullWordEdgeIndex(forest)] = true;
- forest.BeamPruneInsideOutside(1.0, false, beam_prune, &preserve_mask);
- cerr << viterbi_stats(forest," Pruned forest",false,false);
- }
+
+ maybe_prune(forest,conf,"beam_prune","density_prune","+LM",srclen);
if (conf.count("forest_output") && !has_ref) {
ForestWriter writer(conf["forest_output"].as<string>(), sent_id);
diff --git a/decoder/hg.cc b/decoder/hg.cc
index 70511c07..11dd6f44 100644
--- a/decoder/hg.cc
+++ b/decoder/hg.cc
@@ -175,10 +175,31 @@ void Hypergraph::PruneEdges(const std::vector<bool>& prune_edge, bool run_inside
TopologicallySortNodesAndEdges(nodes_.size() - 1, &filtered);
}
+void Hypergraph_finish_prune(Hypergraph &hg,vector<prob_t> const& io,double cutoff,vector<bool> const* preserve_mask,bool verbose=false)
+{
+ vector<bool> prune(hg.NumberOfEdges());
+ if (verbose) {
+ if (preserve_mask) cerr << preserve_mask->size() << " " << prune.size() << endl;
+ cerr<<"Finishing prune for "<<prune.size()<<" edges; CUTOFF=" << cutoff << endl;
+ }
+ unsigned pc = 0;
+ for (int i = 0; i < io.size(); ++i) {
+ const bool prune_edge = (io[i] < cutoff);
+ if (prune_edge) {
+ ++pc;
+ prune[i] = !(preserve_mask && (*preserve_mask)[i]);
+ }
+ }
+ if (verbose)
+ cerr << "Finished pruning; removed " << pc << "/" << io.size() << " edges\n";
+ hg.PruneEdges(prune);
+}
+
void Hypergraph::DensityPruneInsideOutside(const double scale,
const bool use_sum_prod_semiring,
const double density,
- const vector<bool>* preserve_mask) {
+ const vector<bool>* preserve_mask)
+{
assert(density >= 1.0);
const int plen = ViterbiPathLength(*this);
vector<WordID> bp;
@@ -195,13 +216,7 @@ void Hypergraph::DensityPruneInsideOutside(const double scale,
assert(edges_.size() == io.size());
vector<prob_t> sorted = io;
nth_element(sorted.begin(), sorted.begin() + rnum, sorted.end(), greater<prob_t>());
- const double cutoff = sorted[rnum];
- vector<bool> prune(edges_.size());
- for (int i = 0; i < edges_.size(); ++i) {
- prune[i] = (io[i] < cutoff);
- if (preserve_mask && (*preserve_mask)[i]) prune[i] = false;
- }
- PruneEdges(prune);
+ Hypergraph_finish_prune(*this,io,sorted[rnum],preserve_mask);
}
void Hypergraph::BeamPruneInsideOutside(
@@ -209,7 +224,7 @@ void Hypergraph::BeamPruneInsideOutside(
const bool use_sum_prod_semiring,
const double alpha,
const vector<bool>* preserve_mask) {
- assert(alpha >= 0.0);
+ assert(alpha > 0.0);
assert(scale > 0.0);
vector<prob_t> io(edges_.size());
if (use_sum_prod_semiring)
@@ -220,20 +235,7 @@ void Hypergraph::BeamPruneInsideOutside(
prob_t best; // initializes to zero
for (int i = 0; i < io.size(); ++i)
if (io[i] > best) best = io[i];
- const prob_t aprob(exp(-alpha));
- const prob_t cutoff = best * aprob;
- // cerr << "aprob = " << aprob << "\t CUTOFF=" << cutoff << endl;
- vector<bool> prune(edges_.size());
- //cerr << preserve_mask.size() << " " << edges_.size() << endl;
- int pc = 0;
- for (int i = 0; i < io.size(); ++i) {
- const bool prune_edge = (io[i] < cutoff);
- if (prune_edge) ++pc;
- prune[i] = (io[i] < cutoff);
- if (preserve_mask && (*preserve_mask)[i]) prune[i] = false;
- }
- // cerr << "Beam pruning " << pc << "/" << io.size() << " edges\n";
- PruneEdges(prune);
+ Hypergraph_finish_prune(*this,io,best*exp(-alpha),preserve_mask);
}
void Hypergraph::PrintGraphviz() const {