From 0f6b65e81a655e5935fd7260abb554f3a54d94d1 Mon Sep 17 00:00:00 2001 From: desaicwtf Date: Fri, 23 Jul 2010 17:08:53 +0000 Subject: vb runnable from trainer git-svn-id: https://ws10smt.googlecode.com/svn/trunk@380 ec762483-ff6d-05da-a07a-a48fb63a330f --- .../prjava/src/phrase/Trainer.java | 43 ++++++++++------------ 1 file changed, 20 insertions(+), 23 deletions(-) (limited to 'gi/posterior-regularisation/prjava/src/phrase/Trainer.java') diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java index b51db919..cea6a20a 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java @@ -18,7 +18,7 @@ public class Trainer { public static void main(String[] args) { - + OptionParser parser = new OptionParser(); parser.accepts("help"); parser.accepts("in").withRequiredArg().ofType(File.class); @@ -107,6 +107,7 @@ public class Trainer PhraseCluster cluster = null; Agree2Sides agree2sides = null; Agree agree= null; + VB vbModel=null; if (options.has("agree-language")) agree2sides = new Agree2Sides(tags, corpus,corpus1); else if (options.has("agree-direction")) @@ -115,7 +116,11 @@ public class Trainer { cluster = new PhraseCluster(tags, corpus); if (threads > 0) cluster.useThreadPool(threads); - if (vb) cluster.initialiseVB(alphaEmit, alphaPi); + + if (vb) { + //cluster.initialiseVB(alphaEmit, alphaPi); + vbModel=new VB(tags,corpus); + } if (options.has("no-parameter-cache")) cluster.cacheLambda = false; if (options.has("start")) @@ -149,7 +154,7 @@ public class Trainer if (!vb) o = cluster.EM((i < skip) ? i+1 : 0); else - o = cluster.VBEM(alphaEmit, alphaPi); + o = vbModel.EM(); } else o = cluster.PREM(scale_phrase, scale_context, (i < skip) ? i+1 : 0); @@ -166,10 +171,8 @@ public class Trainer last = o; } - if (cluster == null && agree != null) + if (cluster == null) cluster = agree.model1; - else if (cluster == null && agree2sides != null) - cluster = agree2sides.model1; double pl1lmax = cluster.phrase_l1lmax(); double cl1lmax = cluster.context_l1lmax(); @@ -180,26 +183,20 @@ public class Trainer File outfile = (File) options.valueOf("out"); try { PrintStream ps = FileUtil.printstream(outfile); - List test = corpus.getEdges(); - if (options.has("test")) // just use the training + List test; + if (!options.has("test")) // just use the training + test = corpus.getEdges(); + else { // if --test supplied, load up the file - if (agree2sides == null) - { - infile = (File) options.valueOf("test"); - System.out.println("Reading testing concordance from " + infile); - test = corpus.readEdges(FileUtil.reader(infile)); - } - else - System.err.println("Can't run bilingual agreement model on different test data cf training (yet); --test ignored."); + infile = (File) options.valueOf("test"); + System.out.println("Reading testing concordance from " + infile); + test = corpus.readEdges(FileUtil.reader(infile)); } - - if (agree != null) - agree.displayPosterior(ps, test); - else if (agree2sides != null) - agree2sides.displayPosterior(ps); - else + if(vb){ + vbModel.displayPosterior(ps); + }else{ cluster.displayPosterior(ps, test); - + } ps.close(); } catch (IOException e) { System.err.println("Failed to open either testing file or output file"); -- cgit v1.2.3