package phrase;

import gnu.trove.TIntArrayList;

import io.FileUtil;

import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.List;

import phrase.Corpus.Edge;

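/**
 * Clustering model in which each context generates a phrase: a latent tag is
 * drawn from p(tag | context), and the tag emits the words at a fixed set of
 * positions of the phrase via p(word | tag, position). Parameters are
 * estimated with EM.
 * @author desaic
 *
 */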
public class C2F {
	/** Number of latent tags (clusters). */
	public int K;
	/** Vocabulary size, number of contexts, and number of scored word positions per phrase. */
	private int n_words, n_contexts, n_positions;
	/** Corpus the model is trained on. */
	public Corpus c;
	
	/**@brief
	 *  emit[tag][position][word] = p(word | tag, position in phrase)
	 */
	public double emit[][][];
	/**@brief
	 *  pi[context][tag] = p(tag | context)
	 */
	public double pi[][];
	
	/**
	 * Creates a model whose parameters are initialised with {@code arr.F.randomise}.
	 *
	 * @param numCluster number of latent tags K; must be positive
	 * @param corpus corpus providing the vocabulary, the contexts and the edges; the
	 *        number of scored positions is taken from the phrase of its first edge
	 * @throws IllegalArgumentException if {@code numCluster} is not positive
	 */
	public C2F(int numCluster, Corpus corpus){
		if (numCluster <= 0) {
			throw new IllegalArgumentException("numCluster must be positive: " + numCluster);
		}
		K=numCluster;
		c=corpus;
		n_words=c.getNumWords();
		n_contexts=c.getNumContexts();
		
		//number of words in a phrase to be considered
		//currently the first and last word in source and target
		//if the phrase has length 1 in either dimension then
		//we use the same word for two positions
		n_positions=c.phraseEdges(c.getEdges().get(0).getPhrase()).size();
		
		emit=new double [K][n_positions][n_words];
		pi=new double[n_contexts][K];
		
		for(double [][]i:emit){
			for(double []j:i){
				arr.F.randomise(j);
			}
		}
		
		for(double []j:pi){
			arr.F.randomise(j);
		}
	}
	
	/**
	 * Test driver. Reads the concordance from ../pdata/canned.con, runs 20 EM
	 * iterations with 25 clusters, and writes the most probable tag of every
	 * edge to ../pdata/posterior.out.
	 */
	public static void main(String args[]){
		String in="../pdata/canned.con";
		String out="../pdata/posterior.out";
		int numCluster=25;
		Corpus corpus = null;
		File infile = new File(in);
		try {
			System.out.println("Reading concordance from " + infile);
			corpus = Corpus.readFromFile(FileUtil.reader(infile));
			corpus.printStats(System.out);
		} catch (IOException e) {
			System.err.println("Failed to open input file: " + infile);
			e.printStackTrace();
			System.exit(1);
		}
		
		C2F c2f=new C2F(numCluster,corpus);
		int iter=20;
		for(int i=0;i<iter;i++){
			double llh=c2f.EM();
			System.out.println("Iter"+i+", llh: "+llh);
		}
		
		File outfile = new File (out);
		PrintStream ps = null;
		try {
			ps = FileUtil.printstream(outfile);
			c2f.displayPosterior(ps);
			// c2f.displayModelParam(ps) may be called here to also dump the model parameters.
		} catch (IOException e) {
			System.err.println("Failed to open output file: " + outfile);
			e.printStackTrace();
			System.exit(1);
		} finally {
			// Release the file even if writing the posterior throws.
			if (ps != null) {
				ps.close();
			}
		}
		// PrintStream swallows IOExceptions; checkError() reports any write or close failure.
		if (ps != null && ps.checkError()) {
			System.err.println("Failed to write output file: " + outfile);
			System.exit(1);
		}
	}
	
	/**
	 * Runs one EM iteration and replaces {@link #emit} and {@link #pi} with the
	 * re-estimated parameters.
	 *
	 * @return the log-likelihood of the corpus under the parameters in effect
	 *         before this update, i.e. the sum over edges of
	 *         count * log(sum over tags of the unnormalised {@link #posterior})
	 */
	public double EM(){
		double [][][]exp_emit=new double [K][n_positions][n_words];
		double [][]exp_pi=new double[n_contexts][K];
		
		double loglikelihood=0;
		
		//E: accumulate expected counts under the current parameters
		for(int context=0; context< n_contexts; context++){
			List<Edge> edges = c.getEdgesForContext(context);
			for (Edge edge : edges){
				double p[]=posterior(edge);
				double z = arr.F.l1norm(p);
				assert z > 0;
				double count = edge.getCount();
				loglikelihood += count * Math.log(z);
				arr.F.l1normalize(p);
				
				// Credit exactly the word positions that posterior() scores, so the
				// expected counts match the likelihood for any number of positions
				// (not only the first and last word of the phrase).
				TIntArrayList phrase= edge.getPhrase();
				TIntArrayList offsets = c.phraseEdges(phrase);
				for(int tag=0;tag<K;tag++){
					double weight = p[tag]*count;
					for (int pos = 0; pos < offsets.size(); pos++) {
						exp_emit[tag][pos][phrase.get(offsets.get(pos))] += weight;
					}
					exp_pi[context][tag] += weight;
				}
			}
		}
		
		//M: normalise expected counts into distributions
		for(double [][]i:exp_emit){
			for(double []j:i){
				arr.F.l1normalize(j);
			}
		}
		
		emit=exp_emit;
		
		for(double []j:exp_pi){
			arr.F.l1normalize(j);
		}
		
		pi=exp_pi;
		
		return loglikelihood;
	}

	/**
	 * Computes, for every tag, pi[context][tag] times the product of
	 * emit[tag][i][word at offset i] over the positions returned by
	 * {@code c.phraseEdges(phrase)}.
	 *
	 * @param edge the (phrase, context) edge to score
	 * @return a fresh array of length K with unnormalised tag scores
	 */
	public double[] posterior(Corpus.Edge edge) 
	{
		double[] prob=Arrays.copyOf(pi[edge.getContextId()], K);
		
		TIntArrayList phrase = edge.getPhrase();
		TIntArrayList offsets = c.phraseEdges(phrase);
		for(int tag=0;tag<K;tag++)
		{
			for (int i=0; i < offsets.size(); ++i)
				prob[tag]*=emit[tag][i][phrase.get(offsets.get(i))];
		}
			
		return prob;
	}

	/**
	 * Prints one line per corpus edge in the form
	 * {@code phrase TAB context " ||| C=" tag}, where tag is the argmax of the
	 * normalised posterior.
	 *
	 * @param ps destination stream
	 */
	public void displayPosterior(PrintStream ps)
	{	
		for (Edge edge : c.getEdges())
		{
			double probs[] = posterior(edge);
			arr.F.l1normalize(probs);

			// emit phrase
			ps.print(edge.getPhraseString());
			ps.print("\t");
			ps.print(edge.getContextString(true));
			int t=arr.F.argmax(probs);
			ps.println(" ||| C=" + t);
		}
	}
	
	/**
	 * Prints p(tag | context) per context and p(word | tag, position) per
	 * tag/position, omitting entries not greater than 1e-6.
	 *
	 * @param ps destination stream
	 */
	public void displayModelParam(PrintStream ps)
	{
		final double EPS = 1e-6;
		
		ps.println("P(tag|context)");
		for (int i = 0; i < n_contexts; ++i)
		{
			ps.print(c.getContext(i));
			for(int j=0;j<pi[i].length;j++){
				if (pi[i][j] > EPS)
					ps.print("\t" + j + ": " + pi[i][j]);
			}
			ps.println();
		}
		
		ps.println("P(word|tag,position)");
		for (int i = 0; i < K; ++i)
		{
			for(int position=0;position<n_positions;position++){
				ps.println("tag " + i + " position " + position);
				for(int word=0;word<emit[i][position].length;word++){
					if (emit[i][position][word] > EPS)
						ps.print(c.getWord(word)+"="+emit[i][position][word]+"\t");
				}
				ps.println();
			}
			ps.println();
		}
		
	}
	
}