package phrase;
import gnu.trove.TIntArrayList;
import io.FileUtil;
import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
import java.util.List;
import phrase.Corpus.Edge;
public class Agree {
PhraseCluster model1;
C2F model2;
Corpus c;
private int K,n_phrases, n_words, n_contexts, n_positions1,n_positions2;
/**@brief sum of loglikelihood of two
* individual models
*/
public double llh;
/**@brief Bhattacharyya distance
*
*/
public double bdist;
/**
*
* @param numCluster
* @param corpus
*/
public Agree(int numCluster, Corpus corpus){
model1=new PhraseCluster(numCluster, corpus);
model2=new C2F(numCluster,corpus);
c=corpus;
n_words=c.getNumWords();
n_phrases=c.getNumPhrases();
n_contexts=c.getNumContexts();
n_positions1=c.getNumContextPositions();
n_positions2=2;
K=numCluster;
}
/**@brief test
*
*/
public static void main(String args[]){
//String in="../pdata/canned.con";
String in="../pdata/btec.con";
String out="../pdata/posterior.out";
int numCluster=25;
Corpus corpus = null;
File infile = new File(in);
try {
System.out.println("Reading concordance from " + infile);
corpus = Corpus.readFromFile(FileUtil.reader(infile));
corpus.printStats(System.out);
} catch (IOException e) {
System.err.println("Failed to open input file: " + infile);
e.printStackTrace();
System.exit(1);
}
Agree agree=new Agree(numCluster, corpus);
int iter=20;
for(int i=0;i<iter;i++){
agree.EM();
System.out.println("Iter"+i+", llh: "+agree.llh+
", divergence:"+agree.bdist+
" sum: "+(agree.llh+agree.bdist));
}
File outfile = new File (out);
try {
PrintStream ps = FileUtil.printstream(outfile);
agree.displayPosterior(ps);
// ps.println();
// c2f.displayModelParam(ps);
ps.close();
} catch (IOException e) {
System.err.println("Failed to open output file: " + outfile);
e.printStackTrace();
System.exit(1);
}
}
public double EM(){
double [][][]exp_emit1=new double [K][n_positions1][n_words];
double [][]exp_pi1=new double[n_phrases][K];
double [][][]exp_emit2=new double [K][n_positions2][n_words];
double [][]exp_pi2=new double[n_contexts][K];
llh=0;
bdist=0;
//E
for(int context=0; context< n_contexts; context++){
List<Edge> contexts = c.getEdgesForContext(context);
for (int ctx=0; ctx<contexts.size(); ctx++){
Edge edge = contexts.get(ctx);
int phrase=edge.getPhraseId();
double p[]=posterior(edge);
double z = arr.F.l1norm(p);
assert z > 0;
bdist += edge.getCount() * Math.log(z);
arr.F.l1normalize(p);
double count = edge.getCount();
//increment expected count
TIntArrayList phraseToks = edge.getPhrase();
TIntArrayList contextToks = edge.getContext();
for(int tag=0;tag<K;tag++){
for(int position=0;position<n_positions1;position++){
exp_emit1[tag][position][contextToks.get(position)]+=p[tag]*count;
}
exp_emit2[tag][0][phraseToks.get(0)]+=p[tag]*count;
exp_emit2[tag][1][phraseToks.get(phraseToks.size()-1)]+=p[tag]*count;
exp_pi1[phrase][tag]+=p[tag]*count;
exp_pi2[context][tag]+=p[tag]*count;
}
}
}
//System.out.println("Log likelihood: "+loglikelihood);
//M
for(double [][]i:exp_emit1){
for(double []j:i){
arr.F.l1normalize(j);
}
}
for(double []j:exp_pi1){
arr.F.l1normalize(j);
}
for(double [][]i:exp_emit2){
for(double []j:i){
arr.F.l1normalize(j);
}
}
for(double []j:exp_pi2){
arr.F.l1normalize(j);
}
model1.emit=exp_emit1;
model1.pi=exp_pi1;
model2.emit=exp_emit2;
model2.pi=exp_pi2;
return llh;
}
public double[] posterior(Corpus.Edge edge)
{
double[] prob1=model1.posterior(edge);
double[] prob2=model2.posterior(edge);
llh+=edge.getCount()*Math.log(arr.F.l1norm(prob1));
llh+=edge.getCount()*Math.log(arr.F.l1norm(prob2));
arr.F.l1normalize(prob1);
arr.F.l1normalize(prob2);
for(int i=0;i<prob1.length;i++){
prob1[i]*=prob2[i];
prob1[i]=Math.sqrt(prob1[i]);
}
return prob1;
}
public void displayPosterior(PrintStream ps)
{
displayPosterior(ps, c.getEdges());
}
public void displayPosterior(PrintStream ps, List<Edge> test)
{
for (Edge edge : test)
{
double probs[] = posterior(edge);
arr.F.l1normalize(probs);
// emit phrase
ps.print(edge.getPhraseString());
ps.print("\t");
ps.print(edge.getContextString(true));
int t=arr.F.argmax(probs);
ps.println(" ||| C=" + t);
}
}
}