summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation')
-rw-r--r--gi/posterior-regularisation/Corpus.java120
-rw-r--r--gi/posterior-regularisation/PhraseContextModel.java314
2 files changed, 155 insertions, 279 deletions
diff --git a/gi/posterior-regularisation/Corpus.java b/gi/posterior-regularisation/Corpus.java
index 047e6ee8..07b27387 100644
--- a/gi/posterior-regularisation/Corpus.java
+++ b/gi/posterior-regularisation/Corpus.java
@@ -7,69 +7,64 @@ import java.util.regex.Pattern;
public class Corpus
{
private Lexicon<String> tokenLexicon = new Lexicon<String>();
- private Lexicon<TIntArrayList> ngramLexicon = new Lexicon<TIntArrayList>();
+ private Lexicon<TIntArrayList> phraseLexicon = new Lexicon<TIntArrayList>();
+ private Lexicon<TIntArrayList> contextLexicon = new Lexicon<TIntArrayList>();
private List<Edge> edges = new ArrayList<Edge>();
- private Map<Ngram,List<Edge>> phraseToContext = new HashMap<Ngram,List<Edge>>();
- private Map<Ngram,List<Edge>> contextToPhrase = new HashMap<Ngram,List<Edge>>();
+ private List<List<Edge>> phraseToContext = new ArrayList<List<Edge>>();
+ private List<List<Edge>> contextToPhrase = new ArrayList<List<Edge>>();
- public class Ngram
+ public class Edge
{
- private Ngram(int id)
+ Edge(int phraseId, int contextId, int count)
{
- ngramId = id;
+ this.phraseId = phraseId;
+ this.contextId = contextId;
+ this.count = count;
}
- public int getId()
+ public int getPhraseId()
{
- return ngramId;
+ return phraseId;
}
- public TIntArrayList getTokenIds()
+ public TIntArrayList getPhrase()
{
- return ngramLexicon.lookup(ngramId);
+ return phraseLexicon.lookup(phraseId);
}
- public String toString()
+ public String getPhraseString()
{
StringBuffer b = new StringBuffer();
- for (int tid: getTokenIds().toNativeArray())
+ for (int tid: getPhrase().toNativeArray())
{
if (b.length() > 0)
b.append(" ");
b.append(tokenLexicon.lookup(tid));
}
return b.toString();
- }
- public int hashCode()
- {
- return ngramId;
- }
- public boolean equals(Object other)
- {
- return other instanceof Ngram && ngramId == ((Ngram) other).ngramId;
- }
- private int ngramId;
- }
-
- public class Edge
- {
- Edge(Ngram phrase, Ngram context, int count)
+ }
+ public int getContextId()
{
- this.phrase = phrase;
- this.context = context;
- this.count = count;
+ return contextId;
}
- public Ngram getPhrase()
+ public TIntArrayList getContext()
{
- return phrase;
+ return contextLexicon.lookup(contextId);
}
- public Ngram getContext()
+ public String getContextString()
{
- return context;
+ StringBuffer b = new StringBuffer();
+ for (int tid: getContext().toNativeArray())
+ {
+ if (b.length() > 0)
+ b.append(" ");
+ b.append(tokenLexicon.lookup(tid));
+ }
+ return b.toString();
}
public int getCount()
{
return count;
}
- private Ngram phrase;
- private Ngram context;
+ private int phraseId;
+ private int contextId;
private int count;
}
@@ -78,32 +73,32 @@ public class Corpus
return edges;
}
- int numEdges()
+ int getNumEdges()
{
return edges.size();
}
- Set<Ngram> getPhrases()
+ int getNumPhrases()
{
- return phraseToContext.keySet();
+ return phraseLexicon.size();
}
- List<Edge> getEdgesForPhrase(Ngram phrase)
+ List<Edge> getEdgesForPhrase(int phraseId)
{
- return phraseToContext.get(phrase);
+ return phraseToContext.get(phraseId);
}
- Set<Ngram> getContexts()
+ int getNumContexts()
{
- return contextToPhrase.keySet();
+ return contextLexicon.size();
}
- List<Edge> getEdgesForContext(Ngram context)
+ List<Edge> getEdgesForContext(int contextId)
{
- return contextToPhrase.get(context);
+ return contextToPhrase.get(contextId);
}
- int numTokens()
+ int getNumTokens()
{
return tokenLexicon.size();
}
@@ -132,8 +127,9 @@ public class Corpus
TIntArrayList ptoks = new TIntArrayList();
while (st.hasMoreTokens())
ptoks.add(c.tokenLexicon.insert(st.nextToken()));
- int phraseId = c.ngramLexicon.insert(ptoks);
- Ngram phrase = c.new Ngram(phraseId);
+ int phraseId = c.phraseLexicon.insert(ptoks);
+ if (phraseId == c.phraseToContext.size())
+ c.phraseToContext.add(new ArrayList<Edge>());
// process contexts
String[] parts = separator.split(rest);
@@ -151,30 +147,18 @@ public class Corpus
if (!token.equals("<PHRASE>"))
ctx.add(c.tokenLexicon.insert(token));
}
- int contextId = c.ngramLexicon.insert(ctx);
- Ngram context = c.new Ngram(contextId);
+ int contextId = c.contextLexicon.insert(ctx);
+ if (contextId == c.contextToPhrase.size())
+ c.contextToPhrase.add(new ArrayList<Edge>());
assert (countString.startsWith("C="));
- Edge e = c.new Edge(phrase, context, Integer.parseInt(countString.substring(2).trim()));
+ Edge e = c.new Edge(phraseId, contextId,
+ Integer.parseInt(countString.substring(2).trim()));
c.edges.add(e);
- // index the edge for fast phrase lookup
- List<Edge> edges = c.phraseToContext.get(phrase);
- if (edges == null)
- {
- edges = new ArrayList<Edge>();
- c.phraseToContext.put(phrase, edges);
- }
- edges.add(e);
-
- // index the edge for fast context lookup
- edges = c.contextToPhrase.get(context);
- if (edges == null)
- {
- edges = new ArrayList<Edge>();
- c.contextToPhrase.put(context, edges);
- }
- edges.add(e);
+ // index the edge for fast phrase, context lookup
+ c.phraseToContext.get(phraseId).add(e);
+ c.contextToPhrase.get(contextId).add(e);
}
}
diff --git a/gi/posterior-regularisation/PhraseContextModel.java b/gi/posterior-regularisation/PhraseContextModel.java
index d0a92dde..c48cfacd 100644
--- a/gi/posterior-regularisation/PhraseContextModel.java
+++ b/gi/posterior-regularisation/PhraseContextModel.java
@@ -40,44 +40,16 @@ import optimization.stopCriteria.ProjectedGradientL2Norm;
import optimization.stopCriteria.StopingCriteria;
import optimization.stopCriteria.ValueDifference;
import optimization.util.MathUtils;
-
import java.util.*;
import java.util.regex.*;
import gnu.trove.TDoubleArrayList;
+import gnu.trove.TIntArrayList;
import static java.lang.Math.*;
-class Lexicon<T>
-{
- public int insert(T word)
- {
- Integer i = wordToIndex.get(word);
- if (i == null)
- {
- i = indexToWord.size();
- wordToIndex.put(word, i);
- indexToWord.add(word);
- }
- return i;
- }
-
- public T lookup(int index)
- {
- return indexToWord.get(index);
- }
-
- public int size()
- {
- return indexToWord.size();
- }
-
- private Map<T, Integer> wordToIndex = new HashMap<T, Integer>();
- private List<T> indexToWord = new ArrayList<T>();
-}
-
class PhraseContextModel
{
// model/optimisation configuration parameters
- int numTags, numEdges;
+ int numTags;
boolean posteriorRegularisation = true;
double constraintScale = 3; // FIXME: make configurable
@@ -88,33 +60,32 @@ class PhraseContextModel
int minOccurrencesForProjection = 0;
// book keeping
- Lexicon<String> tokenLexicon = new Lexicon<String>();
int numPositions;
Random rng = new Random();
- // training set; 1 entry for each unique phrase
- PhraseAndContexts training[];
+ // training set
+ Corpus training;
// model parameters (learnt)
double emissions[][][]; // position in 0 .. 3 x tag x word Pr(word | tag, position)
double prior[][]; // phrase x tag Pr(tag | phrase)
double lambda[]; // edge = (phrase, context) x tag flattened lagrange multipliers
- PhraseContextModel(File infile, int tags) throws IOException
+ PhraseContextModel(Corpus training, int tags)
{
- numTags = tags;
- numEdges = 0;
- readTrainingFromFile(new FileReader(infile));
- assert (training.length > 0);
+ this.training = training;
+ this.numTags = tags;
+ assert (!training.getEdges().isEmpty());
+ assert (numTags > 1);
// now initialise emissions
- assert (training[0].contexts.length > 0);
- numPositions = training[0].contexts[0].tokens.length;
+ numPositions = training.getEdges().get(0).getContext().size();
+ assert (numPositions > 0);
- emissions = new double[numPositions][numTags][tokenLexicon.size()];
- prior = new double[training.length][numTags];
+ emissions = new double[numPositions][numTags][training.getNumTokens()];
+ prior = new double[training.getNumEdges()][numTags];
if (posteriorRegularisation)
- lambda = new double[numEdges * numTags];
+ lambda = new double[training.getNumEdges() * numTags];
for (double[][] emissionTW : emissions)
for (double[] emissionW : emissionTW)
@@ -130,8 +101,8 @@ class PhraseContextModel
for (int iteration = 0; iteration < numIterations; ++iteration)
{
- double emissionsCounts[][][] = new double[numPositions][numTags][tokenLexicon.size()];
- double priorCounts[][] = new double[training.length][numTags];
+ double emissionsCounts[][][] = new double[numPositions][numTags][training.getNumTokens()];
+ double priorCounts[][] = new double[training.getNumPhrases()][numTags];
// E-step
double llh = 0;
@@ -140,71 +111,70 @@ class PhraseContextModel
EStepDualObjective objective = new EStepDualObjective();
// copied from x2y2withconstraints
- LineSearchMethod ls = new ArmijoLineSearchMinimizationAlongProjectionArc(new InterpolationPickFirstStep(1));
- OptimizerStats stats = new OptimizerStats();
- ProjectedGradientDescent optimizer = new ProjectedGradientDescent(ls);
- CompositeStopingCriteria compositeStop = new CompositeStopingCriteria();
- compositeStop.add(new ProjectedGradientL2Norm(0.001));
- compositeStop.add(new ValueDifference(0.001));
- optimizer.setMaxIterations(50);
- boolean succeed = optimizer.optimize(objective,stats,compositeStop);
+// LineSearchMethod ls = new ArmijoLineSearchMinimizationAlongProjectionArc(new InterpolationPickFirstStep(1));
+// OptimizerStats stats = new OptimizerStats();
+// ProjectedGradientDescent optimizer = new ProjectedGradientDescent(ls);
+// CompositeStopingCriteria compositeStop = new CompositeStopingCriteria();
+// compositeStop.add(new ProjectedGradientL2Norm(0.001));
+// compositeStop.add(new ValueDifference(0.001));
+// optimizer.setMaxIterations(50);
+// boolean succeed = optimizer.optimize(objective,stats,compositeStop);
// copied from depparser l1lmaxobjective
-// ProjectedOptimizerStats stats = new ProjectedOptimizerStats();
-// GenericPickFirstStep pickFirstStep = new GenericPickFirstStep(1);
-// LineSearchMethod linesearch = new WolfRuleLineSearch(pickFirstStep, c1, c2);
-// ProjectedGradientDescent optimizer = new ProjectedGradientDescent(linesearch);
-// optimizer.setMaxIterations(maxProjectionIterations);
-// StopingCriteria stopGrad = new NormalizedProjectedGradientL2Norm(stoppingPrecision);
-// StopingCriteria stopValue = new NormalizedValueDifference(stoppingPrecision);
-// CompositeStopingCriteria stop = new CompositeStopingCriteria();
-// stop.add(stopGrad);
-// stop.add(stopValue);
-// boolean succeed = optimizer.optimize(objective, stats, stop);
-
- //System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
+ ProjectedOptimizerStats stats = new ProjectedOptimizerStats();
+ GenericPickFirstStep pickFirstStep = new GenericPickFirstStep(1);
+ LineSearchMethod linesearch = new WolfRuleLineSearch(pickFirstStep, c1, c2);
+ ProjectedGradientDescent optimizer = new ProjectedGradientDescent(linesearch);
+ optimizer.setMaxIterations(maxProjectionIterations);
+ CompositeStopingCriteria stop = new CompositeStopingCriteria();
+ stop.add(new NormalizedProjectedGradientL2Norm(stoppingPrecision));
+ stop.add(new NormalizedValueDifference(stoppingPrecision));
+ boolean succeed = optimizer.optimize(objective, stats, stop);
+
+ System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
//System.out.println("Solution: " + objective.parameters);
if (!succeed)
System.out.println("Failed to optimize");
//System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
- // make sure we update the dual params
- //llh = objective.getValue();
+ lambda = objective.getParameters();
llh = objective.primal();
- // FIXME: this is the dual not the primal and omits the llh term
- for (int i = 0; i < training.length; ++i)
+ for (int i = 0; i < training.getNumPhrases(); ++i)
{
- PhraseAndContexts instance = training[i];
- for (int j = 0; j < instance.contexts.length; ++j)
+ List<Corpus.Edge> edges = training.getEdgesForPhrase(i);
+ for (int j = 0; j < edges.size(); ++j)
{
- Context c = instance.contexts[j];
+ Corpus.Edge e = edges.get(j);
for (int t = 0; t < numTags; t++)
{
double p = objective.q.get(i).get(j).get(t);
- priorCounts[i][t] += c.count * p;
- for (int k = 0; k < c.tokens.length; ++k)
- emissionsCounts[k][t][c.tokens[k]] += c.count * p;
+ priorCounts[i][t] += e.getCount() * p;
+ TIntArrayList tokens = e.getContext();
+ for (int k = 0; k < tokens.size(); ++k)
+ emissionsCounts[k][t][tokens.get(k)] += e.getCount() * p;
}
}
}
}
else
{
- for (int i = 0; i < training.length; ++i)
+ for (int i = 0; i < training.getNumPhrases(); ++i)
{
- PhraseAndContexts instance = training[i];
- for (Context ctx : instance.contexts)
+ List<Corpus.Edge> edges = training.getEdgesForPhrase(i);
+ for (int j = 0; j < edges.size(); ++j)
{
- double probs[] = posterior(i, ctx);
+ Corpus.Edge e = edges.get(j);
+ double probs[] = posterior(i, e);
double z = normalise(probs);
- llh += log(z) * ctx.count;
-
+ llh += log(z) * e.getCount();
+
+ TIntArrayList tokens = e.getContext();
for (int t = 0; t < numTags; ++t)
{
- priorCounts[i][t] += ctx.count * probs[t];
- for (int j = 0; j < ctx.tokens.length; ++j)
- emissionsCounts[j][t][ctx.tokens[j]] += ctx.count * probs[t];
+ priorCounts[i][t] += e.getCount() * probs[t];
+ for (int k = 0; k < tokens.size(); ++k)
+ emissionsCounts[j][t][tokens.get(k)] += e.getCount() * probs[t];
}
}
}
@@ -268,104 +238,34 @@ class PhraseContextModel
return mi;
}
- double[] posterior(int phraseId, Context c) // unnormalised
+ double[] posterior(int phraseId, Corpus.Edge e) // unnormalised
{
double probs[] = new double[numTags];
+ TIntArrayList tokens = e.getContext();
for (int t = 0; t < numTags; ++t)
{
probs[t] = prior[phraseId][t];
- for (int j = 0; j < c.tokens.length; ++j)
- probs[t] *= emissions[j][t][c.tokens[j]];
+ for (int k = 0; k < tokens.size(); ++k)
+ probs[t] *= emissions[k][t][tokens.get(k)];
}
return probs;
}
- private void readTrainingFromFile(Reader in) throws IOException
- {
- // read in line-by-line
- BufferedReader bin = new BufferedReader(in);
- String line;
- List<PhraseAndContexts> instances = new ArrayList<PhraseAndContexts>();
- Pattern separator = Pattern.compile(" \\|\\|\\| ");
-
- while ((line = bin.readLine()) != null)
- {
- // split into phrase and contexts
- StringTokenizer st = new StringTokenizer(line, "\t");
- assert (st.hasMoreTokens());
- String phrase = st.nextToken();
- assert (st.hasMoreTokens());
- String rest = st.nextToken();
- assert (!st.hasMoreTokens());
-
- // process phrase
- st = new StringTokenizer(phrase, " ");
- List<Integer> ptoks = new ArrayList<Integer>();
- while (st.hasMoreTokens())
- ptoks.add(tokenLexicon.insert(st.nextToken()));
-
- // process contexts
- ArrayList<Context> contexts = new ArrayList<Context>();
- String[] parts = separator.split(rest);
- assert (parts.length % 2 == 0);
- for (int i = 0; i < parts.length; i += 2)
- {
- // process pairs of strings - context and count
- ArrayList<Integer> ctx = new ArrayList<Integer>();
- String ctxString = parts[i];
- String countString = parts[i + 1];
- StringTokenizer ctxStrtok = new StringTokenizer(ctxString, " ");
- while (ctxStrtok.hasMoreTokens())
- {
- String token = ctxStrtok.nextToken();
- if (!token.equals("<PHRASE>"))
- ctx.add(tokenLexicon.insert(token));
- }
-
- assert (countString.startsWith("C="));
- Context c = new Context();
- c.count = Integer.parseInt(countString.substring(2).trim());
- // damn unboxing doesn't work with toArray
- c.tokens = new int[ctx.size()];
- for (int k = 0; k < ctx.size(); ++k)
- c.tokens[k] = ctx.get(k);
- contexts.add(c);
-
- numEdges += 1;
- }
-
- // package up
- PhraseAndContexts instance = new PhraseAndContexts();
- // damn unboxing doesn't work with toArray
- instance.phraseTokens = new int[ptoks.size()];
- for (int k = 0; k < ptoks.size(); ++k)
- instance.phraseTokens[k] = ptoks.get(k);
- instance.contexts = contexts.toArray(new Context[] {});
- instances.add(instance);
- }
-
- training = instances.toArray(new PhraseAndContexts[] {});
-
- System.out.println("Read in " + training.length + " phrases and " + numEdges + " edges");
- }
-
void displayPosterior()
{
- for (int i = 0; i < training.length; ++i)
+ for (int i = 0; i < training.getNumPhrases(); ++i)
{
- PhraseAndContexts instance = training[i];
- for (Context ctx : instance.contexts)
+ List<Corpus.Edge> edges = training.getEdgesForPhrase(i);
+ for (Corpus.Edge e: edges)
{
- double probs[] = posterior(i, ctx);
+ double probs[] = posterior(i, e);
normalise(probs);
// emit phrase
- for (int t : instance.phraseTokens)
- System.out.print(tokenLexicon.lookup(t) + " ");
+ System.out.print(e.getPhraseString());
System.out.print("\t");
- for (int c : ctx.tokens)
- System.out.print(tokenLexicon.lookup(c) + " ");
- System.out.print("||| C=" + ctx.count + " |||");
+ System.out.print(e.getContextString());
+ System.out.print("||| C=" + e.getCount() + " |||");
int t = argmax(probs);
System.out.print(" " + t + " ||| " + probs[t]);
@@ -376,24 +276,13 @@ class PhraseContextModel
}
}
- class PhraseAndContexts
- {
- int phraseTokens[];
- Context contexts[];
- }
-
- class Context
- {
- int count;
- int[] tokens;
- }
-
public static void main(String[] args)
{
assert (args.length >= 2);
try
{
- PhraseContextModel model = new PhraseContextModel(new File(args[0]), Integer.parseInt(args[1]));
+ Corpus corpus = Corpus.readFromFile(new FileReader(new File(args[0])));
+ PhraseContextModel model = new PhraseContextModel(corpus, Integer.parseInt(args[1]));
model.expectationMaximisation(Integer.parseInt(args[2]));
model.displayPosterior();
}
@@ -416,26 +305,27 @@ class PhraseContextModel
{
super();
// compute conditionals p(context, tag | phrase) for all training instances
- conditionals = new ArrayList<List<TDoubleArrayList>>(training.length);
- q = new ArrayList<List<TDoubleArrayList>>(training.length);
- for (int i = 0; i < training.length; ++i)
+ conditionals = new ArrayList<List<TDoubleArrayList>>(training.getNumPhrases());
+ q = new ArrayList<List<TDoubleArrayList>>(training.getNumPhrases());
+ for (int i = 0; i < training.getNumPhrases(); ++i)
{
- PhraseAndContexts instance = training[i];
- conditionals.add(new ArrayList<TDoubleArrayList>(instance.contexts.length));
- q.add(new ArrayList<TDoubleArrayList>(instance.contexts.length));
+ List<Corpus.Edge> edges = training.getEdgesForPhrase(i);
- for (int j = 0; j < instance.contexts.length; ++j)
+ conditionals.add(new ArrayList<TDoubleArrayList>(edges.size()));
+ q.add(new ArrayList<TDoubleArrayList>(edges.size()));
+
+ for (int j = 0; j < edges.size(); ++j)
{
- Context c = instance.contexts[j];
- double probs[] = posterior(i, c);
+ Corpus.Edge e = edges.get(j);
+ double probs[] = posterior(i, e);
double z = normalise(probs);
- llh += log(z) * c.count;
+ llh += log(z) * e.getCount();
conditionals.get(i).add(new TDoubleArrayList(probs));
q.get(i).add(new TDoubleArrayList(probs));
}
}
- gradient = new double[numEdges*numTags];
+ gradient = new double[training.getNumEdges()*numTags];
setInitialParameters(lambda);
}
@@ -446,22 +336,22 @@ class PhraseContextModel
double[] newPoint = point.clone();
int edgeIndex = 0;
- for (int i = 0; i < training.length; ++i)
+ for (int i = 0; i < training.getNumPhrases(); ++i)
{
- PhraseAndContexts instance = training[i];
+ List<Corpus.Edge> edges = training.getEdgesForPhrase(i);
for (int t = 0; t < numTags; t++)
{
- double[] subPoint = new double[instance.contexts.length];
- for (int j = 0; j < instance.contexts.length; ++j)
+ double[] subPoint = new double[edges.size()];
+ for (int j = 0; j < edges.size(); ++j)
subPoint[j] = point[edgeIndex+j*numTags+t];
-
+
p.project(subPoint);
- for (int j = 0; j < instance.contexts.length; ++j)
+ for (int j = 0; j < edges.size(); ++j)
newPoint[edgeIndex+j*numTags+t] = subPoint[j];
}
- edgeIndex += instance.contexts.length * numTags;
+ edgeIndex += edges.size() * numTags;
}
//System.out.println("Project point: " + Arrays.toString(point)
// + " => " + Arrays.toString(newPoint));
@@ -492,13 +382,13 @@ class PhraseContextModel
int edgeIndex = 0;
objective = 0;
Arrays.fill(gradient, 0);
- for (int i = 0; i < training.length; ++i)
+ for (int i = 0; i < training.getNumPhrases(); ++i)
{
- PhraseAndContexts instance = training[i];
-
- for (int j = 0; j < instance.contexts.length; ++j)
+ List<Corpus.Edge> edges = training.getEdgesForPhrase(i);
+
+ for (int j = 0; j < edges.size(); ++j)
{
- Context c = instance.contexts[j];
+ Corpus.Edge e = edges.get(j);
double z = 0;
for (int t = 0; t < numTags; t++)
@@ -507,20 +397,21 @@ class PhraseContextModel
q.get(i).get(j).set(t, v);
z += v;
}
- objective = log(z) * c.count;
+ objective = log(z) * e.getCount();
for (int t = 0; t < numTags; t++)
{
double v = q.get(i).get(j).get(t) / z;
q.get(i).get(j).set(t, v);
- gradient[edgeIndex+t] -= c.count * v;
+ gradient[edgeIndex+t] -= e.getCount() * v;
}
edgeIndex += numTags;
}
}
-// System.out.println("computeObjectiveAndGradient logz=" + objective);
-// System.out.println("gradient=" + Arrays.toString(gradient));
+ System.out.println("computeObjectiveAndGradient logz=" + objective);
+ System.out.println("lambda= " + Arrays.toString(parameters));
+ System.out.println("gradient=" + Arrays.toString(gradient));
}
public String toString()
@@ -528,7 +419,7 @@ class PhraseContextModel
StringBuilder sb = new StringBuilder();
sb.append(getClass().getCanonicalName()).append(" with ");
sb.append(parameters.length).append(" parameters and ");
- sb.append(training.length * numTags).append(" constraints");
+ sb.append(training.getNumPhrases() * numTags).append(" constraints");
return sb.toString();
}
@@ -538,16 +429,17 @@ class PhraseContextModel
// kl = sum_Y q(Y) log q(Y) / p(Y|X)
// = sum_Y q(Y) { -lambda . phi(Y) - log Z }
// = -log Z - lambda . E_q[phi]
+ // = -objective + lambda . gradient
- double kl = -objective - MathUtils.dotProduct(parameters, gradient);
+ double kl = -objective + MathUtils.dotProduct(parameters, gradient);
double l1lmax = 0;
- for (int i = 0; i < training.length; ++i)
+ for (int i = 0; i < training.getNumPhrases(); ++i)
{
- PhraseAndContexts instance = training[i];
+ List<Corpus.Edge> edges = training.getEdgesForPhrase(i);
for (int t = 0; t < numTags; t++)
{
double lmax = Double.NEGATIVE_INFINITY;
- for (int j = 0; j < instance.contexts.length; ++j)
+ for (int j = 0; j < edges.size(); ++j)
lmax = max(lmax, q.get(i).get(j).get(t));
l1lmax += lmax;
}