summaryrefslogtreecommitdiff
path: root/gi/pyp-topics/scripts/spans2labels.py
diff options
context:
space:
mode:
Diffstat (limited to 'gi/pyp-topics/scripts/spans2labels.py')
-rwxr-xr-xgi/pyp-topics/scripts/spans2labels.py63
1 files changed, 51 insertions, 12 deletions
diff --git a/gi/pyp-topics/scripts/spans2labels.py b/gi/pyp-topics/scripts/spans2labels.py
index 73ea20f2..50fa8106 100755
--- a/gi/pyp-topics/scripts/spans2labels.py
+++ b/gi/pyp-topics/scripts/spans2labels.py
@@ -4,7 +4,7 @@ import sys
from operator import itemgetter
if len(sys.argv) <= 2:
- print "Usage: spans2labels.py phrase_context_index [order] [threshold] [languages={s,t,b}{s,t,b}]"
+ print "Usage: spans2labels.py phrase_context_index [order] [threshold] [languages={s,t,b}{s,t,b}] [type={tag,tok,both},{tag,tok,both}]"
exit(1)
order=1
@@ -19,8 +19,13 @@ if len(sys.argv) > 4:
phr, ctx = sys.argv[4]
assert phr in 'stb'
assert ctx in 'stb'
+phr_typ = ctx_typ = 'both'
+if len(sys.argv) > 5:
+ phr_typ, ctx_typ = sys.argv[5].split(',')
+ assert phr_typ in ('tag', 'tok', 'both')
+ assert ctx_typ in ('tag', 'tok', 'both')
-print >>sys.stderr, "Loading phrase index"
+#print >>sys.stderr, "Loading phrase index"
phrase_context_index = {}
for line in file(sys.argv[1], 'r'):
phrase,tail= line.split('\t')
@@ -43,13 +48,49 @@ for line in file(sys.argv[1], 'r'):
phrase_context_index[(phrase,contexts[i])] = category
#print (phrase,contexts[i]), category
-print >>sys.stderr, "Labelling spans"
+#print >>sys.stderr, "Labelling spans"
for line in sys.stdin:
- line_segments = line.split('|||')
+ #print >>sys.stderr, "line", line.strip()
+ line_segments = line.split(' ||| ')
+ assert len(line_segments) >= 3
source = ['<s>' for x in range(order)] + line_segments[0].split() + ['</s>' for x in range(order)]
target = ['<s>' for x in range(order)] + line_segments[1].split() + ['</s>' for x in range(order)]
phrases = [ [int(i) for i in x.split('-')] for x in line_segments[2].split()]
+ if phr_typ != 'both' or ctx_typ != 'both':
+ if phr in 'tb' or ctx in 'tb':
+ target_toks = ['<s>' for x in range(order)] + map(lambda x: x.rsplit('_', 1)[0], line_segments[1].split()) + ['</s>' for x in range(order)]
+ target_tags = ['<s>' for x in range(order)] + map(lambda x: x.rsplit('_', 1)[-1], line_segments[1].split()) + ['</s>' for x in range(order)]
+
+ if phr in 'tb':
+ if phr_typ == 'tok':
+ targetP = target_toks
+ elif phr_typ == 'tag':
+ targetP = target_tags
+ if ctx in 'tb':
+ if ctx_typ == 'tok':
+ targetC = target_toks
+ elif ctx_typ == 'tag':
+ targetC = target_tags
+
+ if phr in 'sb' or ctx in 'sb':
+ source_toks = ['<s>' for x in range(order)] + map(lambda x: x.rsplit('_', 1)[0], line_segments[0].split()) + ['</s>' for x in range(order)]
+ source_tags = ['<s>' for x in range(order)] + map(lambda x: x.rsplit('_', 1)[-1], line_segments[0].split()) + ['</s>' for x in range(order)]
+
+ if phr in 'sb':
+ if phr_typ == 'tok':
+ sourceP = source_toks
+ elif phr_typ == 'tag':
+ sourceP = source_tags
+ if ctx in 'sb':
+ if ctx_typ == 'tok':
+ sourceC = source_toks
+ elif ctx_typ == 'tag':
+ sourceC = source_tags
+ else:
+ sourceP = sourceC = source
+ targetP = targetC = target
+
#print >>sys.stderr, "line", source, '---', target, 'phrases', phrases
print "|||",
@@ -62,17 +103,17 @@ for line in sys.stdin:
phraset = phrases = contextt = contexts = ''
if phr in 'tb':
- phraset = reduce(lambda x, y: x+y+" ", target[t1:t2], "").strip()
+ phraset = reduce(lambda x, y: x+y+" ", targetP[t1:t2], "").strip()
if phr in 'sb':
- phrases = reduce(lambda x, y: x+y+" ", source[s1:s2], "").strip()
+ phrases = reduce(lambda x, y: x+y+" ", sourceP[s1:s2], "").strip()
if ctx in 'tb':
- left_context = reduce(lambda x, y: x+y+" ", target[t1-order:t1], "")
- right_context = reduce(lambda x, y: x+y+" ", target[t2:t2+order], "").strip()
+ left_context = reduce(lambda x, y: x+y+" ", targetC[t1-order:t1], "")
+ right_context = reduce(lambda x, y: x+y+" ", targetC[t2:t2+order], "").strip()
contextt = "%s<PHRASE> %s" % (left_context, right_context)
if ctx in 'sb':
- left_context = reduce(lambda x, y: x+y+" ", source[s1-order:s1], "")
- right_context = reduce(lambda x, y: x+y+" ", source[s2:s2+order], "").strip()
+ left_context = reduce(lambda x, y: x+y+" ", sourceC[s1-order:s1], "")
+ right_context = reduce(lambda x, y: x+y+" ", sourceC[s2:s2+order], "").strip()
contexts = "%s<PHRASE> %s" % (left_context, right_context)
if phr == 'b':
@@ -94,5 +135,3 @@ for line in sys.stdin:
if label != cutoff_cat: #cutoff'd spans are left unlabelled
print "%d-%d-%d-%d:X%s" % (s1-order,s2-order,t1-order,t2-order,label),
print
-
-