summaryrefslogtreecommitdiff
path: root/rst_parser/arc_factored.cc
blob: 74bf7516f037046914a73e9a1f080468bbac87c6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#include "arc_factored.h"

#include <set>
#include <tr1/unordered_set>

#include <boost/pending/disjoint_sets.hpp>
#include <boost/functional/hash.hpp>

#include "arc_ff.h"

using namespace std;
using namespace std::tr1;
using namespace boost;

void EdgeSubset::ExtractFeatures(const TaggedSentence& sentence,
                                 const ArcFeatureFunctions& ffs,
                                 SparseVector<double>* features) const {
  SparseVector<weight_t> efmap;
  for (int j = 0; j < h_m_pairs.size(); ++j) {
    efmap.clear();
    ffs.EdgeFeatures(sentence, h_m_pairs[j].first,
                     h_m_pairs[j].second,
                     &efmap);
    (*features) += efmap;
  }
  for (int j = 0; j < roots.size(); ++j) {
    efmap.clear();
    ffs.EdgeFeatures(sentence, -1, roots[j], &efmap);
    (*features) += efmap;
  }
}

void ArcFactoredForest::ExtractFeatures(const TaggedSentence& sentence,
                                        const ArcFeatureFunctions& ffs) {
  for (int m = 0; m < num_words_; ++m) {
    for (int h = 0; h < num_words_; ++h) {
      ffs.EdgeFeatures(sentence, h, m, &edges_(h,m).features);
    }
    ffs.EdgeFeatures(sentence, -1, m, &root_edges_[m].features);
  }
}

void ArcFactoredForest::PickBestParentForEachWord(EdgeSubset* st) const {
  for (int m = 0; m < num_words_; ++m) {
    int best_head = -2;
    prob_t best_score;
    for (int h = -1; h < num_words_; ++h) {
      const Edge& edge = (*this)(h,m);
      if (best_head < -1 || edge.edge_prob > best_score) {
        best_score = edge.edge_prob;
        best_head = h;
      }
    }
    assert(best_head >= -1);
    if (best_head >= 0)
      st->h_m_pairs.push_back(make_pair<short,short>(best_head, m));
    else
      st->roots.push_back(m);
  }
}

struct WeightedEdge {
  WeightedEdge() : h(), m(), weight() {}
  WeightedEdge(short hh, short mm, float w) : h(hh), m(mm), weight(w) {}
  short h, m;
  float weight;
  inline bool operator==(const WeightedEdge& o) const {
    return h == o.h && m == o.m && weight == o.weight;
  }
  inline bool operator!=(const WeightedEdge& o) const {
    return h != o.h || m != o.m || weight != o.weight;
  }
};
inline bool operator<(const WeightedEdge& l, const WeightedEdge& o) { return l.weight < o.weight; }
inline size_t hash_value(const WeightedEdge& e) { return reinterpret_cast<const size_t&>(e); }


struct PriorityQueue {
  void push(const WeightedEdge& e) {}
  const WeightedEdge& top() const {
    static WeightedEdge w(1,2,3);
    return w;
  }
  void pop() {}
  void increment_all(float p) {}
};

// based on Trajan 1977
void ArcFactoredForest::MaximumSpanningTree(EdgeSubset* st) const {
  typedef disjoint_sets_with_storage<identity_property_map, identity_property_map,
      find_with_full_path_compression> DisjointSet;
  DisjointSet strongly(num_words_ + 1);
  DisjointSet weakly(num_words_ + 1);
  set<unsigned> roots, rset;
  unordered_set<WeightedEdge, boost::hash<WeightedEdge> > h;
  vector<PriorityQueue> qs(num_words_ + 1);
  vector<WeightedEdge> enter(num_words_ + 1);
  vector<unsigned> mins(num_words_ + 1);
  const WeightedEdge kDUMMY(0,0,0.0f);
  for (unsigned i = 0; i <= num_words_; ++i) {
    if (i > 0) {
      // I(i) incidence on i -- all incoming edges
      for (unsigned j = 0; j <= num_words_; ++j) {
        qs[i].push(WeightedEdge(j, i, Weight(j,i)));
      }
    }
    strongly.make_set(i);
    weakly.make_set(i);
    roots.insert(i);
    enter[i] = kDUMMY;
    mins[i] = i;
  }
  while(!roots.empty()) {
    set<unsigned>::iterator it = roots.begin();
    const unsigned k = *it;
    roots.erase(it);
    cerr << "k=" << k << endl;
    WeightedEdge ij = qs[k].top();  // MAX(k)
    qs[k].pop();
    if (ij.weight <= 0) {
      rset.insert(k);
    } else {
      if (strongly.find_set(ij.h) == k) {
        roots.insert(k);
      } else {
        h.insert(ij);
        if (weakly.find_set(ij.h) != weakly.find_set(ij.m)) {
          weakly.union_set(ij.h, ij.m);
          enter[k] = ij;
        } else {
          unsigned vertex = 0;
          float val = 99999999999;
          WeightedEdge xy = ij;
          while(xy != kDUMMY) {
            if (xy.weight < val) {
              val = xy.weight;
              vertex = strongly.find_set(xy.m);
            }
            xy = enter[strongly.find_set(xy.h)];
          }
          qs[k].increment_all(val - ij.weight);
          mins[k] = mins[vertex];
          xy = enter[strongly.find_set(ij.h)];
          while (xy != kDUMMY) {
          }
        }
      }
    }
  }
}