diff options
author | graehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-07 21:26:51 +0000 |
---|---|---|
committer | graehl <graehl@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-07 21:26:51 +0000 |
commit | c4cb48ad003de65f97e0a6013e9da4329c89faf1 (patch) | |
tree | 1d96a2224375bfbb0a8fb97c975147d37a4e324d /decoder/hg.cc | |
parent | ffe002f8792dd8693c12e9bc6a7f715ca170acfc (diff) |
safe hg pruning without needing additional inside reachability pass (max margin tightness is less at bottom of derivation tree)
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@181 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'decoder/hg.cc')
-rw-r--r-- | decoder/hg.cc | 29 |
1 files changed, 18 insertions, 11 deletions
diff --git a/decoder/hg.cc b/decoder/hg.cc index b6b9d8bd..70831d3d 100644 --- a/decoder/hg.cc +++ b/decoder/hg.cc @@ -129,7 +129,7 @@ void Hypergraph::PushWeightsToGoal(double scale) { } struct EdgeExistsWeightFunction { - EdgeExistsWeightFunction(const std::vector<bool>& prunes) : prunes_(prunes) {} + EdgeExistsWeightFunction(const vector<bool>& prunes) : prunes_(prunes) {} bool operator()(const Hypergraph::Edge& edge) const { return !prunes_[edge.id_]; } @@ -137,7 +137,7 @@ struct EdgeExistsWeightFunction { const vector<bool>& prunes_; }; -void Hypergraph::PruneEdges(const std::vector<bool>& prune_edge, bool run_inside_algorithm) { +void Hypergraph::PruneEdges(const EdgeMask& prune_edge, bool run_inside_algorithm) { assert(prune_edge.size() == edges_.size()); vector<bool> filtered = prune_edge; @@ -175,18 +175,22 @@ void Hypergraph::PruneEdges(const std::vector<bool>& prune_edge, bool run_inside TopologicallySortNodesAndEdges(nodes_.size() - 1, &filtered); } -void Hypergraph_finish_prune(Hypergraph &hg,vector<prob_t> const& io,double cutoff,vector<bool> const* preserve_mask,bool verbose=false) +void Hypergraph::MarginPrune(vector<prob_t> const& io,prob_t cutoff,vector<bool> const* preserve_mask,bool safe_inside,bool verbose) { - const double EPSILON=1e-5; + const prob_t BARELY_SMALLER(1e-6,false); // nearly 1; 1-epsilon //TODO: //FIXME: if EPSILON is 0, then remnants (useless edges that don't connect to top? or top-connected but not bottom-up buildable referneced?) are left in the hypergraph output that cause mr_vest_map to segfault. adding EPSILON probably just covers up the symptom by making it far less frequent; I imagine any time threshold is set by DensityPrune, cutoff is exactly equal to the io of several nodes, but because of how it's computed, some round slightly down vs. slightly up. probably the flaw is in PruneEdges. - cutoff=cutoff-EPSILON; - vector<bool> prune(hg.NumberOfEdges()); + int ne=NumberOfEdges(); + cutoff*=BARELY_SMALLER; + prob_t creep=BARELY_SMALLER.root(-(ne+1)); // start more permissive, then become less generous. this is barely more than 1. we want to do this because it's a disaster if something lower in a derivation tree is deleted, but the higher thing remains (unless safe_inside) + + vector<bool> prune(ne); if (verbose) { if (preserve_mask) cerr << preserve_mask->size() << " " << prune.size() << endl; cerr<<"Finishing prune for "<<prune.size()<<" edges; CUTOFF=" << cutoff << endl; } unsigned pc = 0; for (int i = 0; i < io.size(); ++i) { + cutoff*=creep; const bool prune_edge = (io[i] < cutoff); if (prune_edge) { ++pc; @@ -195,13 +199,15 @@ void Hypergraph_finish_prune(Hypergraph &hg,vector<prob_t> const& io,double cuto } if (verbose) cerr << "Finished pruning; removed " << pc << "/" << io.size() << " edges\n"; - hg.PruneEdges(prune,true); // inside reachability check in case cutoff rounded down too much (probably redundant with EPSILON hack) + PruneEdges(prune,safe_inside); // inside reachability check in case cutoff rounded down too much (probably redundant with EPSILON hack) } void Hypergraph::DensityPruneInsideOutside(const double scale, const bool use_sum_prod_semiring, const double density, - const vector<bool>* preserve_mask) + const vector<bool>* preserve_mask + , bool safe_inside + ) { assert(density >= 1.0); const int plen = ViterbiPathLength(*this); @@ -221,14 +227,15 @@ void Hypergraph::DensityPruneInsideOutside(const double scale, assert(edges_.size() == io.size()); vector<prob_t> sorted = io; nth_element(sorted.begin(), sorted.begin() + rnum, sorted.end(), greater<prob_t>()); - Hypergraph_finish_prune(*this,io,sorted[rnum],preserve_mask); + MarginPrune(io,sorted[rnum],preserve_mask,safe_inside); } void Hypergraph::BeamPruneInsideOutside( const double scale, const bool use_sum_prod_semiring, const double alpha, - const vector<bool>* preserve_mask) { + const vector<bool>* preserve_mask + ,bool safe_inside) { assert(alpha >= 0.0); vector<prob_t> io(edges_.size()); if (use_sum_prod_semiring) { @@ -240,7 +247,7 @@ void Hypergraph::BeamPruneInsideOutside( prob_t best; // initializes to zero for (int i = 0; i < io.size(); ++i) if (io[i] > best) best = io[i]; - Hypergraph_finish_prune(*this,io,best*exp(-alpha),preserve_mask); + MarginPrune(io,best*prob_t::exp(-alpha),preserve_mask,safe_inside); } void Hypergraph::PrintGraphviz() const { |