// path: decoder/tree_fragment.h
// blob: b1dbbae09cd3440e9521bd07dc6a071a0c829a8b
#ifndef TREE_FRAGMENT
#define TREE_FRAGMENT

#include <cstddef>
#include <iostream>
#include <iterator>
#include <queue>
#include <string>
#include <vector>

#include "tdict.h"

namespace cdec {

class BreadthFirstIterator;

// Flag bits stored in the high bits of a symbol value.
// NT_BIT is the bit tested by IsInternalNT(); FRONTIER_BIT is the bit tested
// by IsFrontier(). ALL_MASK keeps the low 28 bits; BreadthFirstIterator applies
// it to an internal-nonterminal symbol to obtain an index into
// TreeFragment::nodes.
static const unsigned NT_BIT       = 1u << 30;
static const unsigned FRONTIER_BIT = 1u << 31;
static const unsigned ALL_MASK     = (1u << 28) - 1u;

// Returns true when the NT_BIT flag is set in symbol x.
inline bool IsInternalNT(unsigned x) {
  const unsigned flag = x & NT_BIT;
  return flag != 0u;
}

// Returns true when the FRONTIER_BIT flag is set in symbol x.
inline bool IsFrontier(unsigned x) {
  const unsigned flag = x & FRONTIER_BIT;
  return flag != 0u;
}

// One production of a tree fragment: a left-hand-side symbol and the sequence
// of right-hand-side symbols it rewrites to.
struct TreeFragmentProduction {
  // Creates an empty production. lhs is value-initialized to 0 so that a
  // default-constructed production never exposes an indeterminate value.
  TreeFragmentProduction() : lhs(), rhs() {}
  // nttype: value stored in lhs (converted to unsigned).
  // r:      right-hand-side symbols; copied into rhs.
  TreeFragmentProduction(int nttype, const std::vector<unsigned>& r) : lhs(nttype), rhs(r) {}
  unsigned lhs;
  std::vector<unsigned> rhs;
};

// This data structure represents a tree or forest.
// Productions can have mixtures of terminals, nonterminal symbols and
// nonterminal frontier sites.
class TreeFragment {
 public:
  // Creates an empty fragment: no nodes, and root / frontier_sites / terminals
  // all value-initialized to 0 (root was previously left indeterminate).
  TreeFragment() : root(), frontier_sites(), terminals() {}
  // Builds a fragment from a bracketed tree string such as
  //   (S (NP a (X b) c d) (VP (V foo) (NP (NN bar))))
  // NOTE(review): allow_frontier_sites presumably permits frontier-site
  // notation in the input; the definition lives outside this header - confirm.
  explicit TreeFragment(const std::string& tree, bool allow_frontier_sites = false);
  // Writes a debug rendering of node `cur` to *out (defined out of line).
  void DebugRec(unsigned cur, std::ostream* out) const;

  // Container-style typedefs; iteration is breadth-first over symbols.
  typedef BreadthFirstIterator iterator;
  typedef std::ptrdiff_t difference_type;
  typedef unsigned value_type;
  typedef const unsigned * pointer;
  typedef const unsigned & reference;

  iterator begin() const;
  iterator end() const;

 private:
  // cp is the character index in the tree
  // np keeps track of the nodes (nonterminals) that have been built
  // symp keeps track of the terminal symbols that have been built
  // pcp / psymp / pnp are output pointers (presumably the updated counterparts
  // of cp / symp / np - confirm against the definition).
  void ParseRec(const std::string& tree, bool afs, unsigned cp, unsigned symp, unsigned np, unsigned* pcp, unsigned* psymp, unsigned* pnp);
 public:
  // NOTE(review): presumably the index of the root production in `nodes`;
  // BreadthFirstIterator and operator<< start from nodes.size() - 1 instead - confirm.
  unsigned root;
  unsigned char frontier_sites;   // presumably the number of frontier sites - confirm
  unsigned short terminals;       // presumably the number of terminal symbols - confirm

  // Productions of the fragment; internal-nonterminal rhs symbols index into
  // this vector (after masking with ALL_MASK).
  std::vector<TreeFragmentProduction> nodes;
};

// Traversal cursor used by BreadthFirstIterator: a node index paired with a
// position inside that node's right-hand side.
struct TFIState {
  TFIState() : node(), rhspos() {}
  // n and p are stored as unsigned short, so larger values are truncated.
  TFIState(unsigned n, unsigned p) : node(n), rhspos(p) {}
  bool operator==(const TFIState& o) const { return node == o.node && rhspos == o.rhspos; }
  // Exact negation of operator==: two states differ as soon as EITHER field
  // differs (the previous `&&` form reported states that differed in only one
  // field as "not unequal").
  bool operator!=(const TFIState& o) const { return !(*this == o); }
  unsigned short node;
  unsigned short rhspos;
};

class BreadthFirstIterator : public std::iterator<std::forward_iterator_tag, unsigned> {
  const TreeFragment* tf_;
  std::queue<TFIState> q_;
  unsigned sym;
 public:
  explicit BreadthFirstIterator(const TreeFragment* tf) : tf_(tf) {
    q_.push(TFIState(tf->nodes.size() - 1, 0));
    Stage();
  }
  BreadthFirstIterator(const TreeFragment* tf, int) : tf_(tf) {}
  const unsigned& operator*() const { return sym; }
  const unsigned* operator->() const { return &sym; }
  bool operator==(const BreadthFirstIterator& other) const {
    return (tf_ == other.tf_) && (q_ == other.q_);
  }
  bool operator!=(const BreadthFirstIterator& other) const {
    return (tf_ != other.tf_) || (q_ != other.q_);
  }
  void Stage() {
    if (q_.empty()) return;
    const TFIState& s = q_.front();
    if (s.rhspos < 0) {
      sym = tf_->nodes[s.node].lhs;
    } else {
      sym = tf_->nodes[s.node].rhs[s.rhspos];
      if (IsInternalNT(sym)) {
        q_.push(TFIState(sym & ALL_MASK, 0));
        sym = tf_->nodes[sym & ALL_MASK].lhs;
      }
    }
  }
  const BreadthFirstIterator& operator++() {
    TFIState& s = q_.front();
    const unsigned len = tf_->nodes[s.node].rhs.size();
    s.rhspos++;
    if (s.rhspos > len) {
      q_.pop();
      Stage();
    } else if (s.rhspos == len) {
      sym = 0;
    } else {
      Stage();
    }
    return *this;
  }
  BreadthFirstIterator operator++(int) {
    BreadthFirstIterator res = *this;
    ++(*this);
    return res;
  }
};

// Streams a debug rendering of the fragment by calling DebugRec on the last
// production in x.nodes.
// NOTE(review): when x.nodes is empty the index wraps around; whether DebugRec
// tolerates that cannot be seen from this header - confirm.
inline std::ostream& operator<<(std::ostream& os, const TreeFragment& x) {
  const unsigned last_node = x.nodes.size() - 1;
  std::ostream* const sink = &os;
  x.DebugRec(last_node, sink);
  return os;
}

}

#endif