diff options
author | trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-09 22:29:02 +0000 |
---|---|---|
committer | trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-09 22:29:02 +0000 |
commit | 296c7cc2082557db4a82b4a1208986c6e93ad935 (patch) | |
tree | 3cf7cb3c0c26d9703c8a2a6b583ce71e21e8fdfe /gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java | |
parent | d5105daa487d67752cd599267f74b7c8d502ef1e (diff) |
Added initial VB implementation for symmetric Dirichlet prior.
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@215 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java')
-rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java | 183 |
1 files changed, 145 insertions, 38 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index b9b1b98c..7bc63c33 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -1,6 +1,7 @@ package phrase;
import gnu.trove.TIntArrayList;
+import org.apache.commons.math.special.Gamma;
import io.FileUtil;
import java.io.IOException;
import java.io.PrintStream;
@@ -12,6 +13,7 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicInteger;
import phrase.Corpus.Edge;
+import util.MathUtil;
public class PhraseCluster {
@@ -26,7 +28,12 @@ public class PhraseCluster { // pi[phrase][tag] = p(tag | phrase)
private double pi[][];
- public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads){
+ double alphaEmit;
+ double alphaPi;
+
+ public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads,
+ double alphaEmit, double alphaPi)
+ {
K=numCluster;
c=corpus;
n_words=c.getNumWords();
@@ -41,29 +48,41 @@ public class PhraseCluster { emit=new double [K][n_positions][n_words];
pi=new double[n_phrases][K];
- for(double [][]i:emit){
- for(double []j:i){
- arr.F.randomise(j);
+ for(double [][]i:emit)
+ {
+ for(double []j:i)
+ {
+ arr.F.randomise(j, alphaEmit <= 0);
+ if (alphaEmit > 0)
+ digammaNormalize(j, alphaEmit);
}
}
- for(double []j:pi){
- arr.F.randomise(j);
+ for(double []j:pi)
+ {
+ arr.F.randomise(j, alphaPi <= 0);
+ if (alphaPi > 0)
+ digammaNormalize(j, alphaPi);
}
+
+ this.alphaEmit = alphaEmit;
+ this.alphaPi = alphaPi;
}
-
- public double EM(){
+ public double EM()
+ {
double [][][]exp_emit=new double [K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
double loglikelihood=0;
//E
- for(int phrase=0; phrase < n_phrases; phrase++){
+ for(int phrase=0; phrase < n_phrases; phrase++)
+ {
List<Edge> contexts = c.getEdgesForPhrase(phrase);
- for (int ctx=0; ctx<contexts.size(); ctx++){
+ for (int ctx=0; ctx<contexts.size(); ctx++)
+ {
Edge edge = contexts.get(ctx);
double p[]=posterior(edge);
double z = arr.F.l1norm(p);
@@ -74,34 +93,127 @@ public class PhraseCluster { int count = edge.getCount();
//increment expected count
TIntArrayList context = edge.getContext();
- for(int tag=0;tag<K;tag++){
- for(int pos=0;pos<n_positions;pos++){
- exp_emit[tag][pos][context.get(pos)]+=p[tag]*count;
- }
-
+ for(int tag=0;tag<K;tag++)
+ {
+ for(int pos=0;pos<n_positions;pos++)
+ exp_emit[tag][pos][context.get(pos)]+=p[tag]*count;
exp_pi[phrase][tag]+=p[tag]*count;
}
}
}
-
- //System.out.println("Log likelihood: "+loglikelihood);
-
+
//M
- for(double [][]i:exp_emit){
- for(double []j:i){
+ for(double [][]i:exp_emit)
+ for(double []j:i)
arr.F.l1normalize(j);
- }
- }
+ for(double []j:exp_pi)
+ arr.F.l1normalize(j);
+
emit=exp_emit;
+ pi=exp_pi;
+
+ return loglikelihood;
+ }
+
+ public double VBEM()
+ {
+ // FIXME: broken - needs to be done entirely in log-space
- for(double []j:exp_pi){
- arr.F.l1normalize(j);
- }
+ double [][][]exp_emit = new double [K][n_positions][n_words];
+ double [][]exp_pi = new double[n_phrases][K];
+ double loglikelihood=0;
+
+ //E
+ for(int phrase=0; phrase < n_phrases; phrase++)
+ {
+ List<Edge> contexts = c.getEdgesForPhrase(phrase);
+
+ for (int ctx=0; ctx<contexts.size(); ctx++)
+ {
+ Edge edge = contexts.get(ctx);
+ double p[] = posterior(edge);
+ double z = arr.F.l1norm(p);
+ assert z > 0;
+ loglikelihood += edge.getCount() * Math.log(z);
+ arr.F.l1normalize(p);
+
+ int count = edge.getCount();
+ //increment expected count
+ TIntArrayList context = edge.getContext();
+ for(int tag=0;tag<K;tag++)
+ {
+ for(int pos=0;pos<n_positions;pos++)
+ exp_emit[tag][pos][context.get(pos)] += p[tag]*count;
+ exp_pi[phrase][tag] += p[tag]*count;
+ }
+ }
+ }
+
+ // find the KL terms, KL(q||p) where p is symmetric Dirichlet prior and q are the expectations
+ double kl = 0;
+ for (int phrase=0; phrase < n_phrases; phrase++)
+ kl += KL_symmetric_dirichlet(exp_pi[phrase], alphaPi);
+
+ for (int tag=0;tag<K;tag++)
+ for (int pos=0;pos<n_positions; ++pos)
+ kl += this.KL_symmetric_dirichlet(exp_emit[tag][pos], alphaEmit);
+ // FIXME: exp_emit[tag][pos] has structural zeros - certain words are *never* seen in that position
+
+ //M
+ for(double [][]i:exp_emit)
+ for(double []j:i)
+ digammaNormalize(j, alphaEmit);
+ emit=exp_emit;
+ for(double []j:exp_pi)
+ digammaNormalize(j, alphaPi);
pi=exp_pi;
+
+ System.out.println("KL=" + kl + " llh=" + loglikelihood);
+ System.out.println(Arrays.toString(pi[0]));
+ System.out.println(Arrays.toString(exp_emit[0][0]));
+ return kl + loglikelihood;
+ }
+
+ public void digammaNormalize(double [] a, double alpha)
+ {
+ double sum=0;
+ for(int i=0;i<a.length;i++)
+ sum += a[i];
- return loglikelihood;
+ assert sum > 1e-20;
+ double dgs = Gamma.digamma(sum + alpha);
+
+ for(int i=0;i<a.length;i++)
+ a[i] = Math.exp(Gamma.digamma(a[i] + alpha/a.length) - dgs);
+ }
+
+ private double KL_symmetric_dirichlet(double[] q, double alpha)
+ {
+ // assumes that zeros in q are structural & should be skipped
+
+ double p0 = alpha;
+ double q0 = 0;
+ int n = 0;
+ for (int i=0; i<q.length; i++)
+ {
+ if (q[i] > 0)
+ {
+ q0 += q[i];
+ n += 1;
+ }
+ }
+
+ double kl = Gamma.logGamma(q0) - Gamma.logGamma(p0);
+ kl += n * Gamma.logGamma(alpha / n);
+ double digamma_q0 = Gamma.digamma(q0);
+ for (int i=0; i<q.length; i++)
+ {
+ if (q[i] > 0)
+ kl -= -Gamma.logGamma(q[i]) - (q[i] - alpha/q.length) * (Gamma.digamma(q[i]) - digamma_q0);
+ }
+ return kl;
}
public double PREM_phrase_constraints(){
@@ -117,7 +229,7 @@ public class PhraseCluster { PhraseObjective po=new PhraseObjective(this,phrase);
boolean ok = po.optimizeWithProjectedGradientDescent();
if (!ok) ++failures;
- iterations += po.iterations;
+ iterations += po.getNumberUpdateCalls();
double [][] q=po.posterior();
loglikelihood += po.loglikelihood();
kl += po.KL_divergence();
@@ -142,24 +254,19 @@ public class PhraseCluster { if (failures > 0)
System.out.println("WARNING: failed to converge in " + failures + "/" + n_phrases + " cases");
- System.out.println("\tmean iters: " + iterations/(double)n_phrases);
+ System.out.println("\tmean iters: " + iterations/(double)n_phrases);
System.out.println("\tllh: " + loglikelihood);
System.out.println("\tKL: " + kl);
System.out.println("\tphrase l1lmax: " + l1lmax);
//M
- for(double [][]i:exp_emit){
- for(double []j:i){
+ for(double [][]i:exp_emit)
+ for(double []j:i)
arr.F.l1normalize(j);
- }
- }
-
emit=exp_emit;
- for(double []j:exp_pi){
+ for(double []j:exp_pi)
arr.F.l1normalize(j);
- }
-
pi=exp_pi;
return primal;
@@ -216,7 +323,7 @@ public class PhraseCluster { kl += po.KL_divergence();
l1lmax += po.l1lmax();
primal += po.primal();
- iterations += po.iterations;
+ iterations += po.getNumberUpdateCalls();
List<Edge> edges = c.getEdgesForPhrase(phrase);
@@ -241,7 +348,7 @@ public class PhraseCluster { if (failures.get() > 0)
System.out.println("WARNING: failed to converge in " + failures.get() + "/" + n_phrases + " cases");
- System.out.println("\tmean iters: " + iterations/(double)n_phrases);
+ System.out.println("\tmean iters: " + iterations/(double)n_phrases);
System.out.println("\tllh: " + loglikelihood);
System.out.println("\tKL: " + kl);
System.out.println("\tphrase l1lmax: " + l1lmax);
|