summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation/prjava/src')
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/Corpus.java24
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java27
2 files changed, 39 insertions, 12 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java
index 6936b28b..21375baa 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java
@@ -28,12 +28,26 @@ public class Corpus
public class Edge
{
+
+ Edge(int phraseId, int contextId, double count,int tag)
+ {
+ this.phraseId = phraseId;
+ this.contextId = contextId;
+ this.count = count;
+ fixTag=tag;
+ }
+
Edge(int phraseId, int contextId, double count)
{
this.phraseId = phraseId;
this.contextId = contextId;
this.count = count;
+ fixTag=-1;
+ }
+ public int getTag(){
+ return fixTag;
}
+
public int getPhraseId()
{
return phraseId;
@@ -85,6 +99,7 @@ public class Corpus
private int phraseId;
private int contextId;
private double count;
+ private int fixTag;
}
List<Edge> getEdges()
@@ -218,7 +233,14 @@ public class Corpus
}
int contextId = contextLexicon.insert(ctx);
- edges.add(new Edge(phraseId, contextId, count));
+ String []countToks=countString.split(" ");
+ if(countToks.length<2){
+ edges.add(new Edge(phraseId, contextId, count));
+ }
+ else{
+ int tag=Integer.parseInt(countToks[1]);
+ edges.add(new Edge(phraseId, contextId, count,tag));
+ }
}
}
return edges;
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index 560100d4..93e743fc 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -78,13 +78,11 @@ public class PhraseCluster {
public double EM(int phraseSizeLimit)
{
double [][][]exp_emit=new double [K][n_positions][n_words];
- double [][]exp_pi=new double[n_phrases][K];
+ double []exp_pi=new double[K];
for(double [][]i:exp_emit)
for(double []j:i)
Arrays.fill(j, 1e-10);
- for(double []j:pi)
- Arrays.fill(j, 1e-10);
double loglikelihood=0;
@@ -93,10 +91,12 @@ public class PhraseCluster {
{
if (phraseSizeLimit >= 1 && c.getPhrase(phrase).size() > phraseSizeLimit)
{
- System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
+ // System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
continue;
}
+ Arrays.fill(exp_pi, 1e-10);
+
List<Edge> contexts = c.getEdgesForPhrase(phrase);
for (int ctx=0; ctx<contexts.size(); ctx++)
@@ -116,21 +116,19 @@ public class PhraseCluster {
{
for(int pos=0;pos<n_positions;pos++)
exp_emit[tag][pos][context.get(pos)]+=p[tag]*count;
- exp_pi[phrase][tag]+=p[tag]*count;
+ exp_pi[tag]+=p[tag]*count;
}
}
+ arr.F.l1norm(exp_pi);
+ pi[phrase]=exp_pi;
}
//M
for(double [][]i:exp_emit)
for(double []j:i)
arr.F.l1normalize(j);
-
- for(double []j:exp_pi)
- arr.F.l1normalize(j);
emit=exp_emit;
- pi=exp_pi;
return loglikelihood;
}
@@ -258,7 +256,7 @@ public class PhraseCluster {
for(double [][]i:exp_emit)
for(double []j:i)
Arrays.fill(j, 1e-10);
- for(double []j:pi)
+ for(double []j:exp_pi)
Arrays.fill(j, 1e-10);
if (lambdaPT == null && cacheLambda)
@@ -338,7 +336,7 @@ public class PhraseCluster {
for(double [][]i:exp_emit)
for(double []j:i)
Arrays.fill(j, 1e-10);
- for(double []j:pi)
+ for(double []j:exp_pi)
Arrays.fill(j, 1e-10);
double loglikelihood=0, kl=0, l1lmax=0, primal=0;
@@ -496,6 +494,13 @@ public class PhraseCluster {
public double[] posterior(Corpus.Edge edge)
{
double[] prob;
+
+ if(edge.getTag()>=0){
+ prob=new double[K];
+ prob[edge.getTag()]=1;
+ return prob;
+ }
+
if (edge.getPhraseId() < n_phrases)
prob = Arrays.copyOf(pi[edge.getPhraseId()], K);
else