diff options
| -rw-r--r-- | decoder/cdec.cc | 49 | ||||
| -rw-r--r-- | decoder/hg.cc | 48 | 
2 files changed, 55 insertions, 42 deletions
| diff --git a/decoder/cdec.cc b/decoder/cdec.cc index e3a2435d..bb29bafb 100644 --- a/decoder/cdec.cc +++ b/decoder/cdec.cc @@ -86,6 +86,8 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {          ("show_expected_length", "Show the expected translation length under the model")          ("show_partition,z", "Compute and show the partition (inside score)")          ("show_cfg_search_space", "Show the search space as a CFG") +    ("prelm_density_prune", po::value<double>(), "Applied to -LM forest just before final LM rescoring: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)") +    ("density_prune", po::value<double>(), "Keep no more than this many times the number of edges used in the best derivation tree (>=1.0)")          ("prelm_beam_prune", po::value<double>(), "Prune paths from -LM forest before LM rescoring, keeping paths within exp(alpha>=0)")          ("beam_prune", po::value<double>(), "Prune paths from +LM forest, keep paths within exp(alpha>=0)")      ("scale_prune_srclen", "scale beams by the input length (in # of tokens; may not be what you want for lattices") @@ -241,7 +243,7 @@ static void ExtractRulesDedupe(const Hypergraph& hg, ostream* os) {  void register_feature_functions(); -bool beam_param(po::variables_map const& conf,char const* name,double *val,bool scale_srclen=false,double srclen=1) +bool beam_param(po::variables_map const& conf,string const& name,double *val,bool scale_srclen=false,double srclen=1)  {    if (conf.count(name)) {      *val=conf[name].as<double>()*(scale_srclen?srclen:1); @@ -263,6 +265,28 @@ bool prelm_weights_string(po::variables_map const& conf,string &s)    return false;  } +void maybe_prune(Hypergraph &forest,po::variables_map const& conf,string nbeam,string ndensity,string forestname,double srclen) { +    double beam_prune,density_prune; +    bool use_beam_prune=beam_param(conf,nbeam,&beam_prune,conf.count("scale_prune_srclen"),srclen); +    bool use_density_prune=beam_param(conf,ndensity,&density_prune); +    if (use_beam_prune || use_density_prune) { +      double presize=forest.edges_.size(); +      vector<bool> preserve_mask,*pm=0; +      if (conf.count("csplit_preserve_full_word")) { +        preserve_mask.resize(forest.edges_.size()); +        preserve_mask[CompoundSplit::GetFullWordEdgeIndex(forest)] = true; +        pm=&preserve_mask; +      } +      if (use_beam_prune) +        forest.BeamPruneInsideOutside(1.0, false, beam_prune, pm); +      if (use_density_prune) +        forest.DensityPruneInsideOutside(1.0 ,false, density_prune, pm); +      if (!forestname.empty()) forestname=" "+forestname; +      cerr << viterbi_stats(forest,"  Pruned "+forestname+" forest",false,false); +      cerr << "  Pruned "<<forestname<<" forest portion of edges kept: "<<forest.edges_.size()/presize<<endl; +    } +} +  int main(int argc, char** argv) {    global_ff_registry.reset(new FFRegistry); @@ -282,12 +306,11 @@ int main(int argc, char** argv) {    const string formalism = LowercaseString(conf["formalism"].as<string>());    const bool csplit_preserve_full_word = conf.count("csplit_preserve_full_word");    if (csplit_preserve_full_word && -      (formalism != "csplit" || !conf.count("beam_prune"))) { +      (formalism != "csplit" || !(conf.count("beam_prune")||conf.count("density_prune")||conf.count("prelm_beam_prune")||conf.count("prelm_density_prune")))) {      cerr << "--csplit_preserve_full_word should only be " -         << "used with csplit AND --beam_prune!\n"; +         << "used with csplit AND --*_prune!\n";      exit(1);    } -  const bool scale_prune_srclen=conf.count("scale_prune_srclen");    const bool csplit_output_plf = conf.count("csplit_output_plf");    if (csplit_output_plf && formalism != "csplit") {      cerr << "--csplit_output_plf should only be used with csplit!\n"; @@ -477,13 +500,7 @@ int main(int argc, char** argv) {        cerr << viterbi_stats(forest," prelm forest",true,show_tree_structure);      } -    double prelm_beam_prune; -    if (beam_param(conf,"prelm_beam_prune",&prelm_beam_prune,scale_prune_srclen,srclen)) { -      double presize=forest.edges_.size(); -      forest.BeamPruneInsideOutside(1.0, false, prelm_beam_prune, NULL); -      cerr << viterbi_stats(forest,"  Pruned -LM forest",false,false); -      cerr << "  Pruned -LM forest (beam="<<prelm_beam_prune<<") portion of edges kept: "<<forest.edges_.size()/presize<<endl; -    } +    maybe_prune(forest,conf,"prelm_beam_prune","prelm_density_prune","-LM",srclen);      bool has_late_models = !late_models.empty();      if (has_late_models) { @@ -501,14 +518,8 @@ int main(int argc, char** argv) {        forest.Reweight(feature_weights);        cerr << viterbi_stats(forest,"  +LM forest",true,show_tree_structure);      } -    double beam_prune; -    if (beam_param(conf,"beam_prune",&beam_prune,scale_prune_srclen,srclen)) { -      vector<bool> preserve_mask(forest.edges_.size(), false); -      if (csplit_preserve_full_word) -        preserve_mask[CompoundSplit::GetFullWordEdgeIndex(forest)] = true; -      forest.BeamPruneInsideOutside(1.0, false, beam_prune, &preserve_mask); -      cerr << viterbi_stats(forest,"  Pruned forest",false,false); -    } + +    maybe_prune(forest,conf,"beam_prune","density_prune","+LM",srclen);      if (conf.count("forest_output") && !has_ref) {        ForestWriter writer(conf["forest_output"].as<string>(), sent_id); diff --git a/decoder/hg.cc b/decoder/hg.cc index 70511c07..11dd6f44 100644 --- a/decoder/hg.cc +++ b/decoder/hg.cc @@ -175,10 +175,31 @@ void Hypergraph::PruneEdges(const std::vector<bool>& prune_edge, bool run_inside    TopologicallySortNodesAndEdges(nodes_.size() - 1, &filtered);  } +void Hypergraph_finish_prune(Hypergraph &hg,vector<prob_t> const& io,double cutoff,vector<bool> const* preserve_mask,bool verbose=false) +{ +  vector<bool> prune(hg.NumberOfEdges()); +  if (verbose) { +    if (preserve_mask) cerr << preserve_mask->size() << " " << prune.size() << endl; +    cerr<<"Finishing prune for "<<prune.size()<<" edges; CUTOFF=" << cutoff << endl; +  } +  unsigned pc = 0; +  for (int i = 0; i < io.size(); ++i) { +    const bool prune_edge = (io[i] < cutoff); +    if (prune_edge) { +      ++pc; +      prune[i] = !(preserve_mask && (*preserve_mask)[i]); +    } +  } +  if (verbose) +    cerr << "Finished pruning; removed " << pc << "/" << io.size() << " edges\n"; +  hg.PruneEdges(prune); +} +  void Hypergraph::DensityPruneInsideOutside(const double scale,                                             const bool use_sum_prod_semiring,                                             const double density, -                                           const vector<bool>* preserve_mask) { +                                           const vector<bool>* preserve_mask) +{    assert(density >= 1.0);    const int plen = ViterbiPathLength(*this);    vector<WordID> bp; @@ -195,13 +216,7 @@ void Hypergraph::DensityPruneInsideOutside(const double scale,    assert(edges_.size() == io.size());    vector<prob_t> sorted = io;    nth_element(sorted.begin(), sorted.begin() + rnum, sorted.end(), greater<prob_t>()); -  const double cutoff = sorted[rnum]; -  vector<bool> prune(edges_.size()); -  for (int i = 0; i < edges_.size(); ++i) { -    prune[i] = (io[i] < cutoff); -    if (preserve_mask && (*preserve_mask)[i]) prune[i] = false; -  } -  PruneEdges(prune); +  Hypergraph_finish_prune(*this,io,sorted[rnum],preserve_mask);  }  void Hypergraph::BeamPruneInsideOutside( @@ -209,7 +224,7 @@ void Hypergraph::BeamPruneInsideOutside(      const bool use_sum_prod_semiring,      const double alpha,      const vector<bool>* preserve_mask) { -  assert(alpha >= 0.0); +  assert(alpha > 0.0);    assert(scale > 0.0);    vector<prob_t> io(edges_.size());    if (use_sum_prod_semiring) @@ -220,20 +235,7 @@ void Hypergraph::BeamPruneInsideOutside(    prob_t best;  // initializes to zero    for (int i = 0; i < io.size(); ++i)      if (io[i] > best) best = io[i]; -  const prob_t aprob(exp(-alpha)); -  const prob_t cutoff = best * aprob; -  // cerr << "aprob = " << aprob << "\t  CUTOFF=" << cutoff << endl; -  vector<bool> prune(edges_.size()); -  //cerr << preserve_mask.size() << " " << edges_.size() << endl; -  int pc = 0; -  for (int i = 0; i < io.size(); ++i) { -    const bool prune_edge = (io[i] < cutoff); -    if (prune_edge) ++pc; -    prune[i] = (io[i] < cutoff); -    if (preserve_mask && (*preserve_mask)[i]) prune[i] = false; -  } -  // cerr << "Beam pruning " << pc << "/" << io.size() << " edges\n"; -  PruneEdges(prune); +  Hypergraph_finish_prune(*this,io,best*exp(-alpha),preserve_mask);  }  void Hypergraph::PrintGraphviz() const { | 
