From f2c8e89b45a18bc4d16e9623d7beb96f43732453 Mon Sep 17 00:00:00 2001 From: "trevor.cohn" Date: Fri, 9 Jul 2010 15:00:17 +0000 Subject: Parallelised PhraseContextObjective.projectPoint Added ANT build file git-svn-id: https://ws10smt.googlecode.com/svn/trunk@193 ec762483-ff6d-05da-a07a-a48fb63a330f --- gi/posterior-regularisation/prjava/build.xml | 38 +++++ gi/posterior-regularisation/prjava/src/arr/F.java | 6 +- .../prjava/src/phrase/PhraseCluster.java | 24 ++- .../prjava/src/phrase/PhraseContextObjective.java | 174 ++++++++++++++++----- .../prjava/src/phrase/PhraseObjective.java | 5 +- .../prjava/train-PR-cluster.sh | 4 + 6 files changed, 198 insertions(+), 53 deletions(-) create mode 100644 gi/posterior-regularisation/prjava/build.xml create mode 100755 gi/posterior-regularisation/prjava/train-PR-cluster.sh (limited to 'gi') diff --git a/gi/posterior-regularisation/prjava/build.xml b/gi/posterior-regularisation/prjava/build.xml new file mode 100644 index 00000000..c9ed2e8d --- /dev/null +++ b/gi/posterior-regularisation/prjava/build.xml @@ -0,0 +1,38 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/gi/posterior-regularisation/prjava/src/arr/F.java b/gi/posterior-regularisation/prjava/src/arr/F.java index c194496e..5821af42 100644 --- a/gi/posterior-regularisation/prjava/src/arr/F.java +++ b/gi/posterior-regularisation/prjava/src/arr/F.java @@ -1,12 +1,16 @@ package arr; +import java.util.Random; + public class F { + private static Random rng = new Random(); //(9562724l); + public static void randomise(double probs[]) { double z = 0; for (int i = 0; i < probs.length; ++i) { - probs[i] = 3 + Math.random(); + probs[i] = 3 + rng.nextDouble(); z += probs[i]; } diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index e4db2a1a..63a60682 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -18,7 +18,7 @@ public class PhraseCluster { public double scalePT, scaleCT; private int n_phrases, n_words, n_contexts, n_positions; public Corpus c; - private ExecutorService pool; + public ExecutorService pool; // emit[tag][position][word] = p(word | tag, position in context) private double emit[][][]; @@ -88,7 +88,8 @@ public class PhraseCluster { //cluster.displayModelParam(ps); //ps.close(); - cluster.finish(); + if (cluster.pool != null) + cluster.pool.shutdown(); } public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads){ @@ -100,7 +101,7 @@ public class PhraseCluster { n_positions=c.getNumContextPositions(); this.scalePT = scalep; this.scaleCT = scalec; - if (threads > 0 && scalec <= 0) + if (threads > 0) pool = Executors.newFixedThreadPool(threads); emit=new double [K][n_positions][n_words]; @@ -116,12 +117,7 @@ public class PhraseCluster { arr.F.randomise(j); } } - - public void finish() - { - if (pool != null) - pool.shutdown(); - } + public double EM(){ double [][][]exp_emit=new double [K][n_positions][n_words]; @@ -318,13 +314,13 @@ public class PhraseCluster { public double PREM_phrase_context_constraints(){ assert (scaleCT > 0); - double [][][]exp_emit=new double [K][n_positions][n_words]; - double [][]exp_pi=new double[n_phrases][K]; + double[][][] exp_emit = new double [K][n_positions][n_words]; + double[][] exp_pi = new double[n_phrases][K]; + double[] lambda = null; //E step - // TODO: cache the lambda values (the null below) - PhraseContextObjective pco = new PhraseContextObjective(this, null); - pco.optimizeWithProjectedGradientDescent(); + PhraseContextObjective pco = new PhraseContextObjective(this, lambda, pool); + lambda = pco.optimizeWithProjectedGradientDescent(); //now extract expectations List edges = c.getEdges(); diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java index 3273f0ad..fbf43a7f 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java @@ -1,10 +1,13 @@ package phrase; -import java.io.PrintStream; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import optimization.gradientBasedMethods.ProjectedGradientDescent; import optimization.gradientBasedMethods.ProjectedObjective; @@ -12,7 +15,6 @@ import optimization.gradientBasedMethods.stats.OptimizerStats; import optimization.linesearch.ArmijoLineSearchMinimizationAlongProjectionArc; import optimization.linesearch.InterpolationPickFirstStep; import optimization.linesearch.LineSearchMethod; -import optimization.linesearch.WolfRuleLineSearch; import optimization.projections.SimplexProjection; import optimization.stopCriteria.CompositeStopingCriteria; import optimization.stopCriteria.ProjectedGradientL2Norm; @@ -52,11 +54,17 @@ public class PhraseContextObjective extends ProjectedObjective private Map edgeIndex; - public PhraseContextObjective(PhraseCluster cluster, double[] startingParameters) + private long projectionTime; + private long objectiveTime; + private long actualProjectionTime; + private ExecutorService pool; + + public PhraseContextObjective(PhraseCluster cluster, double[] startingParameters, ExecutorService pool) { c=cluster; data=c.c.getEdges(); n_param=data.size()*c.K*2; + this.pool=pool; parameters = startingParameters; if (parameters == null) @@ -99,6 +107,7 @@ public class PhraseContextObjective extends ProjectedObjective updateCalls++; loglikelihood=0; + long begin = System.currentTimeMillis(); for (int e=0; e> tasks = new ArrayList>(); + //System.out.println("projectPoint: " + Arrays.toString(point)); Arrays.fill(newPoint, 0, newPoint.length, 0); + if (c.scalePT > 0) { // first project using the phrase-tag constraints, // for all p,t: sum_c lambda_ptc < scaleP - for (int p = 0; p < c.c.getNumPhrases(); ++p) + if (pool == null) { - List edges = c.c.getEdgesForPhrase(p); - double toProject[] = new double[edges.size()]; - for(int tag=0;tag edges = c.c.getEdgesForPhrase(p); + double[] toProject = new double[edges.size()]; + for(int tag=0;tag edges = c.c.getEdgesForPhrase(phrase); + double toProject[] = new double[edges.size()]; + for(int tag=0;tag edges = c.c.getEdgesForContext(ctx); - double toProject[] = new double[edges.size()]; - for(int tag=0;tag edges = c.c.getEdgesForContext(ctx); + double toProject[] = new double[edges.size()]; + for(int tag=0;tag edges = c.c.getEdgesForContext(context); + double toProject[] = new double[edges.size()]; + for(int tag=0;tag task: tasks) + { + try { + task.get(); + } catch (InterruptedException e) { + System.err.println("ERROR: Projection thread interrupted"); + e.printStackTrace(); + failure = e; + } catch (ExecutionException e) { + System.err.println("ERROR: Projection thread died"); + e.printStackTrace(); + failure = e; + } + } + // rethrow the exception + if (failure != null) + throw new RuntimeException(failure); + } + double[] tmp = newPoint; newPoint = point; + projectionTime += System.currentTimeMillis() - begin; + //System.out.println("\treturning " + Arrays.toString(tmp)); return tmp; @@ -214,6 +315,11 @@ public class PhraseContextObjective extends ProjectedObjective public double[] optimizeWithProjectedGradientDescent() { + projectionTime = 0; + actualProjectionTime = 0; + objectiveTime = 0; + long start = System.currentTimeMillis(); + LineSearchMethod ls = new ArmijoLineSearchMinimizationAlongProjectionArc (new InterpolationPickFirstStep(INIT_STEP_SIZE)); @@ -230,20 +336,17 @@ public class PhraseContextObjective extends ProjectedObjective compositeStop.add(stopValue); optimizer.setMaxIterations(ITERATIONS); updateFunction(); - boolean succed = optimizer.optimize(this,stats,compositeStop); + boolean success = optimizer.optimize(this,stats,compositeStop); // System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1)); - if(succed){ - //System.out.println("Ended optimization in " + optimizer.getCurrentIteration()); - }else{ - System.out.println("Failed to optimize"); - } - // ps.println(Arrays.toString(parameters)); - - // for(int edge=0;edge