summaryrefslogtreecommitdiff
path: root/gi
diff options
context:
space:
mode:
Diffstat (limited to 'gi')
-rw-r--r--gi/posterior-regularisation/train_pr_parallel.py18
1 files changed, 13 insertions, 5 deletions
diff --git a/gi/posterior-regularisation/train_pr_parallel.py b/gi/posterior-regularisation/train_pr_parallel.py
index 72c5cf25..d5df87b5 100644
--- a/gi/posterior-regularisation/train_pr_parallel.py
+++ b/gi/posterior-regularisation/train_pr_parallel.py
@@ -1,7 +1,7 @@
import sys
import scipy.optimize
from numpy import *
-from numpy.random import random
+from numpy.random import random, seed
#
# Step 1: load the concordance counts
@@ -48,6 +48,10 @@ num_types = len(types)
num_phrases = len(edges_phrase_to_context)
num_contexts = len(edges_context_to_phrase)
delta = float(sys.argv[1])
+assert sys.argv[2] in ('local', 'global')
+local = sys.argv[2] == 'local'
+if len(sys.argv) >= 2:
+ seed(int(sys.argv[3]))
def normalise(a):
return a / float(sum(a))
@@ -267,7 +271,7 @@ for iteration in range(20):
# E-step
llh = kl = l1lmax = 0
- if False:
+ if local:
for p in range(num_phrases):
o = LocalDualObjective(p, delta)
#print '\toptimising lambda for phrase', p, '=', edges_phrase_to_context[p][0]
@@ -284,18 +288,22 @@ for iteration in range(20):
for i in range(4):
for t in range(num_tags):
contextWordCounts[i][t][types[context[i]]] += count * o.q[j,t]
+
+ #print 'iteration', iteration, 'LOCAL objective', (llh + kl + delta * l1lmax), 'llh', llh, 'kl', kl, 'l1lmax', l1lmax
else:
o = GlobalDualObjective(delta)
obj = o.optimize()
llh, kl, l1lmax = o.optimize()
+ index = 0
for p, (phrase, edges) in enumerate(edges_phrase_to_context):
- for j, (context, count) in enumerate(edges):
+ for context, count in edges:
for t in range(num_tags):
- tagCounts[p][t] += count * o.q[j,t]
+ tagCounts[p][t] += count * o.q[index,t]
for i in range(4):
for t in range(num_tags):
- contextWordCounts[i][t][types[context[i]]] += count * o.q[j,t]
+ contextWordCounts[i][t][types[context[i]]] += count * o.q[index,t]
+ index += 1
print 'iteration', iteration, 'objective', (llh + kl + delta * l1lmax), 'llh', llh, 'kl', kl, 'l1lmax', l1lmax