From 9ac87abac855aaaa6c1dcf686b38443092a10ce6 Mon Sep 17 00:00:00 2001 From: desaicwtf Date: Thu, 22 Jul 2010 23:54:25 +0000 Subject: variational bayes inference git-svn-id: https://ws10smt.googlecode.com/svn/trunk@372 ec762483-ff6d-05da-a07a-a48fb63a330f --- .../prjava/src/phrase/PhraseCluster.java | 24 +- .../prjava/src/phrase/VB.java | 317 +++++++++++++++++++++ 2 files changed, 329 insertions(+), 12 deletions(-) create mode 100644 gi/posterior-regularisation/prjava/src/phrase/VB.java diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index 13ac14ba..ccb6ae9d 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -111,8 +111,9 @@ public class PhraseCluster { TIntArrayList context = edge.getContext(); for(int tag=0;tag= 1 && c.getPhrase(phrase).size() > phraseSizeLimit) { - System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K); + //System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K); continue; } + Arrays.fill(exp_pi, 1e-10); + // FIXME: add rare edge check to phrase objective & posterior processing PhraseObjective po = new PhraseObjective(this, phrase, scalePT, (cacheLambda) ? lambdaPT[phrase] : null); boolean ok = po.optimizeWithProjectedGradientDescent(); @@ -294,9 +295,12 @@ public class PhraseCluster { exp_emit[tag][pos][context.get(pos)]+=q[edge][tag]*contextCnt; } - exp_pi[phrase][tag]+=q[edge][tag]*contextCnt; + exp_pi[tag]+=q[edge][tag]*contextCnt; + } } + arr.F.l1normalize(exp_pi); + System.arraycopy(exp_pi, 0, pi[phrase], 0, K); } long end = System.currentTimeMillis(); @@ -313,10 +317,6 @@ public class PhraseCluster { arr.F.l1normalize(j); emit=exp_emit; - for(double []j:exp_pi) - arr.F.l1normalize(j); - pi=exp_pi; - return primal; } diff --git a/gi/posterior-regularisation/prjava/src/phrase/VB.java b/gi/posterior-regularisation/prjava/src/phrase/VB.java new file mode 100644 index 00000000..cc1c1c96 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/phrase/VB.java @@ -0,0 +1,317 @@ +package phrase; + +import gnu.trove.TIntArrayList; + +import io.FileUtil; + +import java.io.File; +import java.io.IOException; +import java.io.PrintStream; +import java.util.Arrays; +import java.util.List; + +import org.apache.commons.math.special.Gamma; + +import phrase.Corpus.Edge; + +public class VB { + + public static int MAX_ITER=40; + + /**@brief + * hyper param for beta + * where beta is multinomial + * for generating words from a topic + */ + public double lambda=0.1; + /**@brief + * hyper param for theta + * where theta is dirichlet for z + */ + public double alpha=0.000001; + /**@brief + * variational param for beta + */ + private double rho[][][]; + /**@brief + * variational param for z + */ + private double phi[][]; + /**@brief + * variational param for theta + */ + private double gamma[]; + + private static double VAL_DIFF_RATIO=0.001; + + /**@brief + * objective for a single document + */ + private double obj; + + private int n_positions; + private int n_words; + private int K; + + private Corpus c; + public static void main(String[] args) { + String in="../pdata/canned.con"; + //String in="../pdata/btec.con"; + String out="../pdata/vb.out"; + int numCluster=25; + Corpus corpus = null; + File infile = new File(in); + try { + System.out.println("Reading concordance from " + infile); + corpus = Corpus.readFromFile(FileUtil.reader(infile)); + corpus.printStats(System.out); + } catch (IOException e) { + System.err.println("Failed to open input file: " + infile); + e.printStackTrace(); + System.exit(1); + } + + VB vb=new VB(numCluster, corpus); + int iter=20; + for(int i=0;idoc=c.getEdgesForPhrase(d); + for(int n=0;n doc=c.getEdgesForPhrase(phraseID); + phi=new double[doc.size()][K]; + for(int i=0;i 0){ + phisum = log_sum(phisum, phi[n][i]); + } + else{ + phisum = phi[n][i]; + } + + }//end of a word + + for(int i=0;i1e-10){ + obj+=phi[n][i]*Math.log(phi[n][i]); + } + + double beta_sum=0; + for(int pos=0;pos0 && (obj-prev_val)/Math.abs(obj)doc=c.getEdgesForPhrase(d); + for(int n=0;n doc=c.getEdgesForPhrase(d); + for(int n=0;n