diff options
Diffstat (limited to 'gi/posterior-regularisation/prjava')
5 files changed, 182 insertions, 213 deletions
| diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index ccb6ae9d..c032bb2b 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -4,14 +4,16 @@ import gnu.trove.TIntArrayList;  import org.apache.commons.math.special.Gamma;
  import java.io.BufferedReader;
 -import java.io.File;
  import java.io.IOException;
  import java.io.PrintStream;
 +import java.util.ArrayList;
  import java.util.Arrays;
  import java.util.List;
 -import java.util.StringTokenizer;
 +import java.util.concurrent.Callable;
 +import java.util.concurrent.ExecutionException;
  import java.util.concurrent.ExecutorService;
  import java.util.concurrent.Executors;
 +import java.util.concurrent.Future;
  import java.util.concurrent.LinkedBlockingQueue;
  import java.util.concurrent.atomic.AtomicInteger;
  import java.util.concurrent.atomic.AtomicLong;
 @@ -56,23 +58,9 @@ public class PhraseCluster {  			arr.F.randomise(j, true);
  	}
 -	public void initialiseVB(double alphaEmit, double alphaPi)
 +	void useThreadPool(ExecutorService pool)
  	{
 -		assert alphaEmit > 0;
 -		assert alphaPi > 0;
 -		
 -		for(double [][]i:emit)
 -			for(double []j:i)
 -				digammaNormalize(j, alphaEmit);
 -		
 -		for(double []j:pi)
 -			digammaNormalize(j, alphaPi);
 -	}
 -	
 -	void useThreadPool(int threads)
 -	{
 -		assert threads > 0;
 -		pool = Executors.newFixedThreadPool(threads);
 +		this.pool = pool;
  	}
  	public double EM(int phraseSizeLimit)
 @@ -131,107 +119,6 @@ public class PhraseCluster {  		return loglikelihood;
  	}
 -	public double VBEM(double alphaEmit, double alphaPi)
 -	{
 -		// FIXME: broken - needs to be done entirely in log-space
 -		
 -		double [][][]exp_emit = new double [K][n_positions][n_words];
 -		double [][]exp_pi = new double[n_phrases][K];
 -		
 -		double loglikelihood=0;
 -		
 -		//E
 -		for(int phrase=0; phrase < n_phrases; phrase++)
 -		{
 -			List<Edge> contexts = c.getEdgesForPhrase(phrase);
 -
 -			for (int ctx=0; ctx<contexts.size(); ctx++)
 -			{
 -				Edge edge = contexts.get(ctx);
 -				double p[] = posterior(edge);
 -				double z = arr.F.l1norm(p);
 -				assert z > 0;
 -				loglikelihood += edge.getCount() * Math.log(z);
 -				arr.F.l1normalize(p);
 -				
 -				double count = edge.getCount();
 -				//increment expected count
 -				TIntArrayList context = edge.getContext();
 -				for(int tag=0;tag<K;tag++)
 -				{
 -					for(int pos=0;pos<n_positions;pos++)
 -						exp_emit[tag][pos][context.get(pos)] += p[tag]*count;		
 -					exp_pi[phrase][tag] += p[tag]*count;
 -				}
 -			}
 -		}
 -
 -		// find the KL terms, KL(q||p) where p is symmetric Dirichlet prior and q are the expectations 
 -		double kl = 0;
 -		for (int phrase=0; phrase < n_phrases; phrase++)
 -			kl += KL_symmetric_dirichlet(exp_pi[phrase], alphaPi);
 -	
 -		for (int tag=0;tag<K;tag++)
 -			for (int pos=0;pos<n_positions; ++pos)
 -				kl += this.KL_symmetric_dirichlet(exp_emit[tag][pos], alphaEmit); 
 -		// FIXME: exp_emit[tag][pos] has structural zeros - certain words are *never* seen in that position
 -
 -		//M
 -		for(double [][]i:exp_emit)
 -			for(double []j:i)
 -				digammaNormalize(j, alphaEmit);
 -		emit=exp_emit;
 -		for(double []j:exp_pi)
 -			digammaNormalize(j, alphaPi);
 -		pi=exp_pi;
 -
 -		System.out.println("KL=" + kl + " llh=" + loglikelihood);
 -		System.out.println(Arrays.toString(pi[0]));
 -		System.out.println(Arrays.toString(exp_emit[0][0]));
 -		return kl + loglikelihood;
 -	}
 -	
 -	public void digammaNormalize(double [] a, double alpha)
 -	{
 -		double sum=0;
 -		for(int i=0;i<a.length;i++)
 -			sum += a[i];
 -		
 -		assert sum > 1e-20;
 -		double dgs = Gamma.digamma(sum + alpha);
 -		
 -		for(int i=0;i<a.length;i++)
 -			a[i] = Math.exp(Gamma.digamma(a[i] + alpha/a.length) - dgs);
 -	}
 -	
 -	private double KL_symmetric_dirichlet(double[] q, double alpha)
 -	{
 -		// assumes that zeros in q are structural & should be skipped
 -		// FIXME: asssumption doesn't hold
 -		
 -		double p0 = alpha;
 -		double q0 = 0;
 -		int n = 0;
 -		for (int i=0; i<q.length; i++)
 -		{
 -			if (q[i] > 0)
 -			{
 -				q0 += q[i];
 -				n += 1;
 -			}
 -		}
 -
 -		double kl = Gamma.logGamma(q0) - Gamma.logGamma(p0);
 -		kl += n * Gamma.logGamma(alpha / n);
 -		double digamma_q0 = Gamma.digamma(q0);
 -		for (int i=0; i<q.length; i++)
 -		{
 -			if (q[i] > 0)
 -				kl -= -Gamma.logGamma(q[i]) - (q[i] - alpha/q.length) * (Gamma.digamma(q[i]) - digamma_q0);
 -		}
 -		return kl;
 -	}
 -	
  	public double PREM(double scalePT, double scaleCT, int phraseSizeLimit)
  	{
  		if (scaleCT == 0)
 @@ -339,51 +226,44 @@ public class PhraseCluster {  		double loglikelihood=0, kl=0, l1lmax=0, primal=0;
  		final AtomicInteger failures = new AtomicInteger(0);
  		final AtomicLong elapsed = new AtomicLong(0l);
 -		int iterations=0, n=n_phrases;
 +		int iterations=0;
  		long start = System.currentTimeMillis();
 +		List<Future<PhraseObjective>> results = new ArrayList<Future<PhraseObjective>>();
  		if (lambdaPT == null && cacheLambda)
  			lambdaPT = new double[n_phrases][];
  		//E
 -		for(int phrase=0;phrase<n_phrases;phrase++){
 -			if (phraseSizeLimit >= 1 && c.getPhrase(phrase).size() > phraseSizeLimit)
 -			{
 -				n -= 1;
 +		for(int phrase=0;phrase<n_phrases;phrase++) {
 +			if (phraseSizeLimit >= 1 && c.getPhrase(phrase).size() > phraseSizeLimit) {
  				System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
  				continue;
  			}
  			final int p=phrase;
 -			pool.execute(new Runnable() {
 -				public void run() {
 -					try {
 -						//System.out.println("" + Thread.currentThread().getId() + " optimising lambda for " + p);
 -						long start = System.currentTimeMillis();
 -						PhraseObjective po = new PhraseObjective(PhraseCluster.this, p, scalePT, (cacheLambda) ? lambdaPT[p] : null);
 -						boolean ok = po.optimizeWithProjectedGradientDescent();
 -						if (!ok) failures.incrementAndGet();
 -						long end = System.currentTimeMillis();
 -						elapsed.addAndGet(end - start);
 -
 -						//System.out.println("" + Thread.currentThread().getId() + " done optimising lambda for " + p);
 -						expectations.put(po);
 -						//System.out.println("" + Thread.currentThread().getId() + " added to queue " + p);
 -					} catch (InterruptedException e) {
 -						System.err.println(Thread.currentThread().getId() + " Local e-step thread interrupted; will cause deadlock.");
 -						e.printStackTrace();
 -					}
 +			results.add(pool.submit(new Callable<PhraseObjective>() {
 +				public PhraseObjective call() {
 +					//System.out.println("" + Thread.currentThread().getId() + " optimising lambda for " + p);
 +					long start = System.currentTimeMillis();
 +					PhraseObjective po = new PhraseObjective(PhraseCluster.this, p, scalePT, (cacheLambda) ? lambdaPT[p] : null);
 +					boolean ok = po.optimizeWithProjectedGradientDescent();
 +					if (!ok) failures.incrementAndGet();
 +					long end = System.currentTimeMillis();
 +					elapsed.addAndGet(end - start);
 +					//System.out.println("" + Thread.currentThread().getId() + " done optimising lambda for " + p);
 +					return po;
  				}
 -			});
 +			}));
  		}
  		// aggregate the expectations as they become available
 -		for(int count=0;count<n;count++) {
 +		for (Future<PhraseObjective> fpo : results)
 +		{
  			try {
  				//System.out.println("" + Thread.currentThread().getId() + " reading queue #" + count);
  				// wait (blocking) until something is ready
 -				PhraseObjective po = expectations.take();
 +				PhraseObjective po = fpo.get();
  				// process
  				int phrase = po.phrase;
  				if (cacheLambda) lambdaPT[phrase] = po.getParameters();
 @@ -408,10 +288,12 @@ public class PhraseCluster {  						exp_pi[phrase][tag]+=q[edge][tag]*contextCnt;
  					}
  				}
 -			} catch (InterruptedException e)
 -			{
 +			} catch (InterruptedException e) {
  				System.err.println("M-step thread interrupted. Probably fatal!");
 -				e.printStackTrace();
 +				throw new RuntimeException(e);
 +			} catch (ExecutionException e) {
 +				System.err.println("M-step thread execution died. Probably fatal!");
 +				throw new RuntimeException(e);
  			}
  		}
 diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java index 06a9f8cb..5947c4be 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java @@ -277,7 +277,10 @@ public class PhraseContextObjective extends ProjectedObjective  			}
  			// rethrow the exception
  			if (failure != null)
 +			{
 +				pool.shutdownNow();
  				throw new RuntimeException(failure);
 +			}
  		}
  		double[] tmp = newPoint;
 diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java index 7c32d9c0..5efe778a 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java @@ -192,7 +192,7 @@ public class PhraseObjective extends ProjectedObjective  		//	for(int edge=0;edge<data.getSize();edge++){
  		//	ps.println(Arrays.toString(q[edge]));
  		//	}
 -		
 +
  		return success;
  	}
 diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java index f205ce67..6f302b20 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java @@ -4,11 +4,12 @@ import io.FileUtil;  import joptsimple.OptionParser;  import joptsimple.OptionSet;  import java.io.File; -import java.io.FileNotFoundException;  import java.io.IOException;  import java.io.PrintStream;  import java.util.List;  import java.util.Random; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors;  import phrase.Corpus.Edge; @@ -18,7 +19,6 @@ public class Trainer  {  	public static void main(String[] args)   	{ -		          OptionParser parser = new OptionParser();          parser.accepts("help");          parser.accepts("in").withRequiredArg().ofType(File.class); @@ -68,6 +68,10 @@ public class Trainer  		if (options.has("seed"))  			F.rng = new Random((Long) options.valueOf("seed")); +		ExecutorService threadPool = null; +		if (threads > 0) +			threadPool = Executors.newFixedThreadPool(threads);			 +		  		if (tags <= 1 || scale_phrase < 0 || scale_context < 0 || threshold < 0)  		{  			System.err.println("Invalid arguments. Try again!"); @@ -114,26 +118,30 @@ public class Trainer   			agree = new Agree(tags, corpus);   		else   		{ - 			cluster = new PhraseCluster(tags, corpus); - 			if (threads > 0) cluster.useThreadPool(threads); - 			 - 			if (vb)	{ - 				//cluster.initialiseVB(alphaEmit, alphaPi); + 			if (vb)	 + 			{   				vbModel=new VB(tags,corpus);   				vbModel.alpha=alphaPi;   				vbModel.lambda=alphaEmit; - 			} - 			if (options.has("no-parameter-cache"))  - 				cluster.cacheLambda = false; - 			if (options.has("start")) + 	 			if (threadPool != null) vbModel.useThreadPool(threadPool); + 			}  + 			else    			{ - 				try { -					System.err.println("Reading starting parameters from " + options.valueOf("start")); -					cluster.loadParameters(FileUtil.reader((File)options.valueOf("start"))); -				} catch (IOException e) { -					System.err.println("Failed to open input file: " + options.valueOf("start")); -					e.printStackTrace(); -				} + 				cluster = new PhraseCluster(tags, corpus); + 	 			if (threadPool != null) cluster.useThreadPool(threadPool); + 				 +	 			if (options.has("no-parameter-cache"))  +	 				cluster.cacheLambda = false; +	 			if (options.has("start")) +	 			{ +	 				try { +						System.err.println("Reading starting parameters from " + options.valueOf("start")); +						cluster.loadParameters(FileUtil.reader((File)options.valueOf("start"))); +					} catch (IOException e) { +						System.err.println("Failed to open input file: " + options.valueOf("start")); +						e.printStackTrace(); +					} +	 			}   			}   		} @@ -143,9 +151,8 @@ public class Trainer  			double o;  			if (agree != null)  				o = agree.EM(); -			else if(agree2sides!=null){ +			else if(agree2sides!=null)  				o = agree2sides.EM(); -			}  			else  			{  				if (i < skip) @@ -173,11 +180,25 @@ public class Trainer  			last = o;  		} -		if (cluster == null) -			cluster = agree.model1; +		double pl1lmax = 0, cl1lmax = 0; +		if (cluster != null) +		{ +			pl1lmax = cluster.phrase_l1lmax(); +			cl1lmax = cluster.context_l1lmax(); +		} +		else if (agree != null) +		{ +			// fairly arbitrary choice of model1 cf model2 +			pl1lmax = agree.model1.phrase_l1lmax(); +			cl1lmax = agree.model1.context_l1lmax(); +		} +		else if (agree2sides != null) +		{ +			// fairly arbitrary choice of model1 cf model2 +			pl1lmax = agree2sides.model1.phrase_l1lmax(); +			cl1lmax = agree2sides.model1.context_l1lmax(); +		} -		double pl1lmax = cluster.phrase_l1lmax(); -		double cl1lmax = cluster.context_l1lmax();  		System.out.println("\nFinal posterior phrase l1lmax " + pl1lmax + " context l1lmax " + cl1lmax);  		if (options.has("out")) @@ -194,11 +215,18 @@ public class Trainer  					System.out.println("Reading testing concordance from " + infile);  					test = corpus.readEdges(FileUtil.reader(infile));  				} -				if(vb){ +				if(vb) { +					assert !options.has("test");  					vbModel.displayPosterior(ps); -				}else{ +				} else if (cluster != null)   					cluster.displayPosterior(ps, test); +				else if (agree != null)  +					agree.displayPosterior(ps, test); +				else if (agree2sides != null) { +					assert !options.has("test"); +					agree2sides.displayPosterior(ps);  				} +				  				ps.close();  			} catch (IOException e) {  				System.err.println("Failed to open either testing file or output file"); @@ -209,6 +237,7 @@ public class Trainer  		if (options.has("parameters"))  		{ +			assert !vb;  			File outfile = (File) options.valueOf("parameters");  			PrintStream ps;  			try { @@ -222,7 +251,7 @@ public class Trainer  			}  		} -		if (cluster.pool != null) +		if (cluster != null && cluster.pool != null)  			cluster.pool.shutdown();  	}  } diff --git a/gi/posterior-regularisation/prjava/src/phrase/VB.java b/gi/posterior-regularisation/prjava/src/phrase/VB.java index a858c883..cd3f4966 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/VB.java +++ b/gi/posterior-regularisation/prjava/src/phrase/VB.java @@ -7,8 +7,13 @@ import io.FileUtil;  import java.io.File;
  import java.io.IOException;
  import java.io.PrintStream;
 +import java.util.ArrayList;
  import java.util.Arrays;
  import java.util.List;
 +import java.util.concurrent.Callable;
 +import java.util.concurrent.ExecutionException;
 +import java.util.concurrent.ExecutorService;
 +import java.util.concurrent.Future;
  import org.apache.commons.math.special.Gamma;
 @@ -38,21 +43,17 @@ public class VB {  	/**@brief
  	 * variational param for z
  	 */
 -	private double phi[][];
 +	//private double phi[][];
  	/**@brief
  	 * variational param for theta
  	 */
  	private double gamma[];
  	private static double VAL_DIFF_RATIO=0.005;
 -	/**@brief
 -	 * objective for a single document
 -	 */
 -	private double obj;
 -	
  	private int n_positions;
  	private int n_words;
  	private int K;
 +	private ExecutorService pool;
  	private Corpus c;
  	public static void main(String[] args) {
 @@ -122,17 +123,14 @@ public class VB {  	}
 -	private void inference(int phraseID){
 +	private double inference(int phraseID, double[][] phi, double[] gamma)
 +	{
  		List<Edge > doc=c.getEdgesForPhrase(phraseID);
 -		phi=new double[doc.size()][K];
  		for(int i=0;i<phi.length;i++){
  			for(int j=0;j<phi[i].length;j++){
  				phi[i][j]=1.0/K;
  			}
  		}
 -		if(gamma==null){
 -			gamma=new double[K];
 -		}
  		Arrays.fill(gamma,alpha+1.0/K);
  		double digamma_gamma[]=new double[K];
 @@ -143,7 +141,7 @@ public class VB {  		}
  		double gammaSum[]=new double [K];
  		double prev_val=0;
 -		obj=0;
 +		double obj=0;
  		for(int iter=0;iter<MAX_ITER;iter++){
  			prev_val=obj;
 @@ -224,6 +222,8 @@ public class VB {  				break;
  			}
  		}//end of inference loop
 +		
 +		return obj;
  	}//end of inference
  	/**
 @@ -251,31 +251,79 @@ public class VB {  			}
  		}
 -		
  		//E
  		double exp_rho[][][]=new double[K][n_positions][n_words];
 -		for (int d=0;d<c.getNumPhrases();d++){
 -			inference(d);
 -			List<Edge>doc=c.getEdgesForPhrase(d);
 -			for(int n=0;n<doc.size();n++){
 -				TIntArrayList context=doc.get(n).getContext();
 -				for(int pos=0;pos<n_positions;pos++){
 -					int word=context.get(pos);
 -					for(int i=0;i<K;i++){	
 -						exp_rho[i][pos][word]+=phi[n][i];
 +		if (pool == null)
 +		{
 +			for (int d=0;d<c.getNumPhrases();d++)
 +			{		
 +				List<Edge > doc=c.getEdgesForPhrase(d);
 +				double[][] phi = new double[doc.size()][K];
 +				double[] gamma = new double[K];
 +				
 +				emObj += inference(d, phi, gamma);
 +				
 +				for(int n=0;n<doc.size();n++){
 +					TIntArrayList context=doc.get(n).getContext();
 +					for(int pos=0;pos<n_positions;pos++){
 +						int word=context.get(pos);
 +						for(int i=0;i<K;i++){	
 +							exp_rho[i][pos][word]+=phi[n][i];
 +						}
  					}
  				}
 +				//if(d!=0 && d%100==0)  System.out.print(".");
 +				//if(d!=0 && d%1000==0) System.out.println(d);
  			}
 -/*			if(d!=0 && d%100==0){
 -				System.out.print(".");
 -			}
 -			if(d!=0 && d%1000==0){
 -				System.out.println(d);
 -			}
 -*/
 -			emObj+=obj;
  		}
 +		else // multi-threaded version of above loop
 +		{
 +			class PartialEStep implements Callable<PartialEStep>
 +			{
 +				double[][] phi;
 +				double[] gamma;
 +				double obj;
 +				int d;
 +				PartialEStep(int d) { this.d = d; }
 +
 +				public PartialEStep call()
 +				{
 +					phi = new double[c.getEdgesForPhrase(d).size()][K];
 +					gamma = new double[K];
 +					obj = inference(d, phi, gamma);
 +					return this;
 +				}			
 +			}
 +
 +			List<Future<PartialEStep>> jobs = new ArrayList<Future<PartialEStep>>();
 +			for (int d=0;d<c.getNumPhrases();d++)
 +				jobs.add(pool.submit(new PartialEStep(d)));
 +			for (Future<PartialEStep> job: jobs)
 +			{
 +				try {
 +					PartialEStep e = job.get();
 +					
 +					emObj += e.obj;				
 +					List<Edge> doc = c.getEdgesForPhrase(e.d);
 +					for(int n=0;n<doc.size();n++){
 +						TIntArrayList context=doc.get(n).getContext();
 +						for(int pos=0;pos<n_positions;pos++){
 +							int word=context.get(pos);
 +							for(int i=0;i<K;i++){	
 +								exp_rho[i][pos][word]+=e.phi[n][i];
 +							}
 +						}
 +					}
 +				} catch (ExecutionException e) {
 +					System.err.println("ERROR: E-step thread execution failed.");
 +					throw new RuntimeException(e);
 +				} catch (InterruptedException e) {
 +					System.err.println("ERROR: Failed to join E-step thread.");
 +					throw new RuntimeException(e);
 +				}
 +			}
 +		}	
  	//	System.out.println("EM Objective:"+emObj);
  		//M
 @@ -309,8 +357,15 @@ public class VB {  	public void displayPosterior(PrintStream ps)
  	{	
  		for(int d=0;d<c.getNumPhrases();d++){
 -			inference(d);
 -			List<Edge> doc=c.getEdgesForPhrase(d);
 +			List<Edge > doc=c.getEdgesForPhrase(d);
 +			double[][] phi = new double[doc.size()][K];
 +			for(int i=0;i<phi.length;i++)
 +				for(int j=0;j<phi[i].length;j++)
 +					phi[i][j]=1.0/K;
 +			double[] gamma = new double[K];
 +
 +			inference(d, phi, gamma);
 +
  			for(int n=0;n<doc.size();n++){
  				Edge edge=doc.get(n);
  				int tag=arr.F.argmax(phi[n]);
 @@ -328,13 +383,9 @@ public class VB {  	  double v;
  	  if (log_a < log_b)
 -	  {
  	      v = log_b+Math.log(1 + Math.exp(log_a-log_b));
 -	  }
  	  else
 -	  {
  	      v = log_a+Math.log(1 + Math.exp(log_b-log_a));
 -	  }
  	  return(v);
  	}
 @@ -360,5 +411,9 @@ public class VB {  	    Math.log(x-2)-Math.log(x-3)-Math.log(x-4)-Math.log(x-5)-Math.log(x-6);
  	    return z;
  	}
 -	
 +
 +	public void useThreadPool(ExecutorService threadPool) 
 +	{
 +		pool = threadPool;
 +	}
  }//End of  class
 | 
