summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/train_pr_parallel.py
diff options
context:
space:
mode:
authortrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-07 14:11:42 +0000
committertrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-07 14:11:42 +0000
commita15d666d23169dafdf01b7f5923570a9ba10787b (patch)
tree5f74648ef8e0d9a6c36c211d0d31a0465b2a295c /gi/posterior-regularisation/train_pr_parallel.py
parent43d74920424e83c397321db549290f167e15db46 (diff)
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@173 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation/train_pr_parallel.py')
-rw-r--r--gi/posterior-regularisation/train_pr_parallel.py25
1 files changed, 15 insertions, 10 deletions
diff --git a/gi/posterior-regularisation/train_pr_parallel.py b/gi/posterior-regularisation/train_pr_parallel.py
index 4de7f504..3b9cefed 100644
--- a/gi/posterior-regularisation/train_pr_parallel.py
+++ b/gi/posterior-regularisation/train_pr_parallel.py
@@ -41,7 +41,7 @@ for line in sys.stdin:
# Step 2: initialise the model parameters
#
-num_tags = 5
+num_tags = 25
num_types = len(types)
num_phrases = len(edges_phrase_to_context)
num_contexts = len(edges_context_to_phrase)
@@ -86,7 +86,7 @@ class GlobalDualObjective:
self.posterior[index,t] = prob
z = sum(self.posterior[index,:])
self.posterior[index,:] /= z
- self.llh += log(z)
+ self.llh += log(z) * count
index += 1
def objective(self, ls):
@@ -192,7 +192,7 @@ class LocalDualObjective:
self.posterior[i,t] = prob
z = sum(self.posterior[i,:])
self.posterior[i,:] /= z
- self.llh += log(z)
+ self.llh += log(z) * count
def objective(self, ls):
edges = edges_phrase_to_context[self.phraseId][1]
@@ -243,9 +243,10 @@ class LocalDualObjective:
gradient[t,i,t] -= count
return gradient.reshape((num_tags, len(edges)*num_tags))
- def optimize(self):
+ def optimize(self, ls=None):
edges = edges_phrase_to_context[self.phraseId][1]
- ls = zeros(len(edges) * num_tags)
+ if ls == None:
+ ls = zeros(len(edges) * num_tags)
#print '\tpre lambda optimisation dual', self.objective(ls) #, 'primal', primal(lamba)
ls = scipy.optimize.fmin_slsqp(self.objective, ls,
bounds=[(0, self.scale)] * len(edges) * num_tags,
@@ -253,6 +254,7 @@ class LocalDualObjective:
fprime=self.gradient,
fprime_ieqcons=self.constraints_gradient,
iprint=0) # =2 for verbose
+ #print '\tlambda', list(ls)
#print '\tpost lambda optimisation dual', self.objective(ls) #, 'primal', primal(lamba)
# returns llh, kl and l1lmax contribution
@@ -263,8 +265,9 @@ class LocalDualObjective:
lmax = max(lmax, self.q[i,t])
l1lmax += lmax
- return self.llh, -self.objective(ls) + dot(ls, self.gradient(ls)), l1lmax
+ return self.llh, -self.objective(ls) + dot(ls, self.gradient(ls)), l1lmax, ls
+ls = [None] * num_phrases
for iteration in range(20):
tagCounts = [zeros(num_tags) for p in range(num_phrases)]
contextWordCounts = [[zeros(num_types) for t in range(num_tags)] for i in range(4)]
@@ -275,11 +278,13 @@ for iteration in range(20):
for p in range(num_phrases):
o = LocalDualObjective(p, delta)
#print '\toptimising lambda for phrase', p, '=', edges_phrase_to_context[p][0]
- obj = o.optimize()
- print '\tphrase', p, 'deltas', obj
+ #print '\toptimising lambda for phrase', p, 'ls', ls[p]
+ obj = o.optimize(ls[p])
+ #print '\tphrase', p, 'deltas', obj
llh += obj[0]
kl += obj[1]
l1lmax += obj[2]
+ ls[p] = obj[3]
edges = edges_phrase_to_context[p][1]
for j, (context, count) in enumerate(edges):
@@ -305,7 +310,7 @@ for iteration in range(20):
contextWordCounts[i][t][types[context[i]]] += count * o.q[index,t]
index += 1
- print 'iteration', iteration, 'objective', (llh + kl + delta * l1lmax), 'llh', llh, 'kl', kl, 'l1lmax', l1lmax
+ print 'iteration', iteration, 'objective', (llh - kl - delta * l1lmax), 'llh', llh, 'kl', kl, 'l1lmax', l1lmax
# M-step
for p in range(num_phrases):
@@ -325,4 +330,4 @@ for p, (phrase, ccs) in enumerate(edges_phrase_to_context):
cz = sum(conditionals)
conditionals /= cz
- print '%s\t%s ||| C=%d ||| %d |||' % (phrase, context, count, argmax(conditionals)), conditionals
+ print '%s\t%s ||| C=%d |||' % (phrase, context, argmax(conditionals)), conditionals