summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation')
-rw-r--r--gi/posterior-regularisation/train_pr_agree.py90
1 files changed, 81 insertions, 9 deletions
diff --git a/gi/posterior-regularisation/train_pr_agree.py b/gi/posterior-regularisation/train_pr_agree.py
index bbd6e007..9d41362d 100644
--- a/gi/posterior-regularisation/train_pr_agree.py
+++ b/gi/posterior-regularisation/train_pr_agree.py
@@ -250,7 +250,10 @@ class ProductModel:
# return the overall objective
return llh1 + llh2 + kl
-class InterpolatedModel:
+class RegularisedProductModel:
+ # as above, but with a slack regularisation term which kills the
+ # closed-form solution for the E-step
+
def __init__(self, epsilon):
self.pcm = PhraseToContextModel()
self.cpm = ContextToPhraseModel()
@@ -260,7 +263,7 @@ class InterpolatedModel:
def prob(self, pid, cid):
p1 = self.pcm.prob(pid, cid)
p2 = self.cpm.prob(pid, cid)
- return (p1 + p2) / 2
+ return (p1 / sum(p1)) * (p2 / sum(p2))
def dual(self, lamba):
return self.logz(lamba) + self.epsilon * dot(lamba, lamba) ** 0.5
@@ -272,17 +275,19 @@ class InterpolatedModel:
# PR-step: optimise lambda to minimise log(z_lambda) + eps ||lambda||_2
self.lamba = scipy.optimize.fmin_slsqp(self.dual, self.lamba,
bounds=[(0, 1e100)] * num_tags,
- fprime=self.dual_gradient, iprint=0)
+ fprime=self.dual_gradient, iprint=1)
# E,M-steps: collect expected counts under q_lambda and normalise
- #llh1 = self.pcm.expectation_maximisation_step(self.lamba)
- #llh2 = self.cpm.expectation_maximisation_step(-self.lamba)
+ llh1 = self.pcm.expectation_maximisation_step(self.lamba)
+ llh2 = self.cpm.expectation_maximisation_step(-self.lamba)
- # return the overall objective: llh1 + llh2 - KL(q||p1.p2)
- pass
+ # return the overall objective: llh - KL(q||p1.p2)
+ # llh = llh1 + llh2
+ # kl = sum q log q / p1 p2 = sum q { lambda . phi } - log Z
+ return llh1 + llh2 + self.logz(self.lamba) \
+ - dot(self.lamba, self.expected_features(self.lamba))
def logz(self, lamba):
- # FIXME: complete this!
lz = 0
for pid, cid, cnt in edges:
p1 = self.pcm.prob(pid, cid)
@@ -306,14 +311,81 @@ class InterpolatedModel:
fs -= cnt * q2 / sum(q2)
return fs
+
+class InterpolatedModel:
+ def __init__(self, epsilon):
+ self.pcm = PhraseToContextModel()
+ self.cpm = ContextToPhraseModel()
+ self.epsilon = epsilon
+ self.lamba = zeros(num_tags)
+
+ def prob(self, pid, cid):
+ p1 = self.pcm.prob(pid, cid)
+ p2 = self.cpm.prob(pid, cid)
+ return (p1 + p2) / 2
+
+ def dual(self, lamba):
+ return self.logz(lamba) + self.epsilon * dot(lamba, lamba) ** 0.5
+
+ def dual_gradient(self, lamba):
+ return self.expected_features(lamba) + self.epsilon * 2 * lamba
+
+ def expectation_maximisation_step(self):
+ # PR-step: optimise lambda to minimise log(z_lambda) + eps ||lambda||_2
+ self.lamba = scipy.optimize.fmin_slsqp(self.dual, self.lamba,
+ bounds=[(0, 1e100)] * num_tags,
+ fprime=self.dual_gradient, iprint=2)
+
+ # E,M-steps: collect expected counts under q_lambda and normalise
+ llh1 = self.pcm.expectation_maximisation_step(self.lamba)
+ llh2 = self.cpm.expectation_maximisation_step(self.lamba)
+
+ # return the overall objective: llh1 + llh2 - KL(q||p1.p2)
+ # kl = sum_y q log q / 0.5 * (p1 + p2) = sum_y q(y) { -lambda . phi(y) } - log Z
+ # = -log Z + lambda . (E_q1[-phi] + E_q2[-phi]) / 2
+ kl = -self.logz(self.lamba) + dot(self.lamba, self.expected_features(self.lamba))
+ return llh1 + llh2 - kl, llh1, llh2, kl
+ # FIXME: KL comes out negative...
+
+ def logz(self, lamba):
+ lz = 0
+ for pid, cid, cnt in edges:
+ p1 = self.pcm.prob(pid, cid)
+ q1 = p1 / sum(p1) * exp(-lamba)
+ q1z = sum(q1)
+
+ p2 = self.cpm.prob(pid, cid)
+ q2 = p2 / sum(p2) * exp(-lamba)
+ q2z = sum(q2)
+
+ lz += log(0.5 * (q1z + q2z)) * cnt
+ return lz
+
+ # z = 1/2 * (sum_y p1(y|x) exp (-lambda . phi(y)) + sum_y p2(y|x) exp (-lambda . phi(y)))
+ # = 1/2 (z1 + z2)
+ # d (log z) / dlambda = 1/2 (E_q1 [ -phi ] + E_q2 [ -phi ] )
+ def expected_features(self, lamba):
+ fs = zeros(num_tags)
+ for pid, cid, cnt in edges:
+ p1 = self.pcm.prob(pid, cid)
+ q1 = (p1 / sum(p1)) * exp(-lamba)
+ fs -= 0.5 * cnt * q1 / sum(q1)
+
+ p2 = self.cpm.prob(pid, cid)
+ q2 = (p2 / sum(p2)) * exp(-lamba)
+ fs -= 0.5 * cnt * q2 / sum(q2)
+ return fs
+
if style == 'p2c':
m = PhraseToContextModel()
elif style == 'c2p':
m = ContextToPhraseModel()
elif style == 'prod':
m = ProductModel()
+elif style == 'prodslack':
+ m = RegularisedProductModel(0.5)
elif style == 'sum':
- m = InterpolatedModel()
+ m = InterpolatedModel(0.5)
for iteration in range(30):
obj = m.expectation_maximisation_step()