diff options
Diffstat (limited to 'gi/posterior-regularisation')
| -rw-r--r-- | gi/posterior-regularisation/train_pr_parallel.py | 18 | 
1 files changed, 13 insertions, 5 deletions
| diff --git a/gi/posterior-regularisation/train_pr_parallel.py b/gi/posterior-regularisation/train_pr_parallel.py index 72c5cf25..d5df87b5 100644 --- a/gi/posterior-regularisation/train_pr_parallel.py +++ b/gi/posterior-regularisation/train_pr_parallel.py @@ -1,7 +1,7 @@  import sys  import scipy.optimize  from numpy import * -from numpy.random import random +from numpy.random import random, seed  #  # Step 1: load the concordance counts @@ -48,6 +48,10 @@ num_types = len(types)  num_phrases = len(edges_phrase_to_context)  num_contexts = len(edges_context_to_phrase)  delta = float(sys.argv[1]) +assert sys.argv[2] in ('local', 'global') +local = sys.argv[2] == 'local' +if len(sys.argv) >= 2: +     seed(int(sys.argv[3]))  def normalise(a):      return a / float(sum(a)) @@ -267,7 +271,7 @@ for iteration in range(20):      # E-step      llh = kl = l1lmax = 0 -    if False: +    if local:          for p in range(num_phrases):              o = LocalDualObjective(p, delta)              #print '\toptimising lambda for phrase', p, '=', edges_phrase_to_context[p][0] @@ -284,18 +288,22 @@ for iteration in range(20):                  for i in range(4):                      for t in range(num_tags):                          contextWordCounts[i][t][types[context[i]]] += count * o.q[j,t] + +        #print 'iteration', iteration, 'LOCAL objective', (llh + kl + delta * l1lmax), 'llh', llh, 'kl', kl, 'l1lmax', l1lmax      else:          o = GlobalDualObjective(delta)          obj = o.optimize()          llh, kl, l1lmax = o.optimize() +        index = 0          for p, (phrase, edges) in enumerate(edges_phrase_to_context): -            for j, (context, count) in enumerate(edges): +            for context, count in edges:                  for t in range(num_tags): -                    tagCounts[p][t] += count * o.q[j,t] +                    tagCounts[p][t] += count * o.q[index,t]                  for i in range(4):                      for t in range(num_tags): -                        contextWordCounts[i][t][types[context[i]]] += count * o.q[j,t] +                        contextWordCounts[i][t][types[context[i]]] += count * o.q[index,t] +                index += 1      print 'iteration', iteration, 'objective', (llh + kl + delta * l1lmax), 'llh', llh, 'kl', kl, 'l1lmax', l1lmax | 
