diff options
Diffstat (limited to 'gi/posterior-regularisation/prjava')
6 files changed, 198 insertions, 53 deletions
diff --git a/gi/posterior-regularisation/prjava/build.xml b/gi/posterior-regularisation/prjava/build.xml new file mode 100644 index 00000000..c9ed2e8d --- /dev/null +++ b/gi/posterior-regularisation/prjava/build.xml @@ -0,0 +1,38 @@ +<project name="prjava" default="dist" basedir="."> + <!-- set global properties for this build --> + <property name="src" location="src"/> + <property name="build" location="build"/> + <property name="dist" location="lib"/> + <path id="classpath"> + <pathelement location="lib/trove-2.0.2.jar"/> + <pathelement location="lib/optimization.jar"/> + </path> + + <target name="init"> + <!-- Create the time stamp --> + <tstamp/> + <!-- Create the build directory structure used by compile --> + <mkdir dir="${build}"/> + </target> + + <target name="compile" depends="init" + description="compile the source " > + <!-- Compile the java code from ${src} into ${build} --> + <javac srcdir="${src}" destdir="${build}"> + <classpath refid="classpath"/> + </javac> + </target> + + <target name="dist" depends="compile" + description="generate the distribution" > + <jar jarfile="${dist}/prjava-${DSTAMP}.jar" basedir="${build}"/> + <symlink link="prjava.jar" resource="${dist}/prjava-${DSTAMP}.jar" overwrite="true"/> + </target> + + <target name="clean" + description="clean up" > + <!-- Delete the ${build} and ${dist} directory trees --> + <delete dir="${build}"/> + <delete dir="${dist}"/> + </target> +</project> diff --git a/gi/posterior-regularisation/prjava/src/arr/F.java b/gi/posterior-regularisation/prjava/src/arr/F.java index c194496e..5821af42 100644 --- a/gi/posterior-regularisation/prjava/src/arr/F.java +++ b/gi/posterior-regularisation/prjava/src/arr/F.java @@ -1,12 +1,16 @@ package arr;
+import java.util.Random;
+
public class F {
+ private static Random rng = new Random(); //(9562724l);
+
public static void randomise(double probs[])
{
double z = 0;
for (int i = 0; i < probs.length; ++i)
{
- probs[i] = 3 + Math.random();
+ probs[i] = 3 + rng.nextDouble();
z += probs[i];
}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index e4db2a1a..63a60682 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -18,7 +18,7 @@ public class PhraseCluster { public double scalePT, scaleCT;
private int n_phrases, n_words, n_contexts, n_positions;
public Corpus c;
- private ExecutorService pool;
+ public ExecutorService pool;
// emit[tag][position][word] = p(word | tag, position in context)
private double emit[][][];
@@ -88,7 +88,8 @@ public class PhraseCluster { //cluster.displayModelParam(ps);
//ps.close();
- cluster.finish();
+ if (cluster.pool != null)
+ cluster.pool.shutdown();
}
public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads){
@@ -100,7 +101,7 @@ public class PhraseCluster { n_positions=c.getNumContextPositions();
this.scalePT = scalep;
this.scaleCT = scalec;
- if (threads > 0 && scalec <= 0)
+ if (threads > 0)
pool = Executors.newFixedThreadPool(threads);
emit=new double [K][n_positions][n_words];
@@ -116,12 +117,7 @@ public class PhraseCluster { arr.F.randomise(j);
}
}
-
- public void finish()
- {
- if (pool != null)
- pool.shutdown();
- }
+
public double EM(){
double [][][]exp_emit=new double [K][n_positions][n_words];
@@ -318,13 +314,13 @@ public class PhraseCluster { public double PREM_phrase_context_constraints(){
assert (scaleCT > 0);
- double [][][]exp_emit=new double [K][n_positions][n_words];
- double [][]exp_pi=new double[n_phrases][K];
+ double[][][] exp_emit = new double [K][n_positions][n_words];
+ double[][] exp_pi = new double[n_phrases][K];
+ double[] lambda = null;
//E step
- // TODO: cache the lambda values (the null below)
- PhraseContextObjective pco = new PhraseContextObjective(this, null);
- pco.optimizeWithProjectedGradientDescent();
+ PhraseContextObjective pco = new PhraseContextObjective(this, lambda, pool);
+ lambda = pco.optimizeWithProjectedGradientDescent();
//now extract expectations
List<Corpus.Edge> edges = c.getEdges();
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java index 3273f0ad..fbf43a7f 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java @@ -1,10 +1,13 @@ package phrase;
-import java.io.PrintStream;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
import optimization.gradientBasedMethods.ProjectedGradientDescent;
import optimization.gradientBasedMethods.ProjectedObjective;
@@ -12,7 +15,6 @@ import optimization.gradientBasedMethods.stats.OptimizerStats; import optimization.linesearch.ArmijoLineSearchMinimizationAlongProjectionArc;
import optimization.linesearch.InterpolationPickFirstStep;
import optimization.linesearch.LineSearchMethod;
-import optimization.linesearch.WolfRuleLineSearch;
import optimization.projections.SimplexProjection;
import optimization.stopCriteria.CompositeStopingCriteria;
import optimization.stopCriteria.ProjectedGradientL2Norm;
@@ -52,11 +54,17 @@ public class PhraseContextObjective extends ProjectedObjective private Map<Corpus.Edge, Integer> edgeIndex;
- public PhraseContextObjective(PhraseCluster cluster, double[] startingParameters)
+ private long projectionTime;
+ private long objectiveTime;
+ private long actualProjectionTime;
+ private ExecutorService pool;
+
+ public PhraseContextObjective(PhraseCluster cluster, double[] startingParameters, ExecutorService pool)
{
c=cluster;
data=c.c.getEdges();
n_param=data.size()*c.K*2;
+ this.pool=pool;
parameters = startingParameters;
if (parameters == null)
@@ -99,6 +107,7 @@ public class PhraseContextObjective extends ProjectedObjective updateCalls++;
loglikelihood=0;
+ long begin = System.currentTimeMillis();
for (int e=0; e<data.size(); e++)
{
Edge edge = data.get(e);
@@ -129,29 +138,64 @@ public class PhraseContextObjective extends ProjectedObjective gradient[ic]=-q[e][tag];
}
}
- //System.out.println("objective " + loglikelihood + " gradient: " + Arrays.toString(gradient));
+ //System.out.println("objective " + loglikelihood + " gradient: " + Arrays.toString(gradient));
+ objectiveTime += System.currentTimeMillis() - begin;
}
@Override
public double[] projectPoint(double[] point)
{
+ long begin = System.currentTimeMillis();
+ List<Future<?>> tasks = new ArrayList<Future<?>>();
+
//System.out.println("projectPoint: " + Arrays.toString(point));
Arrays.fill(newPoint, 0, newPoint.length, 0);
+
if (c.scalePT > 0)
{
// first project using the phrase-tag constraints,
// for all p,t: sum_c lambda_ptc < scaleP
- for (int p = 0; p < c.c.getNumPhrases(); ++p)
+ if (pool == null)
{
- List<Edge> edges = c.c.getEdgesForPhrase(p);
- double toProject[] = new double[edges.size()];
- for(int tag=0;tag<c.K;tag++)
+ for (int p = 0; p < c.c.getNumPhrases(); ++p)
+ {
+ List<Edge> edges = c.c.getEdgesForPhrase(p);
+ double[] toProject = new double[edges.size()];
+ for(int tag=0;tag<c.K;tag++)
+ {
+ for(int e=0; e<edges.size(); e++)
+ toProject[e] = point[index(edges.get(e), tag, true)];
+ long lbegin = System.currentTimeMillis();
+ projectionPhrase.project(toProject);
+ actualProjectionTime += System.currentTimeMillis() - lbegin;
+ for(int e=0; e<edges.size(); e++)
+ newPoint[index(edges.get(e), tag, true)] = toProject[e];
+ }
+ }
+ }
+ else // do above in parallel using thread pool
+ {
+ for (int p = 0; p < c.c.getNumPhrases(); ++p)
{
- for(int e=0; e<edges.size(); e++)
- toProject[e] = point[index(edges.get(e), tag, true)];
- projectionPhrase.project(toProject);
- for(int e=0; e<edges.size(); e++)
- newPoint[index(edges.get(e),tag, true)] = toProject[e];
+ final int phrase = p;
+ final double[] inPoint = point;
+ Runnable task = new Runnable()
+ {
+ public void run()
+ {
+ List<Edge> edges = c.c.getEdgesForPhrase(phrase);
+ double toProject[] = new double[edges.size()];
+ for(int tag=0;tag<c.K;tag++)
+ {
+ for(int e=0; e<edges.size(); e++)
+ toProject[e] = inPoint[index(edges.get(e), tag, true)];
+ projectionPhrase.project(toProject);
+ for(int e=0; e<edges.size(); e++)
+ newPoint[index(edges.get(e), tag, true)] = toProject[e];
+ }
+ }
+ };
+ tasks.add(pool.submit(task));
}
}
}
@@ -161,22 +205,79 @@ public class PhraseContextObjective extends ProjectedObjective {
// now project using the context-tag constraints,
// for all c,t: sum_p omega_pct < scaleC
- for (int ctx = 0; ctx < c.c.getNumContexts(); ++ctx)
+ if (pool == null)
{
- List<Edge> edges = c.c.getEdgesForContext(ctx);
- double toProject[] = new double[edges.size()];
- for(int tag=0;tag<c.K;tag++)
+ for (int ctx = 0; ctx < c.c.getNumContexts(); ++ctx)
{
- for(int e=0; e<edges.size(); e++)
- toProject[e] = point[index(edges.get(e), tag, false)];
- projectionContext.project(toProject);
- for(int e=0; e<edges.size(); e++)
- newPoint[index(edges.get(e),tag, false)] = toProject[e];
+ List<Edge> edges = c.c.getEdgesForContext(ctx);
+ double toProject[] = new double[edges.size()];
+ for(int tag=0;tag<c.K;tag++)
+ {
+ for(int e=0; e<edges.size(); e++)
+ toProject[e] = point[index(edges.get(e), tag, false)];
+ long lbegin = System.currentTimeMillis();
+ projectionContext.project(toProject);
+ actualProjectionTime += System.currentTimeMillis() - lbegin;
+ for(int e=0; e<edges.size(); e++)
+ newPoint[index(edges.get(e), tag, false)] = toProject[e];
+ }
+ }
+ }
+ else
+ {
+ // do above in parallel using thread pool
+ for (int ctx = 0; ctx < c.c.getNumContexts(); ++ctx)
+ {
+ final int context = ctx;
+ final double[] inPoint = point;
+ Runnable task = new Runnable()
+ {
+ public void run()
+ {
+ List<Edge> edges = c.c.getEdgesForContext(context);
+ double toProject[] = new double[edges.size()];
+ for(int tag=0;tag<c.K;tag++)
+ {
+ for(int e=0; e<edges.size(); e++)
+ toProject[e] = inPoint[index(edges.get(e), tag, false)];
+ projectionContext.project(toProject);
+ for(int e=0; e<edges.size(); e++)
+ newPoint[index(edges.get(e), tag, false)] = toProject[e];
+ }
+ }
+ };
+ tasks.add(pool.submit(task));
}
}
}
+
+ if (pool != null)
+ {
+ // wait for all the jobs to complete
+ Exception failure = null;
+ for (Future<?> task: tasks)
+ {
+ try {
+ task.get();
+ } catch (InterruptedException e) {
+ System.err.println("ERROR: Projection thread interrupted");
+ e.printStackTrace();
+ failure = e;
+ } catch (ExecutionException e) {
+ System.err.println("ERROR: Projection thread died");
+ e.printStackTrace();
+ failure = e;
+ }
+ }
+ // rethrow the exception
+ if (failure != null)
+ throw new RuntimeException(failure);
+ }
+
double[] tmp = newPoint;
newPoint = point;
+ projectionTime += System.currentTimeMillis() - begin;
+
//System.out.println("\treturning " + Arrays.toString(tmp));
return tmp;
@@ -214,6 +315,11 @@ public class PhraseContextObjective extends ProjectedObjective public double[] optimizeWithProjectedGradientDescent()
{
+ projectionTime = 0;
+ actualProjectionTime = 0;
+ objectiveTime = 0;
+ long start = System.currentTimeMillis();
+
LineSearchMethod ls =
new ArmijoLineSearchMinimizationAlongProjectionArc
(new InterpolationPickFirstStep(INIT_STEP_SIZE));
@@ -230,20 +336,17 @@ public class PhraseContextObjective extends ProjectedObjective compositeStop.add(stopValue);
optimizer.setMaxIterations(ITERATIONS);
updateFunction();
- boolean succed = optimizer.optimize(this,stats,compositeStop);
+ boolean success = optimizer.optimize(this,stats,compositeStop);
// System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
- if(succed){
- //System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
- }else{
- System.out.println("Failed to optimize");
- }
- // ps.println(Arrays.toString(parameters));
-
- // for(int edge=0;edge<data.getSize();edge++){
- // ps.println(Arrays.toString(q[edge]));
- // }
- //System.out.println(Arrays.toString(parameters));
+ if (success)
+ System.out.print("\toptimization took " + optimizer.getCurrentIteration() + " iterations");
+ else
+ System.out.print("\toptimization failed to converge");
+ long total = System.currentTimeMillis() - start;
+ System.out.println(" and " + total + " ms: projection " + projectionTime +
+ " actual " + actualProjectionTime + " objective " + objectiveTime);
+
return parameters;
}
@@ -298,5 +401,4 @@ public class PhraseContextObjective extends ProjectedObjective {
return loglikelihood() - KL_divergence() - c.scalePT * phrase_l1lmax() - c.scalePT * context_l1lmax();
}
-
-}
+}
\ No newline at end of file diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java index 015ef106..0a76e2dc 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java @@ -22,7 +22,7 @@ public class PhraseObjective extends ProjectedObjective {
static final double GRAD_DIFF = 0.00002;
static double INIT_STEP_SIZE = 10;
- static double VAL_DIFF = 1e-6; // FIXME needs to be tuned
+ static double VAL_DIFF = 1e-4; // FIXME needs to be tuned - and this might be too weak
static int ITERATIONS = 100;
//private double c1=0.0001; // wolf stuff
//private double c2=0.9;
@@ -128,7 +128,8 @@ public class PhraseObjective extends ProjectedObjective }
@Override
- public double[] projectPoint(double[] point) {
+ public double[] projectPoint(double[] point)
+ {
double toProject[]=new double[data.size()];
for(int tag=0;tag<c.K;tag++){
for(int edge=0;edge<data.size();edge++){
diff --git a/gi/posterior-regularisation/prjava/train-PR-cluster.sh b/gi/posterior-regularisation/prjava/train-PR-cluster.sh new file mode 100755 index 00000000..b86d564b --- /dev/null +++ b/gi/posterior-regularisation/prjava/train-PR-cluster.sh @@ -0,0 +1,4 @@ +#!/bin/sh + +d=`dirname $0` +java -ea -Xmx8g -cp $d/prjava.jar:$d/lib/trove-2.0.2.jar:$d/lib/optimization.jar phrase.PhraseCluster $* |