diff options
Diffstat (limited to 'gi/posterior-regularisation')
-rw-r--r-- | gi/posterior-regularisation/train_pr_parallel.py | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/gi/posterior-regularisation/train_pr_parallel.py b/gi/posterior-regularisation/train_pr_parallel.py index 72c5cf25..d5df87b5 100644 --- a/gi/posterior-regularisation/train_pr_parallel.py +++ b/gi/posterior-regularisation/train_pr_parallel.py @@ -1,7 +1,7 @@ import sys import scipy.optimize from numpy import * -from numpy.random import random +from numpy.random import random, seed # # Step 1: load the concordance counts @@ -48,6 +48,10 @@ num_types = len(types) num_phrases = len(edges_phrase_to_context) num_contexts = len(edges_context_to_phrase) delta = float(sys.argv[1]) +assert sys.argv[2] in ('local', 'global') +local = sys.argv[2] == 'local' +if len(sys.argv) >= 2: + seed(int(sys.argv[3])) def normalise(a): return a / float(sum(a)) @@ -267,7 +271,7 @@ for iteration in range(20): # E-step llh = kl = l1lmax = 0 - if False: + if local: for p in range(num_phrases): o = LocalDualObjective(p, delta) #print '\toptimising lambda for phrase', p, '=', edges_phrase_to_context[p][0] @@ -284,18 +288,22 @@ for iteration in range(20): for i in range(4): for t in range(num_tags): contextWordCounts[i][t][types[context[i]]] += count * o.q[j,t] + + #print 'iteration', iteration, 'LOCAL objective', (llh + kl + delta * l1lmax), 'llh', llh, 'kl', kl, 'l1lmax', l1lmax else: o = GlobalDualObjective(delta) obj = o.optimize() llh, kl, l1lmax = o.optimize() + index = 0 for p, (phrase, edges) in enumerate(edges_phrase_to_context): - for j, (context, count) in enumerate(edges): + for context, count in edges: for t in range(num_tags): - tagCounts[p][t] += count * o.q[j,t] + tagCounts[p][t] += count * o.q[index,t] for i in range(4): for t in range(num_tags): - contextWordCounts[i][t][types[context[i]]] += count * o.q[j,t] + contextWordCounts[i][t][types[context[i]]] += count * o.q[index,t] + index += 1 print 'iteration', iteration, 'objective', (llh + kl + delta * l1lmax), 'llh', llh, 'kl', kl, 'l1lmax', l1lmax |