diff options
author | trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-23 19:26:17 +0000 |
---|---|---|
committer | trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-23 19:26:17 +0000 |
commit | 7b61f71be28539f815a171ba84baeaa90f863e88 (patch) | |
tree | bff936400881f788efa4364c4acfd32a3c860f04 /gi/posterior-regularisation/prjava/src/phrase | |
parent | a2d9d0f96502c7d3c04303f3db36a8602d992287 (diff) |
Parallelised VB-EM
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@384 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase')
5 files changed, 182 insertions, 213 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index ccb6ae9d..c032bb2b 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -4,14 +4,16 @@ import gnu.trove.TIntArrayList; import org.apache.commons.math.special.Gamma;
import java.io.BufferedReader;
-import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
-import java.util.StringTokenizer;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
@@ -56,23 +58,9 @@ public class PhraseCluster { arr.F.randomise(j, true);
}
- public void initialiseVB(double alphaEmit, double alphaPi)
+ void useThreadPool(ExecutorService pool)
{
- assert alphaEmit > 0;
- assert alphaPi > 0;
-
- for(double [][]i:emit)
- for(double []j:i)
- digammaNormalize(j, alphaEmit);
-
- for(double []j:pi)
- digammaNormalize(j, alphaPi);
- }
-
- void useThreadPool(int threads)
- {
- assert threads > 0;
- pool = Executors.newFixedThreadPool(threads);
+ this.pool = pool;
}
public double EM(int phraseSizeLimit)
@@ -131,107 +119,6 @@ public class PhraseCluster { return loglikelihood;
}
- public double VBEM(double alphaEmit, double alphaPi)
- {
- // FIXME: broken - needs to be done entirely in log-space
-
- double [][][]exp_emit = new double [K][n_positions][n_words];
- double [][]exp_pi = new double[n_phrases][K];
-
- double loglikelihood=0;
-
- //E
- for(int phrase=0; phrase < n_phrases; phrase++)
- {
- List<Edge> contexts = c.getEdgesForPhrase(phrase);
-
- for (int ctx=0; ctx<contexts.size(); ctx++)
- {
- Edge edge = contexts.get(ctx);
- double p[] = posterior(edge);
- double z = arr.F.l1norm(p);
- assert z > 0;
- loglikelihood += edge.getCount() * Math.log(z);
- arr.F.l1normalize(p);
-
- double count = edge.getCount();
- //increment expected count
- TIntArrayList context = edge.getContext();
- for(int tag=0;tag<K;tag++)
- {
- for(int pos=0;pos<n_positions;pos++)
- exp_emit[tag][pos][context.get(pos)] += p[tag]*count;
- exp_pi[phrase][tag] += p[tag]*count;
- }
- }
- }
-
- // find the KL terms, KL(q||p) where p is symmetric Dirichlet prior and q are the expectations
- double kl = 0;
- for (int phrase=0; phrase < n_phrases; phrase++)
- kl += KL_symmetric_dirichlet(exp_pi[phrase], alphaPi);
-
- for (int tag=0;tag<K;tag++)
- for (int pos=0;pos<n_positions; ++pos)
- kl += this.KL_symmetric_dirichlet(exp_emit[tag][pos], alphaEmit);
- // FIXME: exp_emit[tag][pos] has structural zeros - certain words are *never* seen in that position
-
- //M
- for(double [][]i:exp_emit)
- for(double []j:i)
- digammaNormalize(j, alphaEmit);
- emit=exp_emit;
- for(double []j:exp_pi)
- digammaNormalize(j, alphaPi);
- pi=exp_pi;
-
- System.out.println("KL=" + kl + " llh=" + loglikelihood);
- System.out.println(Arrays.toString(pi[0]));
- System.out.println(Arrays.toString(exp_emit[0][0]));
- return kl + loglikelihood;
- }
-
- public void digammaNormalize(double [] a, double alpha)
- {
- double sum=0;
- for(int i=0;i<a.length;i++)
- sum += a[i];
-
- assert sum > 1e-20;
- double dgs = Gamma.digamma(sum + alpha);
-
- for(int i=0;i<a.length;i++)
- a[i] = Math.exp(Gamma.digamma(a[i] + alpha/a.length) - dgs);
- }
-
- private double KL_symmetric_dirichlet(double[] q, double alpha)
- {
- // assumes that zeros in q are structural & should be skipped
- // FIXME: asssumption doesn't hold
-
- double p0 = alpha;
- double q0 = 0;
- int n = 0;
- for (int i=0; i<q.length; i++)
- {
- if (q[i] > 0)
- {
- q0 += q[i];
- n += 1;
- }
- }
-
- double kl = Gamma.logGamma(q0) - Gamma.logGamma(p0);
- kl += n * Gamma.logGamma(alpha / n);
- double digamma_q0 = Gamma.digamma(q0);
- for (int i=0; i<q.length; i++)
- {
- if (q[i] > 0)
- kl -= -Gamma.logGamma(q[i]) - (q[i] - alpha/q.length) * (Gamma.digamma(q[i]) - digamma_q0);
- }
- return kl;
- }
-
public double PREM(double scalePT, double scaleCT, int phraseSizeLimit)
{
if (scaleCT == 0)
@@ -339,51 +226,44 @@ public class PhraseCluster { double loglikelihood=0, kl=0, l1lmax=0, primal=0;
final AtomicInteger failures = new AtomicInteger(0);
final AtomicLong elapsed = new AtomicLong(0l);
- int iterations=0, n=n_phrases;
+ int iterations=0;
long start = System.currentTimeMillis();
+ List<Future<PhraseObjective>> results = new ArrayList<Future<PhraseObjective>>();
if (lambdaPT == null && cacheLambda)
lambdaPT = new double[n_phrases][];
//E
- for(int phrase=0;phrase<n_phrases;phrase++){
- if (phraseSizeLimit >= 1 && c.getPhrase(phrase).size() > phraseSizeLimit)
- {
- n -= 1;
+ for(int phrase=0;phrase<n_phrases;phrase++) {
+ if (phraseSizeLimit >= 1 && c.getPhrase(phrase).size() > phraseSizeLimit) {
System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
continue;
}
final int p=phrase;
- pool.execute(new Runnable() {
- public void run() {
- try {
- //System.out.println("" + Thread.currentThread().getId() + " optimising lambda for " + p);
- long start = System.currentTimeMillis();
- PhraseObjective po = new PhraseObjective(PhraseCluster.this, p, scalePT, (cacheLambda) ? lambdaPT[p] : null);
- boolean ok = po.optimizeWithProjectedGradientDescent();
- if (!ok) failures.incrementAndGet();
- long end = System.currentTimeMillis();
- elapsed.addAndGet(end - start);
-
- //System.out.println("" + Thread.currentThread().getId() + " done optimising lambda for " + p);
- expectations.put(po);
- //System.out.println("" + Thread.currentThread().getId() + " added to queue " + p);
- } catch (InterruptedException e) {
- System.err.println(Thread.currentThread().getId() + " Local e-step thread interrupted; will cause deadlock.");
- e.printStackTrace();
- }
+ results.add(pool.submit(new Callable<PhraseObjective>() {
+ public PhraseObjective call() {
+ //System.out.println("" + Thread.currentThread().getId() + " optimising lambda for " + p);
+ long start = System.currentTimeMillis();
+ PhraseObjective po = new PhraseObjective(PhraseCluster.this, p, scalePT, (cacheLambda) ? lambdaPT[p] : null);
+ boolean ok = po.optimizeWithProjectedGradientDescent();
+ if (!ok) failures.incrementAndGet();
+ long end = System.currentTimeMillis();
+ elapsed.addAndGet(end - start);
+ //System.out.println("" + Thread.currentThread().getId() + " done optimising lambda for " + p);
+ return po;
}
- });
+ }));
}
// aggregate the expectations as they become available
- for(int count=0;count<n;count++) {
+ for (Future<PhraseObjective> fpo : results)
+ {
try {
//System.out.println("" + Thread.currentThread().getId() + " reading queue #" + count);
// wait (blocking) until something is ready
- PhraseObjective po = expectations.take();
+ PhraseObjective po = fpo.get();
// process
int phrase = po.phrase;
if (cacheLambda) lambdaPT[phrase] = po.getParameters();
@@ -408,10 +288,12 @@ public class PhraseCluster { exp_pi[phrase][tag]+=q[edge][tag]*contextCnt;
}
}
- } catch (InterruptedException e)
- {
+ } catch (InterruptedException e) {
System.err.println("M-step thread interrupted. Probably fatal!");
- e.printStackTrace();
+ throw new RuntimeException(e);
+ } catch (ExecutionException e) {
+ System.err.println("M-step thread execution died. Probably fatal!");
+ throw new RuntimeException(e);
}
}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java index 06a9f8cb..5947c4be 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java @@ -277,7 +277,10 @@ public class PhraseContextObjective extends ProjectedObjective }
// rethrow the exception
if (failure != null)
+ {
+ pool.shutdownNow();
throw new RuntimeException(failure);
+ }
}
double[] tmp = newPoint;
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java index 7c32d9c0..5efe778a 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java @@ -192,7 +192,7 @@ public class PhraseObjective extends ProjectedObjective // for(int edge=0;edge<data.getSize();edge++){
// ps.println(Arrays.toString(q[edge]));
// }
-
+
return success;
}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java index f205ce67..6f302b20 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java @@ -4,11 +4,12 @@ import io.FileUtil; import joptsimple.OptionParser; import joptsimple.OptionSet; import java.io.File; -import java.io.FileNotFoundException; import java.io.IOException; import java.io.PrintStream; import java.util.List; import java.util.Random; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import phrase.Corpus.Edge; @@ -18,7 +19,6 @@ public class Trainer { public static void main(String[] args) { - OptionParser parser = new OptionParser(); parser.accepts("help"); parser.accepts("in").withRequiredArg().ofType(File.class); @@ -68,6 +68,10 @@ public class Trainer if (options.has("seed")) F.rng = new Random((Long) options.valueOf("seed")); + ExecutorService threadPool = null; + if (threads > 0) + threadPool = Executors.newFixedThreadPool(threads); + if (tags <= 1 || scale_phrase < 0 || scale_context < 0 || threshold < 0) { System.err.println("Invalid arguments. Try again!"); @@ -114,26 +118,30 @@ public class Trainer agree = new Agree(tags, corpus); else { - cluster = new PhraseCluster(tags, corpus); - if (threads > 0) cluster.useThreadPool(threads); - - if (vb) { - //cluster.initialiseVB(alphaEmit, alphaPi); + if (vb) + { vbModel=new VB(tags,corpus); vbModel.alpha=alphaPi; vbModel.lambda=alphaEmit; - } - if (options.has("no-parameter-cache")) - cluster.cacheLambda = false; - if (options.has("start")) + if (threadPool != null) vbModel.useThreadPool(threadPool); + } + else { - try { - System.err.println("Reading starting parameters from " + options.valueOf("start")); - cluster.loadParameters(FileUtil.reader((File)options.valueOf("start"))); - } catch (IOException e) { - System.err.println("Failed to open input file: " + options.valueOf("start")); - e.printStackTrace(); - } + cluster = new PhraseCluster(tags, corpus); + if (threadPool != null) cluster.useThreadPool(threadPool); + + if (options.has("no-parameter-cache")) + cluster.cacheLambda = false; + if (options.has("start")) + { + try { + System.err.println("Reading starting parameters from " + options.valueOf("start")); + cluster.loadParameters(FileUtil.reader((File)options.valueOf("start"))); + } catch (IOException e) { + System.err.println("Failed to open input file: " + options.valueOf("start")); + e.printStackTrace(); + } + } } } @@ -143,9 +151,8 @@ public class Trainer double o; if (agree != null) o = agree.EM(); - else if(agree2sides!=null){ + else if(agree2sides!=null) o = agree2sides.EM(); - } else { if (i < skip) @@ -173,11 +180,25 @@ public class Trainer last = o; } - if (cluster == null) - cluster = agree.model1; + double pl1lmax = 0, cl1lmax = 0; + if (cluster != null) + { + pl1lmax = cluster.phrase_l1lmax(); + cl1lmax = cluster.context_l1lmax(); + } + else if (agree != null) + { + // fairly arbitrary choice of model1 cf model2 + pl1lmax = agree.model1.phrase_l1lmax(); + cl1lmax = agree.model1.context_l1lmax(); + } + else if (agree2sides != null) + { + // fairly arbitrary choice of model1 cf model2 + pl1lmax = agree2sides.model1.phrase_l1lmax(); + cl1lmax = agree2sides.model1.context_l1lmax(); + } - double pl1lmax = cluster.phrase_l1lmax(); - double cl1lmax = cluster.context_l1lmax(); System.out.println("\nFinal posterior phrase l1lmax " + pl1lmax + " context l1lmax " + cl1lmax); if (options.has("out")) @@ -194,11 +215,18 @@ public class Trainer System.out.println("Reading testing concordance from " + infile); test = corpus.readEdges(FileUtil.reader(infile)); } - if(vb){ + if(vb) { + assert !options.has("test"); vbModel.displayPosterior(ps); - }else{ + } else if (cluster != null) cluster.displayPosterior(ps, test); + else if (agree != null) + agree.displayPosterior(ps, test); + else if (agree2sides != null) { + assert !options.has("test"); + agree2sides.displayPosterior(ps); } + ps.close(); } catch (IOException e) { System.err.println("Failed to open either testing file or output file"); @@ -209,6 +237,7 @@ public class Trainer if (options.has("parameters")) { + assert !vb; File outfile = (File) options.valueOf("parameters"); PrintStream ps; try { @@ -222,7 +251,7 @@ public class Trainer } } - if (cluster.pool != null) + if (cluster != null && cluster.pool != null) cluster.pool.shutdown(); } } diff --git a/gi/posterior-regularisation/prjava/src/phrase/VB.java b/gi/posterior-regularisation/prjava/src/phrase/VB.java index a858c883..cd3f4966 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/VB.java +++ b/gi/posterior-regularisation/prjava/src/phrase/VB.java @@ -7,8 +7,13 @@ import io.FileUtil; import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
import org.apache.commons.math.special.Gamma;
@@ -38,21 +43,17 @@ public class VB { /**@brief
* variational param for z
*/
- private double phi[][];
+ //private double phi[][];
/**@brief
* variational param for theta
*/
private double gamma[];
private static double VAL_DIFF_RATIO=0.005;
- /**@brief
- * objective for a single document
- */
- private double obj;
-
private int n_positions;
private int n_words;
private int K;
+ private ExecutorService pool;
private Corpus c;
public static void main(String[] args) {
@@ -122,17 +123,14 @@ public class VB { }
- private void inference(int phraseID){
+ private double inference(int phraseID, double[][] phi, double[] gamma)
+ {
List<Edge > doc=c.getEdgesForPhrase(phraseID);
- phi=new double[doc.size()][K];
for(int i=0;i<phi.length;i++){
for(int j=0;j<phi[i].length;j++){
phi[i][j]=1.0/K;
}
}
- if(gamma==null){
- gamma=new double[K];
- }
Arrays.fill(gamma,alpha+1.0/K);
double digamma_gamma[]=new double[K];
@@ -143,7 +141,7 @@ public class VB { }
double gammaSum[]=new double [K];
double prev_val=0;
- obj=0;
+ double obj=0;
for(int iter=0;iter<MAX_ITER;iter++){
prev_val=obj;
@@ -224,6 +222,8 @@ public class VB { break;
}
}//end of inference loop
+
+ return obj;
}//end of inference
/**
@@ -251,31 +251,79 @@ public class VB { }
}
-
//E
double exp_rho[][][]=new double[K][n_positions][n_words];
- for (int d=0;d<c.getNumPhrases();d++){
- inference(d);
- List<Edge>doc=c.getEdgesForPhrase(d);
- for(int n=0;n<doc.size();n++){
- TIntArrayList context=doc.get(n).getContext();
- for(int pos=0;pos<n_positions;pos++){
- int word=context.get(pos);
- for(int i=0;i<K;i++){
- exp_rho[i][pos][word]+=phi[n][i];
+ if (pool == null)
+ {
+ for (int d=0;d<c.getNumPhrases();d++)
+ {
+ List<Edge > doc=c.getEdgesForPhrase(d);
+ double[][] phi = new double[doc.size()][K];
+ double[] gamma = new double[K];
+
+ emObj += inference(d, phi, gamma);
+
+ for(int n=0;n<doc.size();n++){
+ TIntArrayList context=doc.get(n).getContext();
+ for(int pos=0;pos<n_positions;pos++){
+ int word=context.get(pos);
+ for(int i=0;i<K;i++){
+ exp_rho[i][pos][word]+=phi[n][i];
+ }
}
}
+ //if(d!=0 && d%100==0) System.out.print(".");
+ //if(d!=0 && d%1000==0) System.out.println(d);
}
-/* if(d!=0 && d%100==0){
- System.out.print(".");
- }
- if(d!=0 && d%1000==0){
- System.out.println(d);
- }
-*/
- emObj+=obj;
}
+ else // multi-threaded version of above loop
+ {
+ class PartialEStep implements Callable<PartialEStep>
+ {
+ double[][] phi;
+ double[] gamma;
+ double obj;
+ int d;
+ PartialEStep(int d) { this.d = d; }
+
+ public PartialEStep call()
+ {
+ phi = new double[c.getEdgesForPhrase(d).size()][K];
+ gamma = new double[K];
+ obj = inference(d, phi, gamma);
+ return this;
+ }
+ }
+
+ List<Future<PartialEStep>> jobs = new ArrayList<Future<PartialEStep>>();
+ for (int d=0;d<c.getNumPhrases();d++)
+ jobs.add(pool.submit(new PartialEStep(d)));
+ for (Future<PartialEStep> job: jobs)
+ {
+ try {
+ PartialEStep e = job.get();
+
+ emObj += e.obj;
+ List<Edge> doc = c.getEdgesForPhrase(e.d);
+ for(int n=0;n<doc.size();n++){
+ TIntArrayList context=doc.get(n).getContext();
+ for(int pos=0;pos<n_positions;pos++){
+ int word=context.get(pos);
+ for(int i=0;i<K;i++){
+ exp_rho[i][pos][word]+=e.phi[n][i];
+ }
+ }
+ }
+ } catch (ExecutionException e) {
+ System.err.println("ERROR: E-step thread execution failed.");
+ throw new RuntimeException(e);
+ } catch (InterruptedException e) {
+ System.err.println("ERROR: Failed to join E-step thread.");
+ throw new RuntimeException(e);
+ }
+ }
+ }
// System.out.println("EM Objective:"+emObj);
//M
@@ -309,8 +357,15 @@ public class VB { public void displayPosterior(PrintStream ps)
{
for(int d=0;d<c.getNumPhrases();d++){
- inference(d);
- List<Edge> doc=c.getEdgesForPhrase(d);
+ List<Edge > doc=c.getEdgesForPhrase(d);
+ double[][] phi = new double[doc.size()][K];
+ for(int i=0;i<phi.length;i++)
+ for(int j=0;j<phi[i].length;j++)
+ phi[i][j]=1.0/K;
+ double[] gamma = new double[K];
+
+ inference(d, phi, gamma);
+
for(int n=0;n<doc.size();n++){
Edge edge=doc.get(n);
int tag=arr.F.argmax(phi[n]);
@@ -328,13 +383,9 @@ public class VB { double v;
if (log_a < log_b)
- {
v = log_b+Math.log(1 + Math.exp(log_a-log_b));
- }
else
- {
v = log_a+Math.log(1 + Math.exp(log_b-log_a));
- }
return(v);
}
@@ -360,5 +411,9 @@ public class VB { Math.log(x-2)-Math.log(x-3)-Math.log(x-4)-Math.log(x-5)-Math.log(x-6);
return z;
}
-
+
+ public void useThreadPool(ExecutorService threadPool)
+ {
+ pool = threadPool;
+ }
}//End of class
|