path: root/rst_parser/arc_factored.cc

#include "arc_factored.h"

#include <cassert>
#include <iostream>
#include <limits>
#include <queue>
#include <set>
#include <utility>
#include <vector>
#include <tr1/unordered_set>

#include <boost/pending/disjoint_sets.hpp>
#include <boost/functional/hash.hpp>

using namespace std;
using namespace std::tr1;
using namespace boost;

// Greedy baseline: independently pick, for every word m, the single
// highest-scoring head h.  h == 0 denotes the artificial root.  The result is
// not guaranteed to be a tree, since no cycle check is performed.
void ArcFactoredForest::PickBestParentForEachWord(EdgeSubset* st) const {
  for (int m = 1; m <= num_words_; ++m) {
    int best_head = -1;
    prob_t best_score;
    for (int h = 0; h <= num_words_; ++h) {
      const Edge& edge = (*this)(h,m);
      if (best_head < 0 || edge.edge_prob > best_score) {
        best_score = edge.edge_prob;
        best_head = h;
      }
    }
    assert(best_head >= 0);
    if (best_head)
      st->h_m_pairs.push_back(make_pair(static_cast<short>(best_head), static_cast<short>(m)));
    else
      st->roots.push_back(m);
  }
}

// A candidate dependency edge h -> m together with its score.
struct WeightedEdge {
  WeightedEdge() : h(), m(), weight() {}
  WeightedEdge(short hh, short mm, float w) : h(hh), m(mm), weight(w) {}
  short h, m;
  float weight;
  inline bool operator==(const WeightedEdge& o) const {
    return h == o.h && m == o.m && weight == o.weight;
  }
  inline bool operator!=(const WeightedEdge& o) const {
    return h != o.h || m != o.m || weight != o.weight;
  }
};
// The heap orders edges by weight only; the hash mixes all three fields.
inline bool operator<(const WeightedEdge& l, const WeightedEdge& o) { return l.weight < o.weight; }
inline size_t hash_value(const WeightedEdge& e) {
  size_t h = 0;
  boost::hash_combine(h, e.h); boost::hash_combine(h, e.m); boost::hash_combine(h, e.weight);
  return h;
}


// Max-heap of edges keyed on weight.  increment_all() adds a constant to every
// queued weight in O(1) via a lazy offset that is applied when reading top().
// (All pushes here happen before any increment, so the heap order stays valid.)
struct PriorityQueue {
  PriorityQueue() : offset_() {}
  void push(const WeightedEdge& e) { h_.push(WeightedEdge(e.h, e.m, e.weight - offset_)); }
  const WeightedEdge& top() const { top_ = h_.top(); top_.weight += offset_; return top_; }
  void pop() { h_.pop(); }
  void increment_all(float p) { offset_ += p; }
  bool empty() const { return h_.empty(); }
 private:
  std::priority_queue<WeightedEdge> h_;
  float offset_;
  mutable WeightedEdge top_;
};

// Maximum spanning arborescence, based on Tarjan (1977), "Finding optimum
// branchings".  NOTE: the contraction and expansion steps below are unfinished.
void ArcFactoredForest::MaximumEdgeSubset(EdgeSubset* st) const {
  typedef disjoint_sets_with_storage<identity_property_map, identity_property_map,
      find_with_full_path_compression> DisjointSet;
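  // Two union-find structures over vertices 0..num_words_: 'strongly' tracks the
  // contracted strongly connected components, 'weakly' tracks weak connectivity
  // of the edges chosen so far (used to detect cycles).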
  DisjointSet strongly(num_words_ + 1);
  DisjointSet weakly(num_words_ + 1);
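  // roots   : components that still need to choose an entering edge
  // rset    : components that end up as roots of the final branching
  // h       : set of edges selected so far (Tarjan's H)
  // qs[i]   : priority queue of edges entering component i
  // enter[i]: entering edge currently chosen for component i
  // mins[i] : representative original vertex for component i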
  set<unsigned> roots, rset;
  unordered_set<WeightedEdge, boost::hash<WeightedEdge> > h;
  vector<PriorityQueue> qs(num_words_ + 1);
  vector<WeightedEdge> enter(num_words_ + 1);
  vector<unsigned> mins(num_words_ + 1);
  const WeightedEdge kDUMMY(0,0,0.0f);
  for (unsigned i = 0; i <= num_words_; ++i) {
    if (i > 0) {
      // I(i) incidence on i -- all incoming edges
      for (unsigned j = 0; j <= num_words_; ++j) {
        qs[i].push(WeightedEdge(j, i, Weight(j,i)));
      }
    }
    strongly.make_set(i);
    weakly.make_set(i);
    roots.insert(i);
    enter[i] = kDUMMY;
    mins[i] = i;
  }
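  // Main loop: repeatedly take a component k that still needs an entering edge
  // and examine the highest-scoring edge coming into it.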
  while(!roots.empty()) {
    set<unsigned>::iterator it = roots.begin();
    const unsigned k = *it;
    roots.erase(it);
    cerr << "k=" << k << endl;  // debug trace
    // Components whose queue is exhausted (in particular the artificial root 0,
    // which has no incoming edges) simply become roots of the branching.
    if (qs[k].empty()) { rset.insert(k); continue; }
    WeightedEdge ij = qs[k].top();  // MAX(k): best-scoring edge entering component k
    qs[k].pop();
    if (ij.weight <= 0) {
      rset.insert(k);
    } else {
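      // If the best entering edge starts inside k's own strongly connected
      // component, it would be a self-loop after contraction: discard it and
      // reconsider k later.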
      if (strongly.find_set(ij.h) == k) {
        roots.insert(k);
      } else {
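        // Accept ij.  If its endpoints lie in different weak components it
        // cannot close a cycle, so simply record it as k's entering edge.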
        h.insert(ij);
        if (weakly.find_set(ij.h) != weakly.find_set(ij.m)) {
          weakly.union_set(ij.h, ij.m);
          enter[k] = ij;
        } else {
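          // ij closes a cycle of chosen edges.  Walk the chain of entering edges
          // to find the cheapest edge on the cycle, then shift k's queue so that
          // candidates are scored relative to replacing that cheapest edge.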
          unsigned vertex = 0;
          float val = numeric_limits<float>::max();  // weight of the cheapest edge on the cycle
          WeightedEdge xy = ij;
          while(xy != kDUMMY) {
            if (xy.weight < val) {
              val = xy.weight;
              vertex = strongly.find_set(xy.m);
            }
            xy = enter[strongly.find_set(xy.h)];
          }
          qs[k].increment_all(val - ij.weight);
          mins[k] = mins[vertex];
          // Contraction (Tarjan 1977): walk the cycle again, unioning each component
          // into k and melding its queue into qs[k].  TODO: still unimplemented.
          xy = enter[strongly.find_set(ij.h)];
          while (xy != kDUMMY) {
            xy = enter[strongly.find_set(xy.h)];
          }
        }
      }
    }
  }
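  // TODO: expand the contracted components and write the selected edges and
  // roots into *st.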
}