diff options
Diffstat (limited to 'gi/posterior-regularisation/prjava/src')
| -rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java | 19 | ||||
| -rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java | 18 | 
2 files changed, 28 insertions, 9 deletions
| diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index 7d7c46dd..b9b1b98c 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -9,6 +9,7 @@ import java.util.List;  import java.util.concurrent.ExecutorService;
  import java.util.concurrent.Executors;
  import java.util.concurrent.LinkedBlockingQueue;
 +import java.util.concurrent.atomic.AtomicInteger;
  import phrase.Corpus.Edge;
 @@ -110,10 +111,13 @@ public class PhraseCluster {  		double [][]exp_pi=new double[n_phrases][K];
  		double loglikelihood=0, kl=0, l1lmax=0, primal=0;
 +		int failures=0, iterations=0;
  		//E
  		for(int phrase=0; phrase<n_phrases; phrase++){
  			PhraseObjective po=new PhraseObjective(this,phrase);
 -			po.optimizeWithProjectedGradientDescent();
 +			boolean ok = po.optimizeWithProjectedGradientDescent();
 +			if (!ok) ++failures;
 +			iterations += po.iterations;
  			double [][] q=po.posterior();
  			loglikelihood += po.loglikelihood();
  			kl += po.KL_divergence();
 @@ -136,6 +140,9 @@ public class PhraseCluster {  			}
  		}
 +		if (failures > 0)
 +			System.out.println("WARNING: failed to converge in " + failures + "/" + n_phrases + " cases");
 +		System.out.println("\tmean iters: 	  " + iterations/(double)n_phrases);
  		System.out.println("\tllh:            " + loglikelihood);
  		System.out.println("\tKL:             " + kl);
  		System.out.println("\tphrase l1lmax:  " + l1lmax);
 @@ -170,6 +177,8 @@ public class PhraseCluster {  		double [][]exp_pi=new double[n_phrases][K];
  		double loglikelihood=0, kl=0, l1lmax=0, primal=0;
 +		final AtomicInteger failures = new AtomicInteger(0);
 +		int iterations=0;
  		//E
  		for(int phrase=0;phrase<n_phrases;phrase++){
 @@ -179,7 +188,8 @@ public class PhraseCluster {  					try {
  						//System.out.println("" + Thread.currentThread().getId() + " optimising lambda for " + p);
  						PhraseObjective po = new PhraseObjective(PhraseCluster.this, p);
 -						po.optimizeWithProjectedGradientDescent();
 +						boolean ok = po.optimizeWithProjectedGradientDescent();
 +						if (!ok) failures.incrementAndGet();
  						//System.out.println("" + Thread.currentThread().getId() + " done optimising lambda for " + p);
  						expectations.put(po);
  						//System.out.println("" + Thread.currentThread().getId() + " added to queue " + p);
 @@ -206,6 +216,8 @@ public class PhraseCluster {  				kl += po.KL_divergence();
  				l1lmax += po.l1lmax();
  				primal += po.primal();
 +				iterations += po.iterations;
 +
  				List<Edge> edges = c.getEdgesForPhrase(phrase);
  				for(int edge=0;edge<q.length;edge++){
 @@ -227,6 +239,9 @@ public class PhraseCluster {  			}
  		}
 +		if (failures.get() > 0)
 +			System.out.println("WARNING: failed to converge in " + failures.get() + "/" + n_phrases + " cases");
 +		System.out.println("\tmean iters: 	  " + iterations/(double)n_phrases);
  		System.out.println("\tllh:            " + loglikelihood);
  		System.out.println("\tKL:             " + kl);
  		System.out.println("\tphrase l1lmax:  " + l1lmax);
 diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java index 3314f74a..f24b903d 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java @@ -22,7 +22,7 @@ public class PhraseObjective extends ProjectedObjective  {
  	static final double GRAD_DIFF = 0.00002;
  	static double INIT_STEP_SIZE = 300;
 -	static double VAL_DIFF = 1e-4; // FIXME needs to be tuned - and this might be too weak
 +	static double VAL_DIFF = 1e-6; // FIXME needs to be tuned - and this might be too weak
  	static int ITERATIONS = 100;
  	//private double c1=0.0001; // wolf stuff
  	//private double c2=0.9;
 @@ -164,7 +164,9 @@ public class PhraseObjective extends ProjectedObjective  		return q;
  	}
 -	public void optimizeWithProjectedGradientDescent(){
 +	public int iterations = 0;
 +	
 +	public boolean optimizeWithProjectedGradientDescent(){
  		LineSearchMethod ls =
  			new ArmijoLineSearchMinimizationAlongProjectionArc
  				(new InterpolationPickFirstStep(INIT_STEP_SIZE));
 @@ -181,13 +183,14 @@ public class PhraseObjective extends ProjectedObjective  		compositeStop.add(stopValue);
  		optimizer.setMaxIterations(ITERATIONS);
  		updateFunction();
 -		boolean succed = optimizer.optimize(this,stats,compositeStop);
 +		boolean success = optimizer.optimize(this,stats,compositeStop);
 +		iterations += optimizer.getCurrentIteration();
  //		System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
 -		if(succed){
 +		//if(succed){
  			//System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
 -		}else{
 -			System.out.println("Failed to optimize");
 -		}
 +		//}else{
 +//			System.out.println("Failed to optimize");
 +		//}
  		lambda[phrase]=parameters;
  		//	ps.println(Arrays.toString(parameters));
 @@ -195,6 +198,7 @@ public class PhraseObjective extends ProjectedObjective  		//	ps.println(Arrays.toString(q[edge]));
  		//	}
 +		return success;
  	}
  	public double KL_divergence()
 | 
