diff options
Diffstat (limited to 'decoder/hg.cc')
-rw-r--r-- | decoder/hg.cc | 48 |
1 files changed, 25 insertions, 23 deletions
diff --git a/decoder/hg.cc b/decoder/hg.cc index 70511c07..11dd6f44 100644 --- a/decoder/hg.cc +++ b/decoder/hg.cc @@ -175,10 +175,31 @@ void Hypergraph::PruneEdges(const std::vector<bool>& prune_edge, bool run_inside TopologicallySortNodesAndEdges(nodes_.size() - 1, &filtered); } +void Hypergraph_finish_prune(Hypergraph &hg,vector<prob_t> const& io,double cutoff,vector<bool> const* preserve_mask,bool verbose=false) +{ + vector<bool> prune(hg.NumberOfEdges()); + if (verbose) { + if (preserve_mask) cerr << preserve_mask->size() << " " << prune.size() << endl; + cerr<<"Finishing prune for "<<prune.size()<<" edges; CUTOFF=" << cutoff << endl; + } + unsigned pc = 0; + for (int i = 0; i < io.size(); ++i) { + const bool prune_edge = (io[i] < cutoff); + if (prune_edge) { + ++pc; + prune[i] = !(preserve_mask && (*preserve_mask)[i]); + } + } + if (verbose) + cerr << "Finished pruning; removed " << pc << "/" << io.size() << " edges\n"; + hg.PruneEdges(prune); +} + void Hypergraph::DensityPruneInsideOutside(const double scale, const bool use_sum_prod_semiring, const double density, - const vector<bool>* preserve_mask) { + const vector<bool>* preserve_mask) +{ assert(density >= 1.0); const int plen = ViterbiPathLength(*this); vector<WordID> bp; @@ -195,13 +216,7 @@ void Hypergraph::DensityPruneInsideOutside(const double scale, assert(edges_.size() == io.size()); vector<prob_t> sorted = io; nth_element(sorted.begin(), sorted.begin() + rnum, sorted.end(), greater<prob_t>()); - const double cutoff = sorted[rnum]; - vector<bool> prune(edges_.size()); - for (int i = 0; i < edges_.size(); ++i) { - prune[i] = (io[i] < cutoff); - if (preserve_mask && (*preserve_mask)[i]) prune[i] = false; - } - PruneEdges(prune); + Hypergraph_finish_prune(*this,io,sorted[rnum],preserve_mask); } void Hypergraph::BeamPruneInsideOutside( @@ -209,7 +224,7 @@ void Hypergraph::BeamPruneInsideOutside( const bool use_sum_prod_semiring, const double alpha, const vector<bool>* preserve_mask) { - assert(alpha >= 0.0); + assert(alpha > 0.0); assert(scale > 0.0); vector<prob_t> io(edges_.size()); if (use_sum_prod_semiring) @@ -220,20 +235,7 @@ void Hypergraph::BeamPruneInsideOutside( prob_t best; // initializes to zero for (int i = 0; i < io.size(); ++i) if (io[i] > best) best = io[i]; - const prob_t aprob(exp(-alpha)); - const prob_t cutoff = best * aprob; - // cerr << "aprob = " << aprob << "\t CUTOFF=" << cutoff << endl; - vector<bool> prune(edges_.size()); - //cerr << preserve_mask.size() << " " << edges_.size() << endl; - int pc = 0; - for (int i = 0; i < io.size(); ++i) { - const bool prune_edge = (io[i] < cutoff); - if (prune_edge) ++pc; - prune[i] = (io[i] < cutoff); - if (preserve_mask && (*preserve_mask)[i]) prune[i] = false; - } - // cerr << "Beam pruning " << pc << "/" << io.size() << " edges\n"; - PruneEdges(prune); + Hypergraph_finish_prune(*this,io,best*exp(-alpha),preserve_mask); } void Hypergraph::PrintGraphviz() const { |