summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/PhraseContextModel.java
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation/PhraseContextModel.java')
-rw-r--r--gi/posterior-regularisation/PhraseContextModel.java29
1 files changed, 22 insertions, 7 deletions
diff --git a/gi/posterior-regularisation/PhraseContextModel.java b/gi/posterior-regularisation/PhraseContextModel.java
index c48cfacd..db152e73 100644
--- a/gi/posterior-regularisation/PhraseContextModel.java
+++ b/gi/posterior-regularisation/PhraseContextModel.java
@@ -88,11 +88,23 @@ class PhraseContextModel
lambda = new double[training.getNumEdges() * numTags];
for (double[][] emissionTW : emissions)
+ {
for (double[] emissionW : emissionTW)
+ {
randomise(emissionW);
-
+// for (int i = 0; i < emissionW.length; ++i)
+// emissionW[i] = i+1;
+// normalise(emissionW);
+ }
+ }
+
for (double[] priorTag : prior)
+ {
randomise(priorTag);
+// for (int i = 0; i < priorTag.length; ++i)
+// priorTag[i] = i+1;
+// normalise(priorTag);
+ }
}
void expectationMaximisation(int numIterations)
@@ -327,6 +339,7 @@ class PhraseContextModel
gradient = new double[training.getNumEdges()*numTags];
setInitialParameters(lambda);
+ computeObjectiveAndGradient();
}
@Override
@@ -353,8 +366,8 @@ class PhraseContextModel
edgeIndex += edges.size() * numTags;
}
- //System.out.println("Project point: " + Arrays.toString(point)
- // + " => " + Arrays.toString(newPoint));
+// System.out.println("Proj from: " + Arrays.toString(point));
+// System.out.println("Proj to: " + Arrays.toString(newPoint));
return newPoint;
}
@@ -368,12 +381,14 @@ class PhraseContextModel
@Override
public double[] getGradient()
{
+ gradientCalls += 1;
return gradient;
}
@Override
public double getValue()
{
+ functionCalls += 1;
return objective;
}
@@ -397,7 +412,7 @@ class PhraseContextModel
q.get(i).get(j).set(t, v);
z += v;
}
- objective = log(z) * e.getCount();
+ objective += log(z) * e.getCount();
for (int t = 0; t < numTags; t++)
{
@@ -409,9 +424,9 @@ class PhraseContextModel
edgeIndex += numTags;
}
}
- System.out.println("computeObjectiveAndGradient logz=" + objective);
- System.out.println("lambda= " + Arrays.toString(parameters));
- System.out.println("gradient=" + Arrays.toString(gradient));
+// System.out.println("computeObjectiveAndGradient logz=" + objective);
+// System.out.println("lambda= " + Arrays.toString(parameters));
+// System.out.println("gradient=" + Arrays.toString(gradient));
}
public String toString()