summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Dyer <redpony@gmail.com>2014-04-08 23:16:24 -0400
committerChris Dyer <redpony@gmail.com>2014-04-08 23:16:24 -0400
commit71c1f8b274e4f0e83252fe3c68fb45c5ec4069e6 (patch)
treefc57496d413655c157985c9cc4a492ef791c5be2
parent7242963e683d7b3d4b6c49ac3814ced360ef10c8 (diff)
smarter union
-rw-r--r--decoder/hg.cc1
-rw-r--r--decoder/hg_test.cc17
-rw-r--r--decoder/hg_union.cc105
3 files changed, 89 insertions, 34 deletions
diff --git a/decoder/hg.cc b/decoder/hg.cc
index 405169c6..e456fa7c 100644
--- a/decoder/hg.cc
+++ b/decoder/hg.cc
@@ -396,6 +396,7 @@ void Hypergraph::PrintGraphviz() const {
for (const auto& node : nodes_) {
cerr << " " << node.id_ << "[label=\"" << (node.cat_ < 0 ? TD::Convert(node.cat_ * -1) : "")
<< " n=" << node.id_
+ << " h=" << node.node_hash
<< "\"];\n";
}
cerr << "}\n";
diff --git a/decoder/hg_test.cc b/decoder/hg_test.cc
index 95cfae51..5cb8626a 100644
--- a/decoder/hg_test.cc
+++ b/decoder/hg_test.cc
@@ -44,6 +44,13 @@ BOOST_AUTO_TEST_CASE(Union) {
Hypergraph hg2;
CreateHG_tiny(path, &hg1);
CreateHG(path, &hg2);
+ int nc = 0;
+ for (auto& node: hg1.nodes_)
+ node.node_hash = ++nc;
+ for (auto& node: hg2.nodes_)
+ node.node_hash = ++nc;
+ hg1.nodes_.back().node_hash = nc;
+
SparseVector<double> wts;
wts.set_value(FD::Convert("f1"), 0.4);
wts.set_value(FD::Convert("f2"), 1.0);
@@ -56,8 +63,11 @@ BOOST_AUTO_TEST_CASE(Union) {
int l2 = ViterbiPathLength(hg2);
cerr << c1 << "\t" << TD::GetString(t1) << endl;
cerr << c2 << "\t" << TD::GetString(t2) << endl;
+ hg1.PrintGraphviz();
+ hg2.PrintGraphviz();
HG::Union(hg2, &hg1);
hg1.Reweight(wts);
+ hg1.PrintGraphviz();
c3 = ViterbiESentence(hg1, &t3);
int l3 = ViterbiPathLength(hg1);
cerr << c3 << "\t" << TD::GetString(t3) << endl;
@@ -84,6 +94,13 @@ BOOST_AUTO_TEST_CASE(Union) {
BOOST_CHECK_CLOSE(log(list[0].second), log(c4), 1e-4);
BOOST_CHECK_EQUAL(list.size(), 6);
BOOST_CHECK_CLOSE(log(list.back().second / list.front().second), -97.7, 1e-4);
+ hg1 = hg2;
+ BOOST_CHECK_EQUAL(hg1.nodes_.size(), hg2.nodes_.size());
+ BOOST_CHECK_EQUAL(hg1.edges_.size(), hg2.edges_.size());
+ HG::Union(hg1, &hg2); // this should be a no-op
+ BOOST_CHECK_EQUAL(hg1.nodes_.size(), hg2.nodes_.size());
+ BOOST_CHECK_EQUAL(hg1.edges_.size(), hg2.edges_.size());
+ cerr << "DONE UNION\n";
}
BOOST_AUTO_TEST_CASE(ControlledKBest) {
diff --git a/decoder/hg_union.cc b/decoder/hg_union.cc
index 37082976..4899e716 100644
--- a/decoder/hg_union.cc
+++ b/decoder/hg_union.cc
@@ -1,56 +1,93 @@
#include "hg_union.h"
+#ifndef HAVE_OLD_CPP
+# include <unordered_map>
+#else
+# include <tr1/unordered_map>
+namespace std { using std::tr1::unordered_set; }
+#endif
+
#include "hg.h"
+#include "sparse_vector.h"
using namespace std;
namespace HG {
+static bool EdgesMatch(const HG::Edge& a, const Hypergraph& ahg, const HG::Edge& b, const Hypergraph& bhg) {
+ const unsigned arity = a.tail_nodes_.size();
+ if (arity != b.tail_nodes_.size()) return false;
+ if (a.rule_->e() != b.rule_->e()) return false;
+ if (a.rule_->f() != b.rule_->f()) return false;
+
+ for (unsigned i = 0; i < arity; ++i)
+ if (ahg.nodes_[a.tail_nodes_[i]].node_hash != bhg.nodes_[b.tail_nodes_[i]].node_hash) return false;
+ const SparseVector<double> diff = a.feature_values_ - b.feature_values_;
+ for (auto& kv : diff)
+ if (fabs(kv.second) > 1e-6) return false;
+ return true;
+}
+
void Union(const Hypergraph& in, Hypergraph* out) {
if (&in == out) return;
if (out->nodes_.empty()) {
out->nodes_ = in.nodes_;
out->edges_ = in.edges_; return;
}
- unsigned noff = out->nodes_.size();
- unsigned eoff = out->edges_.size();
- int ogoal = in.nodes_.size() - 1;
- int cgoal = noff - 1;
- // keep a single goal node, so add nodes.size - 1
- out->nodes_.resize(out->nodes_.size() + ogoal);
- // add all edges
- out->edges_.resize(out->edges_.size() + in.edges_.size());
-
- for (int i = 0; i < ogoal; ++i) {
- const Hypergraph::Node& on = in.nodes_[i];
- Hypergraph::Node& cn = out->nodes_[i + noff];
- cn.id_ = i + noff;
- cn.in_edges_.resize(on.in_edges_.size());
- for (unsigned j = 0; j < on.in_edges_.size(); ++j)
- cn.in_edges_[j] = on.in_edges_[j] + eoff;
-
- cn.out_edges_.resize(on.out_edges_.size());
- for (unsigned j = 0; j < on.out_edges_.size(); ++j)
- cn.out_edges_[j] = on.out_edges_[j] + eoff;
+ if (!in.AreNodesUniquelyIdentified()) {
+ cerr << "Union: Nodes are not uniquely identified in input!\n";
+ abort();
+ }
+ if (!out->AreNodesUniquelyIdentified()) {
+ cerr << "Union: Nodes are not uniquely identified in output!\n";
+ abort();
}
+ if (out->nodes_.back().node_hash != in.nodes_.back().node_hash) {
+ cerr << "Union: Goal nodes are mismatched!\n";
+ abort();
+ }
+ const int cgoal = out->nodes_.back().id_;
- for (unsigned i = 0; i < in.edges_.size(); ++i) {
- const Hypergraph::Edge& oe = in.edges_[i];
- Hypergraph::Edge& ce = out->edges_[i + eoff];
- ce.id_ = i + eoff;
- ce.rule_ = oe.rule_;
- ce.feature_values_ = oe.feature_values_;
- if (oe.head_node_ == ogoal) {
- ce.head_node_ = cgoal;
- out->nodes_[cgoal].in_edges_.push_back(ce.id_);
- } else {
- ce.head_node_ = oe.head_node_ + noff;
+ unordered_map<size_t, unsigned> h2n;
+ for (const auto& node : out->nodes_)
+ h2n[node.node_hash] = node.id_;
+ for (const auto& node : in.nodes_) {
+ if (h2n.count(node.node_hash) == 0) {
+ HG::Node* new_node = out->AddNode(node.cat_);
+ new_node->node_hash = node.node_hash;
+ h2n[node.node_hash] = new_node->id_;
}
- ce.tail_nodes_.resize(oe.tail_nodes_.size());
- for (unsigned j = 0; j < oe.tail_nodes_.size(); ++j)
- ce.tail_nodes_[j] = oe.tail_nodes_[j] + noff;
}
+ for (const auto& in_node : in.nodes_) {
+ HG::Node& out_node = out->nodes_[h2n[in_node.node_hash]];
+ for (const auto oeid : out_node.in_edges_) {
+ // TODO hash currently existing edges for quick check for duplication
+ }
+ for (const auto ieid : in_node.in_edges_) {
+ const HG::Edge& in_edge = in.edges_[ieid];
+ // TODO: replace slow N^2 check with hashing
+ bool edge_exists = false;
+ for (const auto oeid : out_node.in_edges_) {
+ if (EdgesMatch(in_edge, in, out->edges_[oeid], *out)) {
+ edge_exists = true;
+ break;
+ }
+ }
+ if (!edge_exists) {
+ const unsigned arity = in_edge.tail_nodes_.size();
+ TailNodeVector t(arity);
+ HG::Node& head = out->nodes_[h2n[in_node.node_hash]];
+ for (unsigned i = 0; i < arity; ++i)
+ t[i] = h2n[in.nodes_[in_edge.tail_nodes_[i]].node_hash];
+ HG::Edge* new_edge = out->AddEdge(in_edge, t);
+ out->ConnectEdgeToHeadNode(new_edge, &head);
+ //cerr << "Created: " << new_edge->rule_->AsString() << " [head=" << new_edge->head_node_ << "]\n";
+ } //else {
+ // cerr << "Not created: " << in.edges_[ieid].rule_->AsString() << "\n";
+ //}
+ }
+ }
out->TopologicallySortNodesAndEdges(cgoal);
}