From 76ef39de737e7abc0a8fe989dfacb7885617e59f Mon Sep 17 00:00:00 2001 From: desaicwtf Date: Fri, 23 Jul 2010 17:08:53 +0000 Subject: vb runnable from trainer git-svn-id: https://ws10smt.googlecode.com/svn/trunk@380 ec762483-ff6d-05da-a07a-a48fb63a330f --- .../prjava/src/phrase/Trainer.java | 43 +++++----- .../prjava/src/phrase/VB.java | 97 ++++++++++++++++------ 2 files changed, 92 insertions(+), 48 deletions(-) (limited to 'gi/posterior-regularisation/prjava') diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java index b51db919..cea6a20a 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java @@ -18,7 +18,7 @@ public class Trainer { public static void main(String[] args) { - + OptionParser parser = new OptionParser(); parser.accepts("help"); parser.accepts("in").withRequiredArg().ofType(File.class); @@ -107,6 +107,7 @@ public class Trainer PhraseCluster cluster = null; Agree2Sides agree2sides = null; Agree agree= null; + VB vbModel=null; if (options.has("agree-language")) agree2sides = new Agree2Sides(tags, corpus,corpus1); else if (options.has("agree-direction")) @@ -115,7 +116,11 @@ public class Trainer { cluster = new PhraseCluster(tags, corpus); if (threads > 0) cluster.useThreadPool(threads); - if (vb) cluster.initialiseVB(alphaEmit, alphaPi); + + if (vb) { + //cluster.initialiseVB(alphaEmit, alphaPi); + vbModel=new VB(tags,corpus); + } if (options.has("no-parameter-cache")) cluster.cacheLambda = false; if (options.has("start")) @@ -149,7 +154,7 @@ public class Trainer if (!vb) o = cluster.EM((i < skip) ? i+1 : 0); else - o = cluster.VBEM(alphaEmit, alphaPi); + o = vbModel.EM(); } else o = cluster.PREM(scale_phrase, scale_context, (i < skip) ? i+1 : 0); @@ -166,10 +171,8 @@ public class Trainer last = o; } - if (cluster == null && agree != null) + if (cluster == null) cluster = agree.model1; - else if (cluster == null && agree2sides != null) - cluster = agree2sides.model1; double pl1lmax = cluster.phrase_l1lmax(); double cl1lmax = cluster.context_l1lmax(); @@ -180,26 +183,20 @@ public class Trainer File outfile = (File) options.valueOf("out"); try { PrintStream ps = FileUtil.printstream(outfile); - List test = corpus.getEdges(); - if (options.has("test")) // just use the training + List test; + if (!options.has("test")) // just use the training + test = corpus.getEdges(); + else { // if --test supplied, load up the file - if (agree2sides == null) - { - infile = (File) options.valueOf("test"); - System.out.println("Reading testing concordance from " + infile); - test = corpus.readEdges(FileUtil.reader(infile)); - } - else - System.err.println("Can't run bilingual agreement model on different test data cf training (yet); --test ignored."); + infile = (File) options.valueOf("test"); + System.out.println("Reading testing concordance from " + infile); + test = corpus.readEdges(FileUtil.reader(infile)); } - - if (agree != null) - agree.displayPosterior(ps, test); - else if (agree2sides != null) - agree2sides.displayPosterior(ps); - else + if(vb){ + vbModel.displayPosterior(ps); + }else{ cluster.displayPosterior(ps, test); - + } ps.close(); } catch (IOException e) { System.err.println("Failed to open either testing file or output file"); diff --git a/gi/posterior-regularisation/prjava/src/phrase/VB.java b/gi/posterior-regularisation/prjava/src/phrase/VB.java index cc1c1c96..a858c883 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/VB.java +++ b/gi/posterior-regularisation/prjava/src/phrase/VB.java @@ -16,7 +16,7 @@ import phrase.Corpus.Edge; public class VB { - public static int MAX_ITER=40; + public static int MAX_ITER=400; /**@brief * hyper param for beta @@ -28,11 +28,13 @@ public class VB { * hyper param for theta * where theta is dirichlet for z */ - public double alpha=0.000001; + public double alpha=0.0001; /**@brief * variational param for beta */ private double rho[][][]; + private double digamma_rho[][][]; + private double rho_sum[][]; /**@brief * variational param for z */ @@ -41,8 +43,7 @@ public class VB { * variational param for theta */ private double gamma[]; - - private static double VAL_DIFF_RATIO=0.001; + private static double VAL_DIFF_RATIO=0.005; /**@brief * objective for a single document @@ -55,8 +56,8 @@ public class VB { private Corpus c; public static void main(String[] args) { - String in="../pdata/canned.con"; - //String in="../pdata/btec.con"; + // String in="../pdata/canned.con"; + String in="../pdata/btec.con"; String out="../pdata/vb.out"; int numCluster=25; Corpus corpus = null; @@ -118,6 +119,7 @@ public class VB { } } } + } private void inference(int phraseID){ @@ -128,26 +130,21 @@ public class VB { phi[i][j]=1.0/K; } } - gamma = new double[K]; - double digamma_gamma[]=new double[K]; - for(int i=0;i