summary | refs | log | tree | commit | diff
path: root/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java')
-rw-r--r-- gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java | 178
1 files changed, 30 insertions, 148 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index ccb6ae9d..c032bb2b 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -4,14 +4,16 @@ import gnu.trove.TIntArrayList;
import org.apache.commons.math.special.Gamma;
import java.io.BufferedReader;
-import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
-import java.util.StringTokenizer;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
@@ -56,23 +58,9 @@ public class PhraseCluster {
arr.F.randomise(j, true);
}
- public void initialiseVB(double alphaEmit, double alphaPi)
+ void useThreadPool(ExecutorService pool)
{
- assert alphaEmit > 0;
- assert alphaPi > 0;
-
- for(double [][]i:emit)
- for(double []j:i)
- digammaNormalize(j, alphaEmit);
-
- for(double []j:pi)
- digammaNormalize(j, alphaPi);
- }
-
- void useThreadPool(int threads)
- {
- assert threads > 0;
- pool = Executors.newFixedThreadPool(threads);
+ this.pool = pool;
}
public double EM(int phraseSizeLimit)
@@ -131,107 +119,6 @@ public class PhraseCluster {
return loglikelihood;
}
- public double VBEM(double alphaEmit, double alphaPi)
- {
- // FIXME: broken - needs to be done entirely in log-space
-
- double [][][]exp_emit = new double [K][n_positions][n_words];
- double [][]exp_pi = new double[n_phrases][K];
-
- double loglikelihood=0;
-
- //E
- for(int phrase=0; phrase < n_phrases; phrase++)
- {
- List<Edge> contexts = c.getEdgesForPhrase(phrase);
-
- for (int ctx=0; ctx<contexts.size(); ctx++)
- {
- Edge edge = contexts.get(ctx);
- double p[] = posterior(edge);
- double z = arr.F.l1norm(p);
- assert z > 0;
- loglikelihood += edge.getCount() * Math.log(z);
- arr.F.l1normalize(p);
-
- double count = edge.getCount();
- //increment expected count
- TIntArrayList context = edge.getContext();
- for(int tag=0;tag<K;tag++)
- {
- for(int pos=0;pos<n_positions;pos++)
- exp_emit[tag][pos][context.get(pos)] += p[tag]*count;
- exp_pi[phrase][tag] += p[tag]*count;
- }
- }
- }
-
- // find the KL terms, KL(q||p) where p is symmetric Dirichlet prior and q are the expectations
- double kl = 0;
- for (int phrase=0; phrase < n_phrases; phrase++)
- kl += KL_symmetric_dirichlet(exp_pi[phrase], alphaPi);
-
- for (int tag=0;tag<K;tag++)
- for (int pos=0;pos<n_positions; ++pos)
- kl += this.KL_symmetric_dirichlet(exp_emit[tag][pos], alphaEmit);
- // FIXME: exp_emit[tag][pos] has structural zeros - certain words are *never* seen in that position
-
- //M
- for(double [][]i:exp_emit)
- for(double []j:i)
- digammaNormalize(j, alphaEmit);
- emit=exp_emit;
- for(double []j:exp_pi)
- digammaNormalize(j, alphaPi);
- pi=exp_pi;
-
- System.out.println("KL=" + kl + " llh=" + loglikelihood);
- System.out.println(Arrays.toString(pi[0]));
- System.out.println(Arrays.toString(exp_emit[0][0]));
- return kl + loglikelihood;
- }
-
- public void digammaNormalize(double [] a, double alpha)
- {
- double sum=0;
- for(int i=0;i<a.length;i++)
- sum += a[i];
-
- assert sum > 1e-20;
- double dgs = Gamma.digamma(sum + alpha);
-
- for(int i=0;i<a.length;i++)
- a[i] = Math.exp(Gamma.digamma(a[i] + alpha/a.length) - dgs);
- }
-
- private double KL_symmetric_dirichlet(double[] q, double alpha)
- {
- // assumes that zeros in q are structural & should be skipped
- // FIXME: asssumption doesn't hold
-
- double p0 = alpha;
- double q0 = 0;
- int n = 0;
- for (int i=0; i<q.length; i++)
- {
- if (q[i] > 0)
- {
- q0 += q[i];
- n += 1;
- }
- }
-
- double kl = Gamma.logGamma(q0) - Gamma.logGamma(p0);
- kl += n * Gamma.logGamma(alpha / n);
- double digamma_q0 = Gamma.digamma(q0);
- for (int i=0; i<q.length; i++)
- {
- if (q[i] > 0)
- kl -= -Gamma.logGamma(q[i]) - (q[i] - alpha/q.length) * (Gamma.digamma(q[i]) - digamma_q0);
- }
- return kl;
- }
-
public double PREM(double scalePT, double scaleCT, int phraseSizeLimit)
{
if (scaleCT == 0)
@@ -339,51 +226,44 @@ public class PhraseCluster {
double loglikelihood=0, kl=0, l1lmax=0, primal=0;
final AtomicInteger failures = new AtomicInteger(0);
final AtomicLong elapsed = new AtomicLong(0l);
- int iterations=0, n=n_phrases;
+ int iterations=0;
long start = System.currentTimeMillis();
+ List<Future<PhraseObjective>> results = new ArrayList<Future<PhraseObjective>>();
if (lambdaPT == null && cacheLambda)
lambdaPT = new double[n_phrases][];
//E
- for(int phrase=0;phrase<n_phrases;phrase++){
- if (phraseSizeLimit >= 1 && c.getPhrase(phrase).size() > phraseSizeLimit)
- {
- n -= 1;
+ for(int phrase=0;phrase<n_phrases;phrase++) {
+ if (phraseSizeLimit >= 1 && c.getPhrase(phrase).size() > phraseSizeLimit) {
System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
continue;
}
final int p=phrase;
- pool.execute(new Runnable() {
- public void run() {
- try {
- //System.out.println("" + Thread.currentThread().getId() + " optimising lambda for " + p);
- long start = System.currentTimeMillis();
- PhraseObjective po = new PhraseObjective(PhraseCluster.this, p, scalePT, (cacheLambda) ? lambdaPT[p] : null);
- boolean ok = po.optimizeWithProjectedGradientDescent();
- if (!ok) failures.incrementAndGet();
- long end = System.currentTimeMillis();
- elapsed.addAndGet(end - start);
-
- //System.out.println("" + Thread.currentThread().getId() + " done optimising lambda for " + p);
- expectations.put(po);
- //System.out.println("" + Thread.currentThread().getId() + " added to queue " + p);
- } catch (InterruptedException e) {
- System.err.println(Thread.currentThread().getId() + " Local e-step thread interrupted; will cause deadlock.");
- e.printStackTrace();
- }
+ results.add(pool.submit(new Callable<PhraseObjective>() {
+ public PhraseObjective call() {
+ //System.out.println("" + Thread.currentThread().getId() + " optimising lambda for " + p);
+ long start = System.currentTimeMillis();
+ PhraseObjective po = new PhraseObjective(PhraseCluster.this, p, scalePT, (cacheLambda) ? lambdaPT[p] : null);
+ boolean ok = po.optimizeWithProjectedGradientDescent();
+ if (!ok) failures.incrementAndGet();
+ long end = System.currentTimeMillis();
+ elapsed.addAndGet(end - start);
+ //System.out.println("" + Thread.currentThread().getId() + " done optimising lambda for " + p);
+ return po;
}
- });
+ }));
}
// aggregate the expectations as they become available
- for(int count=0;count<n;count++) {
+ for (Future<PhraseObjective> fpo : results)
+ {
try {
//System.out.println("" + Thread.currentThread().getId() + " reading queue #" + count);
// wait (blocking) until something is ready
- PhraseObjective po = expectations.take();
+ PhraseObjective po = fpo.get();
// process
int phrase = po.phrase;
if (cacheLambda) lambdaPT[phrase] = po.getParameters();
@@ -408,10 +288,12 @@ public class PhraseCluster {
exp_pi[phrase][tag]+=q[edge][tag]*contextCnt;
}
}
- } catch (InterruptedException e)
- {
+ } catch (InterruptedException e) {
System.err.println("M-step thread interrupted. Probably fatal!");
- e.printStackTrace();
+ throw new RuntimeException(e);
+ } catch (ExecutionException e) {
+ System.err.println("M-step thread execution died. Probably fatal!");
+ throw new RuntimeException(e);
}
}