#!/usr/bin/python
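#
# Extract fixed-width word contexts from a "word|tag" tagged corpus read on
# stdin: for every token, the `order` words to its left and right form its
# context, and sparse term/context count files are written out.
# Contexts whose corpus frequency is below `cutoff` are dropped from the
# .data file.
#
# Usage: extract_contexts.py output_filename order [cutoff] [lowercase]
#
# Output: <output_filename>.terms, .contexts, .data and .contexts_backoff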

import sys,collections

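# Build the chain of progressively shorter backoff contexts for a full
# 2*order context window.  For example (order=2):
#   extract_backoff(['a','b','c','d'], 2)
#     -> [(['b','c'], ['a||'], ['||d']), (['b|'], ['|c'])]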
def extract_backoff(context_list, order):
  assert len(context_list) == (2*order)
  backoffs = []
  for i in range(1,order+1):
    if i == order:
      backoffs.append(([context_list[i-1]+"|"], ["|"+context_list[i]]))
    else:
      right_limit = 2*order-i
      core = context_list[i:right_limit]
      left = [context_list[i-1]+"|"*(order-i+1)]
      right = ["|"*(order-i+1)+context_list[right_limit]]
      backoffs.append((core, left, right))
# print context_list, backoffs
  return backoffs

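# Join the elements of a tuple with '|', e.g. tuple_to_str(('a','b')) -> 'a|b'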
def tuple_to_str(t):
  s=""
  for i,x in enumerate(t):
    if i > 0: s += "|"
    s += str(x)
  return s

if len(sys.argv) < 3:
  print "Usage: extract_contexts.py output_filename order [cutoff] [lowercase]"
  exit(1)

output_filename = sys.argv[1]
order = int(sys.argv[2])
cutoff = 0
if len(sys.argv) > 3:
  cutoff = int(sys.argv[3])
lowercase = False
if len(sys.argv) > 4:
  # bool() of any non-empty string is True, so parse the flag explicitly
  lowercase = sys.argv[4].lower() in ("1", "true", "yes")

contexts_dict={}                              # context string -> context id
contexts_list=[]                              # context id -> context string
contexts_freq=collections.defaultdict(int)    # context id -> corpus frequency
contexts_backoff={}                           # context id -> tuple of backoff context ids

token_dict={}                                 # term string -> term id
token_list=[]                                 # term id -> term string
documents_dict=collections.defaultdict(dict)  # term id -> {context id: count}

contexts_at_order = [0 for i in range(order+1)]  # distinct contexts seen at each backoff level

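# Pad each sentence with `order` boundary tokens on both sides so that every
# real token has a full-width context.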
prefix = ["<s%d>|<s>"%i for i in range(order)]
suffix = ["</s%d>|</s>"%i for i in range(order)]

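# For every real token on each input line, split the word from its tag, take
# the `order` words on either side as the context and the centre word as the
# term, assigning integer ids to both as they are first seen.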
for line in sys.stdin:
  tokens = list(prefix)
  tokens.extend(line.split())
  tokens.extend(suffix)
  if lowercase:
    tokens = [x.lower() for x in tokens]

  for i in range(order, len(tokens)-order):
    context_list = []
    term=""
    for j in range(i-order, i+order+1):
      token,tag = tokens[j].rsplit('|',2)
      if j != i:
        context_list.append(token)
      else:
        if token not in token_dict: 
          token_dict[token] = len(token_dict)
          token_list.append(token)
        term = token_dict[token] 

    context = tuple_to_str(tuple(context_list))

    if context not in contexts_dict: 
      context_index = len(contexts_dict)
      contexts_dict[context] = context_index
      contexts_list.append(context)
      contexts_at_order[0] += 1

      # handle backoff
      backoff_contexts = extract_backoff(context_list, order)
      bo_indexes=[(context_index,)]
#     bo_indexes=[(context,)]
      # use `level` rather than reusing the outer loop variable i
      for level,bo in enumerate(backoff_contexts):
        factor_indexes=[]
        for factor in bo:
          bo_tuple = tuple_to_str(tuple(factor))
          if bo_tuple not in contexts_dict:
            contexts_dict[bo_tuple] = len(contexts_dict)
            contexts_list.append(bo_tuple)
            contexts_at_order[level+1] += 1
#         factor_indexes.append(bo_tuple)
          factor_indexes.append(contexts_dict[bo_tuple])
        bo_indexes.append(tuple(factor_indexes))

      # chain each context id to the tuple of ids one backoff level shorter
      for level in range(len(bo_indexes)-1):
        contexts_backoff[bo_indexes[level][0]] = bo_indexes[level+1]

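    # count this occurrence of the context and attribute it to the centre term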
    context_index = contexts_dict[context]
    contexts_freq[context_index] += 1

    if context_index not in documents_dict[term]:
      documents_dict[term][context_index] = 1
    else:
      documents_dict[term][context_index] += 1

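# <output>.terms: one term per line, in term-id order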
term_file = open(output_filename+".terms",'w')
for t in token_list: print >>term_file, t
term_file.close()

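# <output>.contexts: one context string per line, in context-id order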
contexts_file = open(output_filename+".contexts",'w')
for c in contexts_list: 
  print >>contexts_file, c
contexts_file.close()

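# <output>.data: one line per term that has at least one context with
# frequency >= cutoff:  num_active  context_id:count  context_id:count ...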
data_file = open(output_filename+".data",'w')
for t in range(len(token_list)): 
  line=""
  num_active=0
  for c in documents_dict[t]:
    count = documents_dict[t][c]
    if contexts_freq[c] >= cutoff:
      line += (' ' + str(c) + ':' + str(count))
      num_active += 1
  if num_active > 0:
    print >>data_file, "%d%s" % (num_active,line)
data_file.close()

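# <output>.contexts_backoff: a header line with the total number of contexts,
# the order and the number of backoff factors at each level, followed by the
# backoff table itself.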
contexts_backoff_file = open(output_filename+".contexts_backoff",'w')
print >>contexts_backoff_file, len(contexts_list), order,
#for x in contexts_at_order: 
#  print >>contexts_backoff_file, x,
#print >>contexts_backoff_file
for x in range(order-1):
  print >>contexts_backoff_file, 3,
print >>contexts_backoff_file, 2

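# one row per context that has backoffs: its id followed by the ids of the
# contexts it backs off to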
for x in contexts_backoff: 
  print >>contexts_backoff_file, x, 
  for y in contexts_backoff[x]: print >>contexts_backoff_file, y,
  print >>contexts_backoff_file 
contexts_backoff_file.close()