summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2012-08-12 23:33:21 -0400
committerChris Dyer <cdyer@cs.cmu.edu>2012-08-12 23:33:21 -0400
commitda176941c1f481f14e93bd7d055cc29cac0ea8c8 (patch)
treec7ec8c0f75b386e6ca6d37da830e5a2e369b1cca /decoder
parent4760209baa483403db3bcb9bf1a32ae87a7b576d (diff)
use new union api
Diffstat (limited to 'decoder')
-rw-r--r--decoder/Makefile.am1
-rw-r--r--decoder/decoder.cc5
-rw-r--r--decoder/hg_test.cc7
-rw-r--r--decoder/hg_union.cc58
-rw-r--r--decoder/hg_union.h9
5 files changed, 75 insertions, 5 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index 0a792549..4a98a4f1 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -44,6 +44,7 @@ libcdec_a_SOURCES = \
hg_remove_eps.cc \
decoder.cc \
hg_intersect.cc \
+ hg_union.cc \
hg_sampler.cc \
factored_lexicon_helper.cc \
viterbi.cc \
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index a6f7b1ce..a69a6d05 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -24,6 +24,7 @@
#include "hg.h"
#include "sentence_metadata.h"
#include "hg_intersect.h"
+#include "hg_union.h"
#include "oracle_bleu.h"
#include "apply_models.h"
@@ -980,7 +981,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg);
if (!succeeded) abort();
}
- new_hg.Union(forest);
+ HG::Union(forest, &new_hg);
bool succeeded = writer.Write(new_hg, false);
if (!succeeded) abort();
} else {
@@ -1067,7 +1068,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg);
if (!succeeded) abort();
}
- new_hg.Union(forest);
+ HG::Union(forest, &new_hg);
bool succeeded = writer.Write(new_hg, false);
if (!succeeded) abort();
} else {
diff --git a/decoder/hg_test.cc b/decoder/hg_test.cc
index 92ed98b2..c47af850 100644
--- a/decoder/hg_test.cc
+++ b/decoder/hg_test.cc
@@ -6,6 +6,7 @@
#include "json_parse.h"
#include "hg_intersect.h"
+#include "hg_union.h"
#include "viterbi.h"
#include "kbest.h"
#include "inside_outside.h"
@@ -52,7 +53,7 @@ BOOST_AUTO_TEST_CASE(Union) {
int l2 = ViterbiPathLength(hg2);
cerr << c1 << "\t" << TD::GetString(t1) << endl;
cerr << c2 << "\t" << TD::GetString(t2) << endl;
- hg1.Union(hg2);
+ HG::Union(hg2, &hg1);
hg1.Reweight(wts);
c3 = ViterbiESentence(hg1, &t3);
int l3 = ViterbiPathLength(hg1);
@@ -121,8 +122,8 @@ BOOST_AUTO_TEST_CASE(InsideScore) {
vector<prob_t> post;
inside = hg.ComputeBestPathThroughEdges(&post);
BOOST_CHECK_CLOSE(-0.3, log(inside), 1e-4); // computed by hand
- BOOST_CHECK_EQUAL(post.size(), 4);
- for (int i = 0; i < 4; ++i) {
+ BOOST_CHECK_EQUAL(post.size(), 5);
+ for (int i = 0; i < 5; ++i) {
cerr << "edge post: " << log(post[i]) << '\t' << hg.edges_[i].rule_->AsString() << endl;
}
}
diff --git a/decoder/hg_union.cc b/decoder/hg_union.cc
new file mode 100644
index 00000000..37082976
--- /dev/null
+++ b/decoder/hg_union.cc
@@ -0,0 +1,58 @@
+#include "hg_union.h"
+
+#include "hg.h"
+
+using namespace std;
+
+namespace HG {
+
+void Union(const Hypergraph& in, Hypergraph* out) {
+ if (&in == out) return;
+ if (out->nodes_.empty()) {
+ out->nodes_ = in.nodes_;
+ out->edges_ = in.edges_; return;
+ }
+ unsigned noff = out->nodes_.size();
+ unsigned eoff = out->edges_.size();
+ int ogoal = in.nodes_.size() - 1;
+ int cgoal = noff - 1;
+ // keep a single goal node, so add nodes.size - 1
+ out->nodes_.resize(out->nodes_.size() + ogoal);
+ // add all edges
+ out->edges_.resize(out->edges_.size() + in.edges_.size());
+
+ for (int i = 0; i < ogoal; ++i) {
+ const Hypergraph::Node& on = in.nodes_[i];
+ Hypergraph::Node& cn = out->nodes_[i + noff];
+ cn.id_ = i + noff;
+ cn.in_edges_.resize(on.in_edges_.size());
+ for (unsigned j = 0; j < on.in_edges_.size(); ++j)
+ cn.in_edges_[j] = on.in_edges_[j] + eoff;
+
+ cn.out_edges_.resize(on.out_edges_.size());
+ for (unsigned j = 0; j < on.out_edges_.size(); ++j)
+ cn.out_edges_[j] = on.out_edges_[j] + eoff;
+ }
+
+ for (unsigned i = 0; i < in.edges_.size(); ++i) {
+ const Hypergraph::Edge& oe = in.edges_[i];
+ Hypergraph::Edge& ce = out->edges_[i + eoff];
+ ce.id_ = i + eoff;
+ ce.rule_ = oe.rule_;
+ ce.feature_values_ = oe.feature_values_;
+ if (oe.head_node_ == ogoal) {
+ ce.head_node_ = cgoal;
+ out->nodes_[cgoal].in_edges_.push_back(ce.id_);
+ } else {
+ ce.head_node_ = oe.head_node_ + noff;
+ }
+ ce.tail_nodes_.resize(oe.tail_nodes_.size());
+ for (unsigned j = 0; j < oe.tail_nodes_.size(); ++j)
+ ce.tail_nodes_[j] = oe.tail_nodes_[j] + noff;
+ }
+
+ out->TopologicallySortNodesAndEdges(cgoal);
+}
+
+}
+
diff --git a/decoder/hg_union.h b/decoder/hg_union.h
new file mode 100644
index 00000000..34624246
--- /dev/null
+++ b/decoder/hg_union.h
@@ -0,0 +1,9 @@
+#ifndef _HG_UNION_H_
+#define _HG_UNION_H_
+
+class Hypergraph;
+namespace HG {
+ void Union(const Hypergraph& in, Hypergraph* out);
+};
+
+#endif