summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation')
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java19
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java18
2 files changed, 28 insertions, 9 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index 7d7c46dd..b9b1b98c 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -9,6 +9,7 @@ import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.atomic.AtomicInteger;
import phrase.Corpus.Edge;
@@ -110,10 +111,13 @@ public class PhraseCluster {
double [][]exp_pi=new double[n_phrases][K];
double loglikelihood=0, kl=0, l1lmax=0, primal=0;
+ int failures=0, iterations=0;
//E
for(int phrase=0; phrase<n_phrases; phrase++){
PhraseObjective po=new PhraseObjective(this,phrase);
- po.optimizeWithProjectedGradientDescent();
+ boolean ok = po.optimizeWithProjectedGradientDescent();
+ if (!ok) ++failures;
+ iterations += po.iterations;
double [][] q=po.posterior();
loglikelihood += po.loglikelihood();
kl += po.KL_divergence();
@@ -136,6 +140,9 @@ public class PhraseCluster {
}
}
+ if (failures > 0)
+ System.out.println("WARNING: failed to converge in " + failures + "/" + n_phrases + " cases");
+ System.out.println("\tmean iters: " + iterations/(double)n_phrases);
System.out.println("\tllh: " + loglikelihood);
System.out.println("\tKL: " + kl);
System.out.println("\tphrase l1lmax: " + l1lmax);
@@ -170,6 +177,8 @@ public class PhraseCluster {
double [][]exp_pi=new double[n_phrases][K];
double loglikelihood=0, kl=0, l1lmax=0, primal=0;
+ final AtomicInteger failures = new AtomicInteger(0);
+ int iterations=0;
//E
for(int phrase=0;phrase<n_phrases;phrase++){
@@ -179,7 +188,8 @@ public class PhraseCluster {
try {
//System.out.println("" + Thread.currentThread().getId() + " optimising lambda for " + p);
PhraseObjective po = new PhraseObjective(PhraseCluster.this, p);
- po.optimizeWithProjectedGradientDescent();
+ boolean ok = po.optimizeWithProjectedGradientDescent();
+ if (!ok) failures.incrementAndGet();
//System.out.println("" + Thread.currentThread().getId() + " done optimising lambda for " + p);
expectations.put(po);
//System.out.println("" + Thread.currentThread().getId() + " added to queue " + p);
@@ -206,6 +216,8 @@ public class PhraseCluster {
kl += po.KL_divergence();
l1lmax += po.l1lmax();
primal += po.primal();
+ iterations += po.iterations;
+
List<Edge> edges = c.getEdgesForPhrase(phrase);
for(int edge=0;edge<q.length;edge++){
@@ -227,6 +239,9 @@ public class PhraseCluster {
}
}
+ if (failures.get() > 0)
+ System.out.println("WARNING: failed to converge in " + failures.get() + "/" + n_phrases + " cases");
+ System.out.println("\tmean iters: " + iterations/(double)n_phrases);
System.out.println("\tllh: " + loglikelihood);
System.out.println("\tKL: " + kl);
System.out.println("\tphrase l1lmax: " + l1lmax);
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
index 3314f74a..f24b903d 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
@@ -22,7 +22,7 @@ public class PhraseObjective extends ProjectedObjective
{
static final double GRAD_DIFF = 0.00002;
static double INIT_STEP_SIZE = 300;
- static double VAL_DIFF = 1e-4; // FIXME needs to be tuned - and this might be too weak
+ static double VAL_DIFF = 1e-6; // FIXME needs to be tuned - and this might be too weak
static int ITERATIONS = 100;
//private double c1=0.0001; // wolf stuff
//private double c2=0.9;
@@ -164,7 +164,9 @@ public class PhraseObjective extends ProjectedObjective
return q;
}
- public void optimizeWithProjectedGradientDescent(){
+ public int iterations = 0;
+
+ public boolean optimizeWithProjectedGradientDescent(){
LineSearchMethod ls =
new ArmijoLineSearchMinimizationAlongProjectionArc
(new InterpolationPickFirstStep(INIT_STEP_SIZE));
@@ -181,13 +183,14 @@ public class PhraseObjective extends ProjectedObjective
compositeStop.add(stopValue);
optimizer.setMaxIterations(ITERATIONS);
updateFunction();
- boolean succed = optimizer.optimize(this,stats,compositeStop);
+ boolean success = optimizer.optimize(this,stats,compositeStop);
+ iterations += optimizer.getCurrentIteration();
// System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
- if(succed){
+ //if(succed){
//System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
- }else{
- System.out.println("Failed to optimize");
- }
+ //}else{
+// System.out.println("Failed to optimize");
+ //}
lambda[phrase]=parameters;
// ps.println(Arrays.toString(parameters));
@@ -195,6 +198,7 @@ public class PhraseObjective extends ProjectedObjective
// ps.println(Arrays.toString(q[edge]));
// }
+ return success;
}
public double KL_divergence()