summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java')
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java24
1 files changed, 12 insertions, 12 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index 13ac14ba..ccb6ae9d 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -111,8 +111,9 @@ public class PhraseCluster {
TIntArrayList context = edge.getContext();
for(int tag=0;tag<K;tag++)
{
- for(int pos=0;pos<n_positions;pos++)
- exp_emit[tag][pos][context.get(pos)]+=p[tag]*count;
+ for(int pos=0;pos<n_positions;pos++){
+ exp_emit[tag][pos][context.get(pos)]+=p[tag]*count;
+ }
exp_pi[tag]+=p[tag]*count;
}
}
@@ -248,14 +249,12 @@ public class PhraseCluster {
public double PREM_phrase_constraints(double scalePT, int phraseSizeLimit)
{
double [][][]exp_emit=new double[K][n_positions][n_words];
- double [][]exp_pi=new double[n_phrases][K];
+ double []exp_pi=new double[K];
for(double [][]i:exp_emit)
for(double []j:i)
Arrays.fill(j, 1e-10);
- for(double []j:exp_pi)
- Arrays.fill(j, 1e-10);
-
+
if (lambdaPT == null && cacheLambda)
lambdaPT = new double[n_phrases][];
@@ -267,10 +266,12 @@ public class PhraseCluster {
{
if (phraseSizeLimit >= 1 && c.getPhrase(phrase).size() > phraseSizeLimit)
{
- System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
+ //System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
continue;
}
+ Arrays.fill(exp_pi, 1e-10);
+
// FIXME: add rare edge check to phrase objective & posterior processing
PhraseObjective po = new PhraseObjective(this, phrase, scalePT, (cacheLambda) ? lambdaPT[phrase] : null);
boolean ok = po.optimizeWithProjectedGradientDescent();
@@ -294,9 +295,12 @@ public class PhraseCluster {
exp_emit[tag][pos][context.get(pos)]+=q[edge][tag]*contextCnt;
}
- exp_pi[phrase][tag]+=q[edge][tag]*contextCnt;
+ exp_pi[tag]+=q[edge][tag]*contextCnt;
+
}
}
+ arr.F.l1normalize(exp_pi);
+ System.arraycopy(exp_pi, 0, pi[phrase], 0, K);
}
long end = System.currentTimeMillis();
@@ -313,10 +317,6 @@ public class PhraseCluster {
arr.F.l1normalize(j);
emit=exp_emit;
- for(double []j:exp_pi)
- arr.F.l1normalize(j);
- pi=exp_pi;
-
return primal;
}