diff options
| author | desaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-23 17:08:53 +0000 | 
|---|---|---|
| committer | desaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-23 17:08:53 +0000 | 
| commit | 76ef39de737e7abc0a8fe989dfacb7885617e59f (patch) | |
| tree | 77c6099236431c4488aa5ac95b6d680bfd5faf05 | |
| parent | 7776119e54c477a27fb0617d8bf8b483ac78898e (diff) | |
vb runnable from trainer
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@380 ec762483-ff6d-05da-a07a-a48fb63a330f
| -rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/Trainer.java | 43 | ||||
| -rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/VB.java | 97 | 
2 files changed, 92 insertions, 48 deletions
| diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java index b51db919..cea6a20a 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java @@ -18,7 +18,7 @@ public class Trainer  {  	public static void main(String[] args)   	{ - +		          OptionParser parser = new OptionParser();          parser.accepts("help");          parser.accepts("in").withRequiredArg().ofType(File.class); @@ -107,6 +107,7 @@ public class Trainer   		PhraseCluster cluster = null;   		Agree2Sides agree2sides = null;   		Agree agree= null; + 		VB vbModel=null;   		if (options.has("agree-language"))   			agree2sides = new Agree2Sides(tags, corpus,corpus1);   		else if (options.has("agree-direction")) @@ -115,7 +116,11 @@ public class Trainer   		{   			cluster = new PhraseCluster(tags, corpus);   			if (threads > 0) cluster.useThreadPool(threads); - 			if (vb)	cluster.initialiseVB(alphaEmit, alphaPi); + 			 + 			if (vb)	{ + 				//cluster.initialiseVB(alphaEmit, alphaPi); + 				vbModel=new VB(tags,corpus); + 			}   			if (options.has("no-parameter-cache"))    				cluster.cacheLambda = false;   			if (options.has("start")) @@ -149,7 +154,7 @@ public class Trainer  					if (!vb)  						o = cluster.EM((i < skip) ? i+1 : 0);  					else -						o = cluster.VBEM(alphaEmit, alphaPi);	 +						o = vbModel.EM();	  				}  				else  					o = cluster.PREM(scale_phrase, scale_context, (i < skip) ? i+1 : 0); @@ -166,10 +171,8 @@ public class Trainer  			last = o;  		} -		if (cluster == null && agree != null) +		if (cluster == null)  			cluster = agree.model1; -		else if (cluster == null && agree2sides != null) -			cluster = agree2sides.model1;  		double pl1lmax = cluster.phrase_l1lmax();  		double cl1lmax = cluster.context_l1lmax(); @@ -180,26 +183,20 @@ public class Trainer  			File outfile = (File) options.valueOf("out");  			try {  				PrintStream ps = FileUtil.printstream(outfile); -				List<Edge> test = corpus.getEdges(); -				if (options.has("test")) // just use the training +				List<Edge> test; +				if (!options.has("test")) // just use the training +					test = corpus.getEdges(); +				else  				{	// if --test supplied, load up the file -					if (agree2sides == null) -					{ -						infile = (File) options.valueOf("test"); -						System.out.println("Reading testing concordance from " + infile); -						test = corpus.readEdges(FileUtil.reader(infile)); -					} -					else -						System.err.println("Can't run bilingual agreement model on different test data cf training (yet); --test ignored."); +					infile = (File) options.valueOf("test"); +					System.out.println("Reading testing concordance from " + infile); +					test = corpus.readEdges(FileUtil.reader(infile));  				} -				 -				if (agree != null) -					agree.displayPosterior(ps, test); -				else if (agree2sides != null) -					agree2sides.displayPosterior(ps); -				else +				if(vb){ +					vbModel.displayPosterior(ps); +				}else{  					cluster.displayPosterior(ps, test); -					 +				}  				ps.close();  			} catch (IOException e) {  				System.err.println("Failed to open either testing file or output file"); diff --git a/gi/posterior-regularisation/prjava/src/phrase/VB.java b/gi/posterior-regularisation/prjava/src/phrase/VB.java index cc1c1c96..a858c883 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/VB.java +++ b/gi/posterior-regularisation/prjava/src/phrase/VB.java @@ -16,7 +16,7 @@ import phrase.Corpus.Edge;  public class VB {
 -	public static int MAX_ITER=40;
 +	public static int MAX_ITER=400;
  	/**@brief
  	 * hyper param for beta
 @@ -28,11 +28,13 @@ public class VB {  	 * hyper param for theta
  	 * where theta is dirichlet for z
  	 */
 -	public double alpha=0.000001;
 +	public double alpha=0.0001;
  	/**@brief
  	 * variational param for beta
  	 */
  	private double rho[][][];
 +	private double digamma_rho[][][];
 +	private double rho_sum[][];
  	/**@brief
  	 * variational param for z
  	 */
 @@ -41,8 +43,7 @@ public class VB {  	 * variational param for theta
  	 */
  	private double gamma[];
 -	
 -	private static double VAL_DIFF_RATIO=0.001;
 +	private static double VAL_DIFF_RATIO=0.005;
  	/**@brief
  	 * objective for a single document
 @@ -55,8 +56,8 @@ public class VB {  	private Corpus c;
  	public static void main(String[] args) {
 -		String in="../pdata/canned.con";
 -		//String in="../pdata/btec.con";
 +	//	String in="../pdata/canned.con";
 +		String in="../pdata/btec.con";
  		String out="../pdata/vb.out";
  		int numCluster=25;
  		Corpus corpus = null;
 @@ -118,6 +119,7 @@ public class VB {  				}
  			}
  		}
 +		
  	}
  	private void inference(int phraseID){
 @@ -128,26 +130,21 @@ public class VB {  				phi[i][j]=1.0/K;
  			}
  		}
 -		gamma = new double[K];
 -		double digamma_gamma[]=new double[K];
 -		for(int i=0;i<gamma.length;i++){
 -			gamma[i] = alpha + 1.0/K;
 +		if(gamma==null){
 +			gamma=new double[K];
  		}
 +		Arrays.fill(gamma,alpha+1.0/K);
 -		double rho_sum[][]=new double [K][n_positions];
 -		for(int i=0;i<K;i++){
 -			for(int pos=0;pos<n_positions;pos++){
 -				rho_sum[i][pos]=Gamma.digamma(arr.F.l1norm(rho[i][pos]));
 -			}
 -		}
 -		double gamma_sum=Gamma.digamma(arr.F.l1norm(gamma));
 +		double digamma_gamma[]=new double[K];
 +		
 +		double gamma_sum=digamma(arr.F.l1norm(gamma));
  		for(int i=0;i<K;i++){
 -			digamma_gamma[i]=Gamma.digamma(gamma[i]);
 +			digamma_gamma[i]=digamma(gamma[i]);
  		}
  		double gammaSum[]=new double [K];
 -		
  		double prev_val=0;
  		obj=0;
 +		
  		for(int iter=0;iter<MAX_ITER;iter++){
  			prev_val=obj;
  			obj=0;
 @@ -159,7 +156,7 @@ public class VB {  					double sum=0;
  					for(int pos=0;pos<n_positions;pos++){
  						int word=context.get(pos);
 -						sum+=Gamma.digamma(rho[i][pos][word])-rho_sum[i][pos];
 +						sum+=digamma_rho[i][pos][word]-rho_sum[i][pos];
  					}
  					sum+= digamma_gamma[i]-gamma_sum;
  					phi[n][i]=sum;
 @@ -183,11 +180,12 @@ public class VB {  			for(int i=0;i<K;i++){
  				gamma[i]=alpha+gammaSum[i];
  			}
 -			gamma_sum=Gamma.digamma(arr.F.l1norm(gamma));
 +			gamma_sum=digamma(arr.F.l1norm(gamma));
  			for(int i=0;i<K;i++){
 -				digamma_gamma[i]=Gamma.digamma(gamma[i]);
 +				digamma_gamma[i]=digamma(gamma[i]);
  			}
  			//compute objective for reporting
 +
  			obj=0;
  			for(int i=0;i<K;i++){
 @@ -209,13 +207,13 @@ public class VB {  					double beta_sum=0;
  					for(int pos=0;pos<n_positions;pos++){
  						int word=context.get(pos);
 -						beta_sum+=(Gamma.digamma(rho[i][pos][word])-rho_sum[i][pos]);
 +						beta_sum+=(digamma(rho[i][pos][word])-rho_sum[i][pos]);
  					}
  					obj+=phi[n][i]*beta_sum;
  				}
  			}
 -			obj-=Gamma.logGamma(arr.F.l1norm(gamma));
 +			obj-=log_gamma(arr.F.l1norm(gamma));
  			for(int i=0;i<K;i++){
  				obj+=Gamma.logGamma(gamma[i]);
  				obj-=(gamma[i]-1)*(digamma_gamma[i]-gamma_sum);
 @@ -233,6 +231,26 @@ public class VB {  	 */
  	public double EM(){
  		double emObj=0;
 +		if(digamma_rho==null){
 +			digamma_rho=new double[K][n_positions][n_words];
 +		}
 +		for(int i=0;i<K;i++){
 +			for (int pos=0;pos<n_positions;pos++){
 +				for(int j=0;j<n_words;j++){
 +					digamma_rho[i][pos][j]= digamma(rho[i][pos][j]);
 +				}
 +			}
 +		}
 +		
 +		if(rho_sum==null){
 +			rho_sum=new double [K][n_positions];
 +		}
 +		for(int i=0;i<K;i++){
 +			for(int pos=0;pos<n_positions;pos++){
 +				rho_sum[i][pos]=digamma(arr.F.l1norm(rho[i][pos]));
 +			}
 +		}
 +
  		//E
  		double exp_rho[][][]=new double[K][n_positions][n_words];
 @@ -248,7 +266,13 @@ public class VB {  					}
  				}
  			}
 -			
 +/*			if(d!=0 && d%100==0){
 +				System.out.print(".");
 +			}
 +			if(d!=0 && d%1000==0){
 +				System.out.println(d);
 +			}
 +*/
  			emObj+=obj;
  		}
 @@ -313,5 +337,28 @@ public class VB {  	  }
  	  return(v);
  	}
 +		
 +	double digamma(double x)
 +	{
 +	    double p;
 +	    x=x+6;
 +	    p=1/(x*x);
 +	    p=(((0.004166666666667*p-0.003968253986254)*p+
 +		0.008333333333333)*p-0.083333333333333)*p;
 +	    p=p+Math.log(x)-0.5/x-1/(x-1)-1/(x-2)-1/(x-3)-1/(x-4)-1/(x-5)-1/(x-6);
 +	    return p;
 +	}
 +	
 +	double log_gamma(double x)
 +	{
 +	     double z=1/(x*x);
 +
 +	    x=x+6;
 +	    z=(((-0.000595238095238*z+0.000793650793651)
 +		*z-0.002777777777778)*z+0.083333333333333)/x;
 +	    z=(x-0.5)*Math.log(x)-x+0.918938533204673+z-Math.log(x-1)-
 +	    Math.log(x-2)-Math.log(x-3)-Math.log(x-4)-Math.log(x-5)-Math.log(x-6);
 +	    return z;
 +	}
  }//End of  class
 | 
