| author | trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-09 22:29:02 +0000 | 
|---|---|---|
| committer | trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-09 22:29:02 +0000 | 
| commit | 9801ac3df2cbf2656b8d21b2fb0046bfb4046e98 (patch) | |
| tree | 02f3be210f1a5a060f6ea89cf6093e1ec9dfab95 /gi/posterior-regularisation/prjava/src/phrase | |
| parent | 6211d023c559f3969ac0a827f4635c5b0959f230 (diff) | |
Added initial VB implementation for symmetric Dirichlet prior.
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@215 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase')
3 files changed, 159 insertions, 46 deletions
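
Note (not part of the commit): the new `digammaNormalize` method below implements the standard mean-field variational Bayes estimate of a multinomial under a symmetric Dirichlet prior with total concentration $\alpha$ spread over $N$ outcomes. Given expected counts $n_1,\dots,n_N$, the update is

$$
\hat\theta_i \;=\; \exp\!\Big(\psi\big(n_i + \tfrac{\alpha}{N}\big) \;-\; \psi\big(\textstyle\sum_j n_j + \alpha\big)\Big),
$$

where $\psi$ is the digamma function; plain EM instead uses $\hat\theta_i = n_i / \sum_j n_j$ (`arr.F.l1normalize`). The symbols $n_i$, $N$, $\alpha$ are our notation for this note, not identifiers from the patch.
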
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index b9b1b98c..7bc63c33 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -1,6 +1,7 @@
  package phrase;
  import gnu.trove.TIntArrayList;
 +import org.apache.commons.math.special.Gamma;
  import io.FileUtil;
  import java.io.IOException;
  import java.io.PrintStream;
@@ -12,6 +13,7 @@ import java.util.concurrent.LinkedBlockingQueue;
  import java.util.concurrent.atomic.AtomicInteger;
  import phrase.Corpus.Edge;
 +import util.MathUtil;
  public class PhraseCluster {
@@ -26,7 +28,12 @@ public class PhraseCluster {
  	// pi[phrase][tag] = p(tag | phrase)
  	private double pi[][];
 -	public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads){
 +	double alphaEmit;
 +	double alphaPi;
 +	
 +	public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads,
 +						 double alphaEmit, double alphaPi)
 +	{
  		K=numCluster;
  		c=corpus;
  		n_words=c.getNumWords();
@@ -41,29 +48,41 @@ public class PhraseCluster {
  		emit=new double [K][n_positions][n_words];
  		pi=new double[n_phrases][K];
 -		for(double [][]i:emit){
 -			for(double []j:i){
 -				arr.F.randomise(j);
 +		for(double [][]i:emit)
 +		{
 +			for(double []j:i)
 +			{
 +				arr.F.randomise(j, alphaEmit <= 0);
 +				if (alphaEmit > 0) 
 +					digammaNormalize(j, alphaEmit);
  			}
  		}
 -		for(double []j:pi){
 -			arr.F.randomise(j);
 +		for(double []j:pi)
 +		{
 +			arr.F.randomise(j, alphaPi <= 0);
 +			if (alphaPi > 0) 
 +				digammaNormalize(j, alphaPi);
  		}
 +		
 +		this.alphaEmit = alphaEmit;
 +		this.alphaPi = alphaPi;
  	}
 -		
 -	public double EM(){
 +	public double EM()
 +	{
  		double [][][]exp_emit=new double [K][n_positions][n_words];
  		double [][]exp_pi=new double[n_phrases][K];
  		double loglikelihood=0;
  		//E
 -		for(int phrase=0; phrase < n_phrases; phrase++){
 +		for(int phrase=0; phrase < n_phrases; phrase++)
 +		{
  			List<Edge> contexts = c.getEdgesForPhrase(phrase);
 -			for (int ctx=0; ctx<contexts.size(); ctx++){
 +			for (int ctx=0; ctx<contexts.size(); ctx++)
 +			{
  				Edge edge = contexts.get(ctx);
  				double p[]=posterior(edge);
  				double z = arr.F.l1norm(p);
@@ -74,34 +93,127 @@ public class PhraseCluster {
  				int count = edge.getCount();
  				//increment expected count
  				TIntArrayList context = edge.getContext();
 -				for(int tag=0;tag<K;tag++){
 -					for(int pos=0;pos<n_positions;pos++){
 -						exp_emit[tag][pos][context.get(pos)]+=p[tag]*count;
 -					}
 -					
 +				for(int tag=0;tag<K;tag++)
 +				{
 +					for(int pos=0;pos<n_positions;pos++)
 +						exp_emit[tag][pos][context.get(pos)]+=p[tag]*count;		
  					exp_pi[phrase][tag]+=p[tag]*count;
  				}
  			}
  		}
 -		
 -		//System.out.println("Log likelihood: "+loglikelihood);
 -		
 +
  		//M
 -		for(double [][]i:exp_emit){
 -			for(double []j:i){
 +		for(double [][]i:exp_emit)
 +			for(double []j:i)
  				arr.F.l1normalize(j);
 -			}
 -		}
 +		for(double []j:exp_pi)
 +				arr.F.l1normalize(j);
 +			
  		emit=exp_emit;
 +		pi=exp_pi;
 +
 +		return loglikelihood;
 +	}
 +	
 +	public double VBEM()
 +	{
 +		// FIXME: broken - needs to be done entirely in log-space
 -		for(double []j:exp_pi){
 -			arr.F.l1normalize(j);
 -		}
 +		double [][][]exp_emit = new double [K][n_positions][n_words];
 +		double [][]exp_pi = new double[n_phrases][K];
 +		double loglikelihood=0;
 +		
 +		//E
 +		for(int phrase=0; phrase < n_phrases; phrase++)
 +		{
 +			List<Edge> contexts = c.getEdgesForPhrase(phrase);
 +
 +			for (int ctx=0; ctx<contexts.size(); ctx++)
 +			{
 +				Edge edge = contexts.get(ctx);
 +				double p[] = posterior(edge);
 +				double z = arr.F.l1norm(p);
 +				assert z > 0;
 +				loglikelihood += edge.getCount() * Math.log(z);
 +				arr.F.l1normalize(p);
 +				
 +				int count = edge.getCount();
 +				//increment expected count
 +				TIntArrayList context = edge.getContext();
 +				for(int tag=0;tag<K;tag++)
 +				{
 +					for(int pos=0;pos<n_positions;pos++)
 +						exp_emit[tag][pos][context.get(pos)] += p[tag]*count;		
 +					exp_pi[phrase][tag] += p[tag]*count;
 +				}
 +			}
 +		}
 +
 +		// find the KL terms, KL(q||p) where p is symmetric Dirichlet prior and q are the expectations 
 +		double kl = 0;
 +		for (int phrase=0; phrase < n_phrases; phrase++)
 +			kl += KL_symmetric_dirichlet(exp_pi[phrase], alphaPi);
 +	
 +		for (int tag=0;tag<K;tag++)
 +			for (int pos=0;pos<n_positions; ++pos)
 +				kl += this.KL_symmetric_dirichlet(exp_emit[tag][pos], alphaEmit); 
 +		// FIXME: exp_emit[tag][pos] has structural zeros - certain words are *never* seen in that position
 +
 +		//M
 +		for(double [][]i:exp_emit)
 +			for(double []j:i)
 +				digammaNormalize(j, alphaEmit);
 +		emit=exp_emit;
 +		for(double []j:exp_pi)
 +			digammaNormalize(j, alphaPi);
  		pi=exp_pi;
 +
 +		System.out.println("KL=" + kl + " llh=" + loglikelihood);
 +		System.out.println(Arrays.toString(pi[0]));
 +		System.out.println(Arrays.toString(exp_emit[0][0]));
 +		return kl + loglikelihood;
 +	}
 +	
 +	public void digammaNormalize(double [] a, double alpha)
 +	{
 +		double sum=0;
 +		for(int i=0;i<a.length;i++)
 +			sum += a[i];
 -		return loglikelihood;
 +		assert sum > 1e-20;
 +		double dgs = Gamma.digamma(sum + alpha);
 +		
 +		for(int i=0;i<a.length;i++)
 +			a[i] = Math.exp(Gamma.digamma(a[i] + alpha/a.length) - dgs);
 +	}
 +	
 +	private double KL_symmetric_dirichlet(double[] q, double alpha)
 +	{
 +		// assumes that zeros in q are structural & should be skipped
 +		
 +		double p0 = alpha;
 +		double q0 = 0;
 +		int n = 0;
 +		for (int i=0; i<q.length; i++)
 +		{
 +			if (q[i] > 0)
 +			{
 +				q0 += q[i];
 +				n += 1;
 +			}
 +		}
 +
 +		double kl = Gamma.logGamma(q0) - Gamma.logGamma(p0);
 +		kl += n * Gamma.logGamma(alpha / n);
 +		double digamma_q0 = Gamma.digamma(q0);
 +		for (int i=0; i<q.length; i++)
 +		{
 +			if (q[i] > 0)
 +				kl -= -Gamma.logGamma(q[i]) - (q[i] - alpha/q.length) * (Gamma.digamma(q[i]) - digamma_q0);
 +		}
 +		return kl;
  	}
  	public double PREM_phrase_constraints(){
@@ -117,7 +229,7 @@ public class PhraseCluster {
  			PhraseObjective po=new PhraseObjective(this,phrase);
  			boolean ok = po.optimizeWithProjectedGradientDescent();
  			if (!ok) ++failures;
 -			iterations += po.iterations;
 +			iterations += po.getNumberUpdateCalls();
  			double [][] q=po.posterior();
  			loglikelihood += po.loglikelihood();
  			kl += po.KL_divergence();
@@ -142,24 +254,19 @@ public class PhraseCluster {
  		if (failures > 0)
  			System.out.println("WARNING: failed to converge in " + failures + "/" + n_phrases + " cases");
 -		System.out.println("\tmean iters: 	  " + iterations/(double)n_phrases);
 +		System.out.println("\tmean iters:     " + iterations/(double)n_phrases);
  		System.out.println("\tllh:            " + loglikelihood);
  		System.out.println("\tKL:             " + kl);
  		System.out.println("\tphrase l1lmax:  " + l1lmax);
  		//M
 -		for(double [][]i:exp_emit){
 -			for(double []j:i){
 +		for(double [][]i:exp_emit)
 +			for(double []j:i)
  				arr.F.l1normalize(j);
 -			}
 -		}
 -		
  		emit=exp_emit;
 -		for(double []j:exp_pi){
 +		for(double []j:exp_pi)
  			arr.F.l1normalize(j);
 -		}
 -		
  		pi=exp_pi;
  		return primal;
@@ -216,7 +323,7 @@ public class PhraseCluster {
  				kl += po.KL_divergence();
  				l1lmax += po.l1lmax();
  				primal += po.primal();
 -				iterations += po.iterations;
 +				iterations += po.getNumberUpdateCalls();
  				List<Edge> edges = c.getEdgesForPhrase(phrase);
@@ -241,7 +348,7 @@ public class PhraseCluster {
  		if (failures.get() > 0)
  			System.out.println("WARNING: failed to converge in " + failures.get() + "/" + n_phrases + " cases");
 -		System.out.println("\tmean iters: 	  " + iterations/(double)n_phrases);
 +		System.out.println("\tmean iters:     " + iterations/(double)n_phrases);
  		System.out.println("\tllh:            " + loglikelihood);
  		System.out.println("\tKL:             " + kl);
  		System.out.println("\tphrase l1lmax:  " + l1lmax);
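
Note (not part of the commit): the `KL_symmetric_dirichlet` helper above is intended to compute the KL divergence between a Dirichlet with parameters $q_1,\dots,q_N$ (the expected counts accumulated in the E-step, restricted to non-zero entries as the FIXME notes) and a symmetric Dirichlet with per-component parameter $\alpha/N$. For comparison only, in our notation, the standard closed form is

$$
\mathrm{KL}\big(\mathrm{Dir}(q)\,\big\|\,\mathrm{Dir}(\tfrac{\alpha}{N},\dots,\tfrac{\alpha}{N})\big)
= \log\Gamma(q_0) - \sum_{i}\log\Gamma(q_i) - \log\Gamma(\alpha) + N\,\log\Gamma\!\big(\tfrac{\alpha}{N}\big)
+ \sum_{i}\big(q_i - \tfrac{\alpha}{N}\big)\big(\psi(q_i) - \psi(q_0)\big),
$$

with $q_0 = \sum_i q_i$.
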
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
index f24b903d..cc12546d 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
@@ -1,7 +1,5 @@
  package phrase;
 -import java.io.PrintStream;
 -import java.util.Arrays;
  import java.util.List;
  import optimization.gradientBasedMethods.ProjectedGradientDescent;
@@ -163,9 +161,7 @@ public class PhraseObjective extends ProjectedObjective
  	public double [][]posterior(){
  		return q;
  	}
 -	
 -	public int iterations = 0;
 -	
 +		
  	public boolean optimizeWithProjectedGradientDescent(){
  		LineSearchMethod ls =
  			new ArmijoLineSearchMinimizationAlongProjectionArc
@@ -184,7 +180,6 @@ public class PhraseObjective extends ProjectedObjective
  		optimizer.setMaxIterations(ITERATIONS);
  		updateFunction();
  		boolean success = optimizer.optimize(this,stats,compositeStop);
 -		iterations += optimizer.getCurrentIteration();
  //		System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
  		//if(succed){
  			//System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
index b19f3fb9..439fb337 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
@@ -27,6 +27,9 @@ public class Trainer
          parser.accepts("scale-context").withRequiredArg().ofType(Double.class).defaultsTo(0.0);
          parser.accepts("seed").withRequiredArg().ofType(Long.class).defaultsTo(0l);
          parser.accepts("convergence-threshold").withRequiredArg().ofType(Double.class).defaultsTo(1e-6);
 +        parser.accepts("variational-bayes");
 +        parser.accepts("alpha-emit").withRequiredArg().ofType(Double.class).defaultsTo(0.1);
 +        parser.accepts("alpha-pi").withRequiredArg().ofType(Double.class).defaultsTo(0.01);
          OptionSet options = parser.parse(args);
          if (options.has("help") || !options.has("in"))
@@ -47,6 +50,9 @@ public class Trainer
  		double scale_context = (Double) options.valueOf("scale-context");
  		int threads = (Integer) options.valueOf("threads");
  		double threshold = (Double) options.valueOf("convergence-threshold");
 +		boolean vb = options.has("variational-bayes");
 +		double alphaEmit = (vb) ? (Double) options.valueOf("alpha-emit") : 0;
 +		double alphaPi = (vb) ? (Double) options.valueOf("alpha-pi") : 0;
  		if (options.has("seed"))
  			F.rng = new Random((Long) options.valueOf("seed"));
@@ -75,14 +81,19 @@ public class Trainer
  				"and " + threads + " threads");
  		System.out.println();
 -		PhraseCluster cluster = new PhraseCluster(tags, corpus, scale_phrase, scale_context, threads);
 +		PhraseCluster cluster = new PhraseCluster(tags, corpus, scale_phrase, scale_context, threads, alphaEmit, alphaPi);
  		double last = 0;
  		for (int i=0; i<em_iterations+pr_iterations; i++)
  		{
  			double o;
  			if (i < em_iterations)
 -				o = cluster.EM();
 +			{
 +				if (!vb)
 +					o = cluster.EM();
 +				else
 +					o = cluster.VBEM();
 +			}
  			else if (scale_context == 0)
  			{
  				if (threads >= 1)
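
A minimal standalone sketch (not part of the patch) showing the effect of the digamma normalisation used on the VB path, assuming Apache Commons Math 2.x (`org.apache.commons.math.special.Gamma`, which the patch itself imports) is on the classpath; the class name, toy counts, and `alpha` value are illustrative only:

```java
import java.util.Arrays;
import org.apache.commons.math.special.Gamma;

public class DigammaNormalizeDemo {
    // Mean-field VB estimate: replace expected counts a[i] with
    // exp(psi(a[i] + alpha/N) - psi(sum(a) + alpha)), mirroring
    // PhraseCluster.digammaNormalize in the patch above.
    static void digammaNormalize(double[] a, double alpha) {
        double sum = 0;
        for (double v : a)
            sum += v;
        double dgs = Gamma.digamma(sum + alpha);
        for (int i = 0; i < a.length; i++)
            a[i] = Math.exp(Gamma.digamma(a[i] + alpha / a.length) - dgs);
    }

    public static void main(String[] args) {
        double[] counts = {5.0, 1.0, 0.5};   // toy expected counts

        // Plain EM / maximum-likelihood estimate: simple l1 normalisation.
        double sum = 0;
        for (double v : counts)
            sum += v;
        double[] ml = new double[counts.length];
        for (int i = 0; i < counts.length; i++)
            ml[i] = counts[i] / sum;

        // VB estimate under a sparse symmetric Dirichlet prior (alpha = 0.1).
        double[] vb = counts.clone();
        digammaNormalize(vb, 0.1);

        System.out.println("EM : " + Arrays.toString(ml));
        System.out.println("VB : " + Arrays.toString(vb));
    }
}
```

On such a toy vector the VB estimate discounts the low-count entries relative to plain l1 normalisation (and no longer sums to one), which is the intended effect of the sparse symmetric Dirichlet prior; in the patch the same update is applied to each row of `emit` and `pi`.
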
