summaryrefslogtreecommitdiff
path: root/src/util.py
blob: 7ce1c7fcb97d4529e68b9825770d4bfba1a8852f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import re
from collections import defaultdict

# Separator placed between a token and its arity label (e.g. "river@1").
ARITY_SEP = '@'
# Arity label given to tokens that were produced from quoted string literals.
ARITY_STR = 's'
# Arity label that replaces the first token's own label when fun_to_mrl is
# called with star_top=True.
ARITY_ANY = '*'

def after_nth(mrl, token, n):
  """Return the tail of `mrl` following the n-th whole-word match of `token`.

  The returned string starts at the LAST character of the matched token (not
  just after it), so the text right behind the token -- typically an opening
  parenthesis -- sits at index 1 of the result.

  Args:
    mrl: the string to search.
    token: literal text to look for; it is matched between regex word
      boundaries (`\\b`) and is escaped, so regex metacharacters in it are
      treated as plain characters.
    n: 1-based index of the occurrence to locate. For n <= 0 the input is
      returned unchanged.

  Returns:
    The suffix of `mrl` beginning at the final character of the n-th match.

  Raises:
    ValueError: if `mrl` contains fewer than `n` whole-word occurrences.
  """
  if n <= 0:
    return mrl

  pattern = re.compile(r'\b%s\b' % re.escape(token))
  pos = 0
  for _ in range(n):
    # Search from the end of the previous match inside the full string.
    # Re-slicing the string instead would leave the token's last character at
    # index 0, where a one-character token matches again and never advances.
    m = pattern.search(mrl, pos)
    if m is None:
      raise ValueError('fewer than %d occurrences of %r in %r' % (n, token, mrl))
    pos = m.end()
  return mrl[pos - 1:]

def count_arguments(s):
  """Count the arguments of the first parenthesised group found in `s`.

  Characters are scanned left to right. Before any '(' has been seen, a ','
  or a ')' ends the scan and the result is 0. Once a '(' has been seen, the
  scan runs until its matching ')' and every ',' at nesting depth 1 separates
  two arguments; commas inside nested groups are ignored.

  Note that an opened group always counts as at least one argument, so
  'f()' yields 1.

  Args:
    s: text to inspect.

  Returns:
    0 when no group is opened, otherwise (depth-1 commas + 1).
  """
  depth = 0
  opened = False
  separators = 0

  for ch in s:
    if opened:
      # The first group has been closed again: nothing more to count.
      if depth <= 0:
        break
    elif depth != 0:
      # A ')' showed up before any '(' -- we have left the enclosing group.
      break

    if ch == '(':
      opened = True
      depth += 1
    elif ch == ')':
      depth -= 1
    elif ch == ',':
      if depth == 1:
        separators += 1
      elif depth < 1:
        # A sibling separator before any '(' means this item has no arguments.
        break

  if not opened:
    return 0
  return separators + 1

def fun_to_mrl(mrl, star_top=False):
  """Flatten a functional-form MRL string into arity-labelled tokens.

  Every token is emitted as '<token><ARITY_SEP><arity>', where the arity is
  the number of arguments in the parenthesised group following that
  occurrence of the token. Quoted literals are turned into a single token
  (inner spaces replaced by '_') carrying the ARITY_STR label instead of a
  number.

  Example (with the module's default constants):
    "answer(river(all))"   -> "answer@1 river@1 all@0"
    "stateid('new york')"  -> "stateid@1 new_york@s"

  Args:
    mrl: the functional-form string, e.g. "answer(river(all))".
    star_top: when true, the label of the first token is replaced by
      ARITY_ANY.

  Returns:
    The labelled tokens joined by single spaces; an empty string when the
    input holds no tokens.
  """
  mrl = mrl.strip()

  # Collapse each quoted literal into one string-labelled token.
  string_suffix = '%s%s' % (ARITY_SEP, ARITY_STR)
  mrl = re.sub(r"' *([A-Za-z0-9_ ]+?) *'",
               lambda x: x.group(1).replace(' ', '_') + string_suffix,
               mrl)
  mrl = re.sub(r'\s+', ' ', mrl)

  # Parentheses and commas only delimit tokens; split() swallows the
  # resulting whitespace runs, so no separate collapsing pass is needed.
  tokens = re.sub(r'[(),]', ' ', mrl).split()

  labeled = []
  seen = defaultdict(int)
  for token in tokens:
    # The occurrence counter tells after_nth which instance to inspect.
    seen[token] += 1
    if token.endswith(string_suffix):
      # String tokens already carry their label; no arity lookup needed.
      labeled.append(token)
    else:
      args = count_arguments(after_nth(mrl, token, seen[token]))
      labeled.append('%s%s%d' % (token, ARITY_SEP, args))

  # Guard against empty input: there is no top token to relabel.
  if star_top and labeled:
    top = labeled[0]
    labeled[0] = top[:top.rindex(ARITY_SEP)] + ARITY_SEP + ARITY_ANY

  return ' '.join(labeled)