From 78763d1966bc6bb7702906b73aeb6b154577418e Mon Sep 17 00:00:00 2001 From: "trevor.cohn" Date: Fri, 9 Jul 2010 19:39:33 +0000 Subject: Prettyfied output regarding optimization failure. git-svn-id: https://ws10smt.googlecode.com/svn/trunk@210 ec762483-ff6d-05da-a07a-a48fb63a330f --- .../prjava/src/phrase/PhraseCluster.java | 19 +++++++++++++++++-- .../prjava/src/phrase/PhraseObjective.java | 18 +++++++++++------- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index 7d7c46dd..b9b1b98c 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -9,6 +9,7 @@ import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.atomic.AtomicInteger; import phrase.Corpus.Edge; @@ -110,10 +111,13 @@ public class PhraseCluster { double [][]exp_pi=new double[n_phrases][K]; double loglikelihood=0, kl=0, l1lmax=0, primal=0; + int failures=0, iterations=0; //E for(int phrase=0; phrase 0) + System.out.println("WARNING: failed to converge in " + failures + "/" + n_phrases + " cases"); + System.out.println("\tmean iters: " + iterations/(double)n_phrases); System.out.println("\tllh: " + loglikelihood); System.out.println("\tKL: " + kl); System.out.println("\tphrase l1lmax: " + l1lmax); @@ -170,6 +177,8 @@ public class PhraseCluster { double [][]exp_pi=new double[n_phrases][K]; double loglikelihood=0, kl=0, l1lmax=0, primal=0; + final AtomicInteger failures = new AtomicInteger(0); + int iterations=0; //E for(int phrase=0;phrase edges = c.getEdgesForPhrase(phrase); for(int edge=0;edge 0) + System.out.println("WARNING: failed to converge in " + failures.get() + "/" + n_phrases + " cases"); + System.out.println("\tmean iters: " + iterations/(double)n_phrases); System.out.println("\tllh: " + loglikelihood); System.out.println("\tKL: " + kl); System.out.println("\tphrase l1lmax: " + l1lmax); diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java index 3314f74a..f24b903d 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java @@ -22,7 +22,7 @@ public class PhraseObjective extends ProjectedObjective { static final double GRAD_DIFF = 0.00002; static double INIT_STEP_SIZE = 300; - static double VAL_DIFF = 1e-4; // FIXME needs to be tuned - and this might be too weak + static double VAL_DIFF = 1e-6; // FIXME needs to be tuned - and this might be too weak static int ITERATIONS = 100; //private double c1=0.0001; // wolf stuff //private double c2=0.9; @@ -164,7 +164,9 @@ public class PhraseObjective extends ProjectedObjective return q; } - public void optimizeWithProjectedGradientDescent(){ + public int iterations = 0; + + public boolean optimizeWithProjectedGradientDescent(){ LineSearchMethod ls = new ArmijoLineSearchMinimizationAlongProjectionArc (new InterpolationPickFirstStep(INIT_STEP_SIZE)); @@ -181,13 +183,14 @@ public class PhraseObjective extends ProjectedObjective compositeStop.add(stopValue); optimizer.setMaxIterations(ITERATIONS); updateFunction(); - boolean succed = optimizer.optimize(this,stats,compositeStop); + boolean success = optimizer.optimize(this,stats,compositeStop); + iterations += optimizer.getCurrentIteration(); // System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1)); - if(succed){ + //if(succed){ //System.out.println("Ended optimization in " + optimizer.getCurrentIteration()); - }else{ - System.out.println("Failed to optimize"); - } + //}else{ +// System.out.println("Failed to optimize"); + //} lambda[phrase]=parameters; // ps.println(Arrays.toString(parameters)); @@ -195,6 +198,7 @@ public class PhraseObjective extends ProjectedObjective // ps.println(Arrays.toString(q[edge])); // } + return success; } public double KL_divergence() -- cgit v1.2.3