diff options
-rw-r--r-- | decoder/hg.cc | 40 | ||||
-rw-r--r-- | decoder/hg.h | 4 | ||||
-rw-r--r-- | decoder/hg_intersect.cc | 4 |
3 files changed, 43 insertions, 5 deletions
diff --git a/decoder/hg.cc b/decoder/hg.cc index de8e5e49..629168ad 100644 --- a/decoder/hg.cc +++ b/decoder/hg.cc @@ -111,9 +111,45 @@ void Hypergraph::PushWeightsToGoal(double scale) { } } -void Hypergraph::PruneEdges(const std::vector<bool>& prune_edge) { +struct EdgeExistsWeightFunction { + EdgeExistsWeightFunction(const std::vector<bool>& prunes) : prunes_(prunes) {} + double operator()(const Hypergraph::Edge& edge) const { + return (prunes_[edge.id_] ? 0.0 : 1.0); + } + private: + const vector<bool>& prunes_; +}; + +void Hypergraph::PruneEdges(const std::vector<bool>& prune_edge, bool run_inside_algorithm) { assert(prune_edge.size() == edges_.size()); - TopologicallySortNodesAndEdges(nodes_.size() - 1, &prune_edge); + vector<bool> filtered = prune_edge; + + if (run_inside_algorithm) { + const EdgeExistsWeightFunction wf(prune_edge); + // use double, not bool since vector<bool> causes problems with the Inside algorithm. + // I don't know a good c++ way to resolve this short of template specialization which + // I dislike. If you know of a better way that doesn't involve specialization, + // fix this! + vector<double> reachable; + bool goal_derivable = (0 < Inside<double, EdgeExistsWeightFunction>(*this, &reachable, wf)); + + assert(reachable.size() == nodes_.size()); + for (int i = 0; i < edges_.size(); ++i) { + bool prune = prune_edge[i]; + if (!prune) { + const Edge& edge = edges_[i]; + for (int j = 0; j < edge.tail_nodes_.size(); ++j) { + if (!reachable[edge.tail_nodes_[j]]) { + prune = true; + break; + } + } + } + filtered[i] = prune; + } + } + + TopologicallySortNodesAndEdges(nodes_.size() - 1, &filtered); } void Hypergraph::DensityPruneInsideOutside(const double scale, diff --git a/decoder/hg.h b/decoder/hg.h index af8d38d2..77d76cc3 100644 --- a/decoder/hg.h +++ b/decoder/hg.h @@ -156,7 +156,9 @@ class Hypergraph { void RemoveNoncoaccessibleStates(int goal_node_id = -1); // remove edges from the hypergraph if prune_edge[edge_id] is true - void PruneEdges(const std::vector<bool>& prune_edge); + // TODO need to investigate why this shouldn't be run for the forest trans + // case. To investigate, change false to true and see where ftrans crashes + void PruneEdges(const std::vector<bool>& prune_edge, bool run_inside_algorithm = false); // if you don't know, use_sum_prod_semiring should be false void DensityPruneInsideOutside(const double scale, const bool use_sum_prod_semiring, const double density, diff --git a/decoder/hg_intersect.cc b/decoder/hg_intersect.cc index e0e70856..8bd11dd3 100644 --- a/decoder/hg_intersect.cc +++ b/decoder/hg_intersect.cc @@ -67,7 +67,7 @@ static bool FastLinearIntersect(const Lattice& target, Hypergraph* hg) { } } } - hg->PruneEdges(prune); + hg->PruneEdges(prune, true); return (cov.size() == target.size()); } @@ -81,7 +81,7 @@ bool HG::Intersect(const Lattice& target, Hypergraph* hg) { const RuleFilter filter(target, 15); // TODO make configurable for (int i = 0; i < rem.size(); ++i) rem[i] = filter(*hg->edges_[i].rule_); - hg->PruneEdges(rem); + hg->PruneEdges(rem, true); const int nedges = hg->edges_.size(); const int nnodes = hg->nodes_.size(); |