package phrase;

import io.FileUtil;

import java.io.PrintStream;
import java.util.Arrays;

/**
 * Mixture model over the contexts of each phrase.
 *
 * The model keeps two tables: {@code pi[phrase][tag]} (printed by
 * {@link #displayModelParam(PrintStream)} as "P(tag|phrase)") and
 * {@code emit[tag][position][word]} (printed as "P(word|tag,position)").
 * They are re-estimated either by plain EM ({@link #EM()}) or by an E-step
 * whose posteriors come from {@link PhraseObjective} ({@link #PREM()}).
 *
 * A context is an {@code int[]} whose last element is the number of times
 * the context was observed; every preceding element is the word id seen at
 * that position.
 */
public class PhraseCluster {

	/** Number of clusters (tags). */
	public int K;
	/** Number of phrases, i.e. {@code c.data.length} at construction time. */
	private int n_phrase;
	/** Vocabulary size, i.e. {@code c.wordLex.size()} at construction time. */
	private int n_words;
	/** Corpus the model is trained on. */
	public PhraseCorpus c;
	
	/** Emission table indexed as {@code emit[tag][position][word]}. */
	private double emit[][][];
	/** Per-phrase tag table indexed as {@code pi[phrase][tag]}. */
	private double pi[][];
	
	/** Number of training iterations run by {@link #main(String[])}. */
	public static int ITER=20;
	/** Output file for the posteriors and the model parameters. */
	public static String postFilename="../pdata/posterior.out";
	/** Output file handed to {@code PhraseObjective.ps}. */
	public static String phraseStatFilename="../pdata/phrase_stat.out";
	/** Number of tags used by {@link #main(String[])}. */
	private static final int NUM_TAG=3;

	/**
	 * Trains a model on {@code PhraseCorpus.DATA_FILENAME} for {@link #ITER}
	 * iterations of {@link #PREM()} and writes the posteriors followed by
	 * the model parameters to {@link #postFilename}.
	 *
	 * @param args ignored
	 */
	public static void main(String[] args) {
		
		PhraseCorpus corpus=new PhraseCorpus(PhraseCorpus.DATA_FILENAME);
		PhraseCluster cluster=new PhraseCluster(NUM_TAG,corpus);
		
		PhraseObjective.ps=FileUtil.openOutFile(phraseStatFilename);
		try{
			for(int i=0;i<ITER;i++){
				PhraseObjective.ps.println("ITER: "+i);
				cluster.PREM();
			}
			
			PrintStream ps=io.FileUtil.openOutFile(postFilename);
			try{
				cluster.displayPosterior(ps);
				ps.println();
				cluster.displayModelParam(ps);
			}finally{
				// close even when a display call throws
				ps.close();
			}
		}finally{
			// close even when training or output fails
			PhraseObjective.ps.close();
		}
	}

	/**
	 * Allocates the parameter tables, fills them through
	 * {@code arr.F.randomise} and then overwrites a few entries with
	 * hand-picked starting values.
	 *
	 * The hand-picked values are clamped to the actual table sizes, so
	 * construction no longer fails for fewer than three tags, fewer than
	 * four words or context positions, or an empty corpus.
	 *
	 * @param numCluster number of clusters (tags)
	 * @param corpus     corpus to train on
	 */
	public PhraseCluster(int numCluster,PhraseCorpus corpus){
		K=numCluster;
		c=corpus;
		n_words=c.wordLex.size();
		n_phrase=c.data.length;
		
		emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
		pi=new double[n_phrase][K];
		
		for(double [][]perTag:emit){
			for(double []dist:perTag){
				arr.F.randomise(dist);
			}
		}
		
		for(double []row:pi){
			arr.F.randomise(row);
		}
		
		// Hand-picked start for the first phrase; copyOf keeps the row
		// K wide (truncating or zero-padding when K is not 3).
		if(n_phrase>0){
			pi[0]=Arrays.copyOf(new double[]{0.3,0.5,0.2},K);
		}
		
		double seed[][]=new double[][]{
				{0.11,0.16,0.19,0.11,0.1},
				{0.10,0.15,0.18,0.1,0.11},
				{0.09,0.07,0.12,0.14,0.13} 
		};
		
		// Only the first four words and positions are seeded; each bound
		// is clamped so small models do not index out of range.
		int seedTags=Math.min(K,seed.length);
		int seedWords=Math.min(n_words,4);
		int seedPositions=Math.min(PhraseCorpus.NUM_CONTEXT,4);
		for(int tag=0;tag<seedTags;tag++){
			for(int word=0;word<seedWords;word++){
				for(int pos=0;pos<seedPositions;pos++){
					emit[tag][pos][word]=seed[tag][word];
				}
			}
		}
		
	}
		
	/**
	 * Runs one EM iteration: the E-step uses
	 * {@link #posterior(int, int[])} normalised with
	 * {@code arr.F.l1normalize}, the M-step replaces {@code emit} and
	 * {@code pi} with the normalised expected counts. Prints the
	 * accumulated log likelihood to standard output.
	 */
	public void EM(){
		double [][][]exp_emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
		double [][]exp_pi=new double[n_phrase][K];
		
		double loglikelihood=0;
		
		//E
		for(int phrase=0;phrase<c.data.length;phrase++){
			int [][] contexts=c.data[phrase];
			for(int ctx=0;ctx<contexts.length;ctx++){
				int context[]=contexts[ctx];
				double p[]=posterior(phrase,context);
				// NOTE(review): each context adds its log-norm once, while
				// the expected counts are weighted by the context count;
				// confirm whether this should be weighted as well.
				loglikelihood+=Math.log(arr.F.l1norm(p));
				arr.F.l1normalize(p);
				
				addExpectedCounts(exp_emit,exp_pi[phrase],context,p);
			}
		}
		
		System.out.println("Log likelihood: "+loglikelihood);
		
		//M
		installNormalised(exp_emit,exp_pi);
	}
	
	/**
	 * Runs one iteration in which the E-step posteriors for every phrase
	 * come from a {@link PhraseObjective} after
	 * {@code optimizeWithProjectedGradientDescent()}; the M-step is the
	 * same as in {@link #EM()}. Prints the summed {@code getValue()} as
	 * "Log likelihood" and the summed {@code primal()} as
	 * "Primal Objective".
	 */
	public void PREM(){
		double [][][]exp_emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
		double [][]exp_pi=new double[n_phrase][K];
		
		double loglikelihood=0;
		double primal=0;
		//E
		for(int phrase=0;phrase<c.data.length;phrase++){
			PhraseObjective po=new PhraseObjective(this,phrase);
			po.optimizeWithProjectedGradientDescent();
			// q[edge][tag]: one row per context of this phrase
			double [][] q=po.posterior();
			loglikelihood+=po.getValue();
			primal+=po.primal();
			for(int edge=0;edge<q.length;edge++){
				addExpectedCounts(exp_emit,exp_pi[phrase],c.data[phrase][edge],q[edge]);
			}
		}
		
		System.out.println("Log likelihood: "+loglikelihood);
		System.out.println("Primal Objective: "+primal);
		
		//M
		installNormalised(exp_emit,exp_pi);
	}
	
	/**
	 * Adds one context's expected counts, weighted by how often the
	 * context was observed.
	 *
	 * @param expEmit  expected emission counts, {@code [tag][position][word]}
	 * @param expPiRow expected tag counts of the context's phrase
	 * @param context  word ids followed by the occurrence count
	 * @param weights  per-tag weight of this context
	 */
	private void addExpectedCounts(double[][][] expEmit,double[] expPiRow,int[] context,double[] weights){
		int last=context.length-1;
		int contextCnt=context[last];
		for(int tag=0;tag<K;tag++){
			double mass=weights[tag]*contextCnt;
			for(int pos=0;pos<last;pos++){
				expEmit[tag][pos][context[pos]]+=mass;
			}
			expPiRow[tag]+=mass;
		}
	}
	
	/**
	 * M-step: passes every innermost row through
	 * {@code arr.F.l1normalize} and makes the two tables the current
	 * model parameters.
	 *
	 * @param expEmit expected emission counts, {@code [tag][position][word]}
	 * @param expPi   expected tag counts, {@code [phrase][tag]}
	 */
	private void installNormalised(double[][][] expEmit,double[][] expPi){
		for(double [][]perTag:expEmit){
			for(double []dist:perTag){
				arr.F.l1normalize(dist);
			}
		}
		emit=expEmit;
		
		for(double []row:expPi){
			arr.F.l1normalize(row);
		}
		pi=expPi;
	}
	
	/**
	 * Computes, for every tag, {@code pi[phrase][tag]} multiplied by the
	 * emission entry of each word in the context.
	 *
	 * @param phrase index of phrase
	 * @param ctx array of context; the last element (the count) is skipped
	 * @return unnormalized posterior, one entry per tag
	 */
	public double[]posterior(int phrase, int[]ctx){
		double[] prob=Arrays.copyOf(pi[phrase], K);
		
		int positions=ctx.length-1;
		for(int tag=0;tag<K;tag++){
			for(int pos=0;pos<positions;pos++){
				prob[tag]*=emit[tag][pos][ctx[pos]];
			}
		}
		
		return prob;
	}
	
	/**
	 * Writes one line per (phrase, context) pair: the phrase, the context
	 * string, the context count, the index returned by
	 * {@code arr.F.argmax} and the normalised posterior of every tag.
	 *
	 * @param ps stream to write to; it is not closed
	 */
	public void displayPosterior(PrintStream ps)
	{
		
		c.buildList();
		
		for (int i = 0; i < n_phrase; ++i)
		{
			for (int[] e: c.data[i])
			{
				double probs[] = posterior(i, e);
				arr.F.l1normalize(probs);

				// emit phrase
				ps.print(c.phraseList[i]);
				ps.print("\t");
				ps.print(c.getContextString(e));
				ps.print("||| C=" + e[e.length-1] + " |||");

				int best=arr.F.argmax(probs);
				
				ps.print(best+"||| [");
				for(int tag=0;tag<K;tag++){
					ps.print(probs[tag]+", ");
				}
				ps.println("]");
			}
		}
	}
	
	/**
	 * Writes the {@code pi} table under the heading "P(tag|phrase)" and
	 * the {@code emit} table under the heading "P(word|tag,position)".
	 *
	 * @param ps stream to write to; it is not closed
	 */
	public void displayModelParam(PrintStream ps)
	{
		
		c.buildList();
		
		ps.println("P(tag|phrase)");
		for (int i = 0; i < n_phrase; ++i)
		{
			ps.print(c.phraseList[i]);
			for(double p:pi[i]){
				ps.print("\t"+p);
			}
			ps.println();
		}
		
		ps.println("P(word|tag,position)");
		for (int tag = 0; tag < K; ++tag)
		{
			ps.println(tag);
			for(int position=0;position<PhraseCorpus.NUM_CONTEXT;position++){
				ps.println(position);
				double[] dist=emit[tag][position];
				for(int word=0;word<dist.length;word++){
					// line break before every 100th entry
					if((word+1)%100==0){
						ps.println();
					}
					ps.print(c.wordList[word]+"="+dist[word]+"\t");
				}
				ps.println();
			}
			ps.println();
		}
		
	}
}