summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava
diff options
context:
space:
mode:
authortrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-09 22:29:02 +0000
committertrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-09 22:29:02 +0000
commit9801ac3df2cbf2656b8d21b2fb0046bfb4046e98 (patch)
tree02f3be210f1a5a060f6ea89cf6093e1ec9dfab95 /gi/posterior-regularisation/prjava
parent6211d023c559f3969ac0a827f4635c5b0959f230 (diff)
Added initial VB implementation for symetric Dirichlet prior.
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@215 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation/prjava')
-rw-r--r--gi/posterior-regularisation/prjava/build.xml1
-rw-r--r--gi/posterior-regularisation/prjava/lib/commons-math-2.1.jarbin0 -> 832410 bytes
-rw-r--r--gi/posterior-regularisation/prjava/lib/optimization-0.1.jarbin120451 -> 0 bytes
-rw-r--r--gi/posterior-regularisation/prjava/lib/optimization.jarbin263823 -> 0 bytes
-rw-r--r--gi/posterior-regularisation/prjava/src/arr/F.java15
-rw-r--r--gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/Objective.java4
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java183
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java7
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/Trainer.java15
-rwxr-xr-xgi/posterior-regularisation/prjava/train-PR-cluster.sh2
10 files changed, 176 insertions, 51 deletions
diff --git a/gi/posterior-regularisation/prjava/build.xml b/gi/posterior-regularisation/prjava/build.xml
index 97569bef..155c45af 100644
--- a/gi/posterior-regularisation/prjava/build.xml
+++ b/gi/posterior-regularisation/prjava/build.xml
@@ -7,6 +7,7 @@
<pathelement location="lib/trove-2.0.2.jar"/>
<pathelement location="lib/optimization.jar"/>
<pathelement location="lib/jopt-simple-3.2.jar"/>
+ <pathelement location="lib/commons-math-2.1.jar"/>
</path>
<target name="init">
diff --git a/gi/posterior-regularisation/prjava/lib/commons-math-2.1.jar b/gi/posterior-regularisation/prjava/lib/commons-math-2.1.jar
new file mode 100644
index 00000000..43b4b369
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/lib/commons-math-2.1.jar
Binary files differ
diff --git a/gi/posterior-regularisation/prjava/lib/optimization-0.1.jar b/gi/posterior-regularisation/prjava/lib/optimization-0.1.jar
deleted file mode 100644
index 2d14fce7..00000000
--- a/gi/posterior-regularisation/prjava/lib/optimization-0.1.jar
+++ /dev/null
Binary files differ
diff --git a/gi/posterior-regularisation/prjava/lib/optimization.jar b/gi/posterior-regularisation/prjava/lib/optimization.jar
deleted file mode 100644
index f6839c5b..00000000
--- a/gi/posterior-regularisation/prjava/lib/optimization.jar
+++ /dev/null
Binary files differ
diff --git a/gi/posterior-regularisation/prjava/src/arr/F.java b/gi/posterior-regularisation/prjava/src/arr/F.java
index 7f2b140a..54dadeac 100644
--- a/gi/posterior-regularisation/prjava/src/arr/F.java
+++ b/gi/posterior-regularisation/prjava/src/arr/F.java
@@ -4,18 +4,25 @@ import java.util.Random;
public class F {
public static Random rng = new Random();
-
+
public static void randomise(double probs[])
{
+ randomise(probs, true);
+ }
+
+ public static void randomise(double probs[], boolean normalise)
+ {
double z = 0;
for (int i = 0; i < probs.length; ++i)
{
probs[i] = 3 + rng.nextDouble();
- z += probs[i];
+ if (normalise)
+ z += probs[i];
}
- for (int i = 0; i < probs.length; ++i)
- probs[i] /= z;
+ if (normalise)
+ for (int i = 0; i < probs.length; ++i)
+ probs[i] /= z;
}
public static void l1normalize(double [] a){
diff --git a/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/Objective.java b/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/Objective.java
index 0e2e27ac..6be01bf9 100644
--- a/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/Objective.java
+++ b/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/Objective.java
@@ -59,6 +59,10 @@ public abstract class Objective {
return gradientCalls;
}
+ public int getNumberUpdateCalls() {
+ return updateCalls;
+ }
+
public String finalInfoString() {
return "FE: " + functionCalls + " GE " + gradientCalls + " Params updates" +
updateCalls;
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index b9b1b98c..7bc63c33 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -1,6 +1,7 @@
package phrase;
import gnu.trove.TIntArrayList;
+import org.apache.commons.math.special.Gamma;
import io.FileUtil;
import java.io.IOException;
import java.io.PrintStream;
@@ -12,6 +13,7 @@ import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import phrase.Corpus.Edge;
+import util.MathUtil;
public class PhraseCluster {
@@ -26,7 +28,12 @@ public class PhraseCluster {
// pi[phrase][tag] = p(tag | phrase)
private double pi[][];
- public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads){
+ double alphaEmit;
+ double alphaPi;
+
+ public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads,
+ double alphaEmit, double alphaPi)
+ {
K=numCluster;
c=corpus;
n_words=c.getNumWords();
@@ -41,29 +48,41 @@ public class PhraseCluster {
emit=new double [K][n_positions][n_words];
pi=new double[n_phrases][K];
- for(double [][]i:emit){
- for(double []j:i){
- arr.F.randomise(j);
+ for(double [][]i:emit)
+ {
+ for(double []j:i)
+ {
+ arr.F.randomise(j, alphaEmit <= 0);
+ if (alphaEmit > 0)
+ digammaNormalize(j, alphaEmit);
}
}
- for(double []j:pi){
- arr.F.randomise(j);
+ for(double []j:pi)
+ {
+ arr.F.randomise(j, alphaPi <= 0);
+ if (alphaPi > 0)
+ digammaNormalize(j, alphaPi);
}
+
+ this.alphaEmit = alphaEmit;
+ this.alphaPi = alphaPi;
}
-
- public double EM(){
+ public double EM()
+ {
double [][][]exp_emit=new double [K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
double loglikelihood=0;
//E
- for(int phrase=0; phrase < n_phrases; phrase++){
+ for(int phrase=0; phrase < n_phrases; phrase++)
+ {
List<Edge> contexts = c.getEdgesForPhrase(phrase);
- for (int ctx=0; ctx<contexts.size(); ctx++){
+ for (int ctx=0; ctx<contexts.size(); ctx++)
+ {
Edge edge = contexts.get(ctx);
double p[]=posterior(edge);
double z = arr.F.l1norm(p);
@@ -74,34 +93,127 @@ public class PhraseCluster {
int count = edge.getCount();
//increment expected count
TIntArrayList context = edge.getContext();
- for(int tag=0;tag<K;tag++){
- for(int pos=0;pos<n_positions;pos++){
- exp_emit[tag][pos][context.get(pos)]+=p[tag]*count;
- }
-
+ for(int tag=0;tag<K;tag++)
+ {
+ for(int pos=0;pos<n_positions;pos++)
+ exp_emit[tag][pos][context.get(pos)]+=p[tag]*count;
exp_pi[phrase][tag]+=p[tag]*count;
}
}
}
-
- //System.out.println("Log likelihood: "+loglikelihood);
-
+
//M
- for(double [][]i:exp_emit){
- for(double []j:i){
+ for(double [][]i:exp_emit)
+ for(double []j:i)
arr.F.l1normalize(j);
- }
- }
+ for(double []j:exp_pi)
+ arr.F.l1normalize(j);
+
emit=exp_emit;
+ pi=exp_pi;
+
+ return loglikelihood;
+ }
+
+ public double VBEM()
+ {
+ // FIXME: broken - needs to be done entirely in log-space
- for(double []j:exp_pi){
- arr.F.l1normalize(j);
- }
+ double [][][]exp_emit = new double [K][n_positions][n_words];
+ double [][]exp_pi = new double[n_phrases][K];
+ double loglikelihood=0;
+
+ //E
+ for(int phrase=0; phrase < n_phrases; phrase++)
+ {
+ List<Edge> contexts = c.getEdgesForPhrase(phrase);
+
+ for (int ctx=0; ctx<contexts.size(); ctx++)
+ {
+ Edge edge = contexts.get(ctx);
+ double p[] = posterior(edge);
+ double z = arr.F.l1norm(p);
+ assert z > 0;
+ loglikelihood += edge.getCount() * Math.log(z);
+ arr.F.l1normalize(p);
+
+ int count = edge.getCount();
+ //increment expected count
+ TIntArrayList context = edge.getContext();
+ for(int tag=0;tag<K;tag++)
+ {
+ for(int pos=0;pos<n_positions;pos++)
+ exp_emit[tag][pos][context.get(pos)] += p[tag]*count;
+ exp_pi[phrase][tag] += p[tag]*count;
+ }
+ }
+ }
+
+ // find the KL terms, KL(q||p) where p is symmetric Dirichlet prior and q are the expectations
+ double kl = 0;
+ for (int phrase=0; phrase < n_phrases; phrase++)
+ kl += KL_symmetric_dirichlet(exp_pi[phrase], alphaPi);
+
+ for (int tag=0;tag<K;tag++)
+ for (int pos=0;pos<n_positions; ++pos)
+ kl += this.KL_symmetric_dirichlet(exp_emit[tag][pos], alphaEmit);
+ // FIXME: exp_emit[tag][pos] has structural zeros - certain words are *never* seen in that position
+
+ //M
+ for(double [][]i:exp_emit)
+ for(double []j:i)
+ digammaNormalize(j, alphaEmit);
+ emit=exp_emit;
+ for(double []j:exp_pi)
+ digammaNormalize(j, alphaPi);
pi=exp_pi;
+
+ System.out.println("KL=" + kl + " llh=" + loglikelihood);
+ System.out.println(Arrays.toString(pi[0]));
+ System.out.println(Arrays.toString(exp_emit[0][0]));
+ return kl + loglikelihood;
+ }
+
+ public void digammaNormalize(double [] a, double alpha)
+ {
+ double sum=0;
+ for(int i=0;i<a.length;i++)
+ sum += a[i];
- return loglikelihood;
+ assert sum > 1e-20;
+ double dgs = Gamma.digamma(sum + alpha);
+
+ for(int i=0;i<a.length;i++)
+ a[i] = Math.exp(Gamma.digamma(a[i] + alpha/a.length) - dgs);
+ }
+
+ private double KL_symmetric_dirichlet(double[] q, double alpha)
+ {
+ // assumes that zeros in q are structural & should be skipped
+
+ double p0 = alpha;
+ double q0 = 0;
+ int n = 0;
+ for (int i=0; i<q.length; i++)
+ {
+ if (q[i] > 0)
+ {
+ q0 += q[i];
+ n += 1;
+ }
+ }
+
+ double kl = Gamma.logGamma(q0) - Gamma.logGamma(p0);
+ kl += n * Gamma.logGamma(alpha / n);
+ double digamma_q0 = Gamma.digamma(q0);
+ for (int i=0; i<q.length; i++)
+ {
+ if (q[i] > 0)
+ kl -= -Gamma.logGamma(q[i]) - (q[i] - alpha/q.length) * (Gamma.digamma(q[i]) - digamma_q0);
+ }
+ return kl;
}
public double PREM_phrase_constraints(){
@@ -117,7 +229,7 @@ public class PhraseCluster {
PhraseObjective po=new PhraseObjective(this,phrase);
boolean ok = po.optimizeWithProjectedGradientDescent();
if (!ok) ++failures;
- iterations += po.iterations;
+ iterations += po.getNumberUpdateCalls();
double [][] q=po.posterior();
loglikelihood += po.loglikelihood();
kl += po.KL_divergence();
@@ -142,24 +254,19 @@ public class PhraseCluster {
if (failures > 0)
System.out.println("WARNING: failed to converge in " + failures + "/" + n_phrases + " cases");
- System.out.println("\tmean iters: " + iterations/(double)n_phrases);
+ System.out.println("\tmean iters: " + iterations/(double)n_phrases);
System.out.println("\tllh: " + loglikelihood);
System.out.println("\tKL: " + kl);
System.out.println("\tphrase l1lmax: " + l1lmax);
//M
- for(double [][]i:exp_emit){
- for(double []j:i){
+ for(double [][]i:exp_emit)
+ for(double []j:i)
arr.F.l1normalize(j);
- }
- }
-
emit=exp_emit;
- for(double []j:exp_pi){
+ for(double []j:exp_pi)
arr.F.l1normalize(j);
- }
-
pi=exp_pi;
return primal;
@@ -216,7 +323,7 @@ public class PhraseCluster {
kl += po.KL_divergence();
l1lmax += po.l1lmax();
primal += po.primal();
- iterations += po.iterations;
+ iterations += po.getNumberUpdateCalls();
List<Edge> edges = c.getEdgesForPhrase(phrase);
@@ -241,7 +348,7 @@ public class PhraseCluster {
if (failures.get() > 0)
System.out.println("WARNING: failed to converge in " + failures.get() + "/" + n_phrases + " cases");
- System.out.println("\tmean iters: " + iterations/(double)n_phrases);
+ System.out.println("\tmean iters: " + iterations/(double)n_phrases);
System.out.println("\tllh: " + loglikelihood);
System.out.println("\tKL: " + kl);
System.out.println("\tphrase l1lmax: " + l1lmax);
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
index f24b903d..cc12546d 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
@@ -1,7 +1,5 @@
package phrase;
-import java.io.PrintStream;
-import java.util.Arrays;
import java.util.List;
import optimization.gradientBasedMethods.ProjectedGradientDescent;
@@ -163,9 +161,7 @@ public class PhraseObjective extends ProjectedObjective
public double [][]posterior(){
return q;
}
-
- public int iterations = 0;
-
+
public boolean optimizeWithProjectedGradientDescent(){
LineSearchMethod ls =
new ArmijoLineSearchMinimizationAlongProjectionArc
@@ -184,7 +180,6 @@ public class PhraseObjective extends ProjectedObjective
optimizer.setMaxIterations(ITERATIONS);
updateFunction();
boolean success = optimizer.optimize(this,stats,compositeStop);
- iterations += optimizer.getCurrentIteration();
// System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
//if(succed){
//System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
index b19f3fb9..439fb337 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
@@ -27,6 +27,9 @@ public class Trainer
parser.accepts("scale-context").withRequiredArg().ofType(Double.class).defaultsTo(0.0);
parser.accepts("seed").withRequiredArg().ofType(Long.class).defaultsTo(0l);
parser.accepts("convergence-threshold").withRequiredArg().ofType(Double.class).defaultsTo(1e-6);
+ parser.accepts("variational-bayes");
+ parser.accepts("alpha-emit").withRequiredArg().ofType(Double.class).defaultsTo(0.1);
+ parser.accepts("alpha-pi").withRequiredArg().ofType(Double.class).defaultsTo(0.01);
OptionSet options = parser.parse(args);
if (options.has("help") || !options.has("in"))
@@ -47,6 +50,9 @@ public class Trainer
double scale_context = (Double) options.valueOf("scale-context");
int threads = (Integer) options.valueOf("threads");
double threshold = (Double) options.valueOf("convergence-threshold");
+ boolean vb = options.has("variational-bayes");
+ double alphaEmit = (vb) ? (Double) options.valueOf("alpha-emit") : 0;
+ double alphaPi = (vb) ? (Double) options.valueOf("alpha-pi") : 0;
if (options.has("seed"))
F.rng = new Random((Long) options.valueOf("seed"));
@@ -75,14 +81,19 @@ public class Trainer
"and " + threads + " threads");
System.out.println();
- PhraseCluster cluster = new PhraseCluster(tags, corpus, scale_phrase, scale_context, threads);
+ PhraseCluster cluster = new PhraseCluster(tags, corpus, scale_phrase, scale_context, threads, alphaEmit, alphaPi);
double last = 0;
for (int i=0; i<em_iterations+pr_iterations; i++)
{
double o;
if (i < em_iterations)
- o = cluster.EM();
+ {
+ if (!vb)
+ o = cluster.EM();
+ else
+ o = cluster.VBEM();
+ }
else if (scale_context == 0)
{
if (threads >= 1)
diff --git a/gi/posterior-regularisation/prjava/train-PR-cluster.sh b/gi/posterior-regularisation/prjava/train-PR-cluster.sh
index 41bb403f..4d4c68d0 100755
--- a/gi/posterior-regularisation/prjava/train-PR-cluster.sh
+++ b/gi/posterior-regularisation/prjava/train-PR-cluster.sh
@@ -1,4 +1,4 @@
#!/bin/sh
d=`dirname $0`
-java -ea -Xmx8g -cp $d/prjava.jar:$d/lib/trove-2.0.2.jar:$d/lib/optimization.jar:$d/lib/jopt-simple-3.2.jar phrase.Trainer $*
+java -ea -Xmx8g -cp $d/prjava.jar:$d/lib/trove-2.0.2.jar:$d/lib/optimization.jar:$d/lib/jopt-simple-3.2.jar:$d/lib/lib/commons-math-2.1.jar phrase.Trainer $*