summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-06-29 04:47:25 +0000
committertrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-06-29 04:47:25 +0000
commit40d446cd8044c001f7a30e88113f26027cce6d22 (patch)
tree56353a32a7c1e5400f0d5afe04af7a848f06d2b9
parent9705782c98265c05bc87a10b1fb69bd35781c754 (diff)
Added first attempt at PR variant of phrase-context model. It does something but not convinced that it's working.
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@50 ec762483-ff6d-05da-a07a-a48fb63a330f
-rw-r--r--gi/posterior-regularisation/PhraseContextModel.java802
1 files changed, 522 insertions, 280 deletions
diff --git a/gi/posterior-regularisation/PhraseContextModel.java b/gi/posterior-regularisation/PhraseContextModel.java
index 3af72d54..d0a92dde 100644
--- a/gi/posterior-regularisation/PhraseContextModel.java
+++ b/gi/posterior-regularisation/PhraseContextModel.java
@@ -24,294 +24,536 @@
// - agreement between phrase->context model and context->phrase model
import java.io.*;
+import optimization.gradientBasedMethods.*;
+import optimization.gradientBasedMethods.stats.OptimizerStats;
+import optimization.gradientBasedMethods.stats.ProjectedOptimizerStats;
+import optimization.linesearch.ArmijoLineSearchMinimizationAlongProjectionArc;
+import optimization.linesearch.GenericPickFirstStep;
+import optimization.linesearch.InterpolationPickFirstStep;
+import optimization.linesearch.LineSearchMethod;
+import optimization.linesearch.WolfRuleLineSearch;
+import optimization.projections.SimplexProjection;
+import optimization.stopCriteria.CompositeStopingCriteria;
+import optimization.stopCriteria.NormalizedProjectedGradientL2Norm;
+import optimization.stopCriteria.NormalizedValueDifference;
+import optimization.stopCriteria.ProjectedGradientL2Norm;
+import optimization.stopCriteria.StopingCriteria;
+import optimization.stopCriteria.ValueDifference;
+import optimization.util.MathUtils;
+
import java.util.*;
import java.util.regex.*;
+import gnu.trove.TDoubleArrayList;
import static java.lang.Math.*;
class Lexicon<T>
{
- public int insert(T word)
- {
- Integer i = wordToIndex.get(word);
- if (i == null)
- {
- i = indexToWord.size();
- wordToIndex.put(word, i);
- indexToWord.add(word);
- }
- return i;
- }
-
- public T lookup(int index)
- {
- return indexToWord.get(index);
- }
-
- public int size()
- {
- return indexToWord.size();
- }
-
- private Map<T,Integer> wordToIndex = new HashMap<T,Integer>();
- private List<T> indexToWord = new ArrayList<T>();
+ public int insert(T word)
+ {
+ Integer i = wordToIndex.get(word);
+ if (i == null)
+ {
+ i = indexToWord.size();
+ wordToIndex.put(word, i);
+ indexToWord.add(word);
+ }
+ return i;
+ }
+
+ public T lookup(int index)
+ {
+ return indexToWord.get(index);
+ }
+
+ public int size()
+ {
+ return indexToWord.size();
+ }
+
+ private Map<T, Integer> wordToIndex = new HashMap<T, Integer>();
+ private List<T> indexToWord = new ArrayList<T>();
}
class PhraseContextModel
{
- // model/optimisation configuration parameters
- int numTags;
- int numPRIterations = 5;
- boolean posteriorRegularisation = false;
- double constraintScale = 10;
-
- // book keeping
- Lexicon<String> tokenLexicon = new Lexicon<String>();
- int numPositions;
- Random rng = new Random();
-
- // training set; 1 entry for each unique phrase
- PhraseAndContexts training[];
-
- // model parameters (learnt)
- double emissions[][][]; // position in 0 .. 3 x tag x word Pr(word | tag, position)
- double prior[][]; // phrase x tag Pr(tag | phrase)
- //double lambda[][][]; // word x context x tag langrange multipliers
-
- PhraseContextModel(File infile, int tags) throws IOException
- {
- numTags = tags;
- readTrainingFromFile(new FileReader(infile));
- assert(training.length > 0);
-
- // now initialise emissions
- assert(training[0].contexts.length > 0);
- numPositions = training[0].contexts[0].tokens.length;
-
- emissions = new double[numPositions][numTags][tokenLexicon.size()];
- prior = new double[training.length][numTags];
- //lambda = new double[tokenLexicon.size()][???][numTags]
-
- for (double[][] emissionTW: emissions)
- for (double[] emissionW: emissionTW)
- randomise(emissionW);
-
- for (double[] priorTag: prior)
- randomise(priorTag);
- }
-
- void expectationMaximisation(int numIterations)
- {
- for (int iteration = 0; iteration < numIterations; ++iteration)
- {
- double emissionsCounts[][][] = new double[numPositions][numTags][tokenLexicon.size()];
- double priorCounts[][] = new double[training.length][numTags];
-
- // E-step
- double llh = 0;
- for (int i = 0; i < training.length; ++i)
- {
- PhraseAndContexts instance = training[i];
- for (Context ctx: instance.contexts)
- {
- double probs[] = posterior(i, ctx);
- double z = normalise(probs);
- llh += log(z) * ctx.count;
-
- for (int t = 0; t < numTags; ++t)
- {
- priorCounts[i][t] += ctx.count * probs[t];
- for (int j = 0; j < ctx.tokens.length; ++j)
- emissionsCounts[j][t][ctx.tokens[j]] += ctx.count * probs[t];
- }
- }
- }
-
- // M-step: normalise
- for (double[][] emissionTW: emissionsCounts)
- for (double[] emissionW: emissionTW)
- normalise(emissionW);
-
- for (double[] priorTag: priorCounts)
- normalise(priorTag);
-
- emissions = emissionsCounts;
- prior = priorCounts;
-
- System.out.println("Iteration " + iteration + " llh " + llh);
- }
- }
-
- static double normalise(double probs[])
- {
- double z = 0;
- for (double p: probs)
- z += p;
- for (int i = 0; i < probs.length; ++i)
- probs[i] /= z;
- return z;
- }
-
- void randomise(double probs[])
- {
- double z = 0;
- for (int i = 0; i < probs.length; ++i)
- {
- probs[i] = 10 + rng.nextDouble();
- z += probs[i];
- }
-
- for (int i = 0; i < probs.length; ++i)
- probs[i] /= z;
- }
-
- static int argmax(double probs[])
- {
- double m = Double.NEGATIVE_INFINITY;
- int mi = -1;
- for (int i = 0; i < probs.length; ++i)
- {
- if (probs[i] > m)
- {
- m = probs[i];
- mi = i;
- }
- }
- return mi;
- }
-
- double[] posterior(int phraseId, Context c) // unnormalised
- {
- double probs[] = new double[numTags];
- for (int t = 0; t < numTags; ++t)
- {
- probs[t] = prior[phraseId][t];
- for (int j = 0; j < c.tokens.length; ++j)
- probs[t] *= emissions[j][t][c.tokens[j]];
- }
- return probs;
- }
-
- private void readTrainingFromFile(Reader in) throws IOException
- {
- // read in line-by-line
- BufferedReader bin = new BufferedReader(in);
- String line;
- List<PhraseAndContexts> instances = new ArrayList<PhraseAndContexts>();
- Pattern separator = Pattern.compile(" \\|\\|\\| ");
-
- int numEdges = 0;
- while ((line = bin.readLine()) != null)
- {
- // split into phrase and contexts
- StringTokenizer st = new StringTokenizer(line, "\t");
- assert(st.hasMoreTokens());
- String phrase = st.nextToken();
- assert(st.hasMoreTokens());
- String rest = st.nextToken();
- assert(!st.hasMoreTokens());
-
- // process phrase
- st = new StringTokenizer(phrase, " ");
- List<Integer> ptoks = new ArrayList<Integer>();
- while (st.hasMoreTokens())
- ptoks.add(tokenLexicon.insert(st.nextToken()));
-
- // process contexts
- ArrayList<Context> contexts = new ArrayList<Context>();
- String[] parts = separator.split(rest);
- assert(parts.length % 2 == 0);
- for (int i = 0; i < parts.length; i += 2)
- {
- // process pairs of strings - context and count
- ArrayList<Integer> ctx = new ArrayList<Integer>();
- String ctxString = parts[i];
- String countString = parts[i+1];
- StringTokenizer ctxStrtok = new StringTokenizer(ctxString, " ");
- while (ctxStrtok.hasMoreTokens())
- {
- String token = ctxStrtok.nextToken();
- if (!token.equals("<PHRASE>"))
- ctx.add(tokenLexicon.insert(token));
- }
-
- assert(countString.startsWith("C="));
- Context c = new Context();
- c.count = Integer.parseInt(countString.substring(2).trim());
- // damn unboxing doesn't work with toArray
- c.tokens = new int[ctx.size()];
- for (int k = 0; k < ctx.size(); ++k)
- c.tokens[k] = ctx.get(k);
- contexts.add(c);
-
- numEdges += 1;
- }
-
- // package up
- PhraseAndContexts instance = new PhraseAndContexts();
- // damn unboxing doesn't work with toArray
- instance.phraseTokens = new int[ptoks.size()];
- for (int k = 0; k < ptoks.size(); ++k)
- instance.phraseTokens[k] = ptoks.get(k);
- instance.contexts = contexts.toArray(new Context[] {});
- instances.add(instance);
- }
-
- training = instances.toArray(new PhraseAndContexts[] {});
-
- System.out.println("Read in " + training.length + " phrases and " + numEdges + " edges");
- }
-
- void displayPosterior()
- {
- for (int i = 0; i < training.length; ++i)
- {
- PhraseAndContexts instance = training[i];
- for (Context ctx: instance.contexts)
- {
- double probs[] = posterior(i, ctx);
- double z = normalise(probs);
-
- // emit phrase
- for (int t: instance.phraseTokens)
- System.out.print(tokenLexicon.lookup(t) + " ");
- System.out.print("\t");
- for (int c: ctx.tokens)
- System.out.print(tokenLexicon.lookup(c) + " ");
- System.out.print("||| C=" + ctx.count + " |||");
-
- System.out.print(" " + argmax(probs));
- //for (int t = 0; t < numTags; ++t)
- //System.out.print(" " + probs[t]);
- System.out.println();
- }
- }
- }
-
- class PhraseAndContexts
- {
- int phraseTokens[];
- Context contexts[];
- }
-
- class Context
- {
- int count;
- int[] tokens;
- }
-
- public static void main(String []args)
- {
- assert(args.length >= 2);
- try
- {
- PhraseContextModel model = new PhraseContextModel(new File(args[0]), Integer.parseInt(args[1]));
- model.expectationMaximisation(Integer.parseInt(args[2]));
- model.displayPosterior();
- }
- catch (IOException e)
- {
- System.out.println("Failed to read input file: " + args[0]);
- e.printStackTrace();
- }
- }
+ // model/optimisation configuration parameters
+ int numTags, numEdges;
+ boolean posteriorRegularisation = true;
+ double constraintScale = 3; // FIXME: make configurable
+
+ // copied from L1LMax in depparsing code
+ final double c1= 0.0001, c2=0.9, stoppingPrecision = 1e-5, maxStep = 10;
+ final int maxZoomEvals = 10, maxExtrapolationIters = 200;
+ int maxProjectionIterations = 200;
+ int minOccurrencesForProjection = 0;
+
+ // book keeping
+ Lexicon<String> tokenLexicon = new Lexicon<String>();
+ int numPositions;
+ Random rng = new Random();
+
+ // training set; 1 entry for each unique phrase
+ PhraseAndContexts training[];
+
+ // model parameters (learnt)
+ double emissions[][][]; // position in 0 .. 3 x tag x word Pr(word | tag, position)
+ double prior[][]; // phrase x tag Pr(tag | phrase)
+ double lambda[]; // edge = (phrase, context) x tag flattened lagrange multipliers
+
+ PhraseContextModel(File infile, int tags) throws IOException
+ {
+ numTags = tags;
+ numEdges = 0;
+ readTrainingFromFile(new FileReader(infile));
+ assert (training.length > 0);
+
+ // now initialise emissions
+ assert (training[0].contexts.length > 0);
+ numPositions = training[0].contexts[0].tokens.length;
+
+ emissions = new double[numPositions][numTags][tokenLexicon.size()];
+ prior = new double[training.length][numTags];
+ if (posteriorRegularisation)
+ lambda = new double[numEdges * numTags];
+
+ for (double[][] emissionTW : emissions)
+ for (double[] emissionW : emissionTW)
+ randomise(emissionW);
+
+ for (double[] priorTag : prior)
+ randomise(priorTag);
+ }
+
+ void expectationMaximisation(int numIterations)
+ {
+ double lastLlh = Double.NEGATIVE_INFINITY;
+
+ for (int iteration = 0; iteration < numIterations; ++iteration)
+ {
+ double emissionsCounts[][][] = new double[numPositions][numTags][tokenLexicon.size()];
+ double priorCounts[][] = new double[training.length][numTags];
+
+ // E-step
+ double llh = 0;
+ if (posteriorRegularisation)
+ {
+ EStepDualObjective objective = new EStepDualObjective();
+
+ // copied from x2y2withconstraints
+ LineSearchMethod ls = new ArmijoLineSearchMinimizationAlongProjectionArc(new InterpolationPickFirstStep(1));
+ OptimizerStats stats = new OptimizerStats();
+ ProjectedGradientDescent optimizer = new ProjectedGradientDescent(ls);
+ CompositeStopingCriteria compositeStop = new CompositeStopingCriteria();
+ compositeStop.add(new ProjectedGradientL2Norm(0.001));
+ compositeStop.add(new ValueDifference(0.001));
+ optimizer.setMaxIterations(50);
+ boolean succeed = optimizer.optimize(objective,stats,compositeStop);
+
+ // copied from depparser l1lmaxobjective
+// ProjectedOptimizerStats stats = new ProjectedOptimizerStats();
+// GenericPickFirstStep pickFirstStep = new GenericPickFirstStep(1);
+// LineSearchMethod linesearch = new WolfRuleLineSearch(pickFirstStep, c1, c2);
+// ProjectedGradientDescent optimizer = new ProjectedGradientDescent(linesearch);
+// optimizer.setMaxIterations(maxProjectionIterations);
+// StopingCriteria stopGrad = new NormalizedProjectedGradientL2Norm(stoppingPrecision);
+// StopingCriteria stopValue = new NormalizedValueDifference(stoppingPrecision);
+// CompositeStopingCriteria stop = new CompositeStopingCriteria();
+// stop.add(stopGrad);
+// stop.add(stopValue);
+// boolean succeed = optimizer.optimize(objective, stats, stop);
+
+ //System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
+ //System.out.println("Solution: " + objective.parameters);
+ if (!succeed)
+ System.out.println("Failed to optimize");
+ //System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
+
+ // make sure we update the dual params
+ //llh = objective.getValue();
+ llh = objective.primal();
+ // FIXME: this is the dual not the primal and omits the llh term
+
+ for (int i = 0; i < training.length; ++i)
+ {
+ PhraseAndContexts instance = training[i];
+ for (int j = 0; j < instance.contexts.length; ++j)
+ {
+ Context c = instance.contexts[j];
+ for (int t = 0; t < numTags; t++)
+ {
+ double p = objective.q.get(i).get(j).get(t);
+ priorCounts[i][t] += c.count * p;
+ for (int k = 0; k < c.tokens.length; ++k)
+ emissionsCounts[k][t][c.tokens[k]] += c.count * p;
+ }
+ }
+ }
+ }
+ else
+ {
+ for (int i = 0; i < training.length; ++i)
+ {
+ PhraseAndContexts instance = training[i];
+ for (Context ctx : instance.contexts)
+ {
+ double probs[] = posterior(i, ctx);
+ double z = normalise(probs);
+ llh += log(z) * ctx.count;
+
+ for (int t = 0; t < numTags; ++t)
+ {
+ priorCounts[i][t] += ctx.count * probs[t];
+ for (int j = 0; j < ctx.tokens.length; ++j)
+ emissionsCounts[j][t][ctx.tokens[j]] += ctx.count * probs[t];
+ }
+ }
+ }
+ }
+
+ // M-step: normalise
+ for (double[][] emissionTW : emissionsCounts)
+ for (double[] emissionW : emissionTW)
+ normalise(emissionW);
+
+ for (double[] priorTag : priorCounts)
+ normalise(priorTag);
+
+ emissions = emissionsCounts;
+ prior = priorCounts;
+
+ System.out.println("Iteration " + iteration + " llh " + llh);
+
+// if (llh - lastLlh < 1e-4)
+// break;
+// else
+// lastLlh = llh;
+ }
+ }
+
+ static double normalise(double probs[])
+ {
+ double z = 0;
+ for (double p : probs)
+ z += p;
+ for (int i = 0; i < probs.length; ++i)
+ probs[i] /= z;
+ return z;
+ }
+
+ void randomise(double probs[])
+ {
+ double z = 0;
+ for (int i = 0; i < probs.length; ++i)
+ {
+ probs[i] = 10 + rng.nextDouble();
+ z += probs[i];
+ }
+
+ for (int i = 0; i < probs.length; ++i)
+ probs[i] /= z;
+ }
+
+ static int argmax(double probs[])
+ {
+ double m = Double.NEGATIVE_INFINITY;
+ int mi = -1;
+ for (int i = 0; i < probs.length; ++i)
+ {
+ if (probs[i] > m)
+ {
+ m = probs[i];
+ mi = i;
+ }
+ }
+ return mi;
+ }
+
+ double[] posterior(int phraseId, Context c) // unnormalised
+ {
+ double probs[] = new double[numTags];
+ for (int t = 0; t < numTags; ++t)
+ {
+ probs[t] = prior[phraseId][t];
+ for (int j = 0; j < c.tokens.length; ++j)
+ probs[t] *= emissions[j][t][c.tokens[j]];
+ }
+ return probs;
+ }
+
+ private void readTrainingFromFile(Reader in) throws IOException
+ {
+ // read in line-by-line
+ BufferedReader bin = new BufferedReader(in);
+ String line;
+ List<PhraseAndContexts> instances = new ArrayList<PhraseAndContexts>();
+ Pattern separator = Pattern.compile(" \\|\\|\\| ");
+
+ while ((line = bin.readLine()) != null)
+ {
+ // split into phrase and contexts
+ StringTokenizer st = new StringTokenizer(line, "\t");
+ assert (st.hasMoreTokens());
+ String phrase = st.nextToken();
+ assert (st.hasMoreTokens());
+ String rest = st.nextToken();
+ assert (!st.hasMoreTokens());
+
+ // process phrase
+ st = new StringTokenizer(phrase, " ");
+ List<Integer> ptoks = new ArrayList<Integer>();
+ while (st.hasMoreTokens())
+ ptoks.add(tokenLexicon.insert(st.nextToken()));
+
+ // process contexts
+ ArrayList<Context> contexts = new ArrayList<Context>();
+ String[] parts = separator.split(rest);
+ assert (parts.length % 2 == 0);
+ for (int i = 0; i < parts.length; i += 2)
+ {
+ // process pairs of strings - context and count
+ ArrayList<Integer> ctx = new ArrayList<Integer>();
+ String ctxString = parts[i];
+ String countString = parts[i + 1];
+ StringTokenizer ctxStrtok = new StringTokenizer(ctxString, " ");
+ while (ctxStrtok.hasMoreTokens())
+ {
+ String token = ctxStrtok.nextToken();
+ if (!token.equals("<PHRASE>"))
+ ctx.add(tokenLexicon.insert(token));
+ }
+
+ assert (countString.startsWith("C="));
+ Context c = new Context();
+ c.count = Integer.parseInt(countString.substring(2).trim());
+ // damn unboxing doesn't work with toArray
+ c.tokens = new int[ctx.size()];
+ for (int k = 0; k < ctx.size(); ++k)
+ c.tokens[k] = ctx.get(k);
+ contexts.add(c);
+
+ numEdges += 1;
+ }
+
+ // package up
+ PhraseAndContexts instance = new PhraseAndContexts();
+ // damn unboxing doesn't work with toArray
+ instance.phraseTokens = new int[ptoks.size()];
+ for (int k = 0; k < ptoks.size(); ++k)
+ instance.phraseTokens[k] = ptoks.get(k);
+ instance.contexts = contexts.toArray(new Context[] {});
+ instances.add(instance);
+ }
+
+ training = instances.toArray(new PhraseAndContexts[] {});
+
+ System.out.println("Read in " + training.length + " phrases and " + numEdges + " edges");
+ }
+
+ void displayPosterior()
+ {
+ for (int i = 0; i < training.length; ++i)
+ {
+ PhraseAndContexts instance = training[i];
+ for (Context ctx : instance.contexts)
+ {
+ double probs[] = posterior(i, ctx);
+ normalise(probs);
+
+ // emit phrase
+ for (int t : instance.phraseTokens)
+ System.out.print(tokenLexicon.lookup(t) + " ");
+ System.out.print("\t");
+ for (int c : ctx.tokens)
+ System.out.print(tokenLexicon.lookup(c) + " ");
+ System.out.print("||| C=" + ctx.count + " |||");
+
+ int t = argmax(probs);
+ System.out.print(" " + t + " ||| " + probs[t]);
+ // for (int t = 0; t < numTags; ++t)
+ // System.out.print(" " + probs[t]);
+ System.out.println();
+ }
+ }
+ }
+
+ class PhraseAndContexts
+ {
+ int phraseTokens[];
+ Context contexts[];
+ }
+
+ class Context
+ {
+ int count;
+ int[] tokens;
+ }
+
+ public static void main(String[] args)
+ {
+ assert (args.length >= 2);
+ try
+ {
+ PhraseContextModel model = new PhraseContextModel(new File(args[0]), Integer.parseInt(args[1]));
+ model.expectationMaximisation(Integer.parseInt(args[2]));
+ model.displayPosterior();
+ }
+ catch (IOException e)
+ {
+ System.out.println("Failed to read input file: " + args[0]);
+ e.printStackTrace();
+ }
+ }
+
+ class EStepDualObjective extends ProjectedObjective
+ {
+ List<List<TDoubleArrayList>> conditionals; // phrase id x context # x tag - precomputed
+ List<List<TDoubleArrayList>> q; // ditto, but including exp(-lambda) terms
+ double objective = 0; // log(z)
+ // Objective.gradient = d log(z) / d lambda = E_q[phi]
+ double llh = 0;
+
+ public EStepDualObjective()
+ {
+ super();
+ // compute conditionals p(context, tag | phrase) for all training instances
+ conditionals = new ArrayList<List<TDoubleArrayList>>(training.length);
+ q = new ArrayList<List<TDoubleArrayList>>(training.length);
+ for (int i = 0; i < training.length; ++i)
+ {
+ PhraseAndContexts instance = training[i];
+ conditionals.add(new ArrayList<TDoubleArrayList>(instance.contexts.length));
+ q.add(new ArrayList<TDoubleArrayList>(instance.contexts.length));
+
+ for (int j = 0; j < instance.contexts.length; ++j)
+ {
+ Context c = instance.contexts[j];
+ double probs[] = posterior(i, c);
+ double z = normalise(probs);
+ llh += log(z) * c.count;
+ conditionals.get(i).add(new TDoubleArrayList(probs));
+ q.get(i).add(new TDoubleArrayList(probs));
+ }
+ }
+
+ gradient = new double[numEdges*numTags];
+ setInitialParameters(lambda);
+ }
+
+ @Override
+ public double[] projectPoint(double[] point)
+ {
+ SimplexProjection p = new SimplexProjection(constraintScale);
+
+ double[] newPoint = point.clone();
+ int edgeIndex = 0;
+ for (int i = 0; i < training.length; ++i)
+ {
+ PhraseAndContexts instance = training[i];
+
+ for (int t = 0; t < numTags; t++)
+ {
+ double[] subPoint = new double[instance.contexts.length];
+ for (int j = 0; j < instance.contexts.length; ++j)
+ subPoint[j] = point[edgeIndex+j*numTags+t];
+
+ p.project(subPoint);
+ for (int j = 0; j < instance.contexts.length; ++j)
+ newPoint[edgeIndex+j*numTags+t] = subPoint[j];
+ }
+
+ edgeIndex += instance.contexts.length * numTags;
+ }
+ //System.out.println("Project point: " + Arrays.toString(point)
+ // + " => " + Arrays.toString(newPoint));
+ return newPoint;
+ }
+
+ @Override
+ public void setParameters(double[] params)
+ {
+ super.setParameters(params);
+ computeObjectiveAndGradient();
+ }
+
+ @Override
+ public double[] getGradient()
+ {
+ return gradient;
+ }
+
+ @Override
+ public double getValue()
+ {
+ return objective;
+ }
+
+ public void computeObjectiveAndGradient()
+ {
+ int edgeIndex = 0;
+ objective = 0;
+ Arrays.fill(gradient, 0);
+ for (int i = 0; i < training.length; ++i)
+ {
+ PhraseAndContexts instance = training[i];
+
+ for (int j = 0; j < instance.contexts.length; ++j)
+ {
+ Context c = instance.contexts[j];
+
+ double z = 0;
+ for (int t = 0; t < numTags; t++)
+ {
+ double v = conditionals.get(i).get(j).get(t) * exp(-parameters[edgeIndex+t]);
+ q.get(i).get(j).set(t, v);
+ z += v;
+ }
+ objective = log(z) * c.count;
+
+ for (int t = 0; t < numTags; t++)
+ {
+ double v = q.get(i).get(j).get(t) / z;
+ q.get(i).get(j).set(t, v);
+ gradient[edgeIndex+t] -= c.count * v;
+ }
+
+ edgeIndex += numTags;
+ }
+ }
+// System.out.println("computeObjectiveAndGradient logz=" + objective);
+// System.out.println("gradient=" + Arrays.toString(gradient));
+ }
+
+ public String toString()
+ {
+ StringBuilder sb = new StringBuilder();
+ sb.append(getClass().getCanonicalName()).append(" with ");
+ sb.append(parameters.length).append(" parameters and ");
+ sb.append(training.length * numTags).append(" constraints");
+ return sb.toString();
+ }
+
+ double primal()
+ {
+ // primal = llh + KL(q||p) + scale * sum_pt max_c E_q[phi_pct]
+ // kl = sum_Y q(Y) log q(Y) / p(Y|X)
+ // = sum_Y q(Y) { -lambda . phi(Y) - log Z }
+ // = -log Z - lambda . E_q[phi]
+
+ double kl = -objective - MathUtils.dotProduct(parameters, gradient);
+ double l1lmax = 0;
+ for (int i = 0; i < training.length; ++i)
+ {
+ PhraseAndContexts instance = training[i];
+ for (int t = 0; t < numTags; t++)
+ {
+ double lmax = Double.NEGATIVE_INFINITY;
+ for (int j = 0; j < instance.contexts.length; ++j)
+ lmax = max(lmax, q.get(i).get(j).get(t));
+ l1lmax += lmax;
+ }
+ }
+
+ return llh + kl + constraintScale * l1lmax;
+ }
+ }
}