summaryrefslogtreecommitdiff
path: root/gi/pyp-topics/scripts/spans2labels.py
blob: 50fa8106b9c4794278b742504a2db46b92b5b9fb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#!/usr/bin/python

import sys
from operator import itemgetter

# ---------------------------------------------------------------------------
# Command-line handling.
#
#   argv[1]  phrase_context_index  path of the index file (required)
#   argv[2]  order                 number of boundary/context tokens taken on
#                                  each side of a span (default 1)
#   argv[3]  threshold             index entries whose P feature is below this
#                                  value are mapped to cutoff_cat (default 0)
#   argv[4]  languages             two letters, phrase side then context side,
#                                  each one of s (source), t (target) or
#                                  b (both); default "tt"
#   argv[5]  type                  "<phrase>,<context>", each one of tag, tok
#                                  or both; default "both,both"
# ---------------------------------------------------------------------------
_usage = "Usage: spans2labels.py phrase_context_index [order] [threshold] [languages={s,t,b}{s,t,b}] [type={tag,tok,both},{tag,tok,both}]"

# Only the index path is mandatory: every other argument is optional in the
# usage string and has a default below.
if len(sys.argv) < 2:
  # stdout carries the labelled output, so diagnostics go to stderr.
  sys.stderr.write(_usage + "\n")
  sys.exit(1)

order = 1
threshold = 0
cutoff_cat = "<UNK>"  # category meaning "leave this span unlabelled"
if len(sys.argv) > 2:
  order = int(sys.argv[2])
if len(sys.argv) > 3:
  threshold = float(sys.argv[3])

phr = ctx = 't'
if len(sys.argv) > 4:
  _langs = sys.argv[4]
  # Explicit checks instead of assert: asserts disappear under "python -O".
  if len(_langs) != 2 or _langs[0] not in 'stb' or _langs[1] not in 'stb':
    sys.exit("languages must be two letters from {s,t,b}, got %r\n%s" % (_langs, _usage))
  phr, ctx = _langs

phr_typ = ctx_typ = 'both'
if len(sys.argv) > 5:
  _valid_typs = ('tag', 'tok', 'both')
  _typs = sys.argv[5].split(',')
  if len(_typs) != 2 or _typs[0] not in _valid_typs or _typs[1] not in _valid_typs:
    sys.exit("type must be two comma-separated values from {tag,tok,both}, got %r\n%s" % (sys.argv[5], _usage))
  phr_typ, ctx_typ = _typs

#print >>sys.stderr, "Loading phrase index"
# Build phrase_context_index: (phrase, context) -> category.
#
# Each line of the index file is split on a single tab into the phrase and a
# tail; the tail is split on " ||| " into alternating context strings and
# whitespace-separated "key=value" feature lists, e.g. "C=1 P=0.8".  The C
# feature is the category.  When a P feature is present and its value is
# below `threshold`, the entry is stored with `cutoff_cat` instead.
phrase_context_index = {}
with open(sys.argv[1], 'r') as index_file:  # closed even if parsing fails
  for line in index_file:
    phrase, tail = line.split('\t')
    contexts = tail.split(" ||| ")

    # Some index files put an extra "<int> <int>" field before the first
    # context; drop it when present.  ValueError (wrong number of tokens, or
    # tokens that are not integers) simply means the field is absent; any
    # other exception is a real error and must propagate.
    try:
      first, second = contexts[0].split()
      int(first)
      int(second)
    except ValueError:
      pass
    else:
      contexts = contexts[1:]

    if len(contexts) == 1:
      continue
    if len(contexts) % 2 != 0:
      raise ValueError("malformed index line (unpaired context/features): %r" % line)

    # Even positions are contexts, odd positions are their feature lists.
    for context, feature_text in zip(contexts[0::2], contexts[1::2]):
      features = dict(keyval.split('=') for keyval in feature_text.split())
      category = features['C']
      if 'P' in features and float(features['P']) < threshold:
        category = cutoff_cat
      phrase_context_index[(phrase, context)] = category

#print >>sys.stderr, "Labelling spans"
# Label the spans of every line read from stdin.
#
# Input line:   <source tokens> ||| <target tokens> ||| s1-s2-t1-t2 s1-s2-t1-t2 ...
# Output line:  "|||" followed by one "s1-s2-t1-t2:X<category>" item for each
#               span whose (phrase, context) key is in phrase_context_index
#               with a category other than cutoff_cat.
#
# With type 'tok' or 'tag' each word is split at its last '_': 'tok' keeps the
# part before it, 'tag' the part after it.  With 'both' the word is used as is.

def _pad(words):
  """Return `words` with `order` '<s>' markers before and '</s>' markers after."""
  return ['<s>'] * order + words + ['</s>'] * order

def _view(words, typ):
  """Return the padded token sequence of `words` for type `typ` (tag/tok/both)."""
  if typ == 'tok':
    return _pad([w.rsplit('_', 1)[0] for w in words])
  if typ == 'tag':
    return _pad([w.rsplit('_', 1)[-1] for w in words])
  # 'both': keep the words untouched.  Handling this case here (rather than
  # only when *both* types are 'both') is what makes mixed settings such as
  # "both,tag" work; previously they left one sequence unassigned.
  return _pad(words)

def _context_string(tokens, start, end):
  """Return '<left tokens> <PHRASE> <right tokens>' around tokens[start:end]."""
  # Every left token is followed by a space, so an empty left context yields
  # a string that starts directly with "<PHRASE>".
  left_context = "".join(tok + " " for tok in tokens[start - order:start])
  right_context = " ".join(tokens[end:end + order]).strip()
  return "%s<PHRASE> %s" % (left_context, right_context)

# Which sides are needed is fixed by the command line, so decide it once.
_phr_uses_source = phr in 'sb'
_phr_uses_target = phr in 'tb'
_ctx_uses_source = ctx in 'sb'
_ctx_uses_target = ctx in 'tb'

for line in sys.stdin:
  line_segments = line.split(' ||| ')
  if len(line_segments) < 3:
    raise ValueError("expected 'source ||| target ||| spans', got: %r" % line)
  source_words = line_segments[0].split()
  target_words = line_segments[1].split()
  spans = [[int(i) for i in x.split('-')] for x in line_segments[2].split()]

  # Only build the sequences that will actually be read below.
  sourceP = _view(source_words, phr_typ) if _phr_uses_source else None
  targetP = _view(target_words, phr_typ) if _phr_uses_target else None
  sourceC = _view(source_words, ctx_typ) if _ctx_uses_source else None
  targetC = _view(target_words, ctx_typ) if _ctx_uses_target else None

  items = ["|||"]
  for s1, s2, t1, t2 in spans:
    # Span indices refer to the unpadded sentence; shift past the padding.
    ps1 = s1 + order
    ps2 = s2 + order
    pt1 = t1 + order
    pt2 = t2 + order

    phrase_t = phrase_s = context_t = context_s = ''
    if _phr_uses_target:
      phrase_t = " ".join(targetP[pt1:pt2]).strip()
    if _phr_uses_source:
      phrase_s = " ".join(sourceP[ps1:ps2]).strip()
    if _ctx_uses_target:
      context_t = _context_string(targetC, pt1, pt2)
    if _ctx_uses_source:
      context_s = _context_string(sourceC, ps1, ps2)

    if phr == 'b':
      phrase = phrase_t + ' <SPLIT> ' + phrase_s
    elif phr == 's':
      phrase = phrase_s
    else:
      phrase = phrase_t

    if ctx == 'b':
      context = context_t + ' <SPLIT> ' + context_s
    elif ctx == 's':
      context = context_s
    else:
      context = context_t

    label = phrase_context_index.get((phrase, context), cutoff_cat)
    if label != cutoff_cat:  # cut-off and unknown spans are left unlabelled
      items.append("%d-%d-%d-%d:X%s" % (s1, s2, t1, t2, label))

  # One space-separated line per input line.
  sys.stdout.write(" ".join(items) + "\n")