summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorChris Dyer <cdyer@allegro.clab.cs.cmu.edu>2014-04-09 20:20:13 -0400
committerChris Dyer <cdyer@allegro.clab.cs.cmu.edu>2014-04-09 20:20:13 -0400
commit6f6c219c5fe032869d8f2235726fb3b9b38c4354 (patch)
treead77d662db1af40c5e0d95ec5d64a1c9326ff485 /decoder
parent74401769fdb8b16f44df8911070b7ae091de5fef (diff)
parent5f3ec63bf30459d97ad9e61d4e9b3b734a4867bf (diff)
Merge branch 'master' of github.com:redpony/cdec
Diffstat (limited to 'decoder')
-rw-r--r--decoder/hg.cc1
-rw-r--r--decoder/hg_test.cc17
-rw-r--r--decoder/hg_union.cc116
-rw-r--r--decoder/node_state_hash.h10
4 files changed, 108 insertions, 36 deletions
diff --git a/decoder/hg.cc b/decoder/hg.cc
index 405169c6..e456fa7c 100644
--- a/decoder/hg.cc
+++ b/decoder/hg.cc
@@ -396,6 +396,7 @@ void Hypergraph::PrintGraphviz() const {
for (const auto& node : nodes_) {
cerr << " " << node.id_ << "[label=\"" << (node.cat_ < 0 ? TD::Convert(node.cat_ * -1) : "")
<< " n=" << node.id_
+ << " h=" << node.node_hash
<< "\"];\n";
}
cerr << "}\n";
diff --git a/decoder/hg_test.cc b/decoder/hg_test.cc
index 95cfae51..5cb8626a 100644
--- a/decoder/hg_test.cc
+++ b/decoder/hg_test.cc
@@ -44,6 +44,13 @@ BOOST_AUTO_TEST_CASE(Union) {
Hypergraph hg2;
CreateHG_tiny(path, &hg1);
CreateHG(path, &hg2);
+ int nc = 0;
+ for (auto& node: hg1.nodes_)
+ node.node_hash = ++nc;
+ for (auto& node: hg2.nodes_)
+ node.node_hash = ++nc;
+ hg1.nodes_.back().node_hash = nc;
+
SparseVector<double> wts;
wts.set_value(FD::Convert("f1"), 0.4);
wts.set_value(FD::Convert("f2"), 1.0);
@@ -56,8 +63,11 @@ BOOST_AUTO_TEST_CASE(Union) {
int l2 = ViterbiPathLength(hg2);
cerr << c1 << "\t" << TD::GetString(t1) << endl;
cerr << c2 << "\t" << TD::GetString(t2) << endl;
+ hg1.PrintGraphviz();
+ hg2.PrintGraphviz();
HG::Union(hg2, &hg1);
hg1.Reweight(wts);
+ hg1.PrintGraphviz();
c3 = ViterbiESentence(hg1, &t3);
int l3 = ViterbiPathLength(hg1);
cerr << c3 << "\t" << TD::GetString(t3) << endl;
@@ -84,6 +94,13 @@ BOOST_AUTO_TEST_CASE(Union) {
BOOST_CHECK_CLOSE(log(list[0].second), log(c4), 1e-4);
BOOST_CHECK_EQUAL(list.size(), 6);
BOOST_CHECK_CLOSE(log(list.back().second / list.front().second), -97.7, 1e-4);
+ hg1 = hg2;
+ BOOST_CHECK_EQUAL(hg1.nodes_.size(), hg2.nodes_.size());
+ BOOST_CHECK_EQUAL(hg1.edges_.size(), hg2.edges_.size());
+ HG::Union(hg1, &hg2); // this should be a no-op
+ BOOST_CHECK_EQUAL(hg1.nodes_.size(), hg2.nodes_.size());
+ BOOST_CHECK_EQUAL(hg1.edges_.size(), hg2.edges_.size());
+ cerr << "DONE UNION\n";
}
BOOST_AUTO_TEST_CASE(ControlledKBest) {
diff --git a/decoder/hg_union.cc b/decoder/hg_union.cc
index 37082976..a659b6bc 100644
--- a/decoder/hg_union.cc
+++ b/decoder/hg_union.cc
@@ -1,56 +1,104 @@
#include "hg_union.h"
+#ifndef HAVE_OLD_CPP
+# include <unordered_map>
+#else
+# include <tr1/unordered_map>
+namespace std { using std::tr1::unordered_set; }
+#endif
+
+#include "verbose.h"
#include "hg.h"
+#include "sparse_vector.h"
using namespace std;
namespace HG {
+static bool EdgesMatch(const HG::Edge& a, const Hypergraph& ahg, const HG::Edge& b, const Hypergraph& bhg) {
+ const unsigned arity = a.tail_nodes_.size();
+ if (arity != b.tail_nodes_.size()) return false;
+ if (a.rule_->e() != b.rule_->e()) return false;
+ if (a.rule_->f() != b.rule_->f()) return false;
+
+ for (unsigned i = 0; i < arity; ++i)
+ if (ahg.nodes_[a.tail_nodes_[i]].node_hash != bhg.nodes_[b.tail_nodes_[i]].node_hash) return false;
+ const SparseVector<double> diff = a.feature_values_ - b.feature_values_;
+ for (auto& kv : diff)
+ if (fabs(kv.second) > 1e-6) return false;
+ return true;
+}
+
void Union(const Hypergraph& in, Hypergraph* out) {
if (&in == out) return;
if (out->nodes_.empty()) {
out->nodes_ = in.nodes_;
out->edges_ = in.edges_; return;
}
- unsigned noff = out->nodes_.size();
- unsigned eoff = out->edges_.size();
- int ogoal = in.nodes_.size() - 1;
- int cgoal = noff - 1;
- // keep a single goal node, so add nodes.size - 1
- out->nodes_.resize(out->nodes_.size() + ogoal);
- // add all edges
- out->edges_.resize(out->edges_.size() + in.edges_.size());
-
- for (int i = 0; i < ogoal; ++i) {
- const Hypergraph::Node& on = in.nodes_[i];
- Hypergraph::Node& cn = out->nodes_[i + noff];
- cn.id_ = i + noff;
- cn.in_edges_.resize(on.in_edges_.size());
- for (unsigned j = 0; j < on.in_edges_.size(); ++j)
- cn.in_edges_[j] = on.in_edges_[j] + eoff;
-
- cn.out_edges_.resize(on.out_edges_.size());
- for (unsigned j = 0; j < on.out_edges_.size(); ++j)
- cn.out_edges_[j] = on.out_edges_[j] + eoff;
+ if (!in.AreNodesUniquelyIdentified()) {
+ cerr << "Union: Nodes are not uniquely identified in input!\n";
+ abort();
+ }
+ if (!out->AreNodesUniquelyIdentified()) {
+ cerr << "Union: Nodes are not uniquely identified in output!\n";
+ abort();
}
+ if (out->nodes_.back().node_hash != in.nodes_.back().node_hash) {
+ cerr << "Union: Goal nodes are mismatched!\n a=" << in.nodes_.back().node_hash << " b=" << out->nodes_.back().node_hash << "\n";
+ abort();
+ }
+ const int cgoal = out->nodes_.back().id_;
- for (unsigned i = 0; i < in.edges_.size(); ++i) {
- const Hypergraph::Edge& oe = in.edges_[i];
- Hypergraph::Edge& ce = out->edges_[i + eoff];
- ce.id_ = i + eoff;
- ce.rule_ = oe.rule_;
- ce.feature_values_ = oe.feature_values_;
- if (oe.head_node_ == ogoal) {
- ce.head_node_ = cgoal;
- out->nodes_[cgoal].in_edges_.push_back(ce.id_);
- } else {
- ce.head_node_ = oe.head_node_ + noff;
+ unordered_map<size_t, unsigned> h2n;
+ for (const auto& node : out->nodes_)
+ h2n[node.node_hash] = node.id_;
+ for (const auto& node : in.nodes_) {
+ if (h2n.count(node.node_hash) == 0) {
+ HG::Node* new_node = out->AddNode(node.cat_);
+ new_node->node_hash = node.node_hash;
+ h2n[node.node_hash] = new_node->id_;
}
- ce.tail_nodes_.resize(oe.tail_nodes_.size());
- for (unsigned j = 0; j < oe.tail_nodes_.size(); ++j)
- ce.tail_nodes_[j] = oe.tail_nodes_[j] + noff;
}
+ double n_exists = 0;
+ double n_created = 0;
+ for (const auto& in_node : in.nodes_) {
+ HG::Node& out_node = out->nodes_[h2n[in_node.node_hash]];
+ for (const auto oeid : out_node.in_edges_) {
+ // TODO hash currently existing edges for quick check for duplication
+ }
+ for (const auto ieid : in_node.in_edges_) {
+ const HG::Edge& in_edge = in.edges_[ieid];
+ // TODO: replace slow N^2 check with hashing
+ bool edge_exists = false;
+ for (const auto oeid : out_node.in_edges_) {
+ if (EdgesMatch(in_edge, in, out->edges_[oeid], *out)) {
+ edge_exists = true;
+ break;
+ }
+ }
+ if (!edge_exists) {
+ const unsigned arity = in_edge.tail_nodes_.size();
+ TailNodeVector t(arity);
+ HG::Node& head = out->nodes_[h2n[in_node.node_hash]];
+ for (unsigned i = 0; i < arity; ++i)
+ t[i] = h2n[in.nodes_[in_edge.tail_nodes_[i]].node_hash];
+ HG::Edge* new_edge = out->AddEdge(in_edge, t);
+ out->ConnectEdgeToHeadNode(new_edge, &head);
+ ++n_created;
+ //cerr << "Created: " << new_edge->rule_->AsString() << " [head=" << new_edge->head_node_ << "]\n";
+ } else {
+ ++n_exists;
+ }
+ // cerr << "Not created: " << in.edges_[ieid].rule_->AsString() << "\n";
+ //}
+ }
+ }
+ if (!SILENT)
+ cerr << " Union: edges_created=" << n_created
+ << " edges_already_existing="
+ << n_exists << " ratio_new=" << (n_created / (n_exists + n_created))
+ << endl;
out->TopologicallySortNodesAndEdges(cgoal);
}
diff --git a/decoder/node_state_hash.h b/decoder/node_state_hash.h
index cdc05877..9fc01a09 100644
--- a/decoder/node_state_hash.h
+++ b/decoder/node_state_hash.h
@@ -3,14 +3,19 @@
#include <cassert>
#include <cstring>
+#include "tdict.h"
#include "murmur_hash3.h"
#include "ffset.h"
namespace cdec {
struct FirstPassNode {
- FirstPassNode(int cat, int i, int j, int pi, int pj) : lhs(cat), s(i), t(j), u(pi), v(pj) {}
- int32_t lhs;
+ FirstPassNode(int cat, int i, int j, int pi, int pj) : s(i), t(j), u(pi), v(pj) {
+ memset(lhs, 0, 120);
+ unsigned it = 0;
+ for (auto& c : TD::Convert(-cat)) { lhs[it++] = c; if (it == 120) break; }
+ }
+ char lhs[120];
short s;
short t;
short u;
@@ -23,6 +28,7 @@ namespace cdec {
}
inline uint64_t HashNode(uint64_t old_hash, const FFState& state) {
+ if (state.size() == 0) return old_hash;
uint8_t buf[1024];
std::memcpy(buf, &old_hash, sizeof(uint64_t));
assert(state.size() < (1024u - sizeof(uint64_t)));