From d2a54056e2acbdfd48e7c088fe25cc24cf280575 Mon Sep 17 00:00:00 2001 From: "trevor.cohn" Date: Wed, 14 Jul 2010 15:15:35 +0000 Subject: Made PhraseObjective thread safe git-svn-id: https://ws10smt.googlecode.com/svn/trunk@248 ec762483-ff6d-05da-a07a-a48fb63a330f --- .../prjava/src/phrase/PhraseCluster.java | 45 +++++++++++++++------- .../prjava/src/phrase/PhraseContextObjective.java | 4 +- .../prjava/src/phrase/PhraseObjective.java | 24 ++++++------ .../prjava/src/phrase/Trainer.java | 3 ++ 4 files changed, 48 insertions(+), 28 deletions(-) (limited to 'gi/posterior-regularisation/prjava/src/phrase') diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index 1f73764e..a369b319 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -2,8 +2,6 @@ package phrase; import gnu.trove.TIntArrayList; import org.apache.commons.math.special.Gamma; -import io.FileUtil; -import java.io.IOException; import java.io.PrintStream; import java.util.Arrays; import java.util.List; @@ -11,9 +9,10 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; import phrase.Corpus.Edge; -import util.MathUtil; + public class PhraseCluster { @@ -21,7 +20,11 @@ public class PhraseCluster { private int n_phrases, n_words, n_contexts, n_positions; public Corpus c; public ExecutorService pool; - + + double[] lambdaPTCT; + double[][] lambdaPT; + boolean cacheLambda = true; + // emit[tag][position][word] = p(word | tag, position in context) double emit[][][]; // pi[phrase][tag] = p(tag | phrase) @@ -232,14 +235,19 @@ public class PhraseCluster { { double [][][]exp_emit=new double[K][n_positions][n_words]; double [][]exp_pi=new double[n_phrases][K]; + + if (lambdaPT == null && cacheLambda) + lambdaPT = new double[n_phrases][]; double loglikelihood=0, kl=0, l1lmax=0, primal=0; int failures=0, iterations=0; + long start = System.currentTimeMillis(); //E for(int phrase=0; phrase 0) System.out.println("WARNING: failed to converge in " + failures + "/" + n_phrases + " cases"); - System.out.println("\tmean iters: " + iterations/(double)n_phrases); + System.out.println("\tmean iters: " + iterations/(double)n_phrases + " elapsed time " + (end - start) / 1000.0); System.out.println("\tllh: " + loglikelihood); System.out.println("\tKL: " + kl); System.out.println("\tphrase l1lmax: " + l1lmax); @@ -295,7 +304,12 @@ public class PhraseCluster { double loglikelihood=0, kl=0, l1lmax=0, primal=0; final AtomicInteger failures = new AtomicInteger(0); + final AtomicLong elapsed = new AtomicLong(0l); int iterations=0; + long start = System.currentTimeMillis(); + + if (lambdaPT == null && cacheLambda) + lambdaPT = new double[n_phrases][]; //E for(int phrase=0;phrase edges = c.getEdgesForPhrase(phrase); for(int edge=0;edge 0) System.out.println("WARNING: failed to converge in " + failures.get() + "/" + n_phrases + " cases"); - System.out.println("\tmean iters: " + iterations/(double)n_phrases); + System.out.println("\tmean iters: " + iterations/(double)n_phrases + " walltime " + (end-start)/1000.0 + " threads " + elapsed.get() / 1000.0); System.out.println("\tllh: " + loglikelihood); System.out.println("\tKL: " + kl); System.out.println("\tphrase l1lmax: " + l1lmax); @@ -376,16 +396,15 @@ public class PhraseCluster { return primal; } - double[] lambda; - public double PREM_phrase_context_constraints(double scalePT, double scaleCT) { double[][][] exp_emit = new double [K][n_positions][n_words]; double[][] exp_pi = new double[n_phrases][K]; //E step - PhraseContextObjective pco = new PhraseContextObjective(this, lambda, pool, scalePT, scaleCT); - lambda = pco.optimizeWithProjectedGradientDescent(); + PhraseContextObjective pco = new PhraseContextObjective(this, lambdaPTCT, pool, scalePT, scaleCT); + boolean ok = pco.optimizeWithProjectedGradientDescent(); + if (cacheLambda) lambdaPTCT = pco.getParameters(); //now extract expectations List edges = c.getEdges(); diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java index 7e6c7f60..06a9f8cb 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java @@ -318,7 +318,7 @@ public class PhraseContextObjective extends ProjectedObjective return q[edgeIndex]; } - public double[] optimizeWithProjectedGradientDescent() + public boolean optimizeWithProjectedGradientDescent() { projectionTime = 0; actualProjectionTime = 0; @@ -354,7 +354,7 @@ public class PhraseContextObjective extends ProjectedObjective System.out.println(" and " + total + " ms: projection " + projectionTime + " actual " + actualProjectionTime + " objective " + objectiveTime); - return parameters; + return success; } double loglikelihood() diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java index e62b62f4..7c32d9c0 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java @@ -25,7 +25,7 @@ public class PhraseObjective extends ProjectedObjective static int ITERATIONS = 100; //private double c1=0.0001; // wolf stuff //private double c2=0.9; - private static double lambda[][]; + //private static double lambda[][]; private PhraseCluster c; /**@brief @@ -64,23 +64,18 @@ public class PhraseObjective extends ProjectedObjective */ public double llh; - public PhraseObjective(PhraseCluster cluster, int phraseIdx, double scale){ + public PhraseObjective(PhraseCluster cluster, int phraseIdx, double scale, double[] lambda){ phrase=phraseIdx; c=cluster; data=c.c.getEdgesForPhrase(phrase); n_param=data.size()*c.K; //System.out.println("Num parameters " + n_param + " for phrase #" + phraseIdx); - if (lambda==null){ - lambda=new double[c.c.getNumPhrases()][]; - } - - if (lambda[phrase]==null){ - lambda[phrase]=new double[n_param]; - } + if (lambda==null) + lambda=new double[n_param]; - parameters=lambda[phrase]; - newPoint = new double[n_param]; + parameters = lambda; + newPoint = new double[n_param]; gradient = new double[n_param]; initP(); projection=new SimplexProjection(scale); @@ -163,8 +158,12 @@ public class PhraseObjective extends ProjectedObjective public double [][]posterior(){ return q; } - + + long optimizationTime; + public boolean optimizeWithProjectedGradientDescent(){ + long start = System.currentTimeMillis(); + LineSearchMethod ls = new ArmijoLineSearchMinimizationAlongProjectionArc (new InterpolationPickFirstStep(INIT_STEP_SIZE)); @@ -188,7 +187,6 @@ public class PhraseObjective extends ProjectedObjective //}else{ // System.out.println("Failed to optimize"); //} - lambda[phrase]=parameters; // ps.println(Arrays.toString(parameters)); // for(int edge=0;edge 0) cluster.useThreadPool(threads); if (vb) cluster.initialiseVB(alphaEmit, alphaPi); + if (options.has("no-parameter-cache")) + cluster.cacheLambda = false; } double last = 0; -- cgit v1.2.3