From 3368ed9579857982a51d78e834cd6f44e1915deb Mon Sep 17 00:00:00 2001
From: desaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f>
Date: Mon, 12 Jul 2010 14:17:09 +0000
Subject: agreement model

git-svn-id: https://ws10smt.googlecode.com/svn/trunk@221 ec762483-ff6d-05da-a07a-a48fb63a330f
---
 .../prjava/src/phrase/Agree.java                   | 174 +++++++++++++++++++++
 .../prjava/src/phrase/C2F.java                     |   4 +-
 2 files changed, 176 insertions(+), 2 deletions(-)
 create mode 100644 gi/posterior-regularisation/prjava/src/phrase/Agree.java

(limited to 'gi')

diff --git a/gi/posterior-regularisation/prjava/src/phrase/Agree.java b/gi/posterior-regularisation/prjava/src/phrase/Agree.java
new file mode 100644
index 00000000..091875ce
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/phrase/Agree.java
@@ -0,0 +1,174 @@
+package phrase;
+
+import gnu.trove.TIntArrayList;
+
+import io.FileUtil;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.List;
+
+import phrase.Corpus.Edge;
+
+public class Agree {
+	private PhraseCluster model1;
+	private C2F model2;
+	Corpus c;
+	private int K,n_phrases, n_words, n_contexts, n_positions1,n_positions2;
+	
+	/**
+	 * 
+	 * @param numCluster
+	 * @param corpus
+	 */
+	public Agree(int numCluster, Corpus corpus){
+		
+		model1=new PhraseCluster(numCluster, corpus, 0, 0, 0);
+		model2=new C2F(numCluster,corpus);
+		c=corpus;
+		n_words=c.getNumWords();
+		n_phrases=c.getNumPhrases();
+		n_contexts=c.getNumContexts();
+		n_positions1=c.getNumContextPositions();
+		n_positions2=2;
+		K=numCluster;
+		
+	}
+	
+	/**@brief test
+	 * 
+	 */
+	public static void main(String args[]){
+		String in="../pdata/canned.con";
+		String out="../pdata/posterior.out";
+		int numCluster=25;
+		Corpus corpus = null;
+		File infile = new File(in);
+		try {
+			System.out.println("Reading concordance from " + infile);
+			corpus = Corpus.readFromFile(FileUtil.reader(infile));
+			corpus.printStats(System.out);
+		} catch (IOException e) {
+			System.err.println("Failed to open input file: " + infile);
+			e.printStackTrace();
+			System.exit(1);
+		}
+		
+		Agree agree=new Agree(numCluster, corpus);
+		int iter=20;
+		double llh=0;
+		for(int i=0;i<iter;i++){
+			llh=agree.EM();
+			System.out.println("Iter"+i+", llh: "+llh);
+		}
+		
+		File outfile = new File (out);
+		try {
+			PrintStream ps = FileUtil.printstream(outfile);
+			agree.displayPosterior(ps);
+		//	ps.println();
+		//	c2f.displayModelParam(ps);
+			ps.close();
+		} catch (IOException e) {
+			System.err.println("Failed to open output file: " + outfile);
+			e.printStackTrace();
+			System.exit(1);
+		}
+		
+	}
+	
+	public double EM(){
+		
+		double [][][]exp_emit1=new double [K][n_positions1][n_words];
+		double [][]exp_pi1=new double[n_phrases][K];
+		
+		double [][][]exp_emit2=new double [K][n_positions2][n_words];
+		double [][]exp_pi2=new double[n_contexts][K];
+		
+		double loglikelihood=0;
+		
+		//E
+		for(int context=0; context< n_contexts; context++){
+			
+			List<Edge> contexts = c.getEdgesForContext(context);
+
+			for (int ctx=0; ctx<contexts.size(); ctx++){
+				Edge edge = contexts.get(ctx);
+				int phrase=edge.getPhraseId();
+				double p[]=posterior(edge);
+				double z = arr.F.l1norm(p);
+				assert z > 0;
+				loglikelihood += edge.getCount() * Math.log(z);
+				arr.F.l1normalize(p);
+				
+				int count = edge.getCount();
+				//increment expected count
+				TIntArrayList phraseToks = edge.getPhrase();
+				TIntArrayList contextToks = edge.getContext();
+				for(int tag=0;tag<K;tag++){
+
+					for(int position=0;position<n_positions1;position++){
+						exp_emit1[tag][position][contextToks.get(position)]+=p[tag]*count;
+					}
+					
+					exp_emit2[tag][0][phraseToks.get(0)]+=p[tag]*count;
+					exp_emit2[tag][1][phraseToks.get(phraseToks.size()-1)]+=p[tag]*count;
+					
+					exp_pi1[phrase][tag]+=p[tag]*count;
+					exp_pi2[context][tag]+=p[tag]*count;
+				}
+			}
+		}
+		
+		//System.out.println("Log likelihood: "+loglikelihood);
+		
+		//M
+		for(double [][]i:exp_emit1){
+			for(double []j:i){
+				arr.F.l1normalize(j);
+			}
+		}
+		
+		for(double []j:exp_pi1){
+			arr.F.l1normalize(j);
+		}
+		
+		model1.emit=exp_emit1;
+		model1.pi=exp_pi1;
+		model2.emit=exp_emit2;
+		model2.pi=exp_pi2;
+		
+		return loglikelihood;
+	}
+
+	public double[] posterior(Corpus.Edge edge) 
+	{
+		double[] prob1=model1.posterior(edge);
+		double[] prob2=model2.posterior(edge);
+		
+		for(int i=0;i<prob1.length;i++){
+			prob1[i]*=prob2[i];
+			prob1[i]=Math.sqrt(prob1[i]);
+		}
+		
+		return prob1;
+	}
+	
+	public void displayPosterior(PrintStream ps)
+	{	
+		for (Edge edge : c.getEdges())
+		{
+			double probs[] = posterior(edge);
+			arr.F.l1normalize(probs);
+
+			// emit phrase
+			ps.print(edge.getPhraseString());
+			ps.print("\t");
+			ps.print(edge.getContextString(true));
+			int t=arr.F.argmax(probs);
+			ps.println(" ||| C=" + t);
+		}
+	}
+	
+}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/C2F.java b/gi/posterior-regularisation/prjava/src/phrase/C2F.java
index 3456c953..a8e557f2 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/C2F.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/C2F.java
@@ -25,11 +25,11 @@ public class C2F {
 	/**@brief
 	 *  emit[tag][position][word] = p(word | tag, position in phrase)
 	 */
-	private double emit[][][];
+	public double emit[][][];
 	/**@brief
 	 *  pi[context][tag] = p(tag | context)
 	 */
-	private double pi[][];
+	public double pi[][];
 	
 	public C2F(int numCluster, Corpus corpus){
 		K=numCluster;
-- 
cgit v1.2.3