diff options
Diffstat (limited to 'gi/posterior-regularisation/PhraseContextModel.java')
-rw-r--r-- | gi/posterior-regularisation/PhraseContextModel.java | 29 |
1 files changed, 22 insertions, 7 deletions
diff --git a/gi/posterior-regularisation/PhraseContextModel.java b/gi/posterior-regularisation/PhraseContextModel.java index c48cfacd..db152e73 100644 --- a/gi/posterior-regularisation/PhraseContextModel.java +++ b/gi/posterior-regularisation/PhraseContextModel.java @@ -88,11 +88,23 @@ class PhraseContextModel lambda = new double[training.getNumEdges() * numTags]; for (double[][] emissionTW : emissions) + { for (double[] emissionW : emissionTW) + { randomise(emissionW); - +// for (int i = 0; i < emissionW.length; ++i) +// emissionW[i] = i+1; +// normalise(emissionW); + } + } + for (double[] priorTag : prior) + { randomise(priorTag); +// for (int i = 0; i < priorTag.length; ++i) +// priorTag[i] = i+1; +// normalise(priorTag); + } } void expectationMaximisation(int numIterations) @@ -327,6 +339,7 @@ class PhraseContextModel gradient = new double[training.getNumEdges()*numTags]; setInitialParameters(lambda); + computeObjectiveAndGradient(); } @Override @@ -353,8 +366,8 @@ class PhraseContextModel edgeIndex += edges.size() * numTags; } - //System.out.println("Project point: " + Arrays.toString(point) - // + " => " + Arrays.toString(newPoint)); +// System.out.println("Proj from: " + Arrays.toString(point)); +// System.out.println("Proj to: " + Arrays.toString(newPoint)); return newPoint; } @@ -368,12 +381,14 @@ class PhraseContextModel @Override public double[] getGradient() { + gradientCalls += 1; return gradient; } @Override public double getValue() { + functionCalls += 1; return objective; } @@ -397,7 +412,7 @@ class PhraseContextModel q.get(i).get(j).set(t, v); z += v; } - objective = log(z) * e.getCount(); + objective += log(z) * e.getCount(); for (int t = 0; t < numTags; t++) { @@ -409,9 +424,9 @@ class PhraseContextModel edgeIndex += numTags; } } - System.out.println("computeObjectiveAndGradient logz=" + objective); - System.out.println("lambda= " + Arrays.toString(parameters)); - System.out.println("gradient=" + Arrays.toString(gradient)); +// System.out.println("computeObjectiveAndGradient logz=" + objective); +// System.out.println("lambda= " + Arrays.toString(parameters)); +// System.out.println("gradient=" + Arrays.toString(gradient)); } public String toString() |