summaryrefslogtreecommitdiff
path: root/decoder/inside_outside.h
diff options
context:
space:
mode:
Diffstat (limited to 'decoder/inside_outside.h')
-rw-r--r--decoder/inside_outside.h109
1 files changed, 86 insertions, 23 deletions
diff --git a/decoder/inside_outside.h b/decoder/inside_outside.h
index 62daca1f..128d89da 100644
--- a/decoder/inside_outside.h
+++ b/decoder/inside_outside.h
@@ -34,8 +34,9 @@ WeightType Inside(const Hypergraph& hg,
const int num_nodes = hg.nodes_.size();
std::vector<WeightType> dummy;
std::vector<WeightType>& inside_score = result ? *result : dummy;
+ inside_score.clear();
inside_score.resize(num_nodes);
- std::fill(inside_score.begin(), inside_score.end(), WeightType());
+// std::fill(inside_score.begin(), inside_score.end(), WeightType()); // clear handles
for (int i = 0; i < num_nodes; ++i) {
const Hypergraph::Node& cur_node = hg.nodes_[i];
WeightType* const cur_node_inside_score = &inside_score[i];
@@ -61,14 +62,17 @@ template<typename WeightType, typename WeightFunction>
void Outside(const Hypergraph& hg,
std::vector<WeightType>& inside_score,
std::vector<WeightType>* result,
- const WeightFunction& weight = WeightFunction()) {
+ const WeightFunction& weight = WeightFunction(),
+ WeightType scale_outside = WeightType(1)
+ ) {
assert(result);
const int num_nodes = hg.nodes_.size();
assert(inside_score.size() == num_nodes);
std::vector<WeightType>& outside_score = *result;
+ outside_score.clear();
outside_score.resize(num_nodes);
- std::fill(outside_score.begin(), outside_score.end(), WeightType());
- outside_score.back() = WeightType(1);
+// std::fill(outside_score.begin(), outside_score.end(), WeightType()); // cleared
+ outside_score.back() = scale_outside;
for (int i = num_nodes - 1; i >= 0; --i) {
const Hypergraph::Node& cur_node = hg.nodes_[i];
const WeightType& head_node_outside_score = outside_score[i];
@@ -94,6 +98,80 @@ void Outside(const Hypergraph& hg,
}
}
+template <class K> // obviously not all semirings have a multiplicative inverse
+struct OutsideNormalize {
+ bool enable;
+ OutsideNormalize(bool enable=true) : enable(enable) {}
+ K operator()(K k) { return enable?K(1)/k:K(1); }
+};
+template <class K>
+struct Outside1 {
+ K operator()(K) { return K(1); }
+};
+
+template <class KType>
+struct InsideOutsides {
+// typedef typename KWeightFunction::Weight KType;
+ typedef std::vector<KType> Ks;
+ Ks inside,outside;
+ KType root_inside() {
+ return inside.back();
+ }
+ InsideOutsides() { }
+ template <class KWeightFunction>
+ KType compute(Hypergraph const& hg,KWeightFunction const& kwf=KWeightFunction()) {
+ return compute(hg,Outside1<KType>(),kwf);
+ }
+ template <class KWeightFunction,class O1>
+ KType compute(Hypergraph const& hg,O1 outside1,KWeightFunction const& kwf=KWeightFunction()) {
+ typedef typename KWeightFunction::Weight KType2;
+ assert(sizeof(KType2)==sizeof(KType)); // why am I doing this? because I want to share the vectors used for tropical and prob_t semirings. should instead have separate value type from semiring operations? or suck it up and split the code calling in Prune* into 2 types (template)
+ typedef std::vector<KType2> K2s;
+ K2s &inside2=reinterpret_cast<K2s &>(inside);
+ Inside<KType2,KWeightFunction>(hg, &inside2, kwf);
+ KType scale=outside1(reinterpret_cast<KType const&>(inside2.back()));
+ Outside<KType2,KWeightFunction>(hg, inside2, reinterpret_cast<K2s *>(&outside), kwf, reinterpret_cast<KType2 const&>(scale));
+ return root_inside();
+ }
+// XWeightFunction::Result is result
+ template <class XWeightFunction>
+ typename XWeightFunction::Result expect(Hypergraph const& hg,XWeightFunction const& xwf=XWeightFunction()) {
+ typename XWeightFunction::Result x; // default constructor is semiring 0
+ for (int i = 0,num_nodes=hg.nodes_.size(); i < num_nodes; ++i) {
+ const Hypergraph::Node& cur_node = hg.nodes_[i];
+ const int num_in_edges = cur_node.in_edges_.size();
+ for (int j = 0; j < num_in_edges; ++j) {
+ const Hypergraph::Edge& edge = hg.edges_[cur_node.in_edges_[j]];
+ KType kbar_e = outside[i];
+ const int num_tail_nodes = edge.tail_nodes_.size();
+ for (int k = 0; k < num_tail_nodes; ++k)
+ kbar_e *= inside[edge.tail_nodes_[k]];
+ x += xwf(edge) * kbar_e;
+ }
+ }
+ return x;
+ }
+ template <class V,class VWeight>
+ void compute_edge_marginals(Hypergraph const& hg,std::vector<V> &vs,VWeight const& weight) {
+ vs.resize(hg.edges_.size());
+ for (int i = 0,num_nodes=hg.nodes_.size(); i < num_nodes; ++i) {
+ const Hypergraph::Node& cur_node = hg.nodes_[i];
+ const int num_in_edges = cur_node.in_edges_.size();
+ for (int j = 0; j < num_in_edges; ++j) {
+ int edgei=cur_node.in_edges_[j];
+ const Hypergraph::Edge& edge = hg.edges_[edgei];
+ V x=weight(edge)*outside[i];
+ const int num_tail_nodes = edge.tail_nodes_.size();
+ for (int k = 0; k < num_tail_nodes; ++k)
+ x *= inside[edge.tail_nodes_[k]];
+ vs[edgei] = x;
+ }
+ }
+ }
+
+};
+
+
// this is the Inside-Outside optimization described in Li and Eisner (EMNLP 2009)
// for computing the inside algorithm over expensive semirings
// (such as expectations over features). See Figure 4.
@@ -105,25 +183,10 @@ KType InsideOutside(const Hypergraph& hg,
XType* result_x,
const KWeightFunction& kwf = KWeightFunction(),
const XWeightFunction& xwf = XWeightFunction()) {
- const int num_nodes = hg.nodes_.size();
- std::vector<KType> inside, outside;
- const KType k = Inside<KType,KWeightFunction>(hg, &inside, kwf);
- Outside<KType,KWeightFunction>(hg, inside, &outside, kwf);
- XType& x = *result_x;
- x = XType(); // default constructor is semiring 0
- for (int i = 0; i < num_nodes; ++i) {
- const Hypergraph::Node& cur_node = hg.nodes_[i];
- const int num_in_edges = cur_node.in_edges_.size();
- for (int j = 0; j < num_in_edges; ++j) {
- const Hypergraph::Edge& edge = hg.edges_[cur_node.in_edges_[j]];
- KType kbar_e = outside[i];
- const int num_tail_nodes = edge.tail_nodes_.size();
- for (int k = 0; k < num_tail_nodes; ++k)
- kbar_e *= inside[edge.tail_nodes_[k]];
- x += xwf(edge) * kbar_e;
- }
- }
- return k;
+ InsideOutsides<KType> io;
+ io.compute(hg,kwf);
+ *result_x=io.expect(hg,xwf);
+ return io.root_inside();
}
#endif