diff options
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/VB.java')
| -rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/VB.java | 129 | 
1 files changed, 92 insertions, 37 deletions
| diff --git a/gi/posterior-regularisation/prjava/src/phrase/VB.java b/gi/posterior-regularisation/prjava/src/phrase/VB.java index a858c883..cd3f4966 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/VB.java +++ b/gi/posterior-regularisation/prjava/src/phrase/VB.java @@ -7,8 +7,13 @@ import io.FileUtil;  import java.io.File;
  import java.io.IOException;
  import java.io.PrintStream;
 +import java.util.ArrayList;
  import java.util.Arrays;
  import java.util.List;
 +import java.util.concurrent.Callable;
 +import java.util.concurrent.ExecutionException;
 +import java.util.concurrent.ExecutorService;
 +import java.util.concurrent.Future;
  import org.apache.commons.math.special.Gamma;
 @@ -38,21 +43,17 @@ public class VB {  	/**@brief
  	 * variational param for z
  	 */
 -	private double phi[][];
 +	//private double phi[][];
  	/**@brief
  	 * variational param for theta
  	 */
  	private double gamma[];
  	private static double VAL_DIFF_RATIO=0.005;
 -	/**@brief
 -	 * objective for a single document
 -	 */
 -	private double obj;
 -	
  	private int n_positions;
  	private int n_words;
  	private int K;
 +	private ExecutorService pool;
  	private Corpus c;
  	public static void main(String[] args) {
 @@ -122,17 +123,14 @@ public class VB {  	}
 -	private void inference(int phraseID){
 +	private double inference(int phraseID, double[][] phi, double[] gamma)
 +	{
  		List<Edge > doc=c.getEdgesForPhrase(phraseID);
 -		phi=new double[doc.size()][K];
  		for(int i=0;i<phi.length;i++){
  			for(int j=0;j<phi[i].length;j++){
  				phi[i][j]=1.0/K;
  			}
  		}
 -		if(gamma==null){
 -			gamma=new double[K];
 -		}
  		Arrays.fill(gamma,alpha+1.0/K);
  		double digamma_gamma[]=new double[K];
 @@ -143,7 +141,7 @@ public class VB {  		}
  		double gammaSum[]=new double [K];
  		double prev_val=0;
 -		obj=0;
 +		double obj=0;
  		for(int iter=0;iter<MAX_ITER;iter++){
  			prev_val=obj;
 @@ -224,6 +222,8 @@ public class VB {  				break;
  			}
  		}//end of inference loop
 +		
 +		return obj;
  	}//end of inference
  	/**
 @@ -251,31 +251,79 @@ public class VB {  			}
  		}
 -		
  		//E
  		double exp_rho[][][]=new double[K][n_positions][n_words];
 -		for (int d=0;d<c.getNumPhrases();d++){
 -			inference(d);
 -			List<Edge>doc=c.getEdgesForPhrase(d);
 -			for(int n=0;n<doc.size();n++){
 -				TIntArrayList context=doc.get(n).getContext();
 -				for(int pos=0;pos<n_positions;pos++){
 -					int word=context.get(pos);
 -					for(int i=0;i<K;i++){	
 -						exp_rho[i][pos][word]+=phi[n][i];
 +		if (pool == null)
 +		{
 +			for (int d=0;d<c.getNumPhrases();d++)
 +			{		
 +				List<Edge > doc=c.getEdgesForPhrase(d);
 +				double[][] phi = new double[doc.size()][K];
 +				double[] gamma = new double[K];
 +				
 +				emObj += inference(d, phi, gamma);
 +				
 +				for(int n=0;n<doc.size();n++){
 +					TIntArrayList context=doc.get(n).getContext();
 +					for(int pos=0;pos<n_positions;pos++){
 +						int word=context.get(pos);
 +						for(int i=0;i<K;i++){	
 +							exp_rho[i][pos][word]+=phi[n][i];
 +						}
  					}
  				}
 +				//if(d!=0 && d%100==0)  System.out.print(".");
 +				//if(d!=0 && d%1000==0) System.out.println(d);
  			}
 -/*			if(d!=0 && d%100==0){
 -				System.out.print(".");
 -			}
 -			if(d!=0 && d%1000==0){
 -				System.out.println(d);
 -			}
 -*/
 -			emObj+=obj;
  		}
 +		else // multi-threaded version of above loop
 +		{
 +			class PartialEStep implements Callable<PartialEStep> // one pool job: runs inference for a single phrase and keeps the results
 +			{
 +				double[][] phi; // allocated in call(), then written by inference()
 +				double[] gamma; // allocated in call() with K entries, then written by inference()
 +				double obj; // value returned by inference() for phrase d
 +				int d; // index of the phrase this job handles
 +				PartialEStep(int d) { this.d = d; }
 +
 +				public PartialEStep call()
 +				{
 +					phi = new double[c.getEdgesForPhrase(d).size()][K]; // one row per edge of phrase d, K columns
 +					gamma = new double[K];
 +					obj = inference(d, phi, gamma);
 +					return this; // hand the whole job back so the caller can read obj, phi and d
 +				}			
 +			}
 +
 +			List<Future<PartialEStep>> jobs = new ArrayList<Future<PartialEStep>>(); // one Future per phrase, kept in submission order
 +			for (int d=0;d<c.getNumPhrases();d++) // submit every phrase before collecting any result
 +				jobs.add(pool.submit(new PartialEStep(d)));
 +			for (Future<PartialEStep> job: jobs) // emObj and exp_rho are accumulated here, not inside call()
 +			{
 +				try {
 +					PartialEStep e = job.get(); // waits for this job to complete
 +					
 +					emObj += e.obj;				
 +					List<Edge> doc = c.getEdgesForPhrase(e.d);
 +					for(int n=0;n<doc.size();n++){
 +						TIntArrayList context=doc.get(n).getContext();
 +						for(int pos=0;pos<n_positions;pos++){
 +							int word=context.get(pos);
 +							for(int i=0;i<K;i++){	
 +								exp_rho[i][pos][word]+=e.phi[n][i]; // same accumulation as the single-threaded branch, reading the job's own phi
 +							}
 +						}
 +					}
 +				} catch (ExecutionException e) { // the job's call() threw; the cause is wrapped and rethrown unchecked
 +					System.err.println("ERROR: E-step thread execution failed.");
 +					throw new RuntimeException(e);
 +				} catch (InterruptedException e) { // NOTE(review): interrupt flag is not restored (Thread.currentThread().interrupt()) before rethrowing - confirm this is intended
 +					System.err.println("ERROR: Failed to join E-step thread.");
 +					throw new RuntimeException(e);
 +				}
 +			}
 +		}	
  	//	System.out.println("EM Objective:"+emObj);
  		//M
 @@ -309,8 +357,15 @@ public class VB {  	public void displayPosterior(PrintStream ps)
  	{	
  		for(int d=0;d<c.getNumPhrases();d++){
 -			inference(d);
 -			List<Edge> doc=c.getEdgesForPhrase(d);
 +			List<Edge > doc=c.getEdgesForPhrase(d);
 +			double[][] phi = new double[doc.size()][K];
 +			for(int i=0;i<phi.length;i++)
 +				for(int j=0;j<phi[i].length;j++)
 +					phi[i][j]=1.0/K;
 +			double[] gamma = new double[K];
 +
 +			inference(d, phi, gamma);
 +
  			for(int n=0;n<doc.size();n++){
  				Edge edge=doc.get(n);
  				int tag=arr.F.argmax(phi[n]);
 @@ -328,13 +383,9 @@ public class VB {  	  double v;
  	  if (log_a < log_b)
 -	  {
  	      v = log_b+Math.log(1 + Math.exp(log_a-log_b));
 -	  }
  	  else
 -	  {
  	      v = log_a+Math.log(1 + Math.exp(log_b-log_a));
 -	  }
  	  return(v);
  	}
 @@ -360,5 +411,9 @@ public class VB {  	    Math.log(x-2)-Math.log(x-3)-Math.log(x-4)-Math.log(x-5)-Math.log(x-6);
  	    return z;
  	}
 -	
 +
 +	public void useThreadPool(ExecutorService threadPool) // stores the executor that the E-step submits its per-phrase jobs to
 +	{
 +		pool = threadPool; // while pool is null the E-step takes its single-threaded loop instead
 +	}
  }//End of  class
 | 
