diff options
Diffstat (limited to 'decoder')
| -rw-r--r-- | decoder/hg.cc | 40 | ||||
| -rw-r--r-- | decoder/hg.h | 4 | ||||
| -rw-r--r-- | decoder/hg_intersect.cc | 4 | 
3 files changed, 43 insertions, 5 deletions
| diff --git a/decoder/hg.cc b/decoder/hg.cc index de8e5e49..629168ad 100644 --- a/decoder/hg.cc +++ b/decoder/hg.cc @@ -111,9 +111,45 @@ void Hypergraph::PushWeightsToGoal(double scale) {    }  } -void Hypergraph::PruneEdges(const std::vector<bool>& prune_edge) { +struct EdgeExistsWeightFunction { +  EdgeExistsWeightFunction(const std::vector<bool>& prunes) : prunes_(prunes) {} +  double operator()(const Hypergraph::Edge& edge) const { +    return (prunes_[edge.id_] ? 0.0 : 1.0); +  } + private: +  const vector<bool>& prunes_; +}; + +void Hypergraph::PruneEdges(const std::vector<bool>& prune_edge, bool run_inside_algorithm) {    assert(prune_edge.size() == edges_.size()); -  TopologicallySortNodesAndEdges(nodes_.size() - 1, &prune_edge); +  vector<bool> filtered = prune_edge; + +  if (run_inside_algorithm) { +    const EdgeExistsWeightFunction wf(prune_edge); +    // use double, not bool since vector<bool> causes problems with the Inside algorithm. +    // I don't know a good c++ way to resolve this short of template specialization which +    // I dislike.  If you know of a better way that doesn't involve specialization, +    // fix this! +    vector<double> reachable; +    bool goal_derivable = (0 < Inside<double, EdgeExistsWeightFunction>(*this, &reachable, wf)); + +    assert(reachable.size() == nodes_.size()); +    for (int i = 0; i < edges_.size(); ++i) { +      bool prune = prune_edge[i]; +      if (!prune) { +        const Edge& edge = edges_[i]; +        for (int j = 0; j < edge.tail_nodes_.size(); ++j) { +          if (!reachable[edge.tail_nodes_[j]]) { +            prune = true; +            break; +          } +        } +      } +      filtered[i] = prune; +    } +  } +  +  TopologicallySortNodesAndEdges(nodes_.size() - 1, &filtered);  }  void Hypergraph::DensityPruneInsideOutside(const double scale, diff --git a/decoder/hg.h b/decoder/hg.h index af8d38d2..77d76cc3 100644 --- a/decoder/hg.h +++ b/decoder/hg.h @@ -156,7 +156,9 @@ class Hypergraph {    void RemoveNoncoaccessibleStates(int goal_node_id = -1);    // remove edges from the hypergraph if prune_edge[edge_id] is true -  void PruneEdges(const std::vector<bool>& prune_edge); +  // TODO need to investigate why this shouldn't be run for the forest trans +  // case.  To investigate, change false to true and see where ftrans crashes +  void PruneEdges(const std::vector<bool>& prune_edge, bool run_inside_algorithm = false);    // if you don't know, use_sum_prod_semiring should be false    void DensityPruneInsideOutside(const double scale, const bool use_sum_prod_semiring, const double density, diff --git a/decoder/hg_intersect.cc b/decoder/hg_intersect.cc index e0e70856..8bd11dd3 100644 --- a/decoder/hg_intersect.cc +++ b/decoder/hg_intersect.cc @@ -67,7 +67,7 @@ static bool FastLinearIntersect(const Lattice& target, Hypergraph* hg) {        }      }    } -  hg->PruneEdges(prune); +  hg->PruneEdges(prune, true);    return (cov.size() == target.size());  } @@ -81,7 +81,7 @@ bool HG::Intersect(const Lattice& target, Hypergraph* hg) {    const RuleFilter filter(target, 15);   // TODO make configurable    for (int i = 0; i < rem.size(); ++i)      rem[i] = filter(*hg->edges_[i].rule_); -  hg->PruneEdges(rem); +  hg->PruneEdges(rem, true);    const int nedges = hg->edges_.size();    const int nnodes = hg->nodes_.size(); | 
