diff options
Diffstat (limited to 'gi')
4 files changed, 48 insertions, 28 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index 1f73764e..a369b319 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -2,8 +2,6 @@ package phrase;  import gnu.trove.TIntArrayList;
  import org.apache.commons.math.special.Gamma;
 -import io.FileUtil;
 -import java.io.IOException;
  import java.io.PrintStream;
  import java.util.Arrays;
  import java.util.List;
 @@ -11,9 +9,10 @@ import java.util.concurrent.ExecutorService;  import java.util.concurrent.Executors;
  import java.util.concurrent.LinkedBlockingQueue;
  import java.util.concurrent.atomic.AtomicInteger;
 +import java.util.concurrent.atomic.AtomicLong;
  import phrase.Corpus.Edge;
 -import util.MathUtil;
 +
  public class PhraseCluster {
 @@ -21,7 +20,11 @@ public class PhraseCluster {  	private int n_phrases, n_words, n_contexts, n_positions;
  	public Corpus c;
  	public ExecutorService pool; 
 -	
 +
 +	double[] lambdaPTCT;
 +	double[][] lambdaPT;
 +	boolean cacheLambda = true;
 +
  	// emit[tag][position][word] = p(word | tag, position in context)
  	double emit[][][];
  	// pi[phrase][tag] = p(tag | phrase)
 @@ -232,14 +235,19 @@ public class PhraseCluster {  	{
  		double [][][]exp_emit=new double[K][n_positions][n_words];
  		double [][]exp_pi=new double[n_phrases][K];
 +
 +		if (lambdaPT == null && cacheLambda)
 +			lambdaPT = new double[n_phrases][];
  		double loglikelihood=0, kl=0, l1lmax=0, primal=0;
  		int failures=0, iterations=0;
 +		long start = System.currentTimeMillis();
  		//E
  		for(int phrase=0; phrase<n_phrases; phrase++){
 -			PhraseObjective po=new PhraseObjective(this, phrase, scalePT);
 +			PhraseObjective po = new PhraseObjective(this, phrase, scalePT, (cacheLambda) ? lambdaPT[phrase] : null);
  			boolean ok = po.optimizeWithProjectedGradientDescent();
  			if (!ok) ++failures;
 +			if (cacheLambda) lambdaPT[phrase] = po.getParameters();
  			iterations += po.getNumberUpdateCalls();
  			double [][] q=po.posterior();
  			loglikelihood += po.loglikelihood();
 @@ -263,9 +271,10 @@ public class PhraseCluster {  			}
  		}
 +		long end = System.currentTimeMillis();
  		if (failures > 0)
  			System.out.println("WARNING: failed to converge in " + failures + "/" + n_phrases + " cases");
 -		System.out.println("\tmean iters:     " + iterations/(double)n_phrases);
 +		System.out.println("\tmean iters:     " + iterations/(double)n_phrases + " elapsed time " + (end - start) / 1000.0);
  		System.out.println("\tllh:            " + loglikelihood);
  		System.out.println("\tKL:             " + kl);
  		System.out.println("\tphrase l1lmax:  " + l1lmax);
 @@ -295,7 +304,12 @@ public class PhraseCluster {  		double loglikelihood=0, kl=0, l1lmax=0, primal=0;
  		final AtomicInteger failures = new AtomicInteger(0);
 +		final AtomicLong elapsed = new AtomicLong(0l);
  		int iterations=0;
 +		long start = System.currentTimeMillis();
 +		
 +		if (lambdaPT == null && cacheLambda)
 +			lambdaPT = new double[n_phrases][];
  		//E
  		for(int phrase=0;phrase<n_phrases;phrase++){
 @@ -304,9 +318,13 @@ public class PhraseCluster {  				public void run() {
  					try {
  						//System.out.println("" + Thread.currentThread().getId() + " optimising lambda for " + p);
 -						PhraseObjective po = new PhraseObjective(PhraseCluster.this, p, scalePT);
 +						long start = System.currentTimeMillis();
 +						PhraseObjective po = new PhraseObjective(PhraseCluster.this, p, scalePT, (cacheLambda) ? lambdaPT[p] : null);
  						boolean ok = po.optimizeWithProjectedGradientDescent();
  						if (!ok) failures.incrementAndGet();
 +						long end = System.currentTimeMillis();
 +						elapsed.addAndGet(end - start);
 +
  						//System.out.println("" + Thread.currentThread().getId() + " done optimising lambda for " + p);
  						expectations.put(po);
  						//System.out.println("" + Thread.currentThread().getId() + " added to queue " + p);
 @@ -327,6 +345,7 @@ public class PhraseCluster {  				PhraseObjective po = expectations.take();
  				// process
  				int phrase = po.phrase;
 +				if (cacheLambda) lambdaPT[phrase] = po.getParameters();
  				//System.out.println("" + Thread.currentThread().getId() + " taken phrase " + phrase);
  				double [][] q=po.posterior();
  				loglikelihood += po.loglikelihood();
 @@ -335,7 +354,6 @@ public class PhraseCluster {  				primal += po.primal(scalePT);
  				iterations += po.getNumberUpdateCalls();
 -				
  				List<Edge> edges = c.getEdgesForPhrase(phrase);
  				for(int edge=0;edge<q.length;edge++){
  					Edge e = edges.get(edge);
 @@ -356,9 +374,11 @@ public class PhraseCluster {  			}
  		}
 +		long end = System.currentTimeMillis();
 +		
  		if (failures.get() > 0)
  			System.out.println("WARNING: failed to converge in " + failures.get() + "/" + n_phrases + " cases");
 -		System.out.println("\tmean iters:     " + iterations/(double)n_phrases);
 +		System.out.println("\tmean iters:     " + iterations/(double)n_phrases + " walltime " + (end-start)/1000.0 + " threads " + elapsed.get() / 1000.0);
  		System.out.println("\tllh:            " + loglikelihood);
  		System.out.println("\tKL:             " + kl);
  		System.out.println("\tphrase l1lmax:  " + l1lmax);
 @@ -376,16 +396,15 @@ public class PhraseCluster {  		return primal;
  	}
 -	double[] lambda;
 -
  	public double PREM_phrase_context_constraints(double scalePT, double scaleCT)
  	{	
  		double[][][] exp_emit = new double [K][n_positions][n_words];
  		double[][] exp_pi = new double[n_phrases][K];
  		//E step
 -		PhraseContextObjective pco = new PhraseContextObjective(this, lambda, pool, scalePT, scaleCT);
 -		lambda = pco.optimizeWithProjectedGradientDescent();
 +		PhraseContextObjective pco = new PhraseContextObjective(this, lambdaPTCT, pool, scalePT, scaleCT);
 +		boolean ok = pco.optimizeWithProjectedGradientDescent();
 +		if (cacheLambda) lambdaPTCT = pco.getParameters();
  		//now extract expectations
  		List<Corpus.Edge> edges = c.getEdges();
 diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java index 7e6c7f60..06a9f8cb 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java @@ -318,7 +318,7 @@ public class PhraseContextObjective extends ProjectedObjective  		return q[edgeIndex];
  	}
 -	public double[] optimizeWithProjectedGradientDescent()
 +	public boolean optimizeWithProjectedGradientDescent()
  	{
  		projectionTime = 0;
  		actualProjectionTime = 0;
 @@ -354,7 +354,7 @@ public class PhraseContextObjective extends ProjectedObjective  		System.out.println(" and " + total + " ms: projection " + projectionTime + 
  				" actual " + actualProjectionTime + " objective " + objectiveTime);
 -		return parameters;
 +		return success;
  	}
  	double loglikelihood()
 diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java index e62b62f4..7c32d9c0 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java @@ -25,7 +25,7 @@ public class PhraseObjective extends ProjectedObjective  	static int ITERATIONS = 100;
  	//private double c1=0.0001; // wolf stuff
  	//private double c2=0.9;
 -	private static double lambda[][];
 +	//private static double lambda[][];
  	private PhraseCluster c;
  	/**@brief
 @@ -64,23 +64,18 @@ public class PhraseObjective extends ProjectedObjective  	 */
  	public double llh;
 -	public PhraseObjective(PhraseCluster cluster, int phraseIdx, double scale){
 +	public PhraseObjective(PhraseCluster cluster, int phraseIdx, double scale, double[] lambda){
  		phrase=phraseIdx;
  		c=cluster;
  		data=c.c.getEdgesForPhrase(phrase);
  		n_param=data.size()*c.K;
  		//System.out.println("Num parameters " + n_param + " for phrase #" + phraseIdx);
 -		if (lambda==null){
 -			lambda=new double[c.c.getNumPhrases()][];
 -		}
 -		
 -		if (lambda[phrase]==null){
 -			lambda[phrase]=new double[n_param];
 -		}
 +		if (lambda==null) 
 +			lambda=new double[n_param];
 -		parameters=lambda[phrase];
 -		newPoint  = new double[n_param];
 +		parameters = lambda;
 +		newPoint = new double[n_param];
  		gradient = new double[n_param];
  		initP();
  		projection=new SimplexProjection(scale);
 @@ -163,8 +158,12 @@ public class PhraseObjective extends ProjectedObjective  	public double [][]posterior(){
  		return q;
  	}
 -		
 +	
 +	long optimizationTime;
 +	
  	public boolean optimizeWithProjectedGradientDescent(){
 +		long start = System.currentTimeMillis();
 +		
  		LineSearchMethod ls =
  			new ArmijoLineSearchMinimizationAlongProjectionArc
  				(new InterpolationPickFirstStep(INIT_STEP_SIZE));
 @@ -188,7 +187,6 @@ public class PhraseObjective extends ProjectedObjective  		//}else{
  //			System.out.println("Failed to optimize");
  		//}
 -		lambda[phrase]=parameters;
  		//	ps.println(Arrays.toString(parameters));
  		//	for(int edge=0;edge<data.getSize();edge++){
 diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java index 240c4d64..20f6c905 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java @@ -31,6 +31,7 @@ public class Trainer          parser.accepts("alpha-emit").withRequiredArg().ofType(Double.class).defaultsTo(0.1);          parser.accepts("alpha-pi").withRequiredArg().ofType(Double.class).defaultsTo(0.01);          parser.accepts("agree"); +        parser.accepts("no-parameter-cache");          OptionSet options = parser.parse(args);          if (options.has("help") || !options.has("in")) @@ -96,6 +97,8 @@ public class Trainer   			cluster = new PhraseCluster(tags, corpus);   			if (threads > 0) cluster.useThreadPool(threads);   			if (vb)	cluster.initialiseVB(alphaEmit, alphaPi); + 			if (options.has("no-parameter-cache"))  + 				cluster.cacheLambda = false;   		}  		double last = 0;  | 
