summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java')
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java19
1 files changed, 17 insertions, 2 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index 7d7c46dd..b9b1b98c 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -9,6 +9,7 @@ import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.atomic.AtomicInteger;
import phrase.Corpus.Edge;
@@ -110,10 +111,13 @@ public class PhraseCluster {
double [][]exp_pi=new double[n_phrases][K];
double loglikelihood=0, kl=0, l1lmax=0, primal=0;
+ int failures=0, iterations=0;
//E
for(int phrase=0; phrase<n_phrases; phrase++){
PhraseObjective po=new PhraseObjective(this,phrase);
- po.optimizeWithProjectedGradientDescent();
+ boolean ok = po.optimizeWithProjectedGradientDescent();
+ if (!ok) ++failures;
+ iterations += po.iterations;
double [][] q=po.posterior();
loglikelihood += po.loglikelihood();
kl += po.KL_divergence();
@@ -136,6 +140,9 @@ public class PhraseCluster {
}
}
+ if (failures > 0)
+ System.out.println("WARNING: failed to converge in " + failures + "/" + n_phrases + " cases");
+ System.out.println("\tmean iters: " + iterations/(double)n_phrases);
System.out.println("\tllh: " + loglikelihood);
System.out.println("\tKL: " + kl);
System.out.println("\tphrase l1lmax: " + l1lmax);
@@ -170,6 +177,8 @@ public class PhraseCluster {
double [][]exp_pi=new double[n_phrases][K];
double loglikelihood=0, kl=0, l1lmax=0, primal=0;
+ final AtomicInteger failures = new AtomicInteger(0);
+ int iterations=0;
//E
for(int phrase=0;phrase<n_phrases;phrase++){
@@ -179,7 +188,8 @@ public class PhraseCluster {
try {
//System.out.println("" + Thread.currentThread().getId() + " optimising lambda for " + p);
PhraseObjective po = new PhraseObjective(PhraseCluster.this, p);
- po.optimizeWithProjectedGradientDescent();
+ boolean ok = po.optimizeWithProjectedGradientDescent();
+ if (!ok) failures.incrementAndGet();
//System.out.println("" + Thread.currentThread().getId() + " done optimising lambda for " + p);
expectations.put(po);
//System.out.println("" + Thread.currentThread().getId() + " added to queue " + p);
@@ -206,6 +216,8 @@ public class PhraseCluster {
kl += po.KL_divergence();
l1lmax += po.l1lmax();
primal += po.primal();
+ iterations += po.iterations;
+
List<Edge> edges = c.getEdgesForPhrase(phrase);
for(int edge=0;edge<q.length;edge++){
@@ -227,6 +239,9 @@ public class PhraseCluster {
}
}
+ if (failures.get() > 0)
+ System.out.println("WARNING: failed to converge in " + failures.get() + "/" + n_phrases + " cases");
+ System.out.println("\tmean iters: " + iterations/(double)n_phrases);
System.out.println("\tllh: " + loglikelihood);
System.out.println("\tKL: " + kl);
System.out.println("\tphrase l1lmax: " + l1lmax);