summaryrefslogtreecommitdiff
path: root/python/src/sa/sym.pxi
blob: 3fd6c5a7ebe02bd08841cd4f2445ad9a182fc9a3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from libc.string cimport strrchr, strstr, strcpy, strlen
from libc.stdlib cimport malloc, realloc, strtol, free

# Bit layout of a nonterminal id (see Alphabet.fromcat / Alphabet.getindex):
# the magnitude of the id keeps the nonterminal's index in its low
# INDEX_SHIFT bits and the category number in the bits above them.
cdef:
    # Number of low-order bits reserved for the index.
    int INDEX_SHIFT = 3
    # Mask selecting exactly those low-order index bits.
    int INDEX_MASK = (1 << INDEX_SHIFT) - 1

cdef class Alphabet:
    """Two-way mapping between integer symbol ids and their string forms.

    Ids >= 0 are treated as words and resolved through ``terminals``.
    Ids < 0 are treated as nonterminals: the magnitude of the id holds the
    category number (its position in ``nonterminals`` plus one) shifted left
    by INDEX_SHIFT, with the nonterminal's index in the low INDEX_SHIFT bits.
    """
    cdef readonly StringMap terminals, nonterminals
    cdef int first_nonterminal, last_nonterminal
    cdef dict id2sym

    def __cinit__(self):
        # id2sym caches the formatted "[CAT]" / "[CAT,i]" strings produced
        # by tostring(), keyed by the nonterminal id.
        self.id2sym = {}
        self.terminals = StringMap()
        self.nonterminals = StringMap()
        # -1 means "no category registered yet" (tested in fromcat()).
        # last_nonterminal is not assigned here.
        self.first_nonterminal = -1

    def __dealloc__(self):
        # Nothing to release explicitly.
        pass

    cdef int isvar(self, int sym):
        """Return nonzero when sym is a nonterminal, i.e. a negative id."""
        return sym < 0

    cdef int isword(self, int sym):
        """Return nonzero when sym is a word, i.e. a non-negative id."""
        return sym >= 0

    cdef int getindex(self, int sym):
        """Return the index stored in the low INDEX_SHIFT bits of -sym."""
        return (-sym) & INDEX_MASK

    cdef int setindex(self, int sym, int ind):
        """Return sym with its index bits cleared and ind OR-ed in.

        ind is combined without masking, so a value that does not fit in
        INDEX_MASK also alters the category bits.
        """
        return -(((-sym) & ~INDEX_MASK) | ind)

    cdef int clearindex(self, int sym):
        """Return sym with its index bits set to zero."""
        return -((-sym) & ~INDEX_MASK)

    cdef int match(self, int sym1, int sym2):
        """Return nonzero when sym1 and sym2 are equal once indices are cleared."""
        return self.clearindex(sym1) == self.clearindex(sym2)

    cdef char* tocat(self, int sym):
        """Return the category name of nonterminal sym from ``nonterminals``."""
        # Undo the encoding done in fromcat(): drop the index bits, then
        # subtract the +1 offset to get back the position in the map.
        return self.nonterminals.word(((-sym) >> INDEX_SHIFT) - 1)

    cdef int fromcat(self, char *s):
        """Return the index-less nonterminal id for category name s.

        Also updates first_nonterminal (set once, while it is still -1) and
        last_nonterminal (raised whenever a larger map position is seen).
        """
        cdef int cat_pos = self.nonterminals.index(s)
        if self.first_nonterminal == -1:
            self.first_nonterminal = cat_pos
        if cat_pos > self.last_nonterminal:
            self.last_nonterminal = cat_pos
        # The +1 offset keeps map position 0 from encoding to id 0, which
        # would be indistinguishable from a word id.
        return -((cat_pos + 1) << INDEX_SHIFT)

    cdef char* tostring(self, int sym):
        """Return the string form of sym.

        Words are looked up in ``terminals``.  Nonterminals are formatted as
        "[CAT]" (index 0) or "[CAT,i]"; the formatted string is stored in
        id2sym and reused on later calls.
        """
        cdef int slot
        if not self.isvar(sym):
            return self.terminals.word(sym)

        if sym not in self.id2sym:
            slot = self.getindex(sym)
            if slot <= 0:
                self.id2sym[sym] = "[%s]" % self.tocat(sym)
            else:
                self.id2sym[sym] = "[%s,%d]" % (self.tocat(sym), slot)
        return self.id2sym[sym]

    cdef int fromstring(self, char *s, bint terminal):
        """Return the symbol id for the string s.

        A string of the form "[...]" (at least three characters, and not
        containing "_SEP_") is parsed as a nonterminal "[CAT]" or "[CAT,i]",
        unless ``terminal`` is true, in which case it is registered as a word
        with a backslash prefixed.  Any other string is registered as a word.

        Warning: the nonterminal path overwrites characters of s in place.
        """
        cdef char *sep = strstr(s, "_SEP_")
        cdef char *comma
        cdef int n = strlen(s)

        # Not a bracketed token (or it contains "_SEP_"): plain word.
        if n < 3 or s[0] != c'[' or s[n-1] != c']' or sep != NULL:
            return self.terminals.index(s)

        if terminal:
            # Bracketed text that must stay a word is stored escaped.
            escaped = "\\"+s
            return self.terminals.index(escaped)

        # Cut off the closing bracket in place and step past the opening one.
        s[n-1] = c'\0'
        s = s + 1
        comma = strrchr(s, c',')
        if comma == NULL:
            return self.fromcat(s)

        # Split "CAT,i" at the last comma: the left part is the category
        # name, the right part is the decimal index.
        comma[0] = c'\0'
        return self.setindex(self.fromcat(s), strtol(comma+1, NULL, 10))

# Module-level Alphabet instance used by the sym_* functions in this file.
cdef Alphabet ALPHABET
ALPHABET = Alphabet()

def sym_tostring(int sym):
    """Return the string form of symbol id ``sym`` as given by ALPHABET.tostring()."""
    cdef char *text = ALPHABET.tostring(sym)
    return text

def sym_fromstring(bytes string, bint terminal):
    """Return the integer symbol id for ``string`` via ALPHABET.fromstring().

    Alphabet.fromstring() overwrites characters of its argument while
    parsing a bracketed nonterminal, and a char* obtained from a bytes
    object points into that object's own storage.  Python bytes objects are
    immutable and may be shared, so the parse is run on a private heap copy
    and the caller's string is left untouched.

    string   -- the symbol text
    terminal -- passed through to Alphabet.fromstring(); when true, a
                bracketed string is registered as a word instead of being
                parsed as a nonterminal
    Returns the symbol id as an int.
    Raises MemoryError if the scratch copy cannot be allocated.
    """
    cdef char *src = string
    cdef char *scratch = <char*> malloc(strlen(src) + 1)
    if scratch == NULL:
        raise MemoryError()
    try:
        strcpy(scratch, src)
        return ALPHABET.fromstring(scratch, terminal)
    finally:
        # NOTE(review): fromstring() already passes a function-local
        # temporary to StringMap.index(), so the maps presumably do not keep
        # the pointer they are given - confirm against StringMap before
        # relying on this free().
        free(scratch)

def sym_isvar(int sym):
    """Return ALPHABET.isvar(sym): nonzero when ``sym`` is a negative (nonterminal) id."""
    cdef int flag = ALPHABET.isvar(sym)
    return flag

cdef int sym_setindex(int sym, int id):
    """Return ``sym`` with its index replaced by ``id``, via ALPHABET.setindex()."""
    cdef int updated = ALPHABET.setindex(sym, id)
    return updated