summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--decoder/hg.cc40
-rw-r--r--decoder/hg.h4
-rw-r--r--decoder/hg_intersect.cc4
3 files changed, 43 insertions, 5 deletions
diff --git a/decoder/hg.cc b/decoder/hg.cc
index de8e5e49..629168ad 100644
--- a/decoder/hg.cc
+++ b/decoder/hg.cc
@@ -111,9 +111,45 @@ void Hypergraph::PushWeightsToGoal(double scale) {
}
}
-void Hypergraph::PruneEdges(const std::vector<bool>& prune_edge) {
+struct EdgeExistsWeightFunction {
+ EdgeExistsWeightFunction(const std::vector<bool>& prunes) : prunes_(prunes) {}
+ double operator()(const Hypergraph::Edge& edge) const {
+ return (prunes_[edge.id_] ? 0.0 : 1.0);
+ }
+ private:
+ const vector<bool>& prunes_;
+};
+
+void Hypergraph::PruneEdges(const std::vector<bool>& prune_edge, bool run_inside_algorithm) {
assert(prune_edge.size() == edges_.size());
- TopologicallySortNodesAndEdges(nodes_.size() - 1, &prune_edge);
+ vector<bool> filtered = prune_edge;
+
+ if (run_inside_algorithm) {
+ const EdgeExistsWeightFunction wf(prune_edge);
+ // use double, not bool since vector<bool> causes problems with the Inside algorithm.
+ // I don't know a good c++ way to resolve this short of template specialization which
+ // I dislike. If you know of a better way that doesn't involve specialization,
+ // fix this!
+ vector<double> reachable;
+ bool goal_derivable = (0 < Inside<double, EdgeExistsWeightFunction>(*this, &reachable, wf));
+
+ assert(reachable.size() == nodes_.size());
+ for (int i = 0; i < edges_.size(); ++i) {
+ bool prune = prune_edge[i];
+ if (!prune) {
+ const Edge& edge = edges_[i];
+ for (int j = 0; j < edge.tail_nodes_.size(); ++j) {
+ if (!reachable[edge.tail_nodes_[j]]) {
+ prune = true;
+ break;
+ }
+ }
+ }
+ filtered[i] = prune;
+ }
+ }
+
+ TopologicallySortNodesAndEdges(nodes_.size() - 1, &filtered);
}
void Hypergraph::DensityPruneInsideOutside(const double scale,
diff --git a/decoder/hg.h b/decoder/hg.h
index af8d38d2..77d76cc3 100644
--- a/decoder/hg.h
+++ b/decoder/hg.h
@@ -156,7 +156,9 @@ class Hypergraph {
void RemoveNoncoaccessibleStates(int goal_node_id = -1);
// remove edges from the hypergraph if prune_edge[edge_id] is true
- void PruneEdges(const std::vector<bool>& prune_edge);
+ // TODO need to investigate why this shouldn't be run for the forest trans
+ // case. To investigate, change false to true and see where ftrans crashes
+ void PruneEdges(const std::vector<bool>& prune_edge, bool run_inside_algorithm = false);
// if you don't know, use_sum_prod_semiring should be false
void DensityPruneInsideOutside(const double scale, const bool use_sum_prod_semiring, const double density,
diff --git a/decoder/hg_intersect.cc b/decoder/hg_intersect.cc
index e0e70856..8bd11dd3 100644
--- a/decoder/hg_intersect.cc
+++ b/decoder/hg_intersect.cc
@@ -67,7 +67,7 @@ static bool FastLinearIntersect(const Lattice& target, Hypergraph* hg) {
}
}
}
- hg->PruneEdges(prune);
+ hg->PruneEdges(prune, true);
return (cov.size() == target.size());
}
@@ -81,7 +81,7 @@ bool HG::Intersect(const Lattice& target, Hypergraph* hg) {
const RuleFilter filter(target, 15); // TODO make configurable
for (int i = 0; i < rem.size(); ++i)
rem[i] = filter(*hg->edges_[i].rule_);
- hg->PruneEdges(rem);
+ hg->PruneEdges(rem, true);
const int nedges = hg->edges_.size();
const int nnodes = hg->nodes_.size();