diff options
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/VB.java')
| -rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/VB.java | 419 | 
1 files changed, 0 insertions, 419 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/VB.java b/gi/posterior-regularisation/prjava/src/phrase/VB.java deleted file mode 100644 index cd3f4966..00000000 --- a/gi/posterior-regularisation/prjava/src/phrase/VB.java +++ /dev/null @@ -1,419 +0,0 @@ -package phrase;
 -
 -import gnu.trove.TIntArrayList;
 -
 -import io.FileUtil;
 -
 -import java.io.File;
 -import java.io.IOException;
 -import java.io.PrintStream;
 -import java.util.ArrayList;
 -import java.util.Arrays;
 -import java.util.List;
 -import java.util.concurrent.Callable;
 -import java.util.concurrent.ExecutionException;
 -import java.util.concurrent.ExecutorService;
 -import java.util.concurrent.Future;
 -
 -import org.apache.commons.math.special.Gamma;
 -
 -import phrase.Corpus.Edge;
 -
 -public class VB {
 -
 -	public static int MAX_ITER=400;
 -	
 -	/**@brief
 -	 * hyper param for beta
 -	 * where beta is multinomial
 -	 * for generating words from a topic
 -	 */
 -	public double lambda=0.1;
 -	/**@brief
 -	 * hyper param for theta
 -	 * where theta is dirichlet for z
 -	 */
 -	public double alpha=0.0001;
 -	/**@brief
 -	 * variational param for beta
 -	 */
 -	private double rho[][][];
 -	private double digamma_rho[][][];
 -	private double rho_sum[][];
 -	/**@brief
 -	 * variational param for z
 -	 */
 -	//private double phi[][];
 -	/**@brief
 -	 * variational param for theta
 -	 */
 -	private double gamma[];
 -	private static double VAL_DIFF_RATIO=0.005;
 -	
 -	private int n_positions;
 -	private int n_words;
 -	private int K;
 -	private ExecutorService pool;
 -	
 -	private Corpus c;
 -	public static void main(String[] args) {
 -	//	String in="../pdata/canned.con";
 -		String in="../pdata/btec.con";
 -		String out="../pdata/vb.out";
 -		int numCluster=25;
 -		Corpus corpus = null;
 -		File infile = new File(in);
 -		try {
 -			System.out.println("Reading concordance from " + infile);
 -			corpus = Corpus.readFromFile(FileUtil.reader(infile));
 -			corpus.printStats(System.out);
 -		} catch (IOException e) {
 -			System.err.println("Failed to open input file: " + infile);
 -			e.printStackTrace();
 -			System.exit(1);
 -		}
 -		
 -		VB vb=new VB(numCluster, corpus);
 -		int iter=20;
 -		for(int i=0;i<iter;i++){
 -			double obj=vb.EM();
 -			System.out.println("Iter "+i+": "+obj);
 -		}
 -		
 -		File outfile = new File (out);
 -		try {
 -			PrintStream ps = FileUtil.printstream(outfile);
 -			vb.displayPosterior(ps);
 -		//	ps.println();
 -		//	c2f.displayModelParam(ps);
 -			ps.close();
 -		} catch (IOException e) {
 -			System.err.println("Failed to open output file: " + outfile);
 -			e.printStackTrace();
 -			System.exit(1);
 -		}
 -	}
 -
 -	public VB(int numCluster, Corpus corpus){
 -		c=corpus;
 -		K=numCluster;
 -		n_positions=c.getNumContextPositions();
 -		n_words=c.getNumWords();
 -		rho=new double[K][n_positions][n_words];
 -		//to init rho
 -		//loop through data and count up words
 -		double[] phi_tmp=new double[K];
 -		for(int i=0;i<K;i++){
 -			for(int pos=0;pos<n_positions;pos++){
 -				Arrays.fill(rho[i][pos], lambda);
 -			}
 -		}
 -		for(int d=0;d<c.getNumPhrases();d++){
 -			List<Edge>doc=c.getEdgesForPhrase(d);
 -			for(int n=0;n<doc.size();n++){
 -				TIntArrayList context=doc.get(n).getContext();
 -				arr.F.randomise(phi_tmp);
 -				for(int i=0;i<K;i++){
 -					for(int pos=0;pos<n_positions;pos++){
 -						rho[i][pos][context.get(pos)]+=phi_tmp[i];
 -					}
 -				}
 -			}
 -		}
 -		
 -	}
 -	
 -	private double inference(int phraseID, double[][] phi, double[] gamma)
 -	{
 -		List<Edge > doc=c.getEdgesForPhrase(phraseID);
 -		for(int i=0;i<phi.length;i++){
 -			for(int j=0;j<phi[i].length;j++){
 -				phi[i][j]=1.0/K;
 -			}
 -		}
 -		Arrays.fill(gamma,alpha+1.0/K);
 -		
 -		double digamma_gamma[]=new double[K];
 -		
 -		double gamma_sum=digamma(arr.F.l1norm(gamma));
 -		for(int i=0;i<K;i++){
 -			digamma_gamma[i]=digamma(gamma[i]);
 -		}
 -		double gammaSum[]=new double [K];
 -		double prev_val=0;
 -		double obj=0;
 -		
 -		for(int iter=0;iter<MAX_ITER;iter++){
 -			prev_val=obj;
 -			obj=0;
 -			Arrays.fill(gammaSum,0.0);
 -			for(int n=0;n<doc.size();n++){
 -				TIntArrayList context=doc.get(n).getContext();
 -				double phisum=0;
 -				for(int i=0;i<K;i++){
 -					double sum=0;
 -					for(int pos=0;pos<n_positions;pos++){
 -						int word=context.get(pos);
 -						sum+=digamma_rho[i][pos][word]-rho_sum[i][pos];
 -					}
 -					sum+= digamma_gamma[i]-gamma_sum;
 -					phi[n][i]=sum;
 -					
 -					if (i > 0){
 -	                    phisum = log_sum(phisum, phi[n][i]);
 -					}
 -	                else{
 -	                    phisum = phi[n][i];
 -	                }
 -					
 -				}//end of  a word
 -				
 -				for(int i=0;i<K;i++){
 -					phi[n][i]=Math.exp(phi[n][i]-phisum);
 -					gammaSum[i]+=phi[n][i];
 -				}
 -				
 -			}//end of doc
 -			
 -			for(int i=0;i<K;i++){
 -				gamma[i]=alpha+gammaSum[i];
 -			}
 -			gamma_sum=digamma(arr.F.l1norm(gamma));
 -			for(int i=0;i<K;i++){
 -				digamma_gamma[i]=digamma(gamma[i]);
 -			}
 -			//compute objective for reporting
 -
 -			obj=0;
 -			
 -			for(int i=0;i<K;i++){
 -				obj+=(alpha-1)*(digamma_gamma[i]-gamma_sum);
 -			}
 -			
 -			
 -			for(int n=0;n<doc.size();n++){
 -				TIntArrayList context=doc.get(n).getContext();
 -				
 -				for(int i=0;i<K;i++){
 -					//entropy of phi + expected log likelihood of z
 -					obj+=phi[n][i]*(digamma_gamma[i]-gamma_sum);
 -					
 -					if(phi[n][i]>1e-10){
 -						obj+=phi[n][i]*Math.log(phi[n][i]);
 -					}
 -					
 -					double beta_sum=0;
 -					for(int pos=0;pos<n_positions;pos++){
 -						int word=context.get(pos);
 -						beta_sum+=(digamma(rho[i][pos][word])-rho_sum[i][pos]);
 -					}
 -					obj+=phi[n][i]*beta_sum;
 -				}
 -			}
 -			
 -			obj-=log_gamma(arr.F.l1norm(gamma));
 -			for(int i=0;i<K;i++){
 -				obj+=Gamma.logGamma(gamma[i]);
 -				obj-=(gamma[i]-1)*(digamma_gamma[i]-gamma_sum);
 -			}
 -			
 -//			System.out.println(phraseID+": "+obj);
 -			if(iter>0 && (obj-prev_val)/Math.abs(obj)<VAL_DIFF_RATIO){
 -				break;
 -			}
 -		}//end of inference loop
 -		
 -		return obj;
 -	}//end of inference
 -	
 -	/**
 -	 * @return objective of this iteration
 -	 */
 -	public double EM(){
 -		double emObj=0;
 -		if(digamma_rho==null){
 -			digamma_rho=new double[K][n_positions][n_words];
 -		}
 -		for(int i=0;i<K;i++){
 -			for (int pos=0;pos<n_positions;pos++){
 -				for(int j=0;j<n_words;j++){
 -					digamma_rho[i][pos][j]= digamma(rho[i][pos][j]);
 -				}
 -			}
 -		}
 -		
 -		if(rho_sum==null){
 -			rho_sum=new double [K][n_positions];
 -		}
 -		for(int i=0;i<K;i++){
 -			for(int pos=0;pos<n_positions;pos++){
 -				rho_sum[i][pos]=digamma(arr.F.l1norm(rho[i][pos]));
 -			}
 -		}
 -
 -		//E
 -		double exp_rho[][][]=new double[K][n_positions][n_words];
 -		if (pool == null)
 -		{
 -			for (int d=0;d<c.getNumPhrases();d++)
 -			{		
 -				List<Edge > doc=c.getEdgesForPhrase(d);
 -				double[][] phi = new double[doc.size()][K];
 -				double[] gamma = new double[K];
 -				
 -				emObj += inference(d, phi, gamma);
 -				
 -				for(int n=0;n<doc.size();n++){
 -					TIntArrayList context=doc.get(n).getContext();
 -					for(int pos=0;pos<n_positions;pos++){
 -						int word=context.get(pos);
 -						for(int i=0;i<K;i++){	
 -							exp_rho[i][pos][word]+=phi[n][i];
 -						}
 -					}
 -				}
 -				//if(d!=0 && d%100==0)  System.out.print(".");
 -				//if(d!=0 && d%1000==0) System.out.println(d);
 -			}
 -		}
 -		else // multi-threaded version of above loop
 -		{
 -			class PartialEStep implements Callable<PartialEStep>
 -			{
 -				double[][] phi;
 -				double[] gamma;
 -				double obj;
 -				int d;
 -				PartialEStep(int d) { this.d = d; }
 -
 -				public PartialEStep call()
 -				{
 -					phi = new double[c.getEdgesForPhrase(d).size()][K];
 -					gamma = new double[K];
 -					obj = inference(d, phi, gamma);
 -					return this;
 -				}			
 -			}
 -
 -			List<Future<PartialEStep>> jobs = new ArrayList<Future<PartialEStep>>();
 -			for (int d=0;d<c.getNumPhrases();d++)
 -				jobs.add(pool.submit(new PartialEStep(d)));
 -		
 -			for (Future<PartialEStep> job: jobs)
 -			{
 -				try {
 -					PartialEStep e = job.get();
 -					
 -					emObj += e.obj;				
 -					List<Edge> doc = c.getEdgesForPhrase(e.d);
 -					for(int n=0;n<doc.size();n++){
 -						TIntArrayList context=doc.get(n).getContext();
 -						for(int pos=0;pos<n_positions;pos++){
 -							int word=context.get(pos);
 -							for(int i=0;i<K;i++){	
 -								exp_rho[i][pos][word]+=e.phi[n][i];
 -							}
 -						}
 -					}
 -				} catch (ExecutionException e) {
 -					System.err.println("ERROR: E-step thread execution failed.");
 -					throw new RuntimeException(e);
 -				} catch (InterruptedException e) {
 -					System.err.println("ERROR: Failed to join E-step thread.");
 -					throw new RuntimeException(e);
 -				}
 -			}
 -		}	
 -	//	System.out.println("EM Objective:"+emObj);
 -		
 -		//M
 -		for(int i=0;i<K;i++){
 -			for(int pos=0;pos<n_positions;pos++){
 -				for(int j=0;j<n_words;j++){
 -					rho[i][pos][j]=lambda+exp_rho[i][pos][j];
 -				}
 -			}
 -		}
 -		
 -		//E[\log p(\beta|\lambda)] - E[\log q(\beta)]
 -		for(int i=0;i<K;i++){
 -			double rhoSum=0;
 -			for(int pos=0;pos<n_positions;pos++){
 -				for(int j=0;j<n_words;j++){
 -					rhoSum+=rho[i][pos][j];
 -				}
 -				double digamma_rhoSum=Gamma.digamma(rhoSum);
 -				emObj-=Gamma.logGamma(rhoSum);
 -				for(int j=0;j<n_words;j++){
 -					emObj+=(lambda-rho[i][pos][j])*(Gamma.digamma(rho[i][pos][j])-digamma_rhoSum);
 -					emObj+=Gamma.logGamma(rho[i][pos][j]);
 -				}
 -			}
 -		}
 -		
 -		return emObj;
 -	}//end of EM
 -	
 -	public void displayPosterior(PrintStream ps)
 -	{	
 -		for(int d=0;d<c.getNumPhrases();d++){
 -			List<Edge > doc=c.getEdgesForPhrase(d);
 -			double[][] phi = new double[doc.size()][K];
 -			for(int i=0;i<phi.length;i++)
 -				for(int j=0;j<phi[i].length;j++)
 -					phi[i][j]=1.0/K;
 -			double[] gamma = new double[K];
 -
 -			inference(d, phi, gamma);
 -
 -			for(int n=0;n<doc.size();n++){
 -				Edge edge=doc.get(n);
 -				int tag=arr.F.argmax(phi[n]);
 -				ps.print(edge.getPhraseString());
 -				ps.print("\t");
 -				ps.print(edge.getContextString(true));
 -
 -				ps.println(" ||| C=" + tag);
 -			}
 -		}
 -	}
 -
 -	double log_sum(double log_a, double log_b)
 -	{
 -	  double v;
 -
 -	  if (log_a < log_b)
 -	      v = log_b+Math.log(1 + Math.exp(log_a-log_b));
 -	  else
 -	      v = log_a+Math.log(1 + Math.exp(log_b-log_a));
 -	  return(v);
 -	}
 -		
 -	double digamma(double x)
 -	{
 -	    double p;
 -	    x=x+6;
 -	    p=1/(x*x);
 -	    p=(((0.004166666666667*p-0.003968253986254)*p+
 -		0.008333333333333)*p-0.083333333333333)*p;
 -	    p=p+Math.log(x)-0.5/x-1/(x-1)-1/(x-2)-1/(x-3)-1/(x-4)-1/(x-5)-1/(x-6);
 -	    return p;
 -	}
 -	
 -	double log_gamma(double x)
 -	{
 -	     double z=1/(x*x);
 -
 -	    x=x+6;
 -	    z=(((-0.000595238095238*z+0.000793650793651)
 -		*z-0.002777777777778)*z+0.083333333333333)/x;
 -	    z=(x-0.5)*Math.log(x)-x+0.918938533204673+z-Math.log(x-1)-
 -	    Math.log(x-2)-Math.log(x-3)-Math.log(x-4)-Math.log(x-5)-Math.log(x-6);
 -	    return z;
 -	}
 -
 -	public void useThreadPool(ExecutorService threadPool) 
 -	{
 -		pool = threadPool;
 -	}
 -}//End of  class
  | 
