summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/PhraseContextModel.java
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation/PhraseContextModel.java')
-rw-r--r--gi/posterior-regularisation/PhraseContextModel.java314
1 files changed, 103 insertions, 211 deletions
diff --git a/gi/posterior-regularisation/PhraseContextModel.java b/gi/posterior-regularisation/PhraseContextModel.java
index d0a92dde..c48cfacd 100644
--- a/gi/posterior-regularisation/PhraseContextModel.java
+++ b/gi/posterior-regularisation/PhraseContextModel.java
@@ -40,44 +40,16 @@ import optimization.stopCriteria.ProjectedGradientL2Norm;
import optimization.stopCriteria.StopingCriteria;
import optimization.stopCriteria.ValueDifference;
import optimization.util.MathUtils;
-
import java.util.*;
import java.util.regex.*;
import gnu.trove.TDoubleArrayList;
+import gnu.trove.TIntArrayList;
import static java.lang.Math.*;
-class Lexicon<T>
-{
- public int insert(T word)
- {
- Integer i = wordToIndex.get(word);
- if (i == null)
- {
- i = indexToWord.size();
- wordToIndex.put(word, i);
- indexToWord.add(word);
- }
- return i;
- }
-
- public T lookup(int index)
- {
- return indexToWord.get(index);
- }
-
- public int size()
- {
- return indexToWord.size();
- }
-
- private Map<T, Integer> wordToIndex = new HashMap<T, Integer>();
- private List<T> indexToWord = new ArrayList<T>();
-}
-
class PhraseContextModel
{
// model/optimisation configuration parameters
- int numTags, numEdges;
+ int numTags;
boolean posteriorRegularisation = true;
double constraintScale = 3; // FIXME: make configurable
@@ -88,33 +60,32 @@ class PhraseContextModel
int minOccurrencesForProjection = 0;
// book keeping
- Lexicon<String> tokenLexicon = new Lexicon<String>();
int numPositions;
Random rng = new Random();
- // training set; 1 entry for each unique phrase
- PhraseAndContexts training[];
+ // training set
+ Corpus training;
// model parameters (learnt)
double emissions[][][]; // position in 0 .. 3 x tag x word Pr(word | tag, position)
double prior[][]; // phrase x tag Pr(tag | phrase)
double lambda[]; // edge = (phrase, context) x tag flattened lagrange multipliers
- PhraseContextModel(File infile, int tags) throws IOException
+ PhraseContextModel(Corpus training, int tags)
{
- numTags = tags;
- numEdges = 0;
- readTrainingFromFile(new FileReader(infile));
- assert (training.length > 0);
+ this.training = training;
+ this.numTags = tags;
+ assert (!training.getEdges().isEmpty());
+ assert (numTags > 1);
// now initialise emissions
- assert (training[0].contexts.length > 0);
- numPositions = training[0].contexts[0].tokens.length;
+ numPositions = training.getEdges().get(0).getContext().size();
+ assert (numPositions > 0);
- emissions = new double[numPositions][numTags][tokenLexicon.size()];
- prior = new double[training.length][numTags];
+ emissions = new double[numPositions][numTags][training.getNumTokens()];
+ prior = new double[training.getNumEdges()][numTags];
if (posteriorRegularisation)
- lambda = new double[numEdges * numTags];
+ lambda = new double[training.getNumEdges() * numTags];
for (double[][] emissionTW : emissions)
for (double[] emissionW : emissionTW)
@@ -130,8 +101,8 @@ class PhraseContextModel
for (int iteration = 0; iteration < numIterations; ++iteration)
{
- double emissionsCounts[][][] = new double[numPositions][numTags][tokenLexicon.size()];
- double priorCounts[][] = new double[training.length][numTags];
+ double emissionsCounts[][][] = new double[numPositions][numTags][training.getNumTokens()];
+ double priorCounts[][] = new double[training.getNumPhrases()][numTags];
// E-step
double llh = 0;
@@ -140,71 +111,70 @@ class PhraseContextModel
EStepDualObjective objective = new EStepDualObjective();
// copied from x2y2withconstraints
- LineSearchMethod ls = new ArmijoLineSearchMinimizationAlongProjectionArc(new InterpolationPickFirstStep(1));
- OptimizerStats stats = new OptimizerStats();
- ProjectedGradientDescent optimizer = new ProjectedGradientDescent(ls);
- CompositeStopingCriteria compositeStop = new CompositeStopingCriteria();
- compositeStop.add(new ProjectedGradientL2Norm(0.001));
- compositeStop.add(new ValueDifference(0.001));
- optimizer.setMaxIterations(50);
- boolean succeed = optimizer.optimize(objective,stats,compositeStop);
+// LineSearchMethod ls = new ArmijoLineSearchMinimizationAlongProjectionArc(new InterpolationPickFirstStep(1));
+// OptimizerStats stats = new OptimizerStats();
+// ProjectedGradientDescent optimizer = new ProjectedGradientDescent(ls);
+// CompositeStopingCriteria compositeStop = new CompositeStopingCriteria();
+// compositeStop.add(new ProjectedGradientL2Norm(0.001));
+// compositeStop.add(new ValueDifference(0.001));
+// optimizer.setMaxIterations(50);
+// boolean succeed = optimizer.optimize(objective,stats,compositeStop);
// copied from depparser l1lmaxobjective
-// ProjectedOptimizerStats stats = new ProjectedOptimizerStats();
-// GenericPickFirstStep pickFirstStep = new GenericPickFirstStep(1);
-// LineSearchMethod linesearch = new WolfRuleLineSearch(pickFirstStep, c1, c2);
-// ProjectedGradientDescent optimizer = new ProjectedGradientDescent(linesearch);
-// optimizer.setMaxIterations(maxProjectionIterations);
-// StopingCriteria stopGrad = new NormalizedProjectedGradientL2Norm(stoppingPrecision);
-// StopingCriteria stopValue = new NormalizedValueDifference(stoppingPrecision);
-// CompositeStopingCriteria stop = new CompositeStopingCriteria();
-// stop.add(stopGrad);
-// stop.add(stopValue);
-// boolean succeed = optimizer.optimize(objective, stats, stop);
-
- //System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
+ ProjectedOptimizerStats stats = new ProjectedOptimizerStats();
+ GenericPickFirstStep pickFirstStep = new GenericPickFirstStep(1);
+ LineSearchMethod linesearch = new WolfRuleLineSearch(pickFirstStep, c1, c2);
+ ProjectedGradientDescent optimizer = new ProjectedGradientDescent(linesearch);
+ optimizer.setMaxIterations(maxProjectionIterations);
+ CompositeStopingCriteria stop = new CompositeStopingCriteria();
+ stop.add(new NormalizedProjectedGradientL2Norm(stoppingPrecision));
+ stop.add(new NormalizedValueDifference(stoppingPrecision));
+ boolean succeed = optimizer.optimize(objective, stats, stop);
+
+ System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
//System.out.println("Solution: " + objective.parameters);
if (!succeed)
System.out.println("Failed to optimize");
//System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
- // make sure we update the dual params
- //llh = objective.getValue();
+ lambda = objective.getParameters();
llh = objective.primal();
- // FIXME: this is the dual not the primal and omits the llh term
- for (int i = 0; i < training.length; ++i)
+ for (int i = 0; i < training.getNumPhrases(); ++i)
{
- PhraseAndContexts instance = training[i];
- for (int j = 0; j < instance.contexts.length; ++j)
+ List<Corpus.Edge> edges = training.getEdgesForPhrase(i);
+ for (int j = 0; j < edges.size(); ++j)
{
- Context c = instance.contexts[j];
+ Corpus.Edge e = edges.get(j);
for (int t = 0; t < numTags; t++)
{
double p = objective.q.get(i).get(j).get(t);
- priorCounts[i][t] += c.count * p;
- for (int k = 0; k < c.tokens.length; ++k)
- emissionsCounts[k][t][c.tokens[k]] += c.count * p;
+ priorCounts[i][t] += e.getCount() * p;
+ TIntArrayList tokens = e.getContext();
+ for (int k = 0; k < tokens.size(); ++k)
+ emissionsCounts[k][t][tokens.get(k)] += e.getCount() * p;
}
}
}
}
else
{
- for (int i = 0; i < training.length; ++i)
+ for (int i = 0; i < training.getNumPhrases(); ++i)
{
- PhraseAndContexts instance = training[i];
- for (Context ctx : instance.contexts)
+ List<Corpus.Edge> edges = training.getEdgesForPhrase(i);
+ for (int j = 0; j < edges.size(); ++j)
{
- double probs[] = posterior(i, ctx);
+ Corpus.Edge e = edges.get(j);
+ double probs[] = posterior(i, e);
double z = normalise(probs);
- llh += log(z) * ctx.count;
-
+ llh += log(z) * e.getCount();
+
+ TIntArrayList tokens = e.getContext();
for (int t = 0; t < numTags; ++t)
{
- priorCounts[i][t] += ctx.count * probs[t];
- for (int j = 0; j < ctx.tokens.length; ++j)
- emissionsCounts[j][t][ctx.tokens[j]] += ctx.count * probs[t];
+ priorCounts[i][t] += e.getCount() * probs[t];
+ for (int k = 0; k < tokens.size(); ++k)
+ emissionsCounts[j][t][tokens.get(k)] += e.getCount() * probs[t];
}
}
}
@@ -268,104 +238,34 @@ class PhraseContextModel
return mi;
}
- double[] posterior(int phraseId, Context c) // unnormalised
+ double[] posterior(int phraseId, Corpus.Edge e) // unnormalised
{
double probs[] = new double[numTags];
+ TIntArrayList tokens = e.getContext();
for (int t = 0; t < numTags; ++t)
{
probs[t] = prior[phraseId][t];
- for (int j = 0; j < c.tokens.length; ++j)
- probs[t] *= emissions[j][t][c.tokens[j]];
+ for (int k = 0; k < tokens.size(); ++k)
+ probs[t] *= emissions[k][t][tokens.get(k)];
}
return probs;
}
- private void readTrainingFromFile(Reader in) throws IOException
- {
- // read in line-by-line
- BufferedReader bin = new BufferedReader(in);
- String line;
- List<PhraseAndContexts> instances = new ArrayList<PhraseAndContexts>();
- Pattern separator = Pattern.compile(" \\|\\|\\| ");
-
- while ((line = bin.readLine()) != null)
- {
- // split into phrase and contexts
- StringTokenizer st = new StringTokenizer(line, "\t");
- assert (st.hasMoreTokens());
- String phrase = st.nextToken();
- assert (st.hasMoreTokens());
- String rest = st.nextToken();
- assert (!st.hasMoreTokens());
-
- // process phrase
- st = new StringTokenizer(phrase, " ");
- List<Integer> ptoks = new ArrayList<Integer>();
- while (st.hasMoreTokens())
- ptoks.add(tokenLexicon.insert(st.nextToken()));
-
- // process contexts
- ArrayList<Context> contexts = new ArrayList<Context>();
- String[] parts = separator.split(rest);
- assert (parts.length % 2 == 0);
- for (int i = 0; i < parts.length; i += 2)
- {
- // process pairs of strings - context and count
- ArrayList<Integer> ctx = new ArrayList<Integer>();
- String ctxString = parts[i];
- String countString = parts[i + 1];
- StringTokenizer ctxStrtok = new StringTokenizer(ctxString, " ");
- while (ctxStrtok.hasMoreTokens())
- {
- String token = ctxStrtok.nextToken();
- if (!token.equals("<PHRASE>"))
- ctx.add(tokenLexicon.insert(token));
- }
-
- assert (countString.startsWith("C="));
- Context c = new Context();
- c.count = Integer.parseInt(countString.substring(2).trim());
- // damn unboxing doesn't work with toArray
- c.tokens = new int[ctx.size()];
- for (int k = 0; k < ctx.size(); ++k)
- c.tokens[k] = ctx.get(k);
- contexts.add(c);
-
- numEdges += 1;
- }
-
- // package up
- PhraseAndContexts instance = new PhraseAndContexts();
- // damn unboxing doesn't work with toArray
- instance.phraseTokens = new int[ptoks.size()];
- for (int k = 0; k < ptoks.size(); ++k)
- instance.phraseTokens[k] = ptoks.get(k);
- instance.contexts = contexts.toArray(new Context[] {});
- instances.add(instance);
- }
-
- training = instances.toArray(new PhraseAndContexts[] {});
-
- System.out.println("Read in " + training.length + " phrases and " + numEdges + " edges");
- }
-
void displayPosterior()
{
- for (int i = 0; i < training.length; ++i)
+ for (int i = 0; i < training.getNumPhrases(); ++i)
{
- PhraseAndContexts instance = training[i];
- for (Context ctx : instance.contexts)
+ List<Corpus.Edge> edges = training.getEdgesForPhrase(i);
+ for (Corpus.Edge e: edges)
{
- double probs[] = posterior(i, ctx);
+ double probs[] = posterior(i, e);
normalise(probs);
// emit phrase
- for (int t : instance.phraseTokens)
- System.out.print(tokenLexicon.lookup(t) + " ");
+ System.out.print(e.getPhraseString());
System.out.print("\t");
- for (int c : ctx.tokens)
- System.out.print(tokenLexicon.lookup(c) + " ");
- System.out.print("||| C=" + ctx.count + " |||");
+ System.out.print(e.getContextString());
+ System.out.print("||| C=" + e.getCount() + " |||");
int t = argmax(probs);
System.out.print(" " + t + " ||| " + probs[t]);
@@ -376,24 +276,13 @@ class PhraseContextModel
}
}
- class PhraseAndContexts
- {
- int phraseTokens[];
- Context contexts[];
- }
-
- class Context
- {
- int count;
- int[] tokens;
- }
-
public static void main(String[] args)
{
assert (args.length >= 2);
try
{
- PhraseContextModel model = new PhraseContextModel(new File(args[0]), Integer.parseInt(args[1]));
+ Corpus corpus = Corpus.readFromFile(new FileReader(new File(args[0])));
+ PhraseContextModel model = new PhraseContextModel(corpus, Integer.parseInt(args[1]));
model.expectationMaximisation(Integer.parseInt(args[2]));
model.displayPosterior();
}
@@ -416,26 +305,27 @@ class PhraseContextModel
{
super();
// compute conditionals p(context, tag | phrase) for all training instances
- conditionals = new ArrayList<List<TDoubleArrayList>>(training.length);
- q = new ArrayList<List<TDoubleArrayList>>(training.length);
- for (int i = 0; i < training.length; ++i)
+ conditionals = new ArrayList<List<TDoubleArrayList>>(training.getNumPhrases());
+ q = new ArrayList<List<TDoubleArrayList>>(training.getNumPhrases());
+ for (int i = 0; i < training.getNumPhrases(); ++i)
{
- PhraseAndContexts instance = training[i];
- conditionals.add(new ArrayList<TDoubleArrayList>(instance.contexts.length));
- q.add(new ArrayList<TDoubleArrayList>(instance.contexts.length));
+ List<Corpus.Edge> edges = training.getEdgesForPhrase(i);
- for (int j = 0; j < instance.contexts.length; ++j)
+ conditionals.add(new ArrayList<TDoubleArrayList>(edges.size()));
+ q.add(new ArrayList<TDoubleArrayList>(edges.size()));
+
+ for (int j = 0; j < edges.size(); ++j)
{
- Context c = instance.contexts[j];
- double probs[] = posterior(i, c);
+ Corpus.Edge e = edges.get(j);
+ double probs[] = posterior(i, e);
double z = normalise(probs);
- llh += log(z) * c.count;
+ llh += log(z) * e.getCount();
conditionals.get(i).add(new TDoubleArrayList(probs));
q.get(i).add(new TDoubleArrayList(probs));
}
}
- gradient = new double[numEdges*numTags];
+ gradient = new double[training.getNumEdges()*numTags];
setInitialParameters(lambda);
}
@@ -446,22 +336,22 @@ class PhraseContextModel
double[] newPoint = point.clone();
int edgeIndex = 0;
- for (int i = 0; i < training.length; ++i)
+ for (int i = 0; i < training.getNumPhrases(); ++i)
{
- PhraseAndContexts instance = training[i];
+ List<Corpus.Edge> edges = training.getEdgesForPhrase(i);
for (int t = 0; t < numTags; t++)
{
- double[] subPoint = new double[instance.contexts.length];
- for (int j = 0; j < instance.contexts.length; ++j)
+ double[] subPoint = new double[edges.size()];
+ for (int j = 0; j < edges.size(); ++j)
subPoint[j] = point[edgeIndex+j*numTags+t];
-
+
p.project(subPoint);
- for (int j = 0; j < instance.contexts.length; ++j)
+ for (int j = 0; j < edges.size(); ++j)
newPoint[edgeIndex+j*numTags+t] = subPoint[j];
}
- edgeIndex += instance.contexts.length * numTags;
+ edgeIndex += edges.size() * numTags;
}
//System.out.println("Project point: " + Arrays.toString(point)
// + " => " + Arrays.toString(newPoint));
@@ -492,13 +382,13 @@ class PhraseContextModel
int edgeIndex = 0;
objective = 0;
Arrays.fill(gradient, 0);
- for (int i = 0; i < training.length; ++i)
+ for (int i = 0; i < training.getNumPhrases(); ++i)
{
- PhraseAndContexts instance = training[i];
-
- for (int j = 0; j < instance.contexts.length; ++j)
+ List<Corpus.Edge> edges = training.getEdgesForPhrase(i);
+
+ for (int j = 0; j < edges.size(); ++j)
{
- Context c = instance.contexts[j];
+ Corpus.Edge e = edges.get(j);
double z = 0;
for (int t = 0; t < numTags; t++)
@@ -507,20 +397,21 @@ class PhraseContextModel
q.get(i).get(j).set(t, v);
z += v;
}
- objective = log(z) * c.count;
+ objective = log(z) * e.getCount();
for (int t = 0; t < numTags; t++)
{
double v = q.get(i).get(j).get(t) / z;
q.get(i).get(j).set(t, v);
- gradient[edgeIndex+t] -= c.count * v;
+ gradient[edgeIndex+t] -= e.getCount() * v;
}
edgeIndex += numTags;
}
}
-// System.out.println("computeObjectiveAndGradient logz=" + objective);
-// System.out.println("gradient=" + Arrays.toString(gradient));
+ System.out.println("computeObjectiveAndGradient logz=" + objective);
+ System.out.println("lambda= " + Arrays.toString(parameters));
+ System.out.println("gradient=" + Arrays.toString(gradient));
}
public String toString()
@@ -528,7 +419,7 @@ class PhraseContextModel
StringBuilder sb = new StringBuilder();
sb.append(getClass().getCanonicalName()).append(" with ");
sb.append(parameters.length).append(" parameters and ");
- sb.append(training.length * numTags).append(" constraints");
+ sb.append(training.getNumPhrases() * numTags).append(" constraints");
return sb.toString();
}
@@ -538,16 +429,17 @@ class PhraseContextModel
// kl = sum_Y q(Y) log q(Y) / p(Y|X)
// = sum_Y q(Y) { -lambda . phi(Y) - log Z }
// = -log Z - lambda . E_q[phi]
+ // = -objective + lambda . gradient
- double kl = -objective - MathUtils.dotProduct(parameters, gradient);
+ double kl = -objective + MathUtils.dotProduct(parameters, gradient);
double l1lmax = 0;
- for (int i = 0; i < training.length; ++i)
+ for (int i = 0; i < training.getNumPhrases(); ++i)
{
- PhraseAndContexts instance = training[i];
+ List<Corpus.Edge> edges = training.getEdgesForPhrase(i);
for (int t = 0; t < numTags; t++)
{
double lmax = Double.NEGATIVE_INFINITY;
- for (int j = 0; j < instance.contexts.length; ++j)
+ for (int j = 0; j < edges.size(); ++j)
lmax = max(lmax, q.get(i).get(j).get(t));
l1lmax += lmax;
}