From 9801ac3df2cbf2656b8d21b2fb0046bfb4046e98 Mon Sep 17 00:00:00 2001 From: "trevor.cohn" Date: Fri, 9 Jul 2010 22:29:02 +0000 Subject: Added initial VB implementation for symetric Dirichlet prior. git-svn-id: https://ws10smt.googlecode.com/svn/trunk@215 ec762483-ff6d-05da-a07a-a48fb63a330f --- .../prjava/src/phrase/PhraseCluster.java | 183 ++++++++++++++++----- 1 file changed, 145 insertions(+), 38 deletions(-) (limited to 'gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java') diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index b9b1b98c..7bc63c33 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -1,6 +1,7 @@ package phrase; import gnu.trove.TIntArrayList; +import org.apache.commons.math.special.Gamma; import io.FileUtil; import java.io.IOException; import java.io.PrintStream; @@ -12,6 +13,7 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicInteger; import phrase.Corpus.Edge; +import util.MathUtil; public class PhraseCluster { @@ -26,7 +28,12 @@ public class PhraseCluster { // pi[phrase][tag] = p(tag | phrase) private double pi[][]; - public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads){ + double alphaEmit; + double alphaPi; + + public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads, + double alphaEmit, double alphaPi) + { K=numCluster; c=corpus; n_words=c.getNumWords(); @@ -41,29 +48,41 @@ public class PhraseCluster { emit=new double [K][n_positions][n_words]; pi=new double[n_phrases][K]; - for(double [][]i:emit){ - for(double []j:i){ - arr.F.randomise(j); + for(double [][]i:emit) + { + for(double []j:i) + { + arr.F.randomise(j, alphaEmit <= 0); + if (alphaEmit > 0) + digammaNormalize(j, alphaEmit); } } - for(double []j:pi){ - arr.F.randomise(j); + for(double []j:pi) + { + arr.F.randomise(j, alphaPi <= 0); + if (alphaPi > 0) + digammaNormalize(j, alphaPi); } + + this.alphaEmit = alphaEmit; + this.alphaPi = alphaPi; } - - public double EM(){ + public double EM() + { double [][][]exp_emit=new double [K][n_positions][n_words]; double [][]exp_pi=new double[n_phrases][K]; double loglikelihood=0; //E - for(int phrase=0; phrase < n_phrases; phrase++){ + for(int phrase=0; phrase < n_phrases; phrase++) + { List contexts = c.getEdgesForPhrase(phrase); - for (int ctx=0; ctx contexts = c.getEdgesForPhrase(phrase); + + for (int ctx=0; ctx 0; + loglikelihood += edge.getCount() * Math.log(z); + arr.F.l1normalize(p); + + int count = edge.getCount(); + //increment expected count + TIntArrayList context = edge.getContext(); + for(int tag=0;tag edges = c.getEdgesForPhrase(phrase); @@ -241,7 +348,7 @@ public class PhraseCluster { if (failures.get() > 0) System.out.println("WARNING: failed to converge in " + failures.get() + "/" + n_phrases + " cases"); - System.out.println("\tmean iters: " + iterations/(double)n_phrases); + System.out.println("\tmean iters: " + iterations/(double)n_phrases); System.out.println("\tllh: " + loglikelihood); System.out.println("\tKL: " + kl); System.out.println("\tphrase l1lmax: " + l1lmax); -- cgit v1.2.3