From 0c901842ddb907fd45d29bdece5b48d42a599616 Mon Sep 17 00:00:00 2001
From: desaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f>
Date: Wed, 21 Jul 2010 14:53:58 +0000
Subject:  corpus reads optional tags from data, EM trains with those tags, fix
 a bug in PhraseCluster where phrase priors are not learned

git-svn-id: https://ws10smt.googlecode.com/svn/trunk@354 ec762483-ff6d-05da-a07a-a48fb63a330f
---
 .../prjava/src/phrase/Corpus.java                  | 24 ++++++++++++++++++-
 .../prjava/src/phrase/PhraseCluster.java           | 27 +++++++++++++---------
 2 files changed, 39 insertions(+), 12 deletions(-)

(limited to 'gi/posterior-regularisation/prjava/src')

diff --git a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java
index 6936b28b..21375baa 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java
@@ -28,12 +28,26 @@ public class Corpus
 	
 	public class Edge
 	{
+		
+		Edge(int phraseId, int contextId, double count,int tag)
+		{
+			this.phraseId = phraseId;
+			this.contextId = contextId;
+			this.count = count;
+			fixTag=tag;
+		}
+		
 		Edge(int phraseId, int contextId, double count)
 		{
 			this.phraseId = phraseId;
 			this.contextId = contextId;
 			this.count = count;
+			fixTag=-1;
+		}
+		public int getTag(){
+			return fixTag;
 		}
+		
 		public int getPhraseId()
 		{
 			return phraseId;
@@ -85,6 +99,7 @@ public class Corpus
 		private int phraseId;
 		private int contextId;
 		private double count;
+		private int fixTag;
 	}
 
 	List<Edge> getEdges()
@@ -218,7 +233,14 @@ public class Corpus
 				}
 				int contextId = contextLexicon.insert(ctx);
 
-				edges.add(new Edge(phraseId, contextId, count));
+				String []countToks=countString.split(" ");
+				if(countToks.length<2){
+					edges.add(new Edge(phraseId, contextId, count));
+				}
+				else{
+					int tag=Integer.parseInt(countToks[1]);
+					edges.add(new Edge(phraseId, contextId, count,tag));
+				}
 			}
 		}
 		return edges;
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index 560100d4..93e743fc 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -78,13 +78,11 @@ public class PhraseCluster {
 	public double EM(int phraseSizeLimit)
 	{
 		double [][][]exp_emit=new double [K][n_positions][n_words];
-		double [][]exp_pi=new double[n_phrases][K];
+		double []exp_pi=new double[K];
 		
 		for(double [][]i:exp_emit)
 			for(double []j:i)
 				Arrays.fill(j, 1e-10);
-		for(double []j:pi)
-			Arrays.fill(j, 1e-10);
 		
 		double loglikelihood=0;
 		
@@ -93,10 +91,12 @@ public class PhraseCluster {
 		{
 			if (phraseSizeLimit >= 1 && c.getPhrase(phrase).size() > phraseSizeLimit)
 			{
-				System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
+			//	System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
 				continue;
 			}	
 
+			Arrays.fill(exp_pi, 1e-10);
+			
 			List<Edge> contexts = c.getEdgesForPhrase(phrase);
 
 			for (int ctx=0; ctx<contexts.size(); ctx++)
@@ -116,21 +116,19 @@ public class PhraseCluster {
 				{
 					for(int pos=0;pos<n_positions;pos++)
 						exp_emit[tag][pos][context.get(pos)]+=p[tag]*count;		
-					exp_pi[phrase][tag]+=p[tag]*count;
+					exp_pi[tag]+=p[tag]*count;
 				}
 			}
+			arr.F.l1norm(exp_pi);
+			pi[phrase]=exp_pi;
 		}
 
 		//M
 		for(double [][]i:exp_emit)
 			for(double []j:i)
 				arr.F.l1normalize(j);
-		
-		for(double []j:exp_pi)
-			arr.F.l1normalize(j);
 			
 		emit=exp_emit;
-		pi=exp_pi;
 
 		return loglikelihood;
 	}
@@ -258,7 +256,7 @@ public class PhraseCluster {
 		for(double [][]i:exp_emit)
 			for(double []j:i)
 				Arrays.fill(j, 1e-10);
-		for(double []j:pi)
+		for(double []j:exp_pi)
 			Arrays.fill(j, 1e-10);
 
 		if (lambdaPT == null && cacheLambda)
@@ -338,7 +336,7 @@ public class PhraseCluster {
 		for(double [][]i:exp_emit)
 			for(double []j:i)
 				Arrays.fill(j, 1e-10);
-		for(double []j:pi)
+		for(double []j:exp_pi)
 			Arrays.fill(j, 1e-10);
 		
 		double loglikelihood=0, kl=0, l1lmax=0, primal=0;
@@ -496,6 +494,13 @@ public class PhraseCluster {
 	public double[] posterior(Corpus.Edge edge) 
 	{
 		double[] prob;
+		
+		if(edge.getTag()>=0){
+			prob=new double[K];
+			prob[edge.getTag()]=1;
+			return prob;
+		}
+		
 		if (edge.getPhraseId() < n_phrases)
 			prob = Arrays.copyOf(pi[edge.getPhraseId()], K);
 		else
-- 
cgit v1.2.3