summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-14 15:15:35 +0000
committertrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-14 15:15:35 +0000
commitd2a54056e2acbdfd48e7c088fe25cc24cf280575 (patch)
tree73589c6900293e34f12a3070ec80bef3398c9a99
parentc29321deae3bc178e9ea0501f598a40894c6bc98 (diff)
Made PhraseObjective thread safe
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@248 ec762483-ff6d-05da-a07a-a48fb63a330f
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java45
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java4
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java24
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/Trainer.java3
4 files changed, 48 insertions, 28 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index 1f73764e..a369b319 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -2,8 +2,6 @@ package phrase;
import gnu.trove.TIntArrayList;
import org.apache.commons.math.special.Gamma;
-import io.FileUtil;
-import java.io.IOException;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.List;
@@ -11,9 +9,10 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
import phrase.Corpus.Edge;
-import util.MathUtil;
+
public class PhraseCluster {
@@ -21,7 +20,11 @@ public class PhraseCluster {
private int n_phrases, n_words, n_contexts, n_positions;
public Corpus c;
public ExecutorService pool;
-
+
+ double[] lambdaPTCT;
+ double[][] lambdaPT;
+ boolean cacheLambda = true;
+
// emit[tag][position][word] = p(word | tag, position in context)
double emit[][][];
// pi[phrase][tag] = p(tag | phrase)
@@ -232,14 +235,19 @@ public class PhraseCluster {
{
double [][][]exp_emit=new double[K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
+
+ if (lambdaPT == null && cacheLambda)
+ lambdaPT = new double[n_phrases][];
double loglikelihood=0, kl=0, l1lmax=0, primal=0;
int failures=0, iterations=0;
+ long start = System.currentTimeMillis();
//E
for(int phrase=0; phrase<n_phrases; phrase++){
- PhraseObjective po=new PhraseObjective(this, phrase, scalePT);
+ PhraseObjective po = new PhraseObjective(this, phrase, scalePT, (cacheLambda) ? lambdaPT[phrase] : null);
boolean ok = po.optimizeWithProjectedGradientDescent();
if (!ok) ++failures;
+ if (cacheLambda) lambdaPT[phrase] = po.getParameters();
iterations += po.getNumberUpdateCalls();
double [][] q=po.posterior();
loglikelihood += po.loglikelihood();
@@ -263,9 +271,10 @@ public class PhraseCluster {
}
}
+ long end = System.currentTimeMillis();
if (failures > 0)
System.out.println("WARNING: failed to converge in " + failures + "/" + n_phrases + " cases");
- System.out.println("\tmean iters: " + iterations/(double)n_phrases);
+ System.out.println("\tmean iters: " + iterations/(double)n_phrases + " elapsed time " + (end - start) / 1000.0);
System.out.println("\tllh: " + loglikelihood);
System.out.println("\tKL: " + kl);
System.out.println("\tphrase l1lmax: " + l1lmax);
@@ -295,7 +304,12 @@ public class PhraseCluster {
double loglikelihood=0, kl=0, l1lmax=0, primal=0;
final AtomicInteger failures = new AtomicInteger(0);
+ final AtomicLong elapsed = new AtomicLong(0l);
int iterations=0;
+ long start = System.currentTimeMillis();
+
+ if (lambdaPT == null && cacheLambda)
+ lambdaPT = new double[n_phrases][];
//E
for(int phrase=0;phrase<n_phrases;phrase++){
@@ -304,9 +318,13 @@ public class PhraseCluster {
public void run() {
try {
//System.out.println("" + Thread.currentThread().getId() + " optimising lambda for " + p);
- PhraseObjective po = new PhraseObjective(PhraseCluster.this, p, scalePT);
+ long start = System.currentTimeMillis();
+ PhraseObjective po = new PhraseObjective(PhraseCluster.this, p, scalePT, (cacheLambda) ? lambdaPT[p] : null);
boolean ok = po.optimizeWithProjectedGradientDescent();
if (!ok) failures.incrementAndGet();
+ long end = System.currentTimeMillis();
+ elapsed.addAndGet(end - start);
+
//System.out.println("" + Thread.currentThread().getId() + " done optimising lambda for " + p);
expectations.put(po);
//System.out.println("" + Thread.currentThread().getId() + " added to queue " + p);
@@ -327,6 +345,7 @@ public class PhraseCluster {
PhraseObjective po = expectations.take();
// process
int phrase = po.phrase;
+ if (cacheLambda) lambdaPT[phrase] = po.getParameters();
//System.out.println("" + Thread.currentThread().getId() + " taken phrase " + phrase);
double [][] q=po.posterior();
loglikelihood += po.loglikelihood();
@@ -335,7 +354,6 @@ public class PhraseCluster {
primal += po.primal(scalePT);
iterations += po.getNumberUpdateCalls();
-
List<Edge> edges = c.getEdgesForPhrase(phrase);
for(int edge=0;edge<q.length;edge++){
Edge e = edges.get(edge);
@@ -356,9 +374,11 @@ public class PhraseCluster {
}
}
+ long end = System.currentTimeMillis();
+
if (failures.get() > 0)
System.out.println("WARNING: failed to converge in " + failures.get() + "/" + n_phrases + " cases");
- System.out.println("\tmean iters: " + iterations/(double)n_phrases);
+ System.out.println("\tmean iters: " + iterations/(double)n_phrases + " walltime " + (end-start)/1000.0 + " threads " + elapsed.get() / 1000.0);
System.out.println("\tllh: " + loglikelihood);
System.out.println("\tKL: " + kl);
System.out.println("\tphrase l1lmax: " + l1lmax);
@@ -376,16 +396,15 @@ public class PhraseCluster {
return primal;
}
- double[] lambda;
-
public double PREM_phrase_context_constraints(double scalePT, double scaleCT)
{
double[][][] exp_emit = new double [K][n_positions][n_words];
double[][] exp_pi = new double[n_phrases][K];
//E step
- PhraseContextObjective pco = new PhraseContextObjective(this, lambda, pool, scalePT, scaleCT);
- lambda = pco.optimizeWithProjectedGradientDescent();
+ PhraseContextObjective pco = new PhraseContextObjective(this, lambdaPTCT, pool, scalePT, scaleCT);
+ boolean ok = pco.optimizeWithProjectedGradientDescent();
+ if (cacheLambda) lambdaPTCT = pco.getParameters();
//now extract expectations
List<Corpus.Edge> edges = c.getEdges();
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
index 7e6c7f60..06a9f8cb 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
@@ -318,7 +318,7 @@ public class PhraseContextObjective extends ProjectedObjective
return q[edgeIndex];
}
- public double[] optimizeWithProjectedGradientDescent()
+ public boolean optimizeWithProjectedGradientDescent()
{
projectionTime = 0;
actualProjectionTime = 0;
@@ -354,7 +354,7 @@ public class PhraseContextObjective extends ProjectedObjective
System.out.println(" and " + total + " ms: projection " + projectionTime +
" actual " + actualProjectionTime + " objective " + objectiveTime);
- return parameters;
+ return success;
}
double loglikelihood()
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
index e62b62f4..7c32d9c0 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
@@ -25,7 +25,7 @@ public class PhraseObjective extends ProjectedObjective
static int ITERATIONS = 100;
//private double c1=0.0001; // wolf stuff
//private double c2=0.9;
- private static double lambda[][];
+ //private static double lambda[][];
private PhraseCluster c;
/**@brief
@@ -64,23 +64,18 @@ public class PhraseObjective extends ProjectedObjective
*/
public double llh;
- public PhraseObjective(PhraseCluster cluster, int phraseIdx, double scale){
+ public PhraseObjective(PhraseCluster cluster, int phraseIdx, double scale, double[] lambda){
phrase=phraseIdx;
c=cluster;
data=c.c.getEdgesForPhrase(phrase);
n_param=data.size()*c.K;
//System.out.println("Num parameters " + n_param + " for phrase #" + phraseIdx);
- if (lambda==null){
- lambda=new double[c.c.getNumPhrases()][];
- }
-
- if (lambda[phrase]==null){
- lambda[phrase]=new double[n_param];
- }
+ if (lambda==null)
+ lambda=new double[n_param];
- parameters=lambda[phrase];
- newPoint = new double[n_param];
+ parameters = lambda;
+ newPoint = new double[n_param];
gradient = new double[n_param];
initP();
projection=new SimplexProjection(scale);
@@ -163,8 +158,12 @@ public class PhraseObjective extends ProjectedObjective
public double [][]posterior(){
return q;
}
-
+
+ long optimizationTime;
+
public boolean optimizeWithProjectedGradientDescent(){
+ long start = System.currentTimeMillis();
+
LineSearchMethod ls =
new ArmijoLineSearchMinimizationAlongProjectionArc
(new InterpolationPickFirstStep(INIT_STEP_SIZE));
@@ -188,7 +187,6 @@ public class PhraseObjective extends ProjectedObjective
//}else{
// System.out.println("Failed to optimize");
//}
- lambda[phrase]=parameters;
// ps.println(Arrays.toString(parameters));
// for(int edge=0;edge<data.getSize();edge++){
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
index 240c4d64..20f6c905 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
@@ -31,6 +31,7 @@ public class Trainer
parser.accepts("alpha-emit").withRequiredArg().ofType(Double.class).defaultsTo(0.1);
parser.accepts("alpha-pi").withRequiredArg().ofType(Double.class).defaultsTo(0.01);
parser.accepts("agree");
+ parser.accepts("no-parameter-cache");
OptionSet options = parser.parse(args);
if (options.has("help") || !options.has("in"))
@@ -96,6 +97,8 @@ public class Trainer
cluster = new PhraseCluster(tags, corpus);
if (threads > 0) cluster.useThreadPool(threads);
if (vb) cluster.initialiseVB(alphaEmit, alphaPi);
+ if (options.has("no-parameter-cache"))
+ cluster.cacheLambda = false;
}
double last = 0;