diff options
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase')
| -rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/Corpus.java | 24 | ||||
| -rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java | 27 | 
2 files changed, 39 insertions, 12 deletions
| diff --git a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java index 6936b28b..21375baa 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java @@ -28,12 +28,26 @@ public class Corpus  	public class Edge  	{ +		 +		Edge(int phraseId, int contextId, double count,int tag) +		{ +			this.phraseId = phraseId; +			this.contextId = contextId; +			this.count = count; +			fixTag=tag; +		} +		  		Edge(int phraseId, int contextId, double count)  		{  			this.phraseId = phraseId;  			this.contextId = contextId;  			this.count = count; +			fixTag=-1; +		} +		public int getTag(){ +			return fixTag;  		} +		  		public int getPhraseId()  		{  			return phraseId; @@ -85,6 +99,7 @@ public class Corpus  		private int phraseId;  		private int contextId;  		private double count; +		private int fixTag;  	}  	List<Edge> getEdges() @@ -218,7 +233,14 @@ public class Corpus  				}  				int contextId = contextLexicon.insert(ctx); -				edges.add(new Edge(phraseId, contextId, count)); +				String []countToks=countString.split(" "); +				if(countToks.length<2){ +					edges.add(new Edge(phraseId, contextId, count)); +				} +				else{ +					int tag=Integer.parseInt(countToks[1]); +					edges.add(new Edge(phraseId, contextId, count,tag)); +				}  			}  		}  		return edges; diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index 560100d4..93e743fc 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -78,13 +78,11 @@ public class PhraseCluster {  	public double EM(int phraseSizeLimit)
  	{
  		double [][][]exp_emit=new double [K][n_positions][n_words];
 -		double [][]exp_pi=new double[n_phrases][K];
 +		double []exp_pi=new double[K];
  		for(double [][]i:exp_emit)
  			for(double []j:i)
  				Arrays.fill(j, 1e-10);
 -		for(double []j:pi)
 -			Arrays.fill(j, 1e-10);
  		double loglikelihood=0;
 @@ -93,10 +91,12 @@ public class PhraseCluster {  		{
  			if (phraseSizeLimit >= 1 && c.getPhrase(phrase).size() > phraseSizeLimit)
  			{
 -				System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
 +			//	System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
  				continue;
  			}	
 +			Arrays.fill(exp_pi, 1e-10);
 +			
  			List<Edge> contexts = c.getEdgesForPhrase(phrase);
  			for (int ctx=0; ctx<contexts.size(); ctx++)
 @@ -116,21 +116,19 @@ public class PhraseCluster {  				{
  					for(int pos=0;pos<n_positions;pos++)
  						exp_emit[tag][pos][context.get(pos)]+=p[tag]*count;		
 -					exp_pi[phrase][tag]+=p[tag]*count;
 +					exp_pi[tag]+=p[tag]*count;
  				}
  			}
 +			arr.F.l1norm(exp_pi);
 +			pi[phrase]=exp_pi;
  		}
  		//M
  		for(double [][]i:exp_emit)
  			for(double []j:i)
  				arr.F.l1normalize(j);
 -		
 -		for(double []j:exp_pi)
 -			arr.F.l1normalize(j);
  		emit=exp_emit;
 -		pi=exp_pi;
  		return loglikelihood;
  	}
 @@ -258,7 +256,7 @@ public class PhraseCluster {  		for(double [][]i:exp_emit)
  			for(double []j:i)
  				Arrays.fill(j, 1e-10);
 -		for(double []j:pi)
 +		for(double []j:exp_pi)
  			Arrays.fill(j, 1e-10);
  		if (lambdaPT == null && cacheLambda)
 @@ -338,7 +336,7 @@ public class PhraseCluster {  		for(double [][]i:exp_emit)
  			for(double []j:i)
  				Arrays.fill(j, 1e-10);
 -		for(double []j:pi)
 +		for(double []j:exp_pi)
  			Arrays.fill(j, 1e-10);
  		double loglikelihood=0, kl=0, l1lmax=0, primal=0;
 @@ -496,6 +494,13 @@ public class PhraseCluster {  	public double[] posterior(Corpus.Edge edge) 
  	{
  		double[] prob;
 +		
 +		if(edge.getTag()>=0){
 +			prob=new double[K];
 +			prob[edge.getTag()]=1;
 +			return prob;
 +		}
 +		
  		if (edge.getPhraseId() < n_phrases)
  			prob = Arrays.copyOf(pi[edge.getPhraseId()], K);
  		else
 | 
