1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
|
from libc.string cimport strrchr, strstr, strcpy, strlen
from libc.stdlib cimport malloc, realloc, strtol
# Bit layout shared by every nonterminal symbol id: the low INDEX_SHIFT bits
# of the (negated) id carry an index, the remaining bits carry the category.
cdef:
    int INDEX_SHIFT = 3
    int INDEX_MASK = (1 << INDEX_SHIFT) - 1
cdef class Alphabet:
    """Translate between textual grammar symbols and integer symbol ids.

    Ids >= 0 are treated as words and are resolved through ``terminals``.
    Ids < 0 are treated as variables (nonterminals): the negated id stores
    ``(category id + 1) << INDEX_SHIFT`` with an index in the low
    ``INDEX_SHIFT`` bits.
    """

    # String tables for the two symbol kinds (StringMap is declared elsewhere).
    cdef readonly StringMap terminals
    cdef readonly StringMap nonterminals
    # Category id returned for the first fromcat() call (-1 until then), and
    # the largest category id seen so far.
    cdef int first_nonterminal
    cdef int last_nonterminal
    # Cache: variable symbol id -> its formatted "[CAT]" / "[CAT,i]" string.
    cdef dict id2sym

    def __cinit__(self):
        self.first_nonterminal = -1
        self.id2sym = {}
        self.terminals = StringMap()
        self.nonterminals = StringMap()

    def __dealloc__(self):
        # Nothing to release explicitly.
        pass

    cdef int isvar(self, int sym):
        """Return nonzero when ``sym`` is a variable (negative id)."""
        return sym < 0

    cdef int isword(self, int sym):
        """Return nonzero when ``sym`` is a word (non-negative id)."""
        return sym >= 0

    cdef int getindex(self, int sym):
        """Extract the index stored in the low bits of a variable id."""
        return (-sym) & INDEX_MASK

    cdef int setindex(self, int sym, int ind):
        """Return ``sym`` with its index bits replaced by ``ind``.

        ``ind`` is OR-ed in unmasked, so it is expected to fit in
        ``INDEX_MASK``.
        """
        return -(((-sym) & (~INDEX_MASK)) | ind)

    cdef int clearindex(self, int sym):
        """Return ``sym`` with its index bits zeroed."""
        return -((-sym) & (~INDEX_MASK))

    cdef int match(self, int sym1, int sym2):
        """Return nonzero when both ids are equal once their indices are cleared."""
        return self.clearindex(sym1) == self.clearindex(sym2)

    cdef char* tocat(self, int sym):
        """Return the category name of a variable id, looked up in ``nonterminals``."""
        cdef int cat_id = ((-sym) >> INDEX_SHIFT) - 1
        return self.nonterminals.word(cat_id)

    cdef int fromcat(self, char *s):
        """Return the index-free variable id for category name ``s``.

        Also records the first category id ever returned and the largest one
        seen in ``first_nonterminal`` / ``last_nonterminal``.
        """
        cdef int cat_id = self.nonterminals.index(s)
        if self.first_nonterminal == -1:
            self.first_nonterminal = cat_id
        if cat_id > self.last_nonterminal:
            self.last_nonterminal = cat_id
        return -((cat_id + 1) << INDEX_SHIFT)

    cdef char* tostring(self, int sym):
        """Return the textual form of ``sym``.

        Words come straight from ``terminals``.  Variables are rendered as
        ``[CAT]`` or, when the index is positive, ``[CAT,index]``; the
        rendered string is cached in ``id2sym`` and the cached object is what
        gets returned.
        """
        cdef int ind
        if not self.isvar(sym):
            return self.terminals.word(sym)
        if sym not in self.id2sym:
            ind = self.getindex(sym)
            if ind > 0:
                self.id2sym[sym] = "[%s,%d]" % (self.tocat(sym), ind)
            else:
                self.id2sym[sym] = "[%s]" % self.tocat(sym)
        return self.id2sym[sym]

    cdef int fromstring(self, char *s, bint terminal):
        """Return the symbol id for the text ``s``.

        Text of the form ``[...]`` (at least 3 characters, not containing
        ``_SEP_``) is parsed as a variable ``[CAT]`` or ``[CAT,index]``, unless
        ``terminal`` is set, in which case it is stored as a word with a
        backslash prepended.  Any other text is stored as a word unchanged.

        Warning: when parsing a variable the buffer ``s`` is modified in
        place (the closing bracket and the last comma are overwritten with
        NUL bytes).
        """
        cdef char *comma
        cdef char *sep
        cdef int n = strlen(s)
        sep = strstr(s, "_SEP_")

        # Not bracketed (or contains the separator): plain word.
        if not (n >= 3 and s[0] == c'[' and s[n - 1] == c']' and sep == NULL):
            return self.terminals.index(s)

        # Bracketed text that must stay a word: store it behind a backslash.
        if terminal:
            escaped = "\\" + s
            return self.terminals.index(escaped)

        # Strip the surrounding brackets in place.
        s[n - 1] = c'\0'
        s = s + 1

        comma = strrchr(s, c',')
        if comma == NULL:
            return self.fromcat(s)

        # Split "CAT,index" at the last comma and parse the index as base 10.
        comma[0] = c'\0'
        return self.setindex(self.fromcat(s), strtol(comma + 1, NULL, 10))
# Module-wide Alphabet instance used by the sym_* helper functions.
cdef Alphabet ALPHABET
ALPHABET = Alphabet()
cdef char* sym_tostring(int sym):
    """Return ``ALPHABET.tostring(sym)``: the textual form of symbol id ``sym``."""
    cdef char* text = ALPHABET.tostring(sym)
    return text
cdef char* sym_tocat(int sym):
    """Return ``ALPHABET.tocat(sym)``: the category name of variable id ``sym``."""
    cdef char* category = ALPHABET.tocat(sym)
    return category
cdef int sym_isvar(int sym):
    """Return ``ALPHABET.isvar(sym)``: nonzero when ``sym`` is a variable id."""
    cdef int is_variable = ALPHABET.isvar(sym)
    return is_variable
cdef int sym_getindex(int sym):
    """Return ``ALPHABET.getindex(sym)``: the index stored in variable id ``sym``."""
    cdef int index = ALPHABET.getindex(sym)
    return index
cdef int sym_setindex(int sym, int id):
    """Return ``ALPHABET.setindex(sym, id)``: ``sym`` with its index set to ``id``."""
    cdef int indexed = ALPHABET.setindex(sym, id)
    return indexed
cdef int sym_fromstring(char* string, bint terminal):
    """Return ``ALPHABET.fromstring(string, terminal)``: the id for ``string``.

    The underlying method may modify the ``string`` buffer in place.
    """
    cdef int sym = ALPHABET.fromstring(string, terminal)
    return sym
def make_lattice(words):
    """Build a linear lattice from a sequence of words.

    Each word is encoded as a terminal symbol id and wrapped in a node that
    holds a single ``(sym, None, 1)`` arc.  Returns the tuple of nodes.
    """
    return tuple(
        ((sym_fromstring(word, True), None, 1),)
        for word in words
    )
def decode_lattice(lattice):
    """Return every arc of ``lattice`` with its symbol id rendered as text.

    ``lattice`` is a sequence of nodes, each node a sequence of
    ``(sym, weight, dist)`` arcs -- the layout ``make_lattice`` builds.  The
    result is a flat tuple of ``(text, weight, dist)`` triples in node order.

    The ``for`` clauses must run outermost-first (nodes, then arcs);
    previously they were written in reverse, so ``arc`` and ``node`` were
    read before being bound.
    """
    return tuple(
        (sym_tostring(sym), weight, dist)
        for node in lattice
        for (sym, weight, dist) in node
    )
def decode_sentence(lattice):
    """Return the words of a linear lattice as a tuple of strings.

    Every node must hold exactly one ``(sym, weight, dist)`` arc; only the
    symbol id is used.
    """
    decoded = []
    for node in lattice:
        (arc,) = node
        sym, _, _ = arc
        decoded.append(sym_tostring(sym))
    return tuple(decoded)
def encode_words(words):
    """Return a tuple with the terminal symbol id of each word in ``words``."""
    encoded = []
    for word in words:
        encoded.append(sym_fromstring(word, True))
    return tuple(encoded)
def decode_words(syms):
    """Return a tuple with the textual form of each symbol id in ``syms``."""
    decoded = []
    for sym in syms:
        decoded.append(sym_tostring(sym))
    return tuple(decoded)
|