Diffstat (limited to 'gi/posterior-regularisation/train_pr_global.py')
-rw-r--r-- | gi/posterior-regularisation/train_pr_global.py | 45 | ++++++++++++++++++++++++---------------------
1 file changed, 24 insertions(+), 21 deletions(-)
diff --git a/gi/posterior-regularisation/train_pr_global.py b/gi/posterior-regularisation/train_pr_global.py
index f2806b6e..8521bccb 100644
--- a/gi/posterior-regularisation/train_pr_global.py
+++ b/gi/posterior-regularisation/train_pr_global.py
@@ -45,7 +45,7 @@ print 'edges_phrase_to_context', edges_phrase_to_context
 
 # Step 2: initialise the model parameters
 #
-num_tags = 5
+num_tags = 10
 num_types = len(types)
 num_phrases = len(edges_phrase_to_context)
 num_contexts = len(edges_context_to_phrase)
@@ -56,11 +56,11 @@ def normalise(a):
     return a / float(sum(a))
 
 # Pr(tag | phrase)
-#tagDist = [normalise(random(num_tags)+1) for p in range(num_phrases)]
-tagDist = [normalise(array(range(1,num_tags+1))) for p in range(num_phrases)]
+tagDist = [normalise(random(num_tags)+1) for p in range(num_phrases)]
+#tagDist = [normalise(array(range(1,num_tags+1))) for p in range(num_phrases)]
 # Pr(context at pos i = w | tag) indexed by i, tag, word
-contextWordDist = [[normalise(array(range(1,num_types+1))) for t in range(num_tags)] for i in range(4)]
-#contextWordDist = [[normalise(random(num_types)+1) for t in range(num_tags)] for i in range(4)]
+#contextWordDist = [[normalise(array(range(1,num_types+1))) for t in range(num_tags)] for i in range(4)]
+contextWordDist = [[normalise(random(num_types)+1) for t in range(num_tags)] for i in range(4)]
 # PR langrange multipliers
 lamba = zeros(2 * num_edges * num_tags)
 omega_offset = num_edges * num_tags
@@ -99,6 +99,8 @@ for iteration in range(20):
             cz = sum(conditionals)
             conditionals /= cz
 
+            #print 'dual', phrase, context, count, 'p =', conditionals
+
             local_z = 0
             for t in range(num_tags):
                 li = lamba_index[phrase,context] + t
@@ -106,8 +108,8 @@ for iteration in range(20):
             logz += log(local_z) * count
 
         #print 'ls', ls
-        print 'lambda', list(ls)
-        print 'dual', logz
+        #print 'lambda', list(ls)
+        #print 'dual', logz
         return logz
 
     def loglikelihood():
@@ -146,12 +148,12 @@ for iteration in range(20):
 
             for t in range(num_tags):
                 best = -1e500
                 for phrase, count in pcs:
-                    li = lamba_index[phrase,context] + t
+                    li = omega_offset + lamba_index[phrase,context] + t
                     s = expectations[li]
                     if s > best: best = s
                 ct_l1linf += best
-        return llh, kl, pt_l1linf, ct_l1linf, llh + kl + delta * pt_l1linf + gamma * ct_l1linf
+        return llh, kl, pt_l1linf, ct_l1linf, llh - kl - delta * pt_l1linf - gamma * ct_l1linf
 
     def dual_deriv(ls):
         # d/dl log(z) = E_q[phi]
@@ -173,13 +175,13 @@ for iteration in range(20):
                 scores[t] = conditionals[t] * exp(-ls[li] - ls[omega_offset + li])
             local_z = sum(scores)
 
+            #print 'ddual', phrase, context, count, 'q =', scores / local_z
+
             for t in range(num_tags):
-                if delta > 0:
-                    deriv[lamba_index[phrase,context] + t] -= count * scores[t] / local_z
-                if gamma > 0:
-                    deriv[omega_offset + lamba_index[phrase,context] + t] -= count * scores[t] / local_z
+                deriv[lamba_index[phrase,context] + t] -= count * scores[t] / local_z
+                deriv[omega_offset + lamba_index[phrase,context] + t] -= count * scores[t] / local_z
 
-        print 'ddual', list(deriv)
+        #print 'ddual', list(deriv)
         return deriv
 
     def constraints(ls):
@@ -244,7 +246,7 @@ for iteration in range(20):
     print 'Post lambda optimisation dual', dual(lamba), 'primal', primal(lamba)
 
     # E-step
-    llh = z = 0
+    llh = log_z = 0
     for p, (phrase, ccs) in enumerate(edges_phrase_to_context):
         for context, count in ccs:
             conditionals = zeros(num_tags)
@@ -257,20 +259,21 @@ for iteration in range(20):
             conditionals /= cz
             llh += log(cz) * count
 
-            scores = zeros(num_tags)
+            q = zeros(num_tags)
             li = lamba_index[phrase, context]
             for t in range(num_tags):
-                scores[t] = conditionals[t] * exp(-lamba[li + t] - lamba[omega_offset + li + t])
-            z += count * sum(scores)
+                q[t] = conditionals[t] * exp(-lamba[li + t] - lamba[omega_offset + li + t])
+            qz = sum(q)
+            log_z += count * log(qz)
 
             for t in range(num_tags):
-                tagCounts[p][t] += count * scores[t]
+                tagCounts[p][t] += count * q[t] / qz
 
             for i in range(4):
                 for t in range(num_tags):
-                    contextWordCounts[i][t][types[context[i]]] += count * scores[t]
+                    contextWordCounts[i][t][types[context[i]]] += count * q[t] / qz
 
-    print 'iteration', iteration, 'llh', llh, 'logz', log(z)
+    print 'iteration', iteration, 'llh', llh, 'logz', log_z
 
     # M-step
     for p in range(num_phrases):
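
The substantive fixes above are easy to miss among the commented-out prints: the primal now subtracts the KL term and the two L1/Linf penalty terms rather than adding them, the context-constraint lookup now includes omega_offset so it reads the omega block of the multiplier vector, and the E-step renormalises the projected posterior q before accumulating expected counts, summing log Z per edge instead of taking the log of a summed z. A minimal standalone sketch of that corrected E-step computation, assuming numpy; the names projected_posterior, p, lam and omega are illustrative, not taken from the repository:

import numpy as np

def projected_posterior(p, lam, omega):
    # q(t) is proportional to p(t | phrase, context) * exp(-lambda_t - omega_t)
    q = p * np.exp(-lam - omega)
    qz = q.sum()                  # per-edge partition function
    return q / qz, np.log(qz)     # normalised posterior and its log-partition

# Toy check with 3 tags: expected counts use the *normalised* q, and the
# log-partition accumulates as count * log(qz) per edge.
p = np.array([0.5, 0.3, 0.2])
lam = np.array([0.1, 0.0, -0.1])
omega = np.zeros(3)
q, log_qz = projected_posterior(p, lam, omega)
assert abs(q.sum() - 1.0) < 1e-12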