package phrase; import gnu.trove.TIntArrayList; import io.FileUtil; import java.io.File; import java.io.IOException; import java.io.PrintStream; import java.util.List; import phrase.Corpus.Edge; public class Agree2Sides { PhraseCluster model1,model2; Corpus c1,c2; private int K; /**@brief sum of loglikelihood of two * individual models */ public double llh; /**@brief Bhattacharyya distance * */ public double bdist; /** * * @param numCluster * @param corpus */ public Agree2Sides(int numCluster, Corpus corpus1 , Corpus corpus2 ){ model1=new PhraseCluster(numCluster, corpus1); model2=new PhraseCluster(numCluster,corpus2); c1=corpus1; c2=corpus2; K=numCluster; } /**@brief test * */ public static void main(String args[]){ //String in="../pdata/canned.con"; // String in="../pdata/btec.con"; String in1="../pdata/source.txt"; String in2="../pdata/target.txt"; String out="../pdata/posterior.out"; int numCluster=25; Corpus corpus1 = null,corpus2=null; File infile1 = new File(in1),infile2=new File(in2); try { System.out.println("Reading concordance from " + infile1); corpus1 = Corpus.readFromFile(FileUtil.reader(infile1)); System.out.println("Reading concordance from " + infile2); corpus2 = Corpus.readFromFile(FileUtil.reader(infile2)); corpus1.printStats(System.out); } catch (IOException e) { System.err.println("Failed to open input file: " + infile1); e.printStackTrace(); System.exit(1); } Agree2Sides agree=new Agree2Sides(numCluster, corpus1,corpus2); int iter=20; for(int i=0;i<iter;i++){ agree.EM(); System.out.println("Iter"+i+", llh: "+agree.llh+ ", divergence:"+agree.bdist+ " sum: "+(agree.llh+agree.bdist)); } File outfile = new File (out); try { PrintStream ps = FileUtil.printstream(outfile); agree.displayPosterior(ps); // ps.println(); // c2f.displayModelParam(ps); ps.close(); } catch (IOException e) { System.err.println("Failed to open output file: " + outfile); e.printStackTrace(); System.exit(1); } } public double EM(){ double [][][]exp_emit1=new double [K][c1.getNumContextPositions()][c1.getNumWords()]; double [][]exp_pi1=new double[c1.getNumPhrases()][K]; double [][][]exp_emit2=new double [K][c2.getNumContextPositions()][c2.getNumWords()]; double [][]exp_pi2=new double[c2.getNumPhrases()][K]; llh=0; bdist=0; //E for(int i=0;i<c1.getEdges().size();i++){ Edge edge1=c1.getEdges().get(i); Edge edge2=c2.getEdges().get(i); double p[]=posterior(i); double z = arr.F.l1norm(p); assert z > 0; bdist += edge1.getCount() * Math.log(z); arr.F.l1normalize(p); double count = edge1.getCount(); //increment expected count TIntArrayList contextToks1 = edge1.getContext(); TIntArrayList contextToks2 = edge2.getContext(); int phrase1=edge1.getPhraseId(); int phrase2=edge2.getPhraseId(); for(int tag=0;tag<K;tag++){ for(int position=0;position<c1.getNumContextPositions();position++){ exp_emit1[tag][position][contextToks1.get(position)]+=p[tag]*count; } for(int position=0;position<c2.getNumContextPositions();position++){ exp_emit2[tag][position][contextToks2.get(position)]+=p[tag]*count; } exp_pi1[phrase1][tag]+=p[tag]*count; exp_pi2[phrase2][tag]+=p[tag]*count; } } //System.out.println("Log likelihood: "+loglikelihood); //M for(double [][]i:exp_emit1){ for(double []j:i){ arr.F.l1normalize(j); } } for(double []j:exp_pi1){ arr.F.l1normalize(j); } for(double [][]i:exp_emit2){ for(double []j:i){ arr.F.l1normalize(j); } } for(double []j:exp_pi2){ arr.F.l1normalize(j); } model1.emit=exp_emit1; model1.pi=exp_pi1; model2.emit=exp_emit2; model2.pi=exp_pi2; return llh; } public double[] posterior(int edgeIdx) { return posterior(c1.getEdges().get(edgeIdx), c2.getEdges().get(edgeIdx)); } public double[] posterior(Edge e1, Edge e2) { double[] prob1=model1.posterior(e1); double[] prob2=model2.posterior(e2); llh+=e1.getCount()*Math.log(arr.F.l1norm(prob1)); llh+=e2.getCount()*Math.log(arr.F.l1norm(prob2)); arr.F.l1normalize(prob1); arr.F.l1normalize(prob2); for(int i=0;i<prob1.length;i++){ prob1[i]*=prob2[i]; prob1[i]=Math.sqrt(prob1[i]); } return prob1; } public void displayPosterior(PrintStream ps) { for (int i=0;i<c1.getEdges().size();i++) { Edge edge=c1.getEdges().get(i); double probs[] = posterior(i); arr.F.l1normalize(probs); // emit phrase ps.print(edge.getPhraseString()); ps.print("\t"); ps.print(edge.getContextString(true)); int t=arr.F.argmax(probs); ps.println(" ||| C=" + t); } } }