Diffstat (limited to 'gi/posterior-regularisation/train_pr_global.py')
-rw-r--r--  gi/posterior-regularisation/train_pr_global.py  45
1 file changed, 24 insertions(+), 21 deletions(-)
diff --git a/gi/posterior-regularisation/train_pr_global.py b/gi/posterior-regularisation/train_pr_global.py
index f2806b6e..8521bccb 100644
--- a/gi/posterior-regularisation/train_pr_global.py
+++ b/gi/posterior-regularisation/train_pr_global.py
@@ -45,7 +45,7 @@ print 'edges_phrase_to_context', edges_phrase_to_context
# Step 2: initialise the model parameters
#
-num_tags = 5
+num_tags = 10
num_types = len(types)
num_phrases = len(edges_phrase_to_context)
num_contexts = len(edges_context_to_phrase)
@@ -56,11 +56,11 @@ def normalise(a):
return a / float(sum(a))
# Pr(tag | phrase)
-#tagDist = [normalise(random(num_tags)+1) for p in range(num_phrases)]
-tagDist = [normalise(array(range(1,num_tags+1))) for p in range(num_phrases)]
+tagDist = [normalise(random(num_tags)+1) for p in range(num_phrases)]
+#tagDist = [normalise(array(range(1,num_tags+1))) for p in range(num_phrases)]
# Pr(context at pos i = w | tag) indexed by i, tag, word
-contextWordDist = [[normalise(array(range(1,num_types+1))) for t in range(num_tags)] for i in range(4)]
-#contextWordDist = [[normalise(random(num_types)+1) for t in range(num_tags)] for i in range(4)]
+#contextWordDist = [[normalise(array(range(1,num_types+1))) for t in range(num_tags)] for i in range(4)]
+contextWordDist = [[normalise(random(num_types)+1) for t in range(num_tags)] for i in range(4)]
# PR Lagrange multipliers
lamba = zeros(2 * num_edges * num_tags)
omega_offset = num_edges * num_tags
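
For reference, the initialisation this hunk switches on boils down to the following minimal sketch; the sizes below are placeholders rather than the values the script derives from the corpus, and the numpy import mirrors what the surrounding file already uses.

from numpy.random import random

def normalise(a):
    return a / float(sum(a))

# placeholder sizes; the script computes these from the phrase/context data
num_tags, num_types, num_phrases = 10, 50, 20

# random (symmetry-breaking) initialisation for Pr(tag | phrase)
tagDist = [normalise(random(num_tags) + 1) for p in range(num_phrases)]
# random initialisation for Pr(context word at position i | tag), i = 0..3
contextWordDist = [[normalise(random(num_types) + 1) for t in range(num_tags)]
                   for i in range(4)]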
@@ -99,6 +99,8 @@ for iteration in range(20):
cz = sum(conditionals)
conditionals /= cz
+ #print 'dual', phrase, context, count, 'p =', conditionals
+
local_z = 0
for t in range(num_tags):
li = lamba_index[phrase,context] + t
@@ -106,8 +108,8 @@ for iteration in range(20):
logz += log(local_z) * count
#print 'ls', ls
- print 'lambda', list(ls)
- print 'dual', logz
+ #print 'lambda', list(ls)
+ #print 'dual', logz
return logz
def loglikelihood():
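
A self-contained sketch of the per-edge quantity the dual accumulates in the hunks above, with toy numbers standing in for the model posteriors and multipliers: z_edge = sum_t p(t|edge) * exp(-lambda_t - omega_t), and the dual adds count * log(z_edge).

from math import exp, log

def edge_dual(conditionals, lam, omega, count):
    # conditionals: normalised p(t | phrase, context); lam/omega: this edge's multipliers
    local_z = sum(p * exp(-l - o) for p, l, o in zip(conditionals, lam, omega))
    return count * log(local_z)

# toy example with three tags
print(edge_dual([0.5, 0.3, 0.2], [0.1, 0.0, 0.0], [0.0, 0.2, 0.0], count=4))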
@@ -146,12 +148,12 @@ for iteration in range(20):
for t in range(num_tags):
best = -1e500
for phrase, count in pcs:
- li = lamba_index[phrase,context] + t
+ li = omega_offset + lamba_index[phrase,context] + t
s = expectations[li]
if s > best: best = s
ct_l1linf += best
- return llh, kl, pt_l1linf, ct_l1linf, llh + kl + delta * pt_l1linf + gamma * ct_l1linf
+ return llh, kl, pt_l1linf, ct_l1linf, llh - kl - delta * pt_l1linf - gamma * ct_l1linf
def dual_deriv(ls):
# d/dl log(z) = -E_q[phi]
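
Two fixes land in the hunk above: the context-constraint term now reads from the omega block of the packed expectation vector, and the l1/linf penalties are subtracted rather than added in the primal. A compact sketch under assumed, simplified data structures (the flat offsets stand in for lamba_index[phrase, context]):

from numpy import zeros

num_edges, num_tags = 4, 3
omega_offset = num_edges * num_tags            # omega block follows the lambda block
expectations = zeros(2 * num_edges * num_tags)

def ct_l1linf(context_edges):
    # context_edges: (flat offset, count) pairs for one context; hypothetical stand-in
    # for iterating edges_context_to_phrase in the script
    total = 0.0
    for t in range(num_tags):
        total += max(expectations[omega_offset + off + t] for off, count in context_edges)
    return total

def primal_value(llh, kl, pt_l1linf, ct_l1linf_val, delta, gamma):
    # penalties subtracted: llh - KL - delta * phrase term - gamma * context term
    return llh - kl - delta * pt_l1linf - gamma * ct_l1linf_val

print(ct_l1linf([(0, 2), (num_tags, 1)]))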
@@ -173,13 +175,13 @@ for iteration in range(20):
scores[t] = conditionals[t] * exp(-ls[li] - ls[omega_offset + li])
local_z = sum(scores)
+ #print 'ddual', phrase, context, count, 'q =', scores / local_z
+
for t in range(num_tags):
- if delta > 0:
- deriv[lamba_index[phrase,context] + t] -= count * scores[t] / local_z
- if gamma > 0:
- deriv[omega_offset + lamba_index[phrase,context] + t] -= count * scores[t] / local_z
+ deriv[lamba_index[phrase,context] + t] -= count * scores[t] / local_z
+ deriv[omega_offset + lamba_index[phrase,context] + t] -= count * scores[t] / local_z
- print 'ddual', list(deriv)
+ #print 'ddual', list(deriv)
return deriv
def constraints(ls):
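
A sketch of the gradient this change computes: with q(t) proportional to p(t) * exp(-lambda_t - omega_t), the derivative of log z with respect to either multiplier is -q(t), so each edge subtracts count * q(t) from both its lambda slot and its omega slot (the old delta/gamma guards are dropped). Toy inputs only:

from math import exp

def edge_grad(conditionals, lam, omega, count):
    scores = [p * exp(-l - o) for p, l, o in zip(conditionals, lam, omega)]
    z = sum(scores)
    q = [s / z for s in scores]
    # identical contribution for the lambda block and the omega block of deriv
    return [-count * qt for qt in q]

print(edge_grad([0.5, 0.3, 0.2], [0.1, 0.0, 0.0], [0.0, 0.2, 0.0], count=4))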
@@ -244,7 +246,7 @@ for iteration in range(20):
print 'Post lambda optimisation dual', dual(lamba), 'primal', primal(lamba)
# E-step
- llh = z = 0
+ llh = log_z = 0
for p, (phrase, ccs) in enumerate(edges_phrase_to_context):
for context, count in ccs:
conditionals = zeros(num_tags)
@@ -257,20 +259,21 @@ for iteration in range(20):
conditionals /= cz
llh += log(cz) * count
- scores = zeros(num_tags)
+ q = zeros(num_tags)
li = lamba_index[phrase, context]
for t in range(num_tags):
- scores[t] = conditionals[t] * exp(-lamba[li + t] - lamba[omega_offset + li + t])
- z += count * sum(scores)
+ q[t] = conditionals[t] * exp(-lamba[li + t] - lamba[omega_offset + li + t])
+ qz = sum(q)
+ log_z += count * log(qz)
for t in range(num_tags):
- tagCounts[p][t] += count * scores[t]
+ tagCounts[p][t] += count * q[t] / qz
for i in range(4):
for t in range(num_tags):
- contextWordCounts[i][t][types[context[i]]] += count * scores[t]
+ contextWordCounts[i][t][types[context[i]]] += count * q[t] / qz
- print 'iteration', iteration, 'llh', llh, 'logz', log(z)
+ print 'iteration', iteration, 'llh', llh, 'logz', log_z
# M-step
for p in range(num_phrases):
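
The E-step fix above normalises the projected posterior before taking expected counts, and accumulates log_z as count * log(qz) per edge. A minimal sketch of one edge's contribution, again with toy inputs in place of the script's conditionals and multipliers:

from math import exp, log

def estep_edge(conditionals, lam, omega, count):
    q = [p * exp(-l - o) for p, l, o in zip(conditionals, lam, omega)]
    qz = sum(q)
    q = [qt / qz for qt in q]               # normalised projected posterior
    tag_counts = [count * qt for qt in q]   # expected counts use q/qz, not raw scores
    return tag_counts, count * log(qz)      # second value is this edge's log_z term

print(estep_edge([0.5, 0.3, 0.2], [0.1, 0.0, 0.0], [0.0, 0.2, 0.0], count=4))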