package phrase; import gnu.trove.TIntArrayList; import io.FileUtil; import java.io.File; import java.io.IOException; import java.io.PrintStream; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import org.apache.commons.math.special.Gamma; import phrase.Corpus.Edge; public class VB { public static int MAX_ITER=400; /**@brief * hyper param for beta * where beta is multinomial * for generating words from a topic */ public double lambda=0.1; /**@brief * hyper param for theta * where theta is dirichlet for z */ public double alpha=0.0001; /**@brief * variational param for beta */ private double rho[][][]; private double digamma_rho[][][]; private double rho_sum[][]; /**@brief * variational param for z */ //private double phi[][]; /**@brief * variational param for theta */ private double gamma[]; private static double VAL_DIFF_RATIO=0.005; private int n_positions; private int n_words; private int K; private ExecutorService pool; private Corpus c; public static void main(String[] args) { // String in="../pdata/canned.con"; String in="../pdata/btec.con"; String out="../pdata/vb.out"; int numCluster=25; Corpus corpus = null; File infile = new File(in); try { System.out.println("Reading concordance from " + infile); corpus = Corpus.readFromFile(FileUtil.reader(infile)); corpus.printStats(System.out); } catch (IOException e) { System.err.println("Failed to open input file: " + infile); e.printStackTrace(); System.exit(1); } VB vb=new VB(numCluster, corpus); int iter=20; for(int i=0;i<iter;i++){ double obj=vb.EM(); System.out.println("Iter "+i+": "+obj); } File outfile = new File (out); try { PrintStream ps = FileUtil.printstream(outfile); vb.displayPosterior(ps); // ps.println(); // c2f.displayModelParam(ps); ps.close(); } catch (IOException e) { System.err.println("Failed to open output file: " + outfile); e.printStackTrace(); System.exit(1); } } public VB(int numCluster, Corpus corpus){ c=corpus; K=numCluster; n_positions=c.getNumContextPositions(); n_words=c.getNumWords(); rho=new double[K][n_positions][n_words]; //to init rho //loop through data and count up words double[] phi_tmp=new double[K]; for(int i=0;i<K;i++){ for(int pos=0;pos<n_positions;pos++){ Arrays.fill(rho[i][pos], lambda); } } for(int d=0;d<c.getNumPhrases();d++){ List<Edge>doc=c.getEdgesForPhrase(d); for(int n=0;n<doc.size();n++){ TIntArrayList context=doc.get(n).getContext(); arr.F.randomise(phi_tmp); for(int i=0;i<K;i++){ for(int pos=0;pos<n_positions;pos++){ rho[i][pos][context.get(pos)]+=phi_tmp[i]; } } } } } private double inference(int phraseID, double[][] phi, double[] gamma) { List<Edge > doc=c.getEdgesForPhrase(phraseID); for(int i=0;i<phi.length;i++){ for(int j=0;j<phi[i].length;j++){ phi[i][j]=1.0/K; } } Arrays.fill(gamma,alpha+1.0/K); double digamma_gamma[]=new double[K]; double gamma_sum=digamma(arr.F.l1norm(gamma)); for(int i=0;i<K;i++){ digamma_gamma[i]=digamma(gamma[i]); } double gammaSum[]=new double [K]; double prev_val=0; double obj=0; for(int iter=0;iter<MAX_ITER;iter++){ prev_val=obj; obj=0; Arrays.fill(gammaSum,0.0); for(int n=0;n<doc.size();n++){ TIntArrayList context=doc.get(n).getContext(); double phisum=0; for(int i=0;i<K;i++){ double sum=0; for(int pos=0;pos<n_positions;pos++){ int word=context.get(pos); sum+=digamma_rho[i][pos][word]-rho_sum[i][pos]; } sum+= digamma_gamma[i]-gamma_sum; phi[n][i]=sum; if (i > 0){ phisum = log_sum(phisum, phi[n][i]); } else{ phisum = phi[n][i]; } }//end of a word for(int i=0;i<K;i++){ phi[n][i]=Math.exp(phi[n][i]-phisum); gammaSum[i]+=phi[n][i]; } }//end of doc for(int i=0;i<K;i++){ gamma[i]=alpha+gammaSum[i]; } gamma_sum=digamma(arr.F.l1norm(gamma)); for(int i=0;i<K;i++){ digamma_gamma[i]=digamma(gamma[i]); } //compute objective for reporting obj=0; for(int i=0;i<K;i++){ obj+=(alpha-1)*(digamma_gamma[i]-gamma_sum); } for(int n=0;n<doc.size();n++){ TIntArrayList context=doc.get(n).getContext(); for(int i=0;i<K;i++){ //entropy of phi + expected log likelihood of z obj+=phi[n][i]*(digamma_gamma[i]-gamma_sum); if(phi[n][i]>1e-10){ obj+=phi[n][i]*Math.log(phi[n][i]); } double beta_sum=0; for(int pos=0;pos<n_positions;pos++){ int word=context.get(pos); beta_sum+=(digamma(rho[i][pos][word])-rho_sum[i][pos]); } obj+=phi[n][i]*beta_sum; } } obj-=log_gamma(arr.F.l1norm(gamma)); for(int i=0;i<K;i++){ obj+=Gamma.logGamma(gamma[i]); obj-=(gamma[i]-1)*(digamma_gamma[i]-gamma_sum); } // System.out.println(phraseID+": "+obj); if(iter>0 && (obj-prev_val)/Math.abs(obj)<VAL_DIFF_RATIO){ break; } }//end of inference loop return obj; }//end of inference /** * @return objective of this iteration */ public double EM(){ double emObj=0; if(digamma_rho==null){ digamma_rho=new double[K][n_positions][n_words]; } for(int i=0;i<K;i++){ for (int pos=0;pos<n_positions;pos++){ for(int j=0;j<n_words;j++){ digamma_rho[i][pos][j]= digamma(rho[i][pos][j]); } } } if(rho_sum==null){ rho_sum=new double [K][n_positions]; } for(int i=0;i<K;i++){ for(int pos=0;pos<n_positions;pos++){ rho_sum[i][pos]=digamma(arr.F.l1norm(rho[i][pos])); } } //E double exp_rho[][][]=new double[K][n_positions][n_words]; if (pool == null) { for (int d=0;d<c.getNumPhrases();d++) { List<Edge > doc=c.getEdgesForPhrase(d); double[][] phi = new double[doc.size()][K]; double[] gamma = new double[K]; emObj += inference(d, phi, gamma); for(int n=0;n<doc.size();n++){ TIntArrayList context=doc.get(n).getContext(); for(int pos=0;pos<n_positions;pos++){ int word=context.get(pos); for(int i=0;i<K;i++){ exp_rho[i][pos][word]+=phi[n][i]; } } } //if(d!=0 && d%100==0) System.out.print("."); //if(d!=0 && d%1000==0) System.out.println(d); } } else // multi-threaded version of above loop { class PartialEStep implements Callable<PartialEStep> { double[][] phi; double[] gamma; double obj; int d; PartialEStep(int d) { this.d = d; } public PartialEStep call() { phi = new double[c.getEdgesForPhrase(d).size()][K]; gamma = new double[K]; obj = inference(d, phi, gamma); return this; } } List<Future<PartialEStep>> jobs = new ArrayList<Future<PartialEStep>>(); for (int d=0;d<c.getNumPhrases();d++) jobs.add(pool.submit(new PartialEStep(d))); for (Future<PartialEStep> job: jobs) { try { PartialEStep e = job.get(); emObj += e.obj; List<Edge> doc = c.getEdgesForPhrase(e.d); for(int n=0;n<doc.size();n++){ TIntArrayList context=doc.get(n).getContext(); for(int pos=0;pos<n_positions;pos++){ int word=context.get(pos); for(int i=0;i<K;i++){ exp_rho[i][pos][word]+=e.phi[n][i]; } } } } catch (ExecutionException e) { System.err.println("ERROR: E-step thread execution failed."); throw new RuntimeException(e); } catch (InterruptedException e) { System.err.println("ERROR: Failed to join E-step thread."); throw new RuntimeException(e); } } } // System.out.println("EM Objective:"+emObj); //M for(int i=0;i<K;i++){ for(int pos=0;pos<n_positions;pos++){ for(int j=0;j<n_words;j++){ rho[i][pos][j]=lambda+exp_rho[i][pos][j]; } } } //E[\log p(\beta|\lambda)] - E[\log q(\beta)] for(int i=0;i<K;i++){ double rhoSum=0; for(int pos=0;pos<n_positions;pos++){ for(int j=0;j<n_words;j++){ rhoSum+=rho[i][pos][j]; } double digamma_rhoSum=Gamma.digamma(rhoSum); emObj-=Gamma.logGamma(rhoSum); for(int j=0;j<n_words;j++){ emObj+=(lambda-rho[i][pos][j])*(Gamma.digamma(rho[i][pos][j])-digamma_rhoSum); emObj+=Gamma.logGamma(rho[i][pos][j]); } } } return emObj; }//end of EM public void displayPosterior(PrintStream ps) { for(int d=0;d<c.getNumPhrases();d++){ List<Edge > doc=c.getEdgesForPhrase(d); double[][] phi = new double[doc.size()][K]; for(int i=0;i<phi.length;i++) for(int j=0;j<phi[i].length;j++) phi[i][j]=1.0/K; double[] gamma = new double[K]; inference(d, phi, gamma); for(int n=0;n<doc.size();n++){ Edge edge=doc.get(n); int tag=arr.F.argmax(phi[n]); ps.print(edge.getPhraseString()); ps.print("\t"); ps.print(edge.getContextString(true)); ps.println(" ||| C=" + tag); } } } double log_sum(double log_a, double log_b) { double v; if (log_a < log_b) v = log_b+Math.log(1 + Math.exp(log_a-log_b)); else v = log_a+Math.log(1 + Math.exp(log_b-log_a)); return(v); } double digamma(double x) { double p; x=x+6; p=1/(x*x); p=(((0.004166666666667*p-0.003968253986254)*p+ 0.008333333333333)*p-0.083333333333333)*p; p=p+Math.log(x)-0.5/x-1/(x-1)-1/(x-2)-1/(x-3)-1/(x-4)-1/(x-5)-1/(x-6); return p; } double log_gamma(double x) { double z=1/(x*x); x=x+6; z=(((-0.000595238095238*z+0.000793650793651) *z-0.002777777777778)*z+0.083333333333333)/x; z=(x-0.5)*Math.log(x)-x+0.918938533204673+z-Math.log(x-1)- Math.log(x-2)-Math.log(x-3)-Math.log(x-4)-Math.log(x-5)-Math.log(x-6); return z; } public void useThreadPool(ExecutorService threadPool) { pool = threadPool; } }//End of class