diff options
Diffstat (limited to 'gi/posterior-regularisation/prjava')
6 files changed, 198 insertions, 53 deletions
| diff --git a/gi/posterior-regularisation/prjava/build.xml b/gi/posterior-regularisation/prjava/build.xml new file mode 100644 index 00000000..c9ed2e8d --- /dev/null +++ b/gi/posterior-regularisation/prjava/build.xml @@ -0,0 +1,38 @@ +<project name="prjava" default="dist" basedir="."> +  <!-- set global properties for this build --> +  <property name="src" location="src"/> +  <property name="build" location="build"/> +  <property name="dist" location="lib"/> +  <path id="classpath"> +      <pathelement location="lib/trove-2.0.2.jar"/> +      <pathelement location="lib/optimization.jar"/> +  </path> + +  <target name="init"> +    <!-- Create the time stamp --> +    <tstamp/> +    <!-- Create the build directory structure used by compile --> +    <mkdir dir="${build}"/> +  </target> + +  <target name="compile" depends="init" +        description="compile the source " > +    <!-- Compile the java code from ${src} into ${build} --> +    <javac srcdir="${src}" destdir="${build}"> +            <classpath refid="classpath"/> +    </javac> +  </target> + +  <target name="dist" depends="compile" +        description="generate the distribution" > +    <jar jarfile="${dist}/prjava-${DSTAMP}.jar" basedir="${build}"/> +    <symlink link="prjava.jar" resource="${dist}/prjava-${DSTAMP}.jar" overwrite="true"/> +  </target> + +  <target name="clean" +        description="clean up" > +    <!-- Delete ${build} and only the jars this build generated; ${dist} is lib, which also holds the dependency jars on the compile classpath, so the directory itself must be kept --> +    <delete dir="${build}"/> +    <delete><fileset dir="${dist}" includes="prjava-*.jar" erroronmissingdir="false"/></delete> +  </target> +</project> diff --git a/gi/posterior-regularisation/prjava/src/arr/F.java b/gi/posterior-regularisation/prjava/src/arr/F.java index c194496e..5821af42 100644 --- a/gi/posterior-regularisation/prjava/src/arr/F.java +++ b/gi/posterior-regularisation/prjava/src/arr/F.java @@ -1,12 +1,16 @@  package arr;
 +import java.util.Random;
 +
  public class F {
 +	private static Random rng = new Random(); //(9562724l);
 +	
  	public static void randomise(double probs[])
  	{
  		double z = 0;
  		for (int i = 0; i < probs.length; ++i)
  		{
 -			probs[i] = 3 + Math.random();
 +			probs[i] = 3 + rng.nextDouble();
  			z += probs[i];
  		}
 diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index e4db2a1a..63a60682 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -18,7 +18,7 @@ public class PhraseCluster {  	public double scalePT, scaleCT;
  	private int n_phrases, n_words, n_contexts, n_positions;
  	public Corpus c;
 -	private ExecutorService pool; 
 +	public ExecutorService pool; 
  	// emit[tag][position][word] = p(word | tag, position in context)
  	private double emit[][][];
 @@ -88,7 +88,8 @@ public class PhraseCluster {  		//cluster.displayModelParam(ps);
  		//ps.close();
 -		cluster.finish();
 +		if (cluster.pool != null)
 +			cluster.pool.shutdown();
  	}
  	public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads){
 @@ -100,7 +101,7 @@ public class PhraseCluster {  		n_positions=c.getNumContextPositions();
  		this.scalePT = scalep;
  		this.scaleCT = scalec;
 -		if (threads > 0 && scalec <= 0)
 +		if (threads > 0)
  			pool = Executors.newFixedThreadPool(threads);
  		emit=new double [K][n_positions][n_words];
 @@ -116,12 +117,7 @@ public class PhraseCluster {  			arr.F.randomise(j);
  		}
  	}
 -	
 -	public void finish()
 -	{
 -		if (pool != null)
 -			pool.shutdown();
 -	}
 +
  	public double EM(){
  		double [][][]exp_emit=new double [K][n_positions][n_words];
 @@ -318,13 +314,13 @@ public class PhraseCluster {  	public double PREM_phrase_context_constraints(){
  		assert (scaleCT > 0);
 -		double [][][]exp_emit=new double [K][n_positions][n_words];
 -		double [][]exp_pi=new double[n_phrases][K];
 +		double[][][] exp_emit = new double [K][n_positions][n_words];
 +		double[][] exp_pi = new double[n_phrases][K];
 +		double[] lambda = null;
  		//E step
 -		// TODO: cache the lambda values (the null below)
 -		PhraseContextObjective pco = new PhraseContextObjective(this, null);
 -		pco.optimizeWithProjectedGradientDescent();
 +		PhraseContextObjective pco = new PhraseContextObjective(this, lambda, pool);
 +		lambda = pco.optimizeWithProjectedGradientDescent();
  		//now extract expectations
  		List<Corpus.Edge> edges = c.getEdges();
 diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java index 3273f0ad..fbf43a7f 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java @@ -1,10 +1,13 @@  package phrase;
 -import java.io.PrintStream;
 +import java.util.ArrayList;
  import java.util.Arrays;
  import java.util.HashMap;
  import java.util.List;
  import java.util.Map;
 +import java.util.concurrent.ExecutionException;
 +import java.util.concurrent.ExecutorService;
 +import java.util.concurrent.Future;
  import optimization.gradientBasedMethods.ProjectedGradientDescent;
  import optimization.gradientBasedMethods.ProjectedObjective;
 @@ -12,7 +15,6 @@ import optimization.gradientBasedMethods.stats.OptimizerStats;  import optimization.linesearch.ArmijoLineSearchMinimizationAlongProjectionArc;
  import optimization.linesearch.InterpolationPickFirstStep;
  import optimization.linesearch.LineSearchMethod;
 -import optimization.linesearch.WolfRuleLineSearch;
  import optimization.projections.SimplexProjection;
  import optimization.stopCriteria.CompositeStopingCriteria;
  import optimization.stopCriteria.ProjectedGradientL2Norm;
 @@ -52,11 +54,17 @@ public class PhraseContextObjective extends ProjectedObjective  	private Map<Corpus.Edge, Integer> edgeIndex;
 -	public PhraseContextObjective(PhraseCluster cluster, double[] startingParameters)
 +	private long projectionTime;
 +	private long objectiveTime;
 +	private long actualProjectionTime;
 +	private ExecutorService pool;
 +	
 +	public PhraseContextObjective(PhraseCluster cluster, double[] startingParameters, ExecutorService pool)
  	{
  		c=cluster;
  		data=c.c.getEdges();
  		n_param=data.size()*c.K*2;
 +		this.pool=pool;
  		parameters = startingParameters;
  		if (parameters == null)
 @@ -99,6 +107,7 @@ public class PhraseContextObjective extends ProjectedObjective  		updateCalls++;
  		loglikelihood=0;
 +		long begin = System.currentTimeMillis();
  		for (int e=0; e<data.size(); e++) 
  		{
  			Edge edge = data.get(e);
 @@ -129,29 +138,64 @@ public class PhraseContextObjective extends ProjectedObjective  				gradient[ic]=-q[e][tag];
  			}
  		}
 -		//System.out.println("objective " + loglikelihood + " gradient: " + Arrays.toString(gradient));
 +		//System.out.println("objective " + loglikelihood + " gradient: " + Arrays.toString(gradient));		
 +		objectiveTime += System.currentTimeMillis() - begin;
  	}
  	@Override
  	public double[] projectPoint(double[] point) 
  	{
 +		long begin = System.currentTimeMillis();
 +		List<Future<?>> tasks = new ArrayList<Future<?>>();
 +
  		//System.out.println("projectPoint: " + Arrays.toString(point));
  		Arrays.fill(newPoint, 0, newPoint.length, 0);
 +		
  		if (c.scalePT > 0)
  		{
  			// first project using the phrase-tag constraints,
  			// for all p,t: sum_c lambda_ptc < scaleP 
 -			for (int p = 0; p < c.c.getNumPhrases(); ++p)
 +			if (pool == null)
  			{
 -				List<Edge> edges = c.c.getEdgesForPhrase(p);
 -				double toProject[] = new double[edges.size()];
 -				for(int tag=0;tag<c.K;tag++)
 +				for (int p = 0; p < c.c.getNumPhrases(); ++p)
 +				{
 +					List<Edge> edges = c.c.getEdgesForPhrase(p);
 +					double[] toProject = new double[edges.size()];
 +					for(int tag=0;tag<c.K;tag++)
 +					{
 +						for(int e=0; e<edges.size(); e++)
 +							toProject[e] = point[index(edges.get(e), tag, true)];
 +						long lbegin = System.currentTimeMillis();
 +						projectionPhrase.project(toProject);
 +						actualProjectionTime += System.currentTimeMillis() - lbegin;
 +						for(int e=0; e<edges.size(); e++)
 +							newPoint[index(edges.get(e), tag, true)] = toProject[e];
 +					}
 +				}
 +			}
 +			else // do above in parallel using thread pool
 +			{	
 +				for (int p = 0; p < c.c.getNumPhrases(); ++p)
  				{
 -					for(int e=0; e<edges.size(); e++)
 -						toProject[e] = point[index(edges.get(e), tag, true)];
 -					projectionPhrase.project(toProject);
 -					for(int e=0; e<edges.size(); e++)
 -						newPoint[index(edges.get(e),tag, true)] = toProject[e];
 +					final int phrase = p;
 +					final double[] inPoint = point;
 +					Runnable task = new Runnable()
 +					{
 +						public void run()
 +						{
 +							List<Edge> edges = c.c.getEdgesForPhrase(phrase);
 +							double toProject[] = new double[edges.size()];
 +							for(int tag=0;tag<c.K;tag++)
 +							{
 +								for(int e=0; e<edges.size(); e++)
 +									toProject[e] = inPoint[index(edges.get(e), tag, true)];
 +								projectionPhrase.project(toProject);
 +								for(int e=0; e<edges.size(); e++)
 +									newPoint[index(edges.get(e), tag, true)] = toProject[e];
 +							}
 +						}		
 +					};
 +					tasks.add(pool.submit(task));
  				}
  			}
  		}
 @@ -161,22 +205,79 @@ public class PhraseContextObjective extends ProjectedObjective  		{
  			// now project using the context-tag constraints,
  			// for all c,t: sum_p omega_pct < scaleC
 -			for (int ctx = 0; ctx < c.c.getNumContexts(); ++ctx)
 +			if (pool == null)
  			{
 -				List<Edge> edges = c.c.getEdgesForContext(ctx);
 -				double toProject[] = new double[edges.size()];
 -				for(int tag=0;tag<c.K;tag++)
 +				for (int ctx = 0; ctx < c.c.getNumContexts(); ++ctx)
  				{
 -					for(int e=0; e<edges.size(); e++)
 -						toProject[e] = point[index(edges.get(e), tag, false)];
 -					projectionContext.project(toProject);
 -					for(int e=0; e<edges.size(); e++)
 -						newPoint[index(edges.get(e),tag, false)] = toProject[e];
 +					List<Edge> edges = c.c.getEdgesForContext(ctx);
 +					double toProject[] = new double[edges.size()];
 +					for(int tag=0;tag<c.K;tag++)
 +					{
 +						for(int e=0; e<edges.size(); e++)
 +							toProject[e] = point[index(edges.get(e), tag, false)];
 +						long lbegin = System.currentTimeMillis();
 +						projectionContext.project(toProject);
 +						actualProjectionTime += System.currentTimeMillis() - lbegin;
 +						for(int e=0; e<edges.size(); e++)
 +							newPoint[index(edges.get(e), tag, false)] = toProject[e];
 +					}
 +				}
 +			}
 +			else
 +			{
 +				// do above in parallel using thread pool
 +				for (int ctx = 0; ctx < c.c.getNumContexts(); ++ctx)
 +				{
 +					final int context = ctx;
 +					final double[] inPoint = point;
 +					Runnable task = new Runnable()
 +					{
 +						public void run()
 +						{
 +							List<Edge> edges = c.c.getEdgesForContext(context);
 +							double toProject[] = new double[edges.size()];
 +							for(int tag=0;tag<c.K;tag++)
 +							{
 +								for(int e=0; e<edges.size(); e++)
 +									toProject[e] = inPoint[index(edges.get(e), tag, false)];
 +								projectionContext.project(toProject);
 +								for(int e=0; e<edges.size(); e++)
 +									newPoint[index(edges.get(e), tag, false)] = toProject[e];
 +							}
 +						}
 +					};
 +					tasks.add(pool.submit(task));
  				}
  			}
  		}
 +		
 +		if (pool != null)
 +		{
 +			// wait for all the jobs to complete
 +			Exception failure = null;
 +			for (Future<?> task: tasks)
 +			{
 +				try {
 +					task.get();
 +				} catch (InterruptedException e) {
 +					System.err.println("ERROR: Projection thread interrupted");
 +					e.printStackTrace();
 +					failure = e;
 +				} catch (ExecutionException e) {
 +					System.err.println("ERROR: Projection thread died");
 +					e.printStackTrace();
 +					failure = e;
 +				}
 +			}
 +			// rethrow the exception
 +			if (failure != null)
 +				throw new RuntimeException(failure);
 +		}
 +		
  		double[] tmp = newPoint;
  		newPoint = point;
 +		projectionTime += System.currentTimeMillis() - begin;
 +
  		//System.out.println("\treturning " + Arrays.toString(tmp));
  		return tmp;
 @@ -214,6 +315,11 @@ public class PhraseContextObjective extends ProjectedObjective  	public double[] optimizeWithProjectedGradientDescent()
  	{
 +		projectionTime = 0;
 +		actualProjectionTime = 0;
 +		objectiveTime = 0;
 +		long start = System.currentTimeMillis();
 +
  		LineSearchMethod ls =
  			new ArmijoLineSearchMinimizationAlongProjectionArc
  				(new InterpolationPickFirstStep(INIT_STEP_SIZE));
 @@ -230,20 +336,17 @@ public class PhraseContextObjective extends ProjectedObjective  		compositeStop.add(stopValue);
  		optimizer.setMaxIterations(ITERATIONS);
  		updateFunction();
 -		boolean succed = optimizer.optimize(this,stats,compositeStop);
 +		boolean success = optimizer.optimize(this,stats,compositeStop);
  //		System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
 -		if(succed){
 -			//System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
 -		}else{
 -			System.out.println("Failed to optimize");
 -		}
 -		//	ps.println(Arrays.toString(parameters));
 -		
 -		//	for(int edge=0;edge<data.getSize();edge++){
 -		//	ps.println(Arrays.toString(q[edge]));
 -		//	}
 -		//System.out.println(Arrays.toString(parameters));
 +		if (success)
 +			System.out.print("\toptimization took " + optimizer.getCurrentIteration() + " iterations");
 +	 	else
 +			System.out.print("\toptimization failed to converge");
 +		long total = System.currentTimeMillis() - start;
 +		System.out.println(" and " + total + " ms: projection " + projectionTime + 
 +				" actual " + actualProjectionTime + " objective " + objectiveTime);
 +
  		return parameters;
  	}
 @@ -298,5 +401,4 @@ public class PhraseContextObjective extends ProjectedObjective  	{
  		return loglikelihood() - KL_divergence() - c.scalePT * phrase_l1lmax() - c.scalePT * context_l1lmax();
  	}
 -	
 -}
 +}
\ No newline at end of file diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java index 015ef106..0a76e2dc 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java @@ -22,7 +22,7 @@ public class PhraseObjective extends ProjectedObjective  {
  	static final double GRAD_DIFF = 0.00002;
  	static double INIT_STEP_SIZE = 10;
 -	static double VAL_DIFF = 1e-6; // FIXME needs to be tuned
 +	static double VAL_DIFF = 1e-4; // FIXME needs to be tuned - and this might be too weak
  	static int ITERATIONS = 100;
  	//private double c1=0.0001; // wolf stuff
  	//private double c2=0.9;
 @@ -128,7 +128,8 @@ public class PhraseObjective extends ProjectedObjective  	}
  	@Override
 -	public double[] projectPoint(double[] point) {
 +	public double[] projectPoint(double[] point) 
 +	{
  		double toProject[]=new double[data.size()];
  		for(int tag=0;tag<c.K;tag++){
  			for(int edge=0;edge<data.size();edge++){
 diff --git a/gi/posterior-regularisation/prjava/train-PR-cluster.sh b/gi/posterior-regularisation/prjava/train-PR-cluster.sh new file mode 100755 index 00000000..b86d564b --- /dev/null +++ b/gi/posterior-regularisation/prjava/train-PR-cluster.sh @@ -0,0 +1,4 @@ +#!/bin/sh + +d=`dirname "$0"` +java -ea -Xmx8g -cp "$d/prjava.jar:$d/lib/trove-2.0.2.jar:$d/lib/optimization.jar" phrase.PhraseCluster "$@" | 
