diff options
| author | trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-06 16:23:11 +0000 | 
|---|---|---|
| committer | trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-06 16:23:11 +0000 | 
| commit | 018a9f9feb6f432fb24e7a44908f165dc405ac05 (patch) | |
| tree | 3cbfee8762c34d73312cb3bcd4a20d9e549d0e88 | |
| parent | 825b1fc172a4f097c94b0fe8137ba2356262b5f4 (diff) | |
Thread pooling
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@151 ec762483-ff6d-05da-a07a-a48fb63a330f
3 files changed, 129 insertions, 45 deletions
| diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index 8b1e0a8c..cd28c12e 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -1,53 +1,65 @@  package phrase;
  import io.FileUtil;
 -
  import java.io.PrintStream;
  import java.util.Arrays;
 +import java.util.concurrent.ExecutorService;
 +import java.util.concurrent.Executors;
 +import java.util.concurrent.LinkedBlockingQueue;
  public class PhraseCluster {
 -
 -	/**@brief number of clusters*/
 +	
  	public int K;
 +	public double scale;
  	private int n_phrase;
  	private int n_words;
  	public PhraseCorpus c;
 +	private ExecutorService pool; 
  	/**@brief
  	 * emit[tag][position][word]
  	 */
  	private double emit[][][];
  	private double pi[][];
 +
 -	public static int ITER=20;
 -	public static String postFilename="../pdata/posterior.out";
 -	public static String phraseStatFilename="../pdata/phrase_stat.out";
 -	private static int NUM_TAG=3;
  	public static void main(String[] args) {
 -		
 -		PhraseCorpus c=new PhraseCorpus(PhraseCorpus.DATA_FILENAME);
 -		
 -		PhraseCluster cluster=new PhraseCluster(NUM_TAG,c);
 -		PhraseObjective.ps=FileUtil.openOutFile(phraseStatFilename);
 -		for(int i=0;i<ITER;i++){
 -			PhraseObjective.ps.println("ITER: "+i);
 -			cluster.PREM();
 -		//	cluster.EM();
 +		String input_fname = args[0];
 +		int tags = Integer.parseInt(args[1]);
 +		String outputDir = args[2];
 +		int iterations = Integer.parseInt(args[3]);
 +		double scale = Double.parseDouble(args[4]);
 +		int threads = Integer.parseInt(args[5]);
 +		
 +		PhraseCorpus corpus = new PhraseCorpus(input_fname);
 +		PhraseCluster cluster = new PhraseCluster(tags, corpus, scale, threads);
 +		
 +		PhraseObjective.ps = FileUtil.openOutFile(outputDir + "/phrase_stat.out");
 +		
 +		for(int i=0;i<iterations;i++){
 +			double o = cluster.PREM();
 +		    //double o = cluster.EM();
 +			PhraseObjective.ps.println("ITER: "+i+" objective: " + o);
  		}
 -		PrintStream ps=io.FileUtil.openOutFile(postFilename);
 +		PrintStream ps=io.FileUtil.openOutFile(outputDir + "/posterior.out");
  		cluster.displayPosterior(ps);
  		ps.println();
  		cluster.displayModelParam(ps);
  		ps.close();
  		PhraseObjective.ps.close();
 +		
 +		cluster.finish();
  	}
 -	public PhraseCluster(int numCluster,PhraseCorpus corpus){
 +	public PhraseCluster(int numCluster, PhraseCorpus corpus, double scale, int threads){
  		K=numCluster;
  		c=corpus;
  		n_words=c.wordLex.size();
  		n_phrase=c.data.length;
 +		this.scale = scale;
 +		if (threads > 0)
 +			pool = Executors.newFixedThreadPool(threads);
  		emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
  		pi=new double[n_phrase][K];
 @@ -61,28 +73,15 @@ public class PhraseCluster {  		for(double []j:pi){
  			arr.F.randomise(j);
  		}
 -		
 -		pi[0]=new double[]{
 -			0.3,0.5,0.2
 -		};
 -		
 -		double temp[][]=new double[][]{
 -				{0.11,0.16,0.19,0.11,0.1},
 -				{0.10,0.15,0.18,0.1,0.11},
 -				{0.09,0.07,0.12,0.14,0.13} 
 -		};
 -		
 -		for(int tag=0;tag<3;tag++){
 -			for(int word=0;word<4;word++){
 -				for(int pos=0;pos<4;pos++){
 -					emit[tag][pos][word]=temp[tag][word];
 -				}          
 -			}
 -		}
 -		
 +	}
 +	
 +	public void finish()
 +	{
 +		if (pool != null)
 +			pool.shutdown();
  	}
 -	public void EM(){
 +	public double EM(){
  		double [][][]exp_emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
  		double [][]exp_pi=new double[n_phrase][K];
 @@ -125,9 +124,14 @@ public class PhraseCluster {  		}
  		pi=exp_pi;
 +		
 +		return loglikelihood;
  	}
 -	public void PREM(){
 +	public double PREM(){
 +		if (pool != null)
 +			return PREMParallel();
 +		
  		double [][][]exp_emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
  		double [][]exp_pi=new double[n_phrase][K];
 @@ -171,6 +175,89 @@ public class PhraseCluster {  		}
  		pi=exp_pi;
 +		
 +		return primal;
 +	}
 +
 +	public double PREMParallel(){
 +		assert(pool != null);
 +		final LinkedBlockingQueue<PhraseObjective> expectations 
 +			= new LinkedBlockingQueue<PhraseObjective>();
 +		
 +		double [][][]exp_emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
 +		double [][]exp_pi=new double[n_phrase][K];
 +		
 +		double loglikelihood=0;
 +		double primal=0;
 +		//E
 +		for(int phrase=0;phrase<c.data.length;phrase++){
 +			final int p=phrase;
 +			pool.execute(new Runnable() {
 +				public void run() {
 +					try {
 +						//System.out.println("" + Thread.currentThread().getId() + " optimising lambda for " + p);
 +						PhraseObjective po = new PhraseObjective(PhraseCluster.this, p);
 +						po.optimizeWithProjectedGradientDescent();
 +						//System.out.println("" + Thread.currentThread().getId() + " done optimising lambda for " + p);
 +						expectations.put(po);
 +						//System.out.println("" + Thread.currentThread().getId() + " added to queue " + p);
 +					} catch (InterruptedException e) {
 +						System.err.println(Thread.currentThread().getId() + " Local e-step thread interrupted; will cause deadlock.");
 +						e.printStackTrace();
 +					}
 +				}
 +			});
 +		}
 +		
 +		// aggregate the expectations as they become available
 +		for(int count=0;count<c.data.length;count++) {
 +			try {
 +				//System.out.println("" + Thread.currentThread().getId() + " reading queue #" + count);
 +
 +				// wait (blocking) until something is ready
 +				PhraseObjective po = expectations.take();
 +				// process
 +				int phrase = po.phrase;
 +				//System.out.println("" + Thread.currentThread().getId() + " taken phrase " + phrase);
 +				double [][] q=po.posterior();
 +				loglikelihood+=po.getValue();
 +				primal+=po.primal();
 +				for(int edge=0;edge<q.length;edge++){
 +					int []context=c.data[phrase][edge];
 +					int contextCnt=context[context.length-1];
 +					//increment expected count
 +					for(int tag=0;tag<K;tag++){
 +						for(int pos=0;pos<context.length-1;pos++){
 +							exp_emit[tag][pos][context[pos]]+=q[edge][tag]*contextCnt;
 +						}
 +						exp_pi[phrase][tag]+=q[edge][tag]*contextCnt;
 +					}
 +				}
 +			} catch (InterruptedException e){
 +				System.err.println("M-step thread interrupted. Probably fatal!");
 +				e.printStackTrace();
 +			}
 +		}
 +		
 +		System.out.println("Log likelihood: "+loglikelihood);
 +		System.out.println("Primal Objective: "+primal);
 +		
 +		//M
 +		for(double [][]i:exp_emit){
 +			for(double []j:i){
 +				arr.F.l1normalize(j);
 +			}
 +		}
 +		
 +		emit=exp_emit;
 +		
 +		for(double []j:exp_pi){
 +			arr.F.l1normalize(j);
 +		}
 +		
 +		pi=exp_pi;
 +		
 +		return primal;
  	}
  	/**
 diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java index 3902f665..99545371 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java @@ -12,7 +12,6 @@ public class PhraseCorpus {  	public static String LEX_FILENAME="../pdata/lex.out";
 -	//public static String DATA_FILENAME="../pdata/canned.con";
  	public static String DATA_FILENAME="../pdata/btec.con";
  	public static int NUM_CONTEXT=4;
 diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java index e9e063d6..71c91b96 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java @@ -22,7 +22,6 @@ public class PhraseObjective extends ProjectedObjective{  	private static final double GRAD_DIFF = 0.002;
  	public static double INIT_STEP_SIZE=1;
  	public static double VAL_DIFF=0.001;
 -	private double scale=5;
  	private double c1=0.0001;
  	private double c2=0.9;
 @@ -73,7 +72,7 @@ public class PhraseObjective extends ProjectedObjective{  		newPoint  = new double[n_param];
  		gradient = new double[n_param];
  		initP();
 -		projection=new SimplexProjection (scale);
 +		projection=new SimplexProjection(c.scale);
  		q=new double [data.length][c.K];
  		setParameters(parameters);
 @@ -111,8 +110,7 @@ public class PhraseObjective extends ProjectedObjective{  		}
  		for(int edge=0;edge<data.length;edge++){
 -			loglikelihood+=Math.log
 -				(data[edge][countIdx]*arr.F.l1norm(q[edge]));
 +			loglikelihood+=data[edge][countIdx] * Math.log(arr.F.l1norm(q[edge]));
  			arr.F.l1normalize(q[edge]);
  		}
 @@ -222,7 +220,7 @@ public class PhraseObjective extends ProjectedObjective{  			sum+=max;
  		}
  //		ps.println(", "+sum);
 -		l=l-scale*sum;
 +		l=l-c.scale*sum;
  		return l;
  	}
 | 
