summaryrefslogtreecommitdiff
path: root/decoder/tree_fragment.cc
blob: 696c8601050014f68a53372de245b23b32adba90 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#include "tree_fragment.h"

#include <cassert>

#include "tdict.h"

using namespace std;

namespace cdec {

// Builds a tree fragment from a parenthesized string such as
//   (S (NP a (X b) c d) (VP (V foo) (NP (NN bar))))
//
// tree                 - the bracketed tree to parse.
// allow_frontier_sites - forwarded to ParseRec; when false, a "[X]" leaf makes
//                        ParseRec abort.
//
// The constructor first validates the parenthesization and counts the opening
// parentheses (one node is built per '('), then hands the string to ParseRec.
// Any malformed input is reported on cerr and terminates the process via
// abort(), matching the error handling used in ParseRec.
TreeFragment::TreeFragment(const string& tree, bool allow_frontier_sites) {
  // ParseRec only ever increments these counters, so they must start at zero.
  // NOTE(review): the class declaration is not visible from this file; if the
  // members already have initializers there, this is redundant but harmless.
  frontier_sites = 0;
  terminals = 0;

  int bal = 0;                       // current nesting depth
  const unsigned len = tree.size();
  unsigned cur = 0;                  // 1-based column of the character being examined
  unsigned open = 0;                 // number of '(' seen == number of nodes to build
  for (const char c : tree) {
    ++cur;
    if (c == '(') { ++open; ++bal; }
    else if (c == ')') {
      --bal;
      // The outermost constituent may only close on the very last character.
      if (bal < 1 && cur != len) {
        cerr << "Badly formed tree detected at column " << cur << " in:\n" << tree << endl;
        abort();
      }
    }
  }
  // Reject trees with unclosed (or surplus) parentheses here: ParseRec scans
  // forward looking for the matching ')' and must not be handed input in
  // which that ')' does not exist.
  if (bal != 0) {
    cerr << "Badly formed tree detected (unbalanced parentheses) in:\n" << tree << endl;
    abort();
  }
  nodes.resize(open);
  unsigned cp = 0, symp = 0, np = 0;
  // ParseRec aborts if the string does not start with '(', so nodes is
  // guaranteed to be non-empty when it returns.
  ParseRec(tree, allow_frontier_sites, cp, symp, np, &cp, &symp, &np);
  // Nodes are written children-first, so the root production is the last one.
  root = nodes.back().lhs;
  //cerr << "ROOT: " << TD::Convert(root & ALL_MASK) << endl;
  //DebugRec(open - 1, &cerr); cerr << "\n";
}

void TreeFragment::DebugRec(unsigned cur, ostream* out) const {
  *out << '(' << TD::Convert(nodes[cur].lhs & ALL_MASK);
  for (auto& x : nodes[cur].rhs) {
    *out << ' ';
    if (IsFrontier(x)) {
      *out << '[' << TD::Convert(x & ALL_MASK) << ']';
    } else if (IsRHS(x)) {
      DebugRec(x & ALL_MASK, out);
    } else { // must be terminal
      *out << TD::Convert(x);
    }
  }
  *out << ')';
}

// cp is the character index in the tree
// np keeps track of the nodes (nonterminals) that have been built
// symp keeps track of the terminal symbols that have been built
void TreeFragment::ParseRec(const string& tree, bool afs, unsigned cp, unsigned symp, unsigned np, unsigned* pcp, unsigned* psymp, unsigned* pnp) {
  if (tree[cp] != '(') {
    cerr << "Expected ( at " << cp << endl;
    abort();
  }
  const unsigned i = symp;
  vector<unsigned> rhs; // w | 0 = terminal, w | NT_BIT, index | FRONTIER_BIT
  ++cp;
  while(tree[cp] == ' ') { ++cp; }
  const unsigned nt_start = cp;
  while(tree[cp] != ' ' && tree[cp] != '(' && tree[cp] != ')') { ++cp; }
  const unsigned nt_end = cp;
  while(tree[cp] == ' ') { ++cp; }
  while (tree[cp] != ')') {
    if (tree[cp] == '(') {
      // recursively call parser to deal with constituent
      ParseRec(tree, afs, cp, symp, np, &cp, &symp, &np);
      unsigned ind = np - 1;
      rhs.push_back(ind | RHS_BIT);
    } else { // deal with terminal / nonterminal substitution
      ++symp;
      assert(tree[cp] != ' ');
      const unsigned t_start = cp;
      while(tree[cp] != ' ' && tree[cp] != ')' && tree[cp] != '(') { ++cp; }
      const unsigned t_end = cp;
      while(tree[cp] == ' ') { ++cp; }
      // TODO: add a terminal symbol to the current edge
      const bool is_terminal = tree[t_start] != '[' || (t_end - t_start < 3 || tree[t_end - 1] != ']');
      if (is_terminal) {
        const unsigned term = TD::Convert(tree.substr(t_start, t_end - t_start));
        rhs.push_back(term);
        // cerr << "T='" << TD::Convert(term) << "'\n";
        ++terminals;
      } else { // frontier site (NT but no recursion)
        const unsigned nt = TD::Convert(tree.substr(t_start + 1, t_end - t_start - 2)) | FRONTIER_BIT;
        rhs.push_back(nt);
        ++frontier_sites;
        // cerr << "FRONT-NT=[" << TD::Convert(nt & ALL_MASK) << "]\n";
        if (!afs) {
          cerr << "Frontier sites not allowed in input: " << tree << endl;
          abort();
        }
      } 
    }
  } // continuent has completed, cp is at ), build node
  const unsigned j = symp; // span from (i,j)
  // add an internal non-terminal symbol
  const unsigned nt = TD::Convert(tree.substr(nt_start, nt_end - nt_start)) | RHS_BIT;
  nodes[np] = TreeFragmentProduction(nt, rhs);
  //cerr << np << " production(" << i << "," << j << ")=  " << TD::Convert(nt & ALL_MASK) << " -->";
  //for (auto& x : rhs) {
  //  cerr << ' ';
  //  if (IsFrontier(x)) cerr << '*';
  //  if (IsInternalNT(x)) cerr << TD::Convert(nodes[x & ALL_MASK].lhs & ALL_MASK); else
  //    cerr << TD::Convert(x & ALL_MASK);
  //}
  //cerr << "\n   "; DebugRec(np,&cerr); cerr << endl;
  ++cp;
  while(tree[cp] == ' ' && cp < tree.size()) { ++cp; }
  *pcp = cp;
  *pnp = np + 1;
  *psymp = symp;
}

}