summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/phrase/VB.java
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/VB.java')
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/VB.java129
1 files changed, 92 insertions, 37 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/VB.java b/gi/posterior-regularisation/prjava/src/phrase/VB.java
index a858c883..cd3f4966 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/VB.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/VB.java
@@ -7,8 +7,13 @@ import io.FileUtil;
import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
import org.apache.commons.math.special.Gamma;
@@ -38,21 +43,17 @@ public class VB {
/**@brief
* variational param for z
*/
- private double phi[][];
+ //private double phi[][];
/**@brief
* variational param for theta
*/
private double gamma[];
private static double VAL_DIFF_RATIO=0.005;
- /**@brief
- * objective for a single document
- */
- private double obj;
-
private int n_positions;
private int n_words;
private int K;
+ private ExecutorService pool;
private Corpus c;
public static void main(String[] args) {
@@ -122,17 +123,14 @@ public class VB {
}
- private void inference(int phraseID){
+ private double inference(int phraseID, double[][] phi, double[] gamma)
+ {
List<Edge > doc=c.getEdgesForPhrase(phraseID);
- phi=new double[doc.size()][K];
for(int i=0;i<phi.length;i++){
for(int j=0;j<phi[i].length;j++){
phi[i][j]=1.0/K;
}
}
- if(gamma==null){
- gamma=new double[K];
- }
Arrays.fill(gamma,alpha+1.0/K);
double digamma_gamma[]=new double[K];
@@ -143,7 +141,7 @@ public class VB {
}
double gammaSum[]=new double [K];
double prev_val=0;
- obj=0;
+ double obj=0;
for(int iter=0;iter<MAX_ITER;iter++){
prev_val=obj;
@@ -224,6 +222,8 @@ public class VB {
break;
}
}//end of inference loop
+
+ return obj;
}//end of inference
/**
@@ -251,31 +251,79 @@ public class VB {
}
}
-
//E
double exp_rho[][][]=new double[K][n_positions][n_words];
- for (int d=0;d<c.getNumPhrases();d++){
- inference(d);
- List<Edge>doc=c.getEdgesForPhrase(d);
- for(int n=0;n<doc.size();n++){
- TIntArrayList context=doc.get(n).getContext();
- for(int pos=0;pos<n_positions;pos++){
- int word=context.get(pos);
- for(int i=0;i<K;i++){
- exp_rho[i][pos][word]+=phi[n][i];
+ if (pool == null)
+ {
+ for (int d=0;d<c.getNumPhrases();d++)
+ {
+ List<Edge > doc=c.getEdgesForPhrase(d);
+ double[][] phi = new double[doc.size()][K];
+ double[] gamma = new double[K];
+
+ emObj += inference(d, phi, gamma);
+
+ for(int n=0;n<doc.size();n++){
+ TIntArrayList context=doc.get(n).getContext();
+ for(int pos=0;pos<n_positions;pos++){
+ int word=context.get(pos);
+ for(int i=0;i<K;i++){
+ exp_rho[i][pos][word]+=phi[n][i];
+ }
}
}
+ //if(d!=0 && d%100==0) System.out.print(".");
+ //if(d!=0 && d%1000==0) System.out.println(d);
}
-/* if(d!=0 && d%100==0){
- System.out.print(".");
- }
- if(d!=0 && d%1000==0){
- System.out.println(d);
- }
-*/
- emObj+=obj;
}
+ else // multi-threaded version of above loop
+ {
+ class PartialEStep implements Callable<PartialEStep>
+ {
+ double[][] phi;
+ double[] gamma;
+ double obj;
+ int d;
+ PartialEStep(int d) { this.d = d; }
+
+ public PartialEStep call()
+ {
+ phi = new double[c.getEdgesForPhrase(d).size()][K];
+ gamma = new double[K];
+ obj = inference(d, phi, gamma);
+ return this;
+ }
+ }
+
+ List<Future<PartialEStep>> jobs = new ArrayList<Future<PartialEStep>>();
+ for (int d=0;d<c.getNumPhrases();d++)
+ jobs.add(pool.submit(new PartialEStep(d)));
+ for (Future<PartialEStep> job: jobs)
+ {
+ try {
+ PartialEStep e = job.get();
+
+ emObj += e.obj;
+ List<Edge> doc = c.getEdgesForPhrase(e.d);
+ for(int n=0;n<doc.size();n++){
+ TIntArrayList context=doc.get(n).getContext();
+ for(int pos=0;pos<n_positions;pos++){
+ int word=context.get(pos);
+ for(int i=0;i<K;i++){
+ exp_rho[i][pos][word]+=e.phi[n][i];
+ }
+ }
+ }
+ } catch (ExecutionException e) {
+ System.err.println("ERROR: E-step thread execution failed.");
+ throw new RuntimeException(e);
+ } catch (InterruptedException e) {
+ System.err.println("ERROR: Failed to join E-step thread.");
+ throw new RuntimeException(e);
+ }
+ }
+ }
// System.out.println("EM Objective:"+emObj);
//M
@@ -309,8 +357,15 @@ public class VB {
public void displayPosterior(PrintStream ps)
{
for(int d=0;d<c.getNumPhrases();d++){
- inference(d);
- List<Edge> doc=c.getEdgesForPhrase(d);
+ List<Edge > doc=c.getEdgesForPhrase(d);
+ double[][] phi = new double[doc.size()][K];
+ for(int i=0;i<phi.length;i++)
+ for(int j=0;j<phi[i].length;j++)
+ phi[i][j]=1.0/K;
+ double[] gamma = new double[K];
+
+ inference(d, phi, gamma);
+
for(int n=0;n<doc.size();n++){
Edge edge=doc.get(n);
int tag=arr.F.argmax(phi[n]);
@@ -328,13 +383,9 @@ public class VB {
double v;
if (log_a < log_b)
- {
v = log_b+Math.log(1 + Math.exp(log_a-log_b));
- }
else
- {
v = log_a+Math.log(1 + Math.exp(log_b-log_a));
- }
return(v);
}
@@ -360,5 +411,9 @@ public class VB {
Math.log(x-2)-Math.log(x-3)-Math.log(x-4)-Math.log(x-5)-Math.log(x-6);
return z;
}
-
+
+ public void useThreadPool(ExecutorService threadPool)
+ {
+ pool = threadPool;
+ }
}//End of class