From b6e70b420ed993ee73f71058d04b382147896068 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 12 Aug 2012 23:33:21 -0400 Subject: use new union api --- decoder/Makefile.am | 1 + decoder/decoder.cc | 5 +++-- decoder/hg_test.cc | 7 ++++--- decoder/hg_union.cc | 58 +++++++++++++++++++++++++++++++++++++++++++++++++++++ decoder/hg_union.h | 9 +++++++++ 5 files changed, 75 insertions(+), 5 deletions(-) create mode 100644 decoder/hg_union.cc create mode 100644 decoder/hg_union.h (limited to 'decoder') diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 0a792549..4a98a4f1 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -44,6 +44,7 @@ libcdec_a_SOURCES = \ hg_remove_eps.cc \ decoder.cc \ hg_intersect.cc \ + hg_union.cc \ hg_sampler.cc \ factored_lexicon_helper.cc \ viterbi.cc \ diff --git a/decoder/decoder.cc b/decoder/decoder.cc index a6f7b1ce..a69a6d05 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -24,6 +24,7 @@ #include "hg.h" #include "sentence_metadata.h" #include "hg_intersect.h" +#include "hg_union.h" #include "oracle_bleu.h" #include "apply_models.h" @@ -980,7 +981,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg); if (!succeeded) abort(); } - new_hg.Union(forest); + HG::Union(forest, &new_hg); bool succeeded = writer.Write(new_hg, false); if (!succeeded) abort(); } else { @@ -1067,7 +1068,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg); if (!succeeded) abort(); } - new_hg.Union(forest); + HG::Union(forest, &new_hg); bool succeeded = writer.Write(new_hg, false); if (!succeeded) abort(); } else { diff --git a/decoder/hg_test.cc b/decoder/hg_test.cc index 92ed98b2..c47af850 100644 --- a/decoder/hg_test.cc +++ b/decoder/hg_test.cc @@ -6,6 +6,7 @@ #include "json_parse.h" #include "hg_intersect.h" +#include "hg_union.h" #include "viterbi.h" #include "kbest.h" #include "inside_outside.h" @@ -52,7 +53,7 @@ BOOST_AUTO_TEST_CASE(Union) { int l2 = ViterbiPathLength(hg2); cerr << c1 << "\t" << TD::GetString(t1) << endl; cerr << c2 << "\t" << TD::GetString(t2) << endl; - hg1.Union(hg2); + HG::Union(hg2, &hg1); hg1.Reweight(wts); c3 = ViterbiESentence(hg1, &t3); int l3 = ViterbiPathLength(hg1); @@ -121,8 +122,8 @@ BOOST_AUTO_TEST_CASE(InsideScore) { vector post; inside = hg.ComputeBestPathThroughEdges(&post); BOOST_CHECK_CLOSE(-0.3, log(inside), 1e-4); // computed by hand - BOOST_CHECK_EQUAL(post.size(), 4); - for (int i = 0; i < 4; ++i) { + BOOST_CHECK_EQUAL(post.size(), 5); + for (int i = 0; i < 5; ++i) { cerr << "edge post: " << log(post[i]) << '\t' << hg.edges_[i].rule_->AsString() << endl; } } diff --git a/decoder/hg_union.cc b/decoder/hg_union.cc new file mode 100644 index 00000000..37082976 --- /dev/null +++ b/decoder/hg_union.cc @@ -0,0 +1,58 @@ +#include "hg_union.h" + +#include "hg.h" + +using namespace std; + +namespace HG { + +void Union(const Hypergraph& in, Hypergraph* out) { + if (&in == out) return; + if (out->nodes_.empty()) { + out->nodes_ = in.nodes_; + out->edges_ = in.edges_; return; + } + unsigned noff = out->nodes_.size(); + unsigned eoff = out->edges_.size(); + int ogoal = in.nodes_.size() - 1; + int cgoal = noff - 1; + // keep a single goal node, so add nodes.size - 1 + out->nodes_.resize(out->nodes_.size() + ogoal); + // add all edges + out->edges_.resize(out->edges_.size() + in.edges_.size()); + + for (int i = 0; i < ogoal; ++i) { + const Hypergraph::Node& on = in.nodes_[i]; + Hypergraph::Node& cn = out->nodes_[i + noff]; + cn.id_ = i + noff; + cn.in_edges_.resize(on.in_edges_.size()); + for (unsigned j = 0; j < on.in_edges_.size(); ++j) + cn.in_edges_[j] = on.in_edges_[j] + eoff; + + cn.out_edges_.resize(on.out_edges_.size()); + for (unsigned j = 0; j < on.out_edges_.size(); ++j) + cn.out_edges_[j] = on.out_edges_[j] + eoff; + } + + for (unsigned i = 0; i < in.edges_.size(); ++i) { + const Hypergraph::Edge& oe = in.edges_[i]; + Hypergraph::Edge& ce = out->edges_[i + eoff]; + ce.id_ = i + eoff; + ce.rule_ = oe.rule_; + ce.feature_values_ = oe.feature_values_; + if (oe.head_node_ == ogoal) { + ce.head_node_ = cgoal; + out->nodes_[cgoal].in_edges_.push_back(ce.id_); + } else { + ce.head_node_ = oe.head_node_ + noff; + } + ce.tail_nodes_.resize(oe.tail_nodes_.size()); + for (unsigned j = 0; j < oe.tail_nodes_.size(); ++j) + ce.tail_nodes_[j] = oe.tail_nodes_[j] + noff; + } + + out->TopologicallySortNodesAndEdges(cgoal); +} + +} + diff --git a/decoder/hg_union.h b/decoder/hg_union.h new file mode 100644 index 00000000..34624246 --- /dev/null +++ b/decoder/hg_union.h @@ -0,0 +1,9 @@ +#ifndef _HG_UNION_H_ +#define _HG_UNION_H_ + +class Hypergraph; +namespace HG { + void Union(const Hypergraph& in, Hypergraph* out); +}; + +#endif -- cgit v1.2.3