diff options
Diffstat (limited to 'gi/posterior-regularisation/prjava')
-rw-r--r-- | gi/posterior-regularisation/prjava/build.xml | 1 | ||||
-rw-r--r-- | gi/posterior-regularisation/prjava/lib/commons-math-2.1.jar | bin | 0 -> 832410 bytes | |||
-rw-r--r-- | gi/posterior-regularisation/prjava/lib/optimization-0.1.jar | bin | 120451 -> 0 bytes | |||
-rw-r--r-- | gi/posterior-regularisation/prjava/lib/optimization.jar | bin | 263823 -> 0 bytes | |||
-rw-r--r-- | gi/posterior-regularisation/prjava/src/arr/F.java | 15 | ||||
-rw-r--r-- | gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/Objective.java | 4 | ||||
-rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java | 183 | ||||
-rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java | 7 | ||||
-rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/Trainer.java | 15 | ||||
-rwxr-xr-x | gi/posterior-regularisation/prjava/train-PR-cluster.sh | 2 |
10 files changed, 176 insertions, 51 deletions
diff --git a/gi/posterior-regularisation/prjava/build.xml b/gi/posterior-regularisation/prjava/build.xml index 97569bef..155c45af 100644 --- a/gi/posterior-regularisation/prjava/build.xml +++ b/gi/posterior-regularisation/prjava/build.xml @@ -7,6 +7,7 @@ <pathelement location="lib/trove-2.0.2.jar"/> <pathelement location="lib/optimization.jar"/> <pathelement location="lib/jopt-simple-3.2.jar"/> + <pathelement location="lib/commons-math-2.1.jar"/> </path> <target name="init"> diff --git a/gi/posterior-regularisation/prjava/lib/commons-math-2.1.jar b/gi/posterior-regularisation/prjava/lib/commons-math-2.1.jar Binary files differnew file mode 100644 index 00000000..43b4b369 --- /dev/null +++ b/gi/posterior-regularisation/prjava/lib/commons-math-2.1.jar diff --git a/gi/posterior-regularisation/prjava/lib/optimization-0.1.jar b/gi/posterior-regularisation/prjava/lib/optimization-0.1.jar Binary files differdeleted file mode 100644 index 2d14fce7..00000000 --- a/gi/posterior-regularisation/prjava/lib/optimization-0.1.jar +++ /dev/null diff --git a/gi/posterior-regularisation/prjava/lib/optimization.jar b/gi/posterior-regularisation/prjava/lib/optimization.jar Binary files differdeleted file mode 100644 index f6839c5b..00000000 --- a/gi/posterior-regularisation/prjava/lib/optimization.jar +++ /dev/null diff --git a/gi/posterior-regularisation/prjava/src/arr/F.java b/gi/posterior-regularisation/prjava/src/arr/F.java index 7f2b140a..54dadeac 100644 --- a/gi/posterior-regularisation/prjava/src/arr/F.java +++ b/gi/posterior-regularisation/prjava/src/arr/F.java @@ -4,18 +4,25 @@ import java.util.Random; public class F {
public static Random rng = new Random();
-
+
public static void randomise(double probs[])
{
+ randomise(probs, true);
+ }
+
+ public static void randomise(double probs[], boolean normalise)
+ {
double z = 0;
for (int i = 0; i < probs.length; ++i)
{
probs[i] = 3 + rng.nextDouble();
- z += probs[i];
+ if (normalise)
+ z += probs[i];
}
- for (int i = 0; i < probs.length; ++i)
- probs[i] /= z;
+ if (normalise)
+ for (int i = 0; i < probs.length; ++i)
+ probs[i] /= z;
}
public static void l1normalize(double [] a){
diff --git a/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/Objective.java b/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/Objective.java index 0e2e27ac..6be01bf9 100644 --- a/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/Objective.java +++ b/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/Objective.java @@ -59,6 +59,10 @@ public abstract class Objective { return gradientCalls; } + public int getNumberUpdateCalls() { + return updateCalls; + } + public String finalInfoString() { return "FE: " + functionCalls + " GE " + gradientCalls + " Params updates" + updateCalls; diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index b9b1b98c..7bc63c33 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -1,6 +1,7 @@ package phrase;
import gnu.trove.TIntArrayList;
+import org.apache.commons.math.special.Gamma;
import io.FileUtil;
import java.io.IOException;
import java.io.PrintStream;
@@ -12,6 +13,7 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicInteger;
import phrase.Corpus.Edge;
+import util.MathUtil;
public class PhraseCluster {
@@ -26,7 +28,12 @@ public class PhraseCluster { // pi[phrase][tag] = p(tag | phrase)
private double pi[][];
- public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads){
+ double alphaEmit;
+ double alphaPi;
+
+ public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads,
+ double alphaEmit, double alphaPi)
+ {
K=numCluster;
c=corpus;
n_words=c.getNumWords();
@@ -41,29 +48,41 @@ public class PhraseCluster { emit=new double [K][n_positions][n_words];
pi=new double[n_phrases][K];
- for(double [][]i:emit){
- for(double []j:i){
- arr.F.randomise(j);
+ for(double [][]i:emit)
+ {
+ for(double []j:i)
+ {
+ arr.F.randomise(j, alphaEmit <= 0);
+ if (alphaEmit > 0)
+ digammaNormalize(j, alphaEmit);
}
}
- for(double []j:pi){
- arr.F.randomise(j);
+ for(double []j:pi)
+ {
+ arr.F.randomise(j, alphaPi <= 0);
+ if (alphaPi > 0)
+ digammaNormalize(j, alphaPi);
}
+
+ this.alphaEmit = alphaEmit;
+ this.alphaPi = alphaPi;
}
-
- public double EM(){
+ public double EM()
+ {
double [][][]exp_emit=new double [K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
double loglikelihood=0;
//E
- for(int phrase=0; phrase < n_phrases; phrase++){
+ for(int phrase=0; phrase < n_phrases; phrase++)
+ {
List<Edge> contexts = c.getEdgesForPhrase(phrase);
- for (int ctx=0; ctx<contexts.size(); ctx++){
+ for (int ctx=0; ctx<contexts.size(); ctx++)
+ {
Edge edge = contexts.get(ctx);
double p[]=posterior(edge);
double z = arr.F.l1norm(p);
@@ -74,34 +93,127 @@ public class PhraseCluster { int count = edge.getCount();
//increment expected count
TIntArrayList context = edge.getContext();
- for(int tag=0;tag<K;tag++){
- for(int pos=0;pos<n_positions;pos++){
- exp_emit[tag][pos][context.get(pos)]+=p[tag]*count;
- }
-
+ for(int tag=0;tag<K;tag++)
+ {
+ for(int pos=0;pos<n_positions;pos++)
+ exp_emit[tag][pos][context.get(pos)]+=p[tag]*count;
exp_pi[phrase][tag]+=p[tag]*count;
}
}
}
-
- //System.out.println("Log likelihood: "+loglikelihood);
-
+
//M
- for(double [][]i:exp_emit){
- for(double []j:i){
+ for(double [][]i:exp_emit)
+ for(double []j:i)
arr.F.l1normalize(j);
- }
- }
+ for(double []j:exp_pi)
+ arr.F.l1normalize(j);
+
emit=exp_emit;
+ pi=exp_pi;
+
+ return loglikelihood;
+ }
+
+ public double VBEM()
+ {
+ // FIXME: broken - needs to be done entirely in log-space
- for(double []j:exp_pi){
- arr.F.l1normalize(j);
- }
+ double [][][]exp_emit = new double [K][n_positions][n_words];
+ double [][]exp_pi = new double[n_phrases][K];
+ double loglikelihood=0;
+
+ //E
+ for(int phrase=0; phrase < n_phrases; phrase++)
+ {
+ List<Edge> contexts = c.getEdgesForPhrase(phrase);
+
+ for (int ctx=0; ctx<contexts.size(); ctx++)
+ {
+ Edge edge = contexts.get(ctx);
+ double p[] = posterior(edge);
+ double z = arr.F.l1norm(p);
+ assert z > 0;
+ loglikelihood += edge.getCount() * Math.log(z);
+ arr.F.l1normalize(p);
+
+ int count = edge.getCount();
+ //increment expected count
+ TIntArrayList context = edge.getContext();
+ for(int tag=0;tag<K;tag++)
+ {
+ for(int pos=0;pos<n_positions;pos++)
+ exp_emit[tag][pos][context.get(pos)] += p[tag]*count;
+ exp_pi[phrase][tag] += p[tag]*count;
+ }
+ }
+ }
+
+ // find the KL terms, KL(q||p) where p is symmetric Dirichlet prior and q are the expectations
+ double kl = 0;
+ for (int phrase=0; phrase < n_phrases; phrase++)
+ kl += KL_symmetric_dirichlet(exp_pi[phrase], alphaPi);
+
+ for (int tag=0;tag<K;tag++)
+ for (int pos=0;pos<n_positions; ++pos)
+ kl += this.KL_symmetric_dirichlet(exp_emit[tag][pos], alphaEmit);
+ // FIXME: exp_emit[tag][pos] has structural zeros - certain words are *never* seen in that position
+
+ //M
+ for(double [][]i:exp_emit)
+ for(double []j:i)
+ digammaNormalize(j, alphaEmit);
+ emit=exp_emit;
+ for(double []j:exp_pi)
+ digammaNormalize(j, alphaPi);
pi=exp_pi;
+
+ System.out.println("KL=" + kl + " llh=" + loglikelihood);
+ System.out.println(Arrays.toString(pi[0]));
+ System.out.println(Arrays.toString(exp_emit[0][0]));
+ return kl + loglikelihood;
+ }
+
+ public void digammaNormalize(double [] a, double alpha)
+ {
+ double sum=0;
+ for(int i=0;i<a.length;i++)
+ sum += a[i];
- return loglikelihood;
+ assert sum > 1e-20;
+ double dgs = Gamma.digamma(sum + alpha);
+
+ for(int i=0;i<a.length;i++)
+ a[i] = Math.exp(Gamma.digamma(a[i] + alpha/a.length) - dgs);
+ }
+
+ private double KL_symmetric_dirichlet(double[] q, double alpha)
+ {
+ // assumes that zeros in q are structural & should be skipped
+
+ double p0 = alpha;
+ double q0 = 0;
+ int n = 0;
+ for (int i=0; i<q.length; i++)
+ {
+ if (q[i] > 0)
+ {
+ q0 += q[i];
+ n += 1;
+ }
+ }
+
+ double kl = Gamma.logGamma(q0) - Gamma.logGamma(p0);
+ kl += n * Gamma.logGamma(alpha / n);
+ double digamma_q0 = Gamma.digamma(q0);
+ for (int i=0; i<q.length; i++)
+ {
+ if (q[i] > 0)
+ kl -= -Gamma.logGamma(q[i]) - (q[i] - alpha/q.length) * (Gamma.digamma(q[i]) - digamma_q0);
+ }
+ return kl;
}
public double PREM_phrase_constraints(){
@@ -117,7 +229,7 @@ public class PhraseCluster { PhraseObjective po=new PhraseObjective(this,phrase);
boolean ok = po.optimizeWithProjectedGradientDescent();
if (!ok) ++failures;
- iterations += po.iterations;
+ iterations += po.getNumberUpdateCalls();
double [][] q=po.posterior();
loglikelihood += po.loglikelihood();
kl += po.KL_divergence();
@@ -142,24 +254,19 @@ public class PhraseCluster { if (failures > 0)
System.out.println("WARNING: failed to converge in " + failures + "/" + n_phrases + " cases");
- System.out.println("\tmean iters: " + iterations/(double)n_phrases);
+ System.out.println("\tmean iters: " + iterations/(double)n_phrases);
System.out.println("\tllh: " + loglikelihood);
System.out.println("\tKL: " + kl);
System.out.println("\tphrase l1lmax: " + l1lmax);
//M
- for(double [][]i:exp_emit){
- for(double []j:i){
+ for(double [][]i:exp_emit)
+ for(double []j:i)
arr.F.l1normalize(j);
- }
- }
-
emit=exp_emit;
- for(double []j:exp_pi){
+ for(double []j:exp_pi)
arr.F.l1normalize(j);
- }
-
pi=exp_pi;
return primal;
@@ -216,7 +323,7 @@ public class PhraseCluster { kl += po.KL_divergence();
l1lmax += po.l1lmax();
primal += po.primal();
- iterations += po.iterations;
+ iterations += po.getNumberUpdateCalls();
List<Edge> edges = c.getEdgesForPhrase(phrase);
@@ -241,7 +348,7 @@ public class PhraseCluster { if (failures.get() > 0)
System.out.println("WARNING: failed to converge in " + failures.get() + "/" + n_phrases + " cases");
- System.out.println("\tmean iters: " + iterations/(double)n_phrases);
+ System.out.println("\tmean iters: " + iterations/(double)n_phrases);
System.out.println("\tllh: " + loglikelihood);
System.out.println("\tKL: " + kl);
System.out.println("\tphrase l1lmax: " + l1lmax);
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java index f24b903d..cc12546d 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java @@ -1,7 +1,5 @@ package phrase;
-import java.io.PrintStream;
-import java.util.Arrays;
import java.util.List;
import optimization.gradientBasedMethods.ProjectedGradientDescent;
@@ -163,9 +161,7 @@ public class PhraseObjective extends ProjectedObjective public double [][]posterior(){
return q;
}
-
- public int iterations = 0;
-
+
public boolean optimizeWithProjectedGradientDescent(){
LineSearchMethod ls =
new ArmijoLineSearchMinimizationAlongProjectionArc
@@ -184,7 +180,6 @@ public class PhraseObjective extends ProjectedObjective optimizer.setMaxIterations(ITERATIONS);
updateFunction();
boolean success = optimizer.optimize(this,stats,compositeStop);
- iterations += optimizer.getCurrentIteration();
// System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
//if(succed){
//System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java index b19f3fb9..439fb337 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java @@ -27,6 +27,9 @@ public class Trainer parser.accepts("scale-context").withRequiredArg().ofType(Double.class).defaultsTo(0.0); parser.accepts("seed").withRequiredArg().ofType(Long.class).defaultsTo(0l); parser.accepts("convergence-threshold").withRequiredArg().ofType(Double.class).defaultsTo(1e-6); + parser.accepts("variational-bayes"); + parser.accepts("alpha-emit").withRequiredArg().ofType(Double.class).defaultsTo(0.1); + parser.accepts("alpha-pi").withRequiredArg().ofType(Double.class).defaultsTo(0.01); OptionSet options = parser.parse(args); if (options.has("help") || !options.has("in")) @@ -47,6 +50,9 @@ public class Trainer double scale_context = (Double) options.valueOf("scale-context"); int threads = (Integer) options.valueOf("threads"); double threshold = (Double) options.valueOf("convergence-threshold"); + boolean vb = options.has("variational-bayes"); + double alphaEmit = (vb) ? (Double) options.valueOf("alpha-emit") : 0; + double alphaPi = (vb) ? (Double) options.valueOf("alpha-pi") : 0; if (options.has("seed")) F.rng = new Random((Long) options.valueOf("seed")); @@ -75,14 +81,19 @@ public class Trainer "and " + threads + " threads"); System.out.println(); - PhraseCluster cluster = new PhraseCluster(tags, corpus, scale_phrase, scale_context, threads); + PhraseCluster cluster = new PhraseCluster(tags, corpus, scale_phrase, scale_context, threads, alphaEmit, alphaPi); double last = 0; for (int i=0; i<em_iterations+pr_iterations; i++) { double o; if (i < em_iterations) - o = cluster.EM(); + { + if (!vb) + o = cluster.EM(); + else + o = cluster.VBEM(); + } else if (scale_context == 0) { if (threads >= 1) diff --git a/gi/posterior-regularisation/prjava/train-PR-cluster.sh b/gi/posterior-regularisation/prjava/train-PR-cluster.sh index 41bb403f..4d4c68d0 100755 --- a/gi/posterior-regularisation/prjava/train-PR-cluster.sh +++ b/gi/posterior-regularisation/prjava/train-PR-cluster.sh @@ -1,4 +1,4 @@ #!/bin/sh d=`dirname $0` -java -ea -Xmx8g -cp $d/prjava.jar:$d/lib/trove-2.0.2.jar:$d/lib/optimization.jar:$d/lib/jopt-simple-3.2.jar phrase.Trainer $* +java -ea -Xmx8g -cp $d/prjava.jar:$d/lib/trove-2.0.2.jar:$d/lib/optimization.jar:$d/lib/jopt-simple-3.2.jar:$d/lib/lib/commons-math-2.1.jar phrase.Trainer $* |