summaryrefslogtreecommitdiff
path: root/decoder/hg_union.cc
diff options
context:
space:
mode:
authorChris Dyer <redpony@gmail.com>2014-04-08 23:16:24 -0400
committerChris Dyer <redpony@gmail.com>2014-04-08 23:16:24 -0400
commit71c1f8b274e4f0e83252fe3c68fb45c5ec4069e6 (patch)
treefc57496d413655c157985c9cc4a492ef791c5be2 /decoder/hg_union.cc
parent7242963e683d7b3d4b6c49ac3814ced360ef10c8 (diff)
smarter union
Diffstat (limited to 'decoder/hg_union.cc')
-rw-r--r--decoder/hg_union.cc105
1 files changed, 71 insertions, 34 deletions
diff --git a/decoder/hg_union.cc b/decoder/hg_union.cc
index 37082976..4899e716 100644
--- a/decoder/hg_union.cc
+++ b/decoder/hg_union.cc
@@ -1,56 +1,93 @@
#include "hg_union.h"
+#ifndef HAVE_OLD_CPP
+# include <unordered_map>
+#else
+# include <tr1/unordered_map>
+namespace std { using std::tr1::unordered_set; }
+#endif
+
#include "hg.h"
+#include "sparse_vector.h"
using namespace std;
namespace HG {
+static bool EdgesMatch(const HG::Edge& a, const Hypergraph& ahg, const HG::Edge& b, const Hypergraph& bhg) {
+ const unsigned arity = a.tail_nodes_.size();
+ if (arity != b.tail_nodes_.size()) return false;
+ if (a.rule_->e() != b.rule_->e()) return false;
+ if (a.rule_->f() != b.rule_->f()) return false;
+
+ for (unsigned i = 0; i < arity; ++i)
+ if (ahg.nodes_[a.tail_nodes_[i]].node_hash != bhg.nodes_[b.tail_nodes_[i]].node_hash) return false;
+ const SparseVector<double> diff = a.feature_values_ - b.feature_values_;
+ for (auto& kv : diff)
+ if (fabs(kv.second) > 1e-6) return false;
+ return true;
+}
+
void Union(const Hypergraph& in, Hypergraph* out) {
if (&in == out) return;
if (out->nodes_.empty()) {
out->nodes_ = in.nodes_;
out->edges_ = in.edges_; return;
}
- unsigned noff = out->nodes_.size();
- unsigned eoff = out->edges_.size();
- int ogoal = in.nodes_.size() - 1;
- int cgoal = noff - 1;
- // keep a single goal node, so add nodes.size - 1
- out->nodes_.resize(out->nodes_.size() + ogoal);
- // add all edges
- out->edges_.resize(out->edges_.size() + in.edges_.size());
-
- for (int i = 0; i < ogoal; ++i) {
- const Hypergraph::Node& on = in.nodes_[i];
- Hypergraph::Node& cn = out->nodes_[i + noff];
- cn.id_ = i + noff;
- cn.in_edges_.resize(on.in_edges_.size());
- for (unsigned j = 0; j < on.in_edges_.size(); ++j)
- cn.in_edges_[j] = on.in_edges_[j] + eoff;
-
- cn.out_edges_.resize(on.out_edges_.size());
- for (unsigned j = 0; j < on.out_edges_.size(); ++j)
- cn.out_edges_[j] = on.out_edges_[j] + eoff;
+ if (!in.AreNodesUniquelyIdentified()) {
+ cerr << "Union: Nodes are not uniquely identified in input!\n";
+ abort();
+ }
+ if (!out->AreNodesUniquelyIdentified()) {
+ cerr << "Union: Nodes are not uniquely identified in output!\n";
+ abort();
}
+ if (out->nodes_.back().node_hash != in.nodes_.back().node_hash) {
+ cerr << "Union: Goal nodes are mismatched!\n";
+ abort();
+ }
+ const int cgoal = out->nodes_.back().id_;
- for (unsigned i = 0; i < in.edges_.size(); ++i) {
- const Hypergraph::Edge& oe = in.edges_[i];
- Hypergraph::Edge& ce = out->edges_[i + eoff];
- ce.id_ = i + eoff;
- ce.rule_ = oe.rule_;
- ce.feature_values_ = oe.feature_values_;
- if (oe.head_node_ == ogoal) {
- ce.head_node_ = cgoal;
- out->nodes_[cgoal].in_edges_.push_back(ce.id_);
- } else {
- ce.head_node_ = oe.head_node_ + noff;
+ unordered_map<size_t, unsigned> h2n;
+ for (const auto& node : out->nodes_)
+ h2n[node.node_hash] = node.id_;
+ for (const auto& node : in.nodes_) {
+ if (h2n.count(node.node_hash) == 0) {
+ HG::Node* new_node = out->AddNode(node.cat_);
+ new_node->node_hash = node.node_hash;
+ h2n[node.node_hash] = new_node->id_;
}
- ce.tail_nodes_.resize(oe.tail_nodes_.size());
- for (unsigned j = 0; j < oe.tail_nodes_.size(); ++j)
- ce.tail_nodes_[j] = oe.tail_nodes_[j] + noff;
}
+ for (const auto& in_node : in.nodes_) {
+ HG::Node& out_node = out->nodes_[h2n[in_node.node_hash]];
+ for (const auto oeid : out_node.in_edges_) {
+ // TODO hash currently existing edges for quick check for duplication
+ }
+ for (const auto ieid : in_node.in_edges_) {
+ const HG::Edge& in_edge = in.edges_[ieid];
+ // TODO: replace slow N^2 check with hashing
+ bool edge_exists = false;
+ for (const auto oeid : out_node.in_edges_) {
+ if (EdgesMatch(in_edge, in, out->edges_[oeid], *out)) {
+ edge_exists = true;
+ break;
+ }
+ }
+ if (!edge_exists) {
+ const unsigned arity = in_edge.tail_nodes_.size();
+ TailNodeVector t(arity);
+ HG::Node& head = out->nodes_[h2n[in_node.node_hash]];
+ for (unsigned i = 0; i < arity; ++i)
+ t[i] = h2n[in.nodes_[in_edge.tail_nodes_[i]].node_hash];
+ HG::Edge* new_edge = out->AddEdge(in_edge, t);
+ out->ConnectEdgeToHeadNode(new_edge, &head);
+ //cerr << "Created: " << new_edge->rule_->AsString() << " [head=" << new_edge->head_node_ << "]\n";
+ } //else {
+ // cerr << "Not created: " << in.edges_[ieid].rule_->AsString() << "\n";
+ //}
+ }
+ }
out->TopologicallySortNodesAndEdges(cgoal);
}