From 01739cab52552013a68843d2f64b02e868dcd281 Mon Sep 17 00:00:00 2001 From: "trevor.cohn" Date: Fri, 23 Jul 2010 19:26:17 +0000 Subject: Parallelised VB-EM git-svn-id: https://ws10smt.googlecode.com/svn/trunk@384 ec762483-ff6d-05da-a07a-a48fb63a330f --- .../prjava/src/phrase/PhraseCluster.java | 178 ++++----------------- .../prjava/src/phrase/PhraseContextObjective.java | 3 + .../prjava/src/phrase/PhraseObjective.java | 2 +- .../prjava/src/phrase/Trainer.java | 83 ++++++---- .../prjava/src/phrase/VB.java | 129 ++++++++++----- 5 files changed, 182 insertions(+), 213 deletions(-) diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index ccb6ae9d..c032bb2b 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -4,14 +4,16 @@ import gnu.trove.TIntArrayList; import org.apache.commons.math.special.Gamma; import java.io.BufferedReader; -import java.io.File; import java.io.IOException; import java.io.PrintStream; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.StringTokenizer; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.Future; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; @@ -56,23 +58,9 @@ public class PhraseCluster { arr.F.randomise(j, true); } - public void initialiseVB(double alphaEmit, double alphaPi) + void useThreadPool(ExecutorService pool) { - assert alphaEmit > 0; - assert alphaPi > 0; - - for(double [][]i:emit) - for(double []j:i) - digammaNormalize(j, alphaEmit); - - for(double []j:pi) - digammaNormalize(j, alphaPi); - } - - void useThreadPool(int threads) - { - assert threads > 0; - pool = Executors.newFixedThreadPool(threads); + this.pool = pool; } public double EM(int phraseSizeLimit) @@ -131,107 +119,6 @@ public class PhraseCluster { return loglikelihood; } - public double VBEM(double alphaEmit, double alphaPi) - { - // FIXME: broken - needs to be done entirely in log-space - - double [][][]exp_emit = new double [K][n_positions][n_words]; - double [][]exp_pi = new double[n_phrases][K]; - - double loglikelihood=0; - - //E - for(int phrase=0; phrase < n_phrases; phrase++) - { - List contexts = c.getEdgesForPhrase(phrase); - - for (int ctx=0; ctx 0; - loglikelihood += edge.getCount() * Math.log(z); - arr.F.l1normalize(p); - - double count = edge.getCount(); - //increment expected count - TIntArrayList context = edge.getContext(); - for(int tag=0;tag() { + public PhraseObjective call() { + //System.out.println("" + Thread.currentThread().getId() + " optimising lambda for " + p); + long start = System.currentTimeMillis(); + PhraseObjective po = new PhraseObjective(PhraseCluster.this, p, scalePT, (cacheLambda) ? lambdaPT[p] : null); + boolean ok = po.optimizeWithProjectedGradientDescent(); + if (!ok) failures.incrementAndGet(); + long end = System.currentTimeMillis(); + elapsed.addAndGet(end - start); + //System.out.println("" + Thread.currentThread().getId() + " done optimising lambda for " + p); + return po; } - }); + })); } // aggregate the expectations as they become available - for(int count=0;count fpo : results) + { try { //System.out.println("" + Thread.currentThread().getId() + " reading queue #" + count); // wait (blocking) until something is ready - PhraseObjective po = expectations.take(); + PhraseObjective po = fpo.get(); // process int phrase = po.phrase; if (cacheLambda) lambdaPT[phrase] = po.getParameters(); @@ -408,10 +288,12 @@ public class PhraseCluster { exp_pi[phrase][tag]+=q[edge][tag]*contextCnt; } } - } catch (InterruptedException e) - { + } catch (InterruptedException e) { System.err.println("M-step thread interrupted. Probably fatal!"); - e.printStackTrace(); + throw new RuntimeException(e); + } catch (ExecutionException e) { + System.err.println("M-step thread execution died. Probably fatal!"); + throw new RuntimeException(e); } } diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java index 06a9f8cb..5947c4be 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java @@ -277,7 +277,10 @@ public class PhraseContextObjective extends ProjectedObjective } // rethrow the exception if (failure != null) + { + pool.shutdownNow(); throw new RuntimeException(failure); + } } double[] tmp = newPoint; diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java index 7c32d9c0..5efe778a 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java @@ -192,7 +192,7 @@ public class PhraseObjective extends ProjectedObjective // for(int edge=0;edge 0) + threadPool = Executors.newFixedThreadPool(threads); + if (tags <= 1 || scale_phrase < 0 || scale_context < 0 || threshold < 0) { System.err.println("Invalid arguments. Try again!"); @@ -114,26 +118,30 @@ public class Trainer agree = new Agree(tags, corpus); else { - cluster = new PhraseCluster(tags, corpus); - if (threads > 0) cluster.useThreadPool(threads); - - if (vb) { - //cluster.initialiseVB(alphaEmit, alphaPi); + if (vb) + { vbModel=new VB(tags,corpus); vbModel.alpha=alphaPi; vbModel.lambda=alphaEmit; - } - if (options.has("no-parameter-cache")) - cluster.cacheLambda = false; - if (options.has("start")) + if (threadPool != null) vbModel.useThreadPool(threadPool); + } + else { - try { - System.err.println("Reading starting parameters from " + options.valueOf("start")); - cluster.loadParameters(FileUtil.reader((File)options.valueOf("start"))); - } catch (IOException e) { - System.err.println("Failed to open input file: " + options.valueOf("start")); - e.printStackTrace(); - } + cluster = new PhraseCluster(tags, corpus); + if (threadPool != null) cluster.useThreadPool(threadPool); + + if (options.has("no-parameter-cache")) + cluster.cacheLambda = false; + if (options.has("start")) + { + try { + System.err.println("Reading starting parameters from " + options.valueOf("start")); + cluster.loadParameters(FileUtil.reader((File)options.valueOf("start"))); + } catch (IOException e) { + System.err.println("Failed to open input file: " + options.valueOf("start")); + e.printStackTrace(); + } + } } } @@ -143,9 +151,8 @@ public class Trainer double o; if (agree != null) o = agree.EM(); - else if(agree2sides!=null){ + else if(agree2sides!=null) o = agree2sides.EM(); - } else { if (i < skip) @@ -173,11 +180,25 @@ public class Trainer last = o; } - if (cluster == null) - cluster = agree.model1; + double pl1lmax = 0, cl1lmax = 0; + if (cluster != null) + { + pl1lmax = cluster.phrase_l1lmax(); + cl1lmax = cluster.context_l1lmax(); + } + else if (agree != null) + { + // fairly arbitrary choice of model1 cf model2 + pl1lmax = agree.model1.phrase_l1lmax(); + cl1lmax = agree.model1.context_l1lmax(); + } + else if (agree2sides != null) + { + // fairly arbitrary choice of model1 cf model2 + pl1lmax = agree2sides.model1.phrase_l1lmax(); + cl1lmax = agree2sides.model1.context_l1lmax(); + } - double pl1lmax = cluster.phrase_l1lmax(); - double cl1lmax = cluster.context_l1lmax(); System.out.println("\nFinal posterior phrase l1lmax " + pl1lmax + " context l1lmax " + cl1lmax); if (options.has("out")) @@ -194,11 +215,18 @@ public class Trainer System.out.println("Reading testing concordance from " + infile); test = corpus.readEdges(FileUtil.reader(infile)); } - if(vb){ + if(vb) { + assert !options.has("test"); vbModel.displayPosterior(ps); - }else{ + } else if (cluster != null) cluster.displayPosterior(ps, test); + else if (agree != null) + agree.displayPosterior(ps, test); + else if (agree2sides != null) { + assert !options.has("test"); + agree2sides.displayPosterior(ps); } + ps.close(); } catch (IOException e) { System.err.println("Failed to open either testing file or output file"); @@ -209,6 +237,7 @@ public class Trainer if (options.has("parameters")) { + assert !vb; File outfile = (File) options.valueOf("parameters"); PrintStream ps; try { @@ -222,7 +251,7 @@ public class Trainer } } - if (cluster.pool != null) + if (cluster != null && cluster.pool != null) cluster.pool.shutdown(); } } diff --git a/gi/posterior-regularisation/prjava/src/phrase/VB.java b/gi/posterior-regularisation/prjava/src/phrase/VB.java index a858c883..cd3f4966 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/VB.java +++ b/gi/posterior-regularisation/prjava/src/phrase/VB.java @@ -7,8 +7,13 @@ import io.FileUtil; import java.io.File; import java.io.IOException; import java.io.PrintStream; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import org.apache.commons.math.special.Gamma; @@ -38,21 +43,17 @@ public class VB { /**@brief * variational param for z */ - private double phi[][]; + //private double phi[][]; /**@brief * variational param for theta */ private double gamma[]; private static double VAL_DIFF_RATIO=0.005; - /**@brief - * objective for a single document - */ - private double obj; - private int n_positions; private int n_words; private int K; + private ExecutorService pool; private Corpus c; public static void main(String[] args) { @@ -122,17 +123,14 @@ public class VB { } - private void inference(int phraseID){ + private double inference(int phraseID, double[][] phi, double[] gamma) + { List doc=c.getEdgesForPhrase(phraseID); - phi=new double[doc.size()][K]; for(int i=0;idoc=c.getEdgesForPhrase(d); - for(int n=0;n doc=c.getEdgesForPhrase(d); + double[][] phi = new double[doc.size()][K]; + double[] gamma = new double[K]; + + emObj += inference(d, phi, gamma); + + for(int n=0;n + { + double[][] phi; + double[] gamma; + double obj; + int d; + PartialEStep(int d) { this.d = d; } + + public PartialEStep call() + { + phi = new double[c.getEdgesForPhrase(d).size()][K]; + gamma = new double[K]; + obj = inference(d, phi, gamma); + return this; + } + } + + List> jobs = new ArrayList>(); + for (int d=0;d job: jobs) + { + try { + PartialEStep e = job.get(); + + emObj += e.obj; + List doc = c.getEdgesForPhrase(e.d); + for(int n=0;n doc=c.getEdgesForPhrase(d); + List doc=c.getEdgesForPhrase(d); + double[][] phi = new double[doc.size()][K]; + for(int i=0;i