summary refs log tree commit diff
diff options
context:
space:
mode:
author trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> 2010-07-23 19:26:17 +0000
committer trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> 2010-07-23 19:26:17 +0000
commit 7b61f71be28539f815a171ba84baeaa90f863e88 (patch)
tree bff936400881f788efa4364c4acfd32a3c860f04
parent a2d9d0f96502c7d3c04303f3db36a8602d992287 (diff)
Parallelised VB-EM
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@384 ec762483-ff6d-05da-a07a-a48fb63a330f
-rw-r--r-- gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java 178
-rw-r--r-- gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java 3
-rw-r--r-- gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java 2
-rw-r--r-- gi/posterior-regularisation/prjava/src/phrase/Trainer.java 83
-rw-r--r-- gi/posterior-regularisation/prjava/src/phrase/VB.java 129
5 files changed, 182 insertions, 213 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index ccb6ae9d..c032bb2b 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -4,14 +4,16 @@ import gnu.trove.TIntArrayList;
import org.apache.commons.math.special.Gamma;
import java.io.BufferedReader;
-import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
-import java.util.StringTokenizer;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
@@ -56,23 +58,9 @@ public class PhraseCluster {
arr.F.randomise(j, true);
}
- public void initialiseVB(double alphaEmit, double alphaPi)
+ void useThreadPool(ExecutorService pool)
{
- assert alphaEmit > 0;
- assert alphaPi > 0;
-
- for(double [][]i:emit)
- for(double []j:i)
- digammaNormalize(j, alphaEmit);
-
- for(double []j:pi)
- digammaNormalize(j, alphaPi);
- }
-
- void useThreadPool(int threads)
- {
- assert threads > 0;
- pool = Executors.newFixedThreadPool(threads);
+ this.pool = pool;
}
public double EM(int phraseSizeLimit)
@@ -131,107 +119,6 @@ public class PhraseCluster {
return loglikelihood;
}
- public double VBEM(double alphaEmit, double alphaPi)
- {
- // FIXME: broken - needs to be done entirely in log-space
-
- double [][][]exp_emit = new double [K][n_positions][n_words];
- double [][]exp_pi = new double[n_phrases][K];
-
- double loglikelihood=0;
-
- //E
- for(int phrase=0; phrase < n_phrases; phrase++)
- {
- List<Edge> contexts = c.getEdgesForPhrase(phrase);
-
- for (int ctx=0; ctx<contexts.size(); ctx++)
- {
- Edge edge = contexts.get(ctx);
- double p[] = posterior(edge);
- double z = arr.F.l1norm(p);
- assert z > 0;
- loglikelihood += edge.getCount() * Math.log(z);
- arr.F.l1normalize(p);
-
- double count = edge.getCount();
- //increment expected count
- TIntArrayList context = edge.getContext();
- for(int tag=0;tag<K;tag++)
- {
- for(int pos=0;pos<n_positions;pos++)
- exp_emit[tag][pos][context.get(pos)] += p[tag]*count;
- exp_pi[phrase][tag] += p[tag]*count;
- }
- }
- }
-
- // find the KL terms, KL(q||p) where p is symmetric Dirichlet prior and q are the expectations
- double kl = 0;
- for (int phrase=0; phrase < n_phrases; phrase++)
- kl += KL_symmetric_dirichlet(exp_pi[phrase], alphaPi);
-
- for (int tag=0;tag<K;tag++)
- for (int pos=0;pos<n_positions; ++pos)
- kl += this.KL_symmetric_dirichlet(exp_emit[tag][pos], alphaEmit);
- // FIXME: exp_emit[tag][pos] has structural zeros - certain words are *never* seen in that position
-
- //M
- for(double [][]i:exp_emit)
- for(double []j:i)
- digammaNormalize(j, alphaEmit);
- emit=exp_emit;
- for(double []j:exp_pi)
- digammaNormalize(j, alphaPi);
- pi=exp_pi;
-
- System.out.println("KL=" + kl + " llh=" + loglikelihood);
- System.out.println(Arrays.toString(pi[0]));
- System.out.println(Arrays.toString(exp_emit[0][0]));
- return kl + loglikelihood;
- }
-
- public void digammaNormalize(double [] a, double alpha)
- {
- double sum=0;
- for(int i=0;i<a.length;i++)
- sum += a[i];
-
- assert sum > 1e-20;
- double dgs = Gamma.digamma(sum + alpha);
-
- for(int i=0;i<a.length;i++)
- a[i] = Math.exp(Gamma.digamma(a[i] + alpha/a.length) - dgs);
- }
-
- private double KL_symmetric_dirichlet(double[] q, double alpha)
- {
- // assumes that zeros in q are structural & should be skipped
- // FIXME: asssumption doesn't hold
-
- double p0 = alpha;
- double q0 = 0;
- int n = 0;
- for (int i=0; i<q.length; i++)
- {
- if (q[i] > 0)
- {
- q0 += q[i];
- n += 1;
- }
- }
-
- double kl = Gamma.logGamma(q0) - Gamma.logGamma(p0);
- kl += n * Gamma.logGamma(alpha / n);
- double digamma_q0 = Gamma.digamma(q0);
- for (int i=0; i<q.length; i++)
- {
- if (q[i] > 0)
- kl -= -Gamma.logGamma(q[i]) - (q[i] - alpha/q.length) * (Gamma.digamma(q[i]) - digamma_q0);
- }
- return kl;
- }
-
public double PREM(double scalePT, double scaleCT, int phraseSizeLimit)
{
if (scaleCT == 0)
@@ -339,51 +226,44 @@ public class PhraseCluster {
double loglikelihood=0, kl=0, l1lmax=0, primal=0;
final AtomicInteger failures = new AtomicInteger(0);
final AtomicLong elapsed = new AtomicLong(0l);
- int iterations=0, n=n_phrases;
+ int iterations=0;
long start = System.currentTimeMillis();
+ List<Future<PhraseObjective>> results = new ArrayList<Future<PhraseObjective>>();
if (lambdaPT == null && cacheLambda)
lambdaPT = new double[n_phrases][];
//E
- for(int phrase=0;phrase<n_phrases;phrase++){
- if (phraseSizeLimit >= 1 && c.getPhrase(phrase).size() > phraseSizeLimit)
- {
- n -= 1;
+ for(int phrase=0;phrase<n_phrases;phrase++) {
+ if (phraseSizeLimit >= 1 && c.getPhrase(phrase).size() > phraseSizeLimit) {
System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
continue;
}
final int p=phrase;
- pool.execute(new Runnable() {
- public void run() {
- try {
- //System.out.println("" + Thread.currentThread().getId() + " optimising lambda for " + p);
- long start = System.currentTimeMillis();
- PhraseObjective po = new PhraseObjective(PhraseCluster.this, p, scalePT, (cacheLambda) ? lambdaPT[p] : null);
- boolean ok = po.optimizeWithProjectedGradientDescent();
- if (!ok) failures.incrementAndGet();
- long end = System.currentTimeMillis();
- elapsed.addAndGet(end - start);
-
- //System.out.println("" + Thread.currentThread().getId() + " done optimising lambda for " + p);
- expectations.put(po);
- //System.out.println("" + Thread.currentThread().getId() + " added to queue " + p);
- } catch (InterruptedException e) {
- System.err.println(Thread.currentThread().getId() + " Local e-step thread interrupted; will cause deadlock.");
- e.printStackTrace();
- }
+ results.add(pool.submit(new Callable<PhraseObjective>() {
+ public PhraseObjective call() {
+ //System.out.println("" + Thread.currentThread().getId() + " optimising lambda for " + p);
+ long start = System.currentTimeMillis();
+ PhraseObjective po = new PhraseObjective(PhraseCluster.this, p, scalePT, (cacheLambda) ? lambdaPT[p] : null);
+ boolean ok = po.optimizeWithProjectedGradientDescent();
+ if (!ok) failures.incrementAndGet();
+ long end = System.currentTimeMillis();
+ elapsed.addAndGet(end - start);
+ //System.out.println("" + Thread.currentThread().getId() + " done optimising lambda for " + p);
+ return po;
}
- });
+ }));
}
// aggregate the expectations as they become available
- for(int count=0;count<n;count++) {
+ for (Future<PhraseObjective> fpo : results)
+ {
try {
//System.out.println("" + Thread.currentThread().getId() + " reading queue #" + count);
// wait (blocking) until something is ready
- PhraseObjective po = expectations.take();
+ PhraseObjective po = fpo.get();
// process
int phrase = po.phrase;
if (cacheLambda) lambdaPT[phrase] = po.getParameters();
@@ -408,10 +288,12 @@ public class PhraseCluster {
exp_pi[phrase][tag]+=q[edge][tag]*contextCnt;
}
}
- } catch (InterruptedException e)
- {
+ } catch (InterruptedException e) {
System.err.println("M-step thread interrupted. Probably fatal!");
- e.printStackTrace();
+ throw new RuntimeException(e);
+ } catch (ExecutionException e) {
+ System.err.println("M-step thread execution died. Probably fatal!");
+ throw new RuntimeException(e);
}
}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
index 06a9f8cb..5947c4be 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
@@ -277,7 +277,10 @@ public class PhraseContextObjective extends ProjectedObjective
}
// rethrow the exception
if (failure != null)
+ {
+ pool.shutdownNow();
throw new RuntimeException(failure);
+ }
}
double[] tmp = newPoint;
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
index 7c32d9c0..5efe778a 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
@@ -192,7 +192,7 @@ public class PhraseObjective extends ProjectedObjective
// for(int edge=0;edge<data.getSize();edge++){
// ps.println(Arrays.toString(q[edge]));
// }
-
+
return success;
}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
index f205ce67..6f302b20 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
@@ -4,11 +4,12 @@ import io.FileUtil;
import joptsimple.OptionParser;
import joptsimple.OptionSet;
import java.io.File;
-import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintStream;
import java.util.List;
import java.util.Random;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
import phrase.Corpus.Edge;
@@ -18,7 +19,6 @@ public class Trainer
{
public static void main(String[] args)
{
-
OptionParser parser = new OptionParser();
parser.accepts("help");
parser.accepts("in").withRequiredArg().ofType(File.class);
@@ -68,6 +68,10 @@ public class Trainer
if (options.has("seed"))
F.rng = new Random((Long) options.valueOf("seed"));
+ ExecutorService threadPool = null;
+ if (threads > 0)
+ threadPool = Executors.newFixedThreadPool(threads);
+
if (tags <= 1 || scale_phrase < 0 || scale_context < 0 || threshold < 0)
{
System.err.println("Invalid arguments. Try again!");
@@ -114,26 +118,30 @@ public class Trainer
agree = new Agree(tags, corpus);
else
{
- cluster = new PhraseCluster(tags, corpus);
- if (threads > 0) cluster.useThreadPool(threads);
-
- if (vb) {
- //cluster.initialiseVB(alphaEmit, alphaPi);
+ if (vb)
+ {
vbModel=new VB(tags,corpus);
vbModel.alpha=alphaPi;
vbModel.lambda=alphaEmit;
- }
- if (options.has("no-parameter-cache"))
- cluster.cacheLambda = false;
- if (options.has("start"))
+ if (threadPool != null) vbModel.useThreadPool(threadPool);
+ }
+ else
{
- try {
- System.err.println("Reading starting parameters from " + options.valueOf("start"));
- cluster.loadParameters(FileUtil.reader((File)options.valueOf("start")));
- } catch (IOException e) {
- System.err.println("Failed to open input file: " + options.valueOf("start"));
- e.printStackTrace();
- }
+ cluster = new PhraseCluster(tags, corpus);
+ if (threadPool != null) cluster.useThreadPool(threadPool);
+
+ if (options.has("no-parameter-cache"))
+ cluster.cacheLambda = false;
+ if (options.has("start"))
+ {
+ try {
+ System.err.println("Reading starting parameters from " + options.valueOf("start"));
+ cluster.loadParameters(FileUtil.reader((File)options.valueOf("start")));
+ } catch (IOException e) {
+ System.err.println("Failed to open input file: " + options.valueOf("start"));
+ e.printStackTrace();
+ }
+ }
}
}
@@ -143,9 +151,8 @@ public class Trainer
double o;
if (agree != null)
o = agree.EM();
- else if(agree2sides!=null){
+ else if(agree2sides!=null)
o = agree2sides.EM();
- }
else
{
if (i < skip)
@@ -173,11 +180,25 @@ public class Trainer
last = o;
}
- if (cluster == null)
- cluster = agree.model1;
+ double pl1lmax = 0, cl1lmax = 0;
+ if (cluster != null)
+ {
+ pl1lmax = cluster.phrase_l1lmax();
+ cl1lmax = cluster.context_l1lmax();
+ }
+ else if (agree != null)
+ {
+ // fairly arbitrary choice of model1 cf model2
+ pl1lmax = agree.model1.phrase_l1lmax();
+ cl1lmax = agree.model1.context_l1lmax();
+ }
+ else if (agree2sides != null)
+ {
+ // fairly arbitrary choice of model1 cf model2
+ pl1lmax = agree2sides.model1.phrase_l1lmax();
+ cl1lmax = agree2sides.model1.context_l1lmax();
+ }
- double pl1lmax = cluster.phrase_l1lmax();
- double cl1lmax = cluster.context_l1lmax();
System.out.println("\nFinal posterior phrase l1lmax " + pl1lmax + " context l1lmax " + cl1lmax);
if (options.has("out"))
@@ -194,11 +215,18 @@ public class Trainer
System.out.println("Reading testing concordance from " + infile);
test = corpus.readEdges(FileUtil.reader(infile));
}
- if(vb){
+ if(vb) {
+ assert !options.has("test");
vbModel.displayPosterior(ps);
- }else{
+ } else if (cluster != null)
cluster.displayPosterior(ps, test);
+ else if (agree != null)
+ agree.displayPosterior(ps, test);
+ else if (agree2sides != null) {
+ assert !options.has("test");
+ agree2sides.displayPosterior(ps);
}
+
ps.close();
} catch (IOException e) {
System.err.println("Failed to open either testing file or output file");
@@ -209,6 +237,7 @@ public class Trainer
if (options.has("parameters"))
{
+ assert !vb;
File outfile = (File) options.valueOf("parameters");
PrintStream ps;
try {
@@ -222,7 +251,7 @@ public class Trainer
}
}
- if (cluster.pool != null)
+ if (cluster != null && cluster.pool != null)
cluster.pool.shutdown();
}
}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/VB.java b/gi/posterior-regularisation/prjava/src/phrase/VB.java
index a858c883..cd3f4966 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/VB.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/VB.java
@@ -7,8 +7,13 @@ import io.FileUtil;
import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
import org.apache.commons.math.special.Gamma;
@@ -38,21 +43,17 @@ public class VB {
/**@brief
* variational param for z
*/
- private double phi[][];
+ //private double phi[][];
/**@brief
* variational param for theta
*/
private double gamma[];
private static double VAL_DIFF_RATIO=0.005;
- /**@brief
- * objective for a single document
- */
- private double obj;
-
private int n_positions;
private int n_words;
private int K;
+ private ExecutorService pool;
private Corpus c;
public static void main(String[] args) {
@@ -122,17 +123,14 @@ public class VB {
}
- private void inference(int phraseID){
+ private double inference(int phraseID, double[][] phi, double[] gamma)
+ {
List<Edge > doc=c.getEdgesForPhrase(phraseID);
- phi=new double[doc.size()][K];
for(int i=0;i<phi.length;i++){
for(int j=0;j<phi[i].length;j++){
phi[i][j]=1.0/K;
}
}
- if(gamma==null){
- gamma=new double[K];
- }
Arrays.fill(gamma,alpha+1.0/K);
double digamma_gamma[]=new double[K];
@@ -143,7 +141,7 @@ public class VB {
}
double gammaSum[]=new double [K];
double prev_val=0;
- obj=0;
+ double obj=0;
for(int iter=0;iter<MAX_ITER;iter++){
prev_val=obj;
@@ -224,6 +222,8 @@ public class VB {
break;
}
}//end of inference loop
+
+ return obj;
}//end of inference
/**
@@ -251,31 +251,79 @@ public class VB {
}
}
-
//E
double exp_rho[][][]=new double[K][n_positions][n_words];
- for (int d=0;d<c.getNumPhrases();d++){
- inference(d);
- List<Edge>doc=c.getEdgesForPhrase(d);
- for(int n=0;n<doc.size();n++){
- TIntArrayList context=doc.get(n).getContext();
- for(int pos=0;pos<n_positions;pos++){
- int word=context.get(pos);
- for(int i=0;i<K;i++){
- exp_rho[i][pos][word]+=phi[n][i];
+ if (pool == null)
+ {
+ for (int d=0;d<c.getNumPhrases();d++)
+ {
+ List<Edge > doc=c.getEdgesForPhrase(d);
+ double[][] phi = new double[doc.size()][K];
+ double[] gamma = new double[K];
+
+ emObj += inference(d, phi, gamma);
+
+ for(int n=0;n<doc.size();n++){
+ TIntArrayList context=doc.get(n).getContext();
+ for(int pos=0;pos<n_positions;pos++){
+ int word=context.get(pos);
+ for(int i=0;i<K;i++){
+ exp_rho[i][pos][word]+=phi[n][i];
+ }
}
}
+ //if(d!=0 && d%100==0) System.out.print(".");
+ //if(d!=0 && d%1000==0) System.out.println(d);
}
-/* if(d!=0 && d%100==0){
- System.out.print(".");
- }
- if(d!=0 && d%1000==0){
- System.out.println(d);
- }
-*/
- emObj+=obj;
}
+ else // multi-threaded version of above loop
+ {
+ class PartialEStep implements Callable<PartialEStep>
+ {
+ double[][] phi;
+ double[] gamma;
+ double obj;
+ int d;
+ PartialEStep(int d) { this.d = d; }
+
+ public PartialEStep call()
+ {
+ phi = new double[c.getEdgesForPhrase(d).size()][K];
+ gamma = new double[K];
+ obj = inference(d, phi, gamma);
+ return this;
+ }
+ }
+
+ List<Future<PartialEStep>> jobs = new ArrayList<Future<PartialEStep>>();
+ for (int d=0;d<c.getNumPhrases();d++)
+ jobs.add(pool.submit(new PartialEStep(d)));
+ for (Future<PartialEStep> job: jobs)
+ {
+ try {
+ PartialEStep e = job.get();
+
+ emObj += e.obj;
+ List<Edge> doc = c.getEdgesForPhrase(e.d);
+ for(int n=0;n<doc.size();n++){
+ TIntArrayList context=doc.get(n).getContext();
+ for(int pos=0;pos<n_positions;pos++){
+ int word=context.get(pos);
+ for(int i=0;i<K;i++){
+ exp_rho[i][pos][word]+=e.phi[n][i];
+ }
+ }
+ }
+ } catch (ExecutionException e) {
+ System.err.println("ERROR: E-step thread execution failed.");
+ throw new RuntimeException(e);
+ } catch (InterruptedException e) {
+ System.err.println("ERROR: Failed to join E-step thread.");
+ throw new RuntimeException(e);
+ }
+ }
+ }
// System.out.println("EM Objective:"+emObj);
//M
@@ -309,8 +357,15 @@ public class VB {
public void displayPosterior(PrintStream ps)
{
for(int d=0;d<c.getNumPhrases();d++){
- inference(d);
- List<Edge> doc=c.getEdgesForPhrase(d);
+ List<Edge > doc=c.getEdgesForPhrase(d);
+ double[][] phi = new double[doc.size()][K];
+ for(int i=0;i<phi.length;i++)
+ for(int j=0;j<phi[i].length;j++)
+ phi[i][j]=1.0/K;
+ double[] gamma = new double[K];
+
+ inference(d, phi, gamma);
+
for(int n=0;n<doc.size();n++){
Edge edge=doc.get(n);
int tag=arr.F.argmax(phi[n]);
@@ -328,13 +383,9 @@ public class VB {
double v;
if (log_a < log_b)
- {
v = log_b+Math.log(1 + Math.exp(log_a-log_b));
- }
else
- {
v = log_a+Math.log(1 + Math.exp(log_b-log_a));
- }
return(v);
}
@@ -360,5 +411,9 @@ public class VB {
Math.log(x-2)-Math.log(x-3)-Math.log(x-4)-Math.log(x-5)-Math.log(x-6);
return z;
}
-
+
+ public void useThreadPool(ExecutorService threadPool)
+ {
+ pool = threadPool;
+ }
}//End of class