From 9801ac3df2cbf2656b8d21b2fb0046bfb4046e98 Mon Sep 17 00:00:00 2001 From: "trevor.cohn" Date: Fri, 9 Jul 2010 22:29:02 +0000 Subject: Added initial VB implementation for symetric Dirichlet prior. git-svn-id: https://ws10smt.googlecode.com/svn/trunk@215 ec762483-ff6d-05da-a07a-a48fb63a330f --- gi/posterior-regularisation/prjava/src/arr/F.java | 15 +- .../gradientBasedMethods/Objective.java | 4 + .../prjava/src/phrase/PhraseCluster.java | 183 ++++++++++++++++----- .../prjava/src/phrase/PhraseObjective.java | 7 +- .../prjava/src/phrase/Trainer.java | 15 +- 5 files changed, 174 insertions(+), 50 deletions(-) (limited to 'gi/posterior-regularisation/prjava/src') diff --git a/gi/posterior-regularisation/prjava/src/arr/F.java b/gi/posterior-regularisation/prjava/src/arr/F.java index 7f2b140a..54dadeac 100644 --- a/gi/posterior-regularisation/prjava/src/arr/F.java +++ b/gi/posterior-regularisation/prjava/src/arr/F.java @@ -4,18 +4,25 @@ import java.util.Random; public class F { public static Random rng = new Random(); - + public static void randomise(double probs[]) + { + randomise(probs, true); + } + + public static void randomise(double probs[], boolean normalise) { double z = 0; for (int i = 0; i < probs.length; ++i) { probs[i] = 3 + rng.nextDouble(); - z += probs[i]; + if (normalise) + z += probs[i]; } - for (int i = 0; i < probs.length; ++i) - probs[i] /= z; + if (normalise) + for (int i = 0; i < probs.length; ++i) + probs[i] /= z; } public static void l1normalize(double [] a){ diff --git a/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/Objective.java b/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/Objective.java index 0e2e27ac..6be01bf9 100644 --- a/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/Objective.java +++ b/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/Objective.java @@ -59,6 +59,10 @@ public abstract class Objective { return gradientCalls; } + public int getNumberUpdateCalls() { + return updateCalls; + } + public String finalInfoString() { return "FE: " + functionCalls + " GE " + gradientCalls + " Params updates" + updateCalls; diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index b9b1b98c..7bc63c33 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -1,6 +1,7 @@ package phrase; import gnu.trove.TIntArrayList; +import org.apache.commons.math.special.Gamma; import io.FileUtil; import java.io.IOException; import java.io.PrintStream; @@ -12,6 +13,7 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicInteger; import phrase.Corpus.Edge; +import util.MathUtil; public class PhraseCluster { @@ -26,7 +28,12 @@ public class PhraseCluster { // pi[phrase][tag] = p(tag | phrase) private double pi[][]; - public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads){ + double alphaEmit; + double alphaPi; + + public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads, + double alphaEmit, double alphaPi) + { K=numCluster; c=corpus; n_words=c.getNumWords(); @@ -41,29 +48,41 @@ public class PhraseCluster { emit=new double [K][n_positions][n_words]; pi=new double[n_phrases][K]; - for(double [][]i:emit){ - for(double []j:i){ - arr.F.randomise(j); + for(double [][]i:emit) + { + for(double []j:i) + { + arr.F.randomise(j, alphaEmit <= 0); + if (alphaEmit > 0) + digammaNormalize(j, alphaEmit); } } - for(double []j:pi){ - arr.F.randomise(j); + for(double []j:pi) + { + arr.F.randomise(j, alphaPi <= 0); + if (alphaPi > 0) + digammaNormalize(j, alphaPi); } + + this.alphaEmit = alphaEmit; + this.alphaPi = alphaPi; } - - public double EM(){ + public double EM() + { double [][][]exp_emit=new double [K][n_positions][n_words]; double [][]exp_pi=new double[n_phrases][K]; double loglikelihood=0; //E - for(int phrase=0; phrase < n_phrases; phrase++){ + for(int phrase=0; phrase < n_phrases; phrase++) + { List contexts = c.getEdgesForPhrase(phrase); - for (int ctx=0; ctx contexts = c.getEdgesForPhrase(phrase); + + for (int ctx=0; ctx 0; + loglikelihood += edge.getCount() * Math.log(z); + arr.F.l1normalize(p); + + int count = edge.getCount(); + //increment expected count + TIntArrayList context = edge.getContext(); + for(int tag=0;tag edges = c.getEdgesForPhrase(phrase); @@ -241,7 +348,7 @@ public class PhraseCluster { if (failures.get() > 0) System.out.println("WARNING: failed to converge in " + failures.get() + "/" + n_phrases + " cases"); - System.out.println("\tmean iters: " + iterations/(double)n_phrases); + System.out.println("\tmean iters: " + iterations/(double)n_phrases); System.out.println("\tllh: " + loglikelihood); System.out.println("\tKL: " + kl); System.out.println("\tphrase l1lmax: " + l1lmax); diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java index f24b903d..cc12546d 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java @@ -1,7 +1,5 @@ package phrase; -import java.io.PrintStream; -import java.util.Arrays; import java.util.List; import optimization.gradientBasedMethods.ProjectedGradientDescent; @@ -163,9 +161,7 @@ public class PhraseObjective extends ProjectedObjective public double [][]posterior(){ return q; } - - public int iterations = 0; - + public boolean optimizeWithProjectedGradientDescent(){ LineSearchMethod ls = new ArmijoLineSearchMinimizationAlongProjectionArc @@ -184,7 +180,6 @@ public class PhraseObjective extends ProjectedObjective optimizer.setMaxIterations(ITERATIONS); updateFunction(); boolean success = optimizer.optimize(this,stats,compositeStop); - iterations += optimizer.getCurrentIteration(); // System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1)); //if(succed){ //System.out.println("Ended optimization in " + optimizer.getCurrentIteration()); diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java index b19f3fb9..439fb337 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java @@ -27,6 +27,9 @@ public class Trainer parser.accepts("scale-context").withRequiredArg().ofType(Double.class).defaultsTo(0.0); parser.accepts("seed").withRequiredArg().ofType(Long.class).defaultsTo(0l); parser.accepts("convergence-threshold").withRequiredArg().ofType(Double.class).defaultsTo(1e-6); + parser.accepts("variational-bayes"); + parser.accepts("alpha-emit").withRequiredArg().ofType(Double.class).defaultsTo(0.1); + parser.accepts("alpha-pi").withRequiredArg().ofType(Double.class).defaultsTo(0.01); OptionSet options = parser.parse(args); if (options.has("help") || !options.has("in")) @@ -47,6 +50,9 @@ public class Trainer double scale_context = (Double) options.valueOf("scale-context"); int threads = (Integer) options.valueOf("threads"); double threshold = (Double) options.valueOf("convergence-threshold"); + boolean vb = options.has("variational-bayes"); + double alphaEmit = (vb) ? (Double) options.valueOf("alpha-emit") : 0; + double alphaPi = (vb) ? (Double) options.valueOf("alpha-pi") : 0; if (options.has("seed")) F.rng = new Random((Long) options.valueOf("seed")); @@ -75,14 +81,19 @@ public class Trainer "and " + threads + " threads"); System.out.println(); - PhraseCluster cluster = new PhraseCluster(tags, corpus, scale_phrase, scale_context, threads); + PhraseCluster cluster = new PhraseCluster(tags, corpus, scale_phrase, scale_context, threads, alphaEmit, alphaPi); double last = 0; for (int i=0; i= 1) -- cgit v1.2.3