diff options
Diffstat (limited to 'gi')
-rw-r--r-- | gi/posterior-regularisation/Corpus.java | 120 | ||||
-rw-r--r-- | gi/posterior-regularisation/PhraseContextModel.java | 314 |
2 files changed, 155 insertions, 279 deletions
diff --git a/gi/posterior-regularisation/Corpus.java b/gi/posterior-regularisation/Corpus.java index 047e6ee8..07b27387 100644 --- a/gi/posterior-regularisation/Corpus.java +++ b/gi/posterior-regularisation/Corpus.java @@ -7,69 +7,64 @@ import java.util.regex.Pattern; public class Corpus { private Lexicon<String> tokenLexicon = new Lexicon<String>(); - private Lexicon<TIntArrayList> ngramLexicon = new Lexicon<TIntArrayList>(); + private Lexicon<TIntArrayList> phraseLexicon = new Lexicon<TIntArrayList>(); + private Lexicon<TIntArrayList> contextLexicon = new Lexicon<TIntArrayList>(); private List<Edge> edges = new ArrayList<Edge>(); - private Map<Ngram,List<Edge>> phraseToContext = new HashMap<Ngram,List<Edge>>(); - private Map<Ngram,List<Edge>> contextToPhrase = new HashMap<Ngram,List<Edge>>(); + private List<List<Edge>> phraseToContext = new ArrayList<List<Edge>>(); + private List<List<Edge>> contextToPhrase = new ArrayList<List<Edge>>(); - public class Ngram + public class Edge { - private Ngram(int id) + Edge(int phraseId, int contextId, int count) { - ngramId = id; + this.phraseId = phraseId; + this.contextId = contextId; + this.count = count; } - public int getId() + public int getPhraseId() { - return ngramId; + return phraseId; } - public TIntArrayList getTokenIds() + public TIntArrayList getPhrase() { - return ngramLexicon.lookup(ngramId); + return phraseLexicon.lookup(phraseId); } - public String toString() + public String getPhraseString() { StringBuffer b = new StringBuffer(); - for (int tid: getTokenIds().toNativeArray()) + for (int tid: getPhrase().toNativeArray()) { if (b.length() > 0) b.append(" "); b.append(tokenLexicon.lookup(tid)); } return b.toString(); - } - public int hashCode() - { - return ngramId; - } - public boolean equals(Object other) - { - return other instanceof Ngram && ngramId == ((Ngram) other).ngramId; - } - private int ngramId; - } - - public class Edge - { - Edge(Ngram phrase, Ngram context, int count) + } + public int getContextId() { - this.phrase = phrase; - this.context = context; - this.count = count; + return contextId; } - public Ngram getPhrase() + public TIntArrayList getContext() { - return phrase; + return contextLexicon.lookup(contextId); } - public Ngram getContext() + public String getContextString() { - return context; + StringBuffer b = new StringBuffer(); + for (int tid: getContext().toNativeArray()) + { + if (b.length() > 0) + b.append(" "); + b.append(tokenLexicon.lookup(tid)); + } + return b.toString(); } public int getCount() { return count; } - private Ngram phrase; - private Ngram context; + private int phraseId; + private int contextId; private int count; } @@ -78,32 +73,32 @@ public class Corpus return edges; } - int numEdges() + int getNumEdges() { return edges.size(); } - Set<Ngram> getPhrases() + int getNumPhrases() { - return phraseToContext.keySet(); + return phraseLexicon.size(); } - List<Edge> getEdgesForPhrase(Ngram phrase) + List<Edge> getEdgesForPhrase(int phraseId) { - return phraseToContext.get(phrase); + return phraseToContext.get(phraseId); } - Set<Ngram> getContexts() + int getNumContexts() { - return contextToPhrase.keySet(); + return contextLexicon.size(); } - List<Edge> getEdgesForContext(Ngram context) + List<Edge> getEdgesForContext(int contextId) { - return contextToPhrase.get(context); + return contextToPhrase.get(contextId); } - int numTokens() + int getNumTokens() { return tokenLexicon.size(); } @@ -132,8 +127,9 @@ public class Corpus TIntArrayList ptoks = new TIntArrayList(); while (st.hasMoreTokens()) ptoks.add(c.tokenLexicon.insert(st.nextToken())); - int phraseId = c.ngramLexicon.insert(ptoks); - Ngram phrase = c.new Ngram(phraseId); + int phraseId = c.phraseLexicon.insert(ptoks); + if (phraseId == c.phraseToContext.size()) + c.phraseToContext.add(new ArrayList<Edge>()); // process contexts String[] parts = separator.split(rest); @@ -151,30 +147,18 @@ public class Corpus if (!token.equals("<PHRASE>")) ctx.add(c.tokenLexicon.insert(token)); } - int contextId = c.ngramLexicon.insert(ctx); - Ngram context = c.new Ngram(contextId); + int contextId = c.contextLexicon.insert(ctx); + if (contextId == c.contextToPhrase.size()) + c.contextToPhrase.add(new ArrayList<Edge>()); assert (countString.startsWith("C=")); - Edge e = c.new Edge(phrase, context, Integer.parseInt(countString.substring(2).trim())); + Edge e = c.new Edge(phraseId, contextId, + Integer.parseInt(countString.substring(2).trim())); c.edges.add(e); - // index the edge for fast phrase lookup - List<Edge> edges = c.phraseToContext.get(phrase); - if (edges == null) - { - edges = new ArrayList<Edge>(); - c.phraseToContext.put(phrase, edges); - } - edges.add(e); - - // index the edge for fast context lookup - edges = c.contextToPhrase.get(context); - if (edges == null) - { - edges = new ArrayList<Edge>(); - c.contextToPhrase.put(context, edges); - } - edges.add(e); + // index the edge for fast phrase, context lookup + c.phraseToContext.get(phraseId).add(e); + c.contextToPhrase.get(contextId).add(e); } } diff --git a/gi/posterior-regularisation/PhraseContextModel.java b/gi/posterior-regularisation/PhraseContextModel.java index d0a92dde..c48cfacd 100644 --- a/gi/posterior-regularisation/PhraseContextModel.java +++ b/gi/posterior-regularisation/PhraseContextModel.java @@ -40,44 +40,16 @@ import optimization.stopCriteria.ProjectedGradientL2Norm; import optimization.stopCriteria.StopingCriteria; import optimization.stopCriteria.ValueDifference; import optimization.util.MathUtils; - import java.util.*; import java.util.regex.*; import gnu.trove.TDoubleArrayList; +import gnu.trove.TIntArrayList; import static java.lang.Math.*; -class Lexicon<T> -{ - public int insert(T word) - { - Integer i = wordToIndex.get(word); - if (i == null) - { - i = indexToWord.size(); - wordToIndex.put(word, i); - indexToWord.add(word); - } - return i; - } - - public T lookup(int index) - { - return indexToWord.get(index); - } - - public int size() - { - return indexToWord.size(); - } - - private Map<T, Integer> wordToIndex = new HashMap<T, Integer>(); - private List<T> indexToWord = new ArrayList<T>(); -} - class PhraseContextModel { // model/optimisation configuration parameters - int numTags, numEdges; + int numTags; boolean posteriorRegularisation = true; double constraintScale = 3; // FIXME: make configurable @@ -88,33 +60,32 @@ class PhraseContextModel int minOccurrencesForProjection = 0; // book keeping - Lexicon<String> tokenLexicon = new Lexicon<String>(); int numPositions; Random rng = new Random(); - // training set; 1 entry for each unique phrase - PhraseAndContexts training[]; + // training set + Corpus training; // model parameters (learnt) double emissions[][][]; // position in 0 .. 3 x tag x word Pr(word | tag, position) double prior[][]; // phrase x tag Pr(tag | phrase) double lambda[]; // edge = (phrase, context) x tag flattened lagrange multipliers - PhraseContextModel(File infile, int tags) throws IOException + PhraseContextModel(Corpus training, int tags) { - numTags = tags; - numEdges = 0; - readTrainingFromFile(new FileReader(infile)); - assert (training.length > 0); + this.training = training; + this.numTags = tags; + assert (!training.getEdges().isEmpty()); + assert (numTags > 1); // now initialise emissions - assert (training[0].contexts.length > 0); - numPositions = training[0].contexts[0].tokens.length; + numPositions = training.getEdges().get(0).getContext().size(); + assert (numPositions > 0); - emissions = new double[numPositions][numTags][tokenLexicon.size()]; - prior = new double[training.length][numTags]; + emissions = new double[numPositions][numTags][training.getNumTokens()]; + prior = new double[training.getNumEdges()][numTags]; if (posteriorRegularisation) - lambda = new double[numEdges * numTags]; + lambda = new double[training.getNumEdges() * numTags]; for (double[][] emissionTW : emissions) for (double[] emissionW : emissionTW) @@ -130,8 +101,8 @@ class PhraseContextModel for (int iteration = 0; iteration < numIterations; ++iteration) { - double emissionsCounts[][][] = new double[numPositions][numTags][tokenLexicon.size()]; - double priorCounts[][] = new double[training.length][numTags]; + double emissionsCounts[][][] = new double[numPositions][numTags][training.getNumTokens()]; + double priorCounts[][] = new double[training.getNumPhrases()][numTags]; // E-step double llh = 0; @@ -140,71 +111,70 @@ class PhraseContextModel EStepDualObjective objective = new EStepDualObjective(); // copied from x2y2withconstraints - LineSearchMethod ls = new ArmijoLineSearchMinimizationAlongProjectionArc(new InterpolationPickFirstStep(1)); - OptimizerStats stats = new OptimizerStats(); - ProjectedGradientDescent optimizer = new ProjectedGradientDescent(ls); - CompositeStopingCriteria compositeStop = new CompositeStopingCriteria(); - compositeStop.add(new ProjectedGradientL2Norm(0.001)); - compositeStop.add(new ValueDifference(0.001)); - optimizer.setMaxIterations(50); - boolean succeed = optimizer.optimize(objective,stats,compositeStop); +// LineSearchMethod ls = new ArmijoLineSearchMinimizationAlongProjectionArc(new InterpolationPickFirstStep(1)); +// OptimizerStats stats = new OptimizerStats(); +// ProjectedGradientDescent optimizer = new ProjectedGradientDescent(ls); +// CompositeStopingCriteria compositeStop = new CompositeStopingCriteria(); +// compositeStop.add(new ProjectedGradientL2Norm(0.001)); +// compositeStop.add(new ValueDifference(0.001)); +// optimizer.setMaxIterations(50); +// boolean succeed = optimizer.optimize(objective,stats,compositeStop); // copied from depparser l1lmaxobjective -// ProjectedOptimizerStats stats = new ProjectedOptimizerStats(); -// GenericPickFirstStep pickFirstStep = new GenericPickFirstStep(1); -// LineSearchMethod linesearch = new WolfRuleLineSearch(pickFirstStep, c1, c2); -// ProjectedGradientDescent optimizer = new ProjectedGradientDescent(linesearch); -// optimizer.setMaxIterations(maxProjectionIterations); -// StopingCriteria stopGrad = new NormalizedProjectedGradientL2Norm(stoppingPrecision); -// StopingCriteria stopValue = new NormalizedValueDifference(stoppingPrecision); -// CompositeStopingCriteria stop = new CompositeStopingCriteria(); -// stop.add(stopGrad); -// stop.add(stopValue); -// boolean succeed = optimizer.optimize(objective, stats, stop); - - //System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1)); + ProjectedOptimizerStats stats = new ProjectedOptimizerStats(); + GenericPickFirstStep pickFirstStep = new GenericPickFirstStep(1); + LineSearchMethod linesearch = new WolfRuleLineSearch(pickFirstStep, c1, c2); + ProjectedGradientDescent optimizer = new ProjectedGradientDescent(linesearch); + optimizer.setMaxIterations(maxProjectionIterations); + CompositeStopingCriteria stop = new CompositeStopingCriteria(); + stop.add(new NormalizedProjectedGradientL2Norm(stoppingPrecision)); + stop.add(new NormalizedValueDifference(stoppingPrecision)); + boolean succeed = optimizer.optimize(objective, stats, stop); + + System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1)); //System.out.println("Solution: " + objective.parameters); if (!succeed) System.out.println("Failed to optimize"); //System.out.println("Ended optimization in " + optimizer.getCurrentIteration()); - // make sure we update the dual params - //llh = objective.getValue(); + lambda = objective.getParameters(); llh = objective.primal(); - // FIXME: this is the dual not the primal and omits the llh term - for (int i = 0; i < training.length; ++i) + for (int i = 0; i < training.getNumPhrases(); ++i) { - PhraseAndContexts instance = training[i]; - for (int j = 0; j < instance.contexts.length; ++j) + List<Corpus.Edge> edges = training.getEdgesForPhrase(i); + for (int j = 0; j < edges.size(); ++j) { - Context c = instance.contexts[j]; + Corpus.Edge e = edges.get(j); for (int t = 0; t < numTags; t++) { double p = objective.q.get(i).get(j).get(t); - priorCounts[i][t] += c.count * p; - for (int k = 0; k < c.tokens.length; ++k) - emissionsCounts[k][t][c.tokens[k]] += c.count * p; + priorCounts[i][t] += e.getCount() * p; + TIntArrayList tokens = e.getContext(); + for (int k = 0; k < tokens.size(); ++k) + emissionsCounts[k][t][tokens.get(k)] += e.getCount() * p; } } } } else { - for (int i = 0; i < training.length; ++i) + for (int i = 0; i < training.getNumPhrases(); ++i) { - PhraseAndContexts instance = training[i]; - for (Context ctx : instance.contexts) + List<Corpus.Edge> edges = training.getEdgesForPhrase(i); + for (int j = 0; j < edges.size(); ++j) { - double probs[] = posterior(i, ctx); + Corpus.Edge e = edges.get(j); + double probs[] = posterior(i, e); double z = normalise(probs); - llh += log(z) * ctx.count; - + llh += log(z) * e.getCount(); + + TIntArrayList tokens = e.getContext(); for (int t = 0; t < numTags; ++t) { - priorCounts[i][t] += ctx.count * probs[t]; - for (int j = 0; j < ctx.tokens.length; ++j) - emissionsCounts[j][t][ctx.tokens[j]] += ctx.count * probs[t]; + priorCounts[i][t] += e.getCount() * probs[t]; + for (int k = 0; k < tokens.size(); ++k) + emissionsCounts[j][t][tokens.get(k)] += e.getCount() * probs[t]; } } } @@ -268,104 +238,34 @@ class PhraseContextModel return mi; } - double[] posterior(int phraseId, Context c) // unnormalised + double[] posterior(int phraseId, Corpus.Edge e) // unnormalised { double probs[] = new double[numTags]; + TIntArrayList tokens = e.getContext(); for (int t = 0; t < numTags; ++t) { probs[t] = prior[phraseId][t]; - for (int j = 0; j < c.tokens.length; ++j) - probs[t] *= emissions[j][t][c.tokens[j]]; + for (int k = 0; k < tokens.size(); ++k) + probs[t] *= emissions[k][t][tokens.get(k)]; } return probs; } - private void readTrainingFromFile(Reader in) throws IOException - { - // read in line-by-line - BufferedReader bin = new BufferedReader(in); - String line; - List<PhraseAndContexts> instances = new ArrayList<PhraseAndContexts>(); - Pattern separator = Pattern.compile(" \\|\\|\\| "); - - while ((line = bin.readLine()) != null) - { - // split into phrase and contexts - StringTokenizer st = new StringTokenizer(line, "\t"); - assert (st.hasMoreTokens()); - String phrase = st.nextToken(); - assert (st.hasMoreTokens()); - String rest = st.nextToken(); - assert (!st.hasMoreTokens()); - - // process phrase - st = new StringTokenizer(phrase, " "); - List<Integer> ptoks = new ArrayList<Integer>(); - while (st.hasMoreTokens()) - ptoks.add(tokenLexicon.insert(st.nextToken())); - - // process contexts - ArrayList<Context> contexts = new ArrayList<Context>(); - String[] parts = separator.split(rest); - assert (parts.length % 2 == 0); - for (int i = 0; i < parts.length; i += 2) - { - // process pairs of strings - context and count - ArrayList<Integer> ctx = new ArrayList<Integer>(); - String ctxString = parts[i]; - String countString = parts[i + 1]; - StringTokenizer ctxStrtok = new StringTokenizer(ctxString, " "); - while (ctxStrtok.hasMoreTokens()) - { - String token = ctxStrtok.nextToken(); - if (!token.equals("<PHRASE>")) - ctx.add(tokenLexicon.insert(token)); - } - - assert (countString.startsWith("C=")); - Context c = new Context(); - c.count = Integer.parseInt(countString.substring(2).trim()); - // damn unboxing doesn't work with toArray - c.tokens = new int[ctx.size()]; - for (int k = 0; k < ctx.size(); ++k) - c.tokens[k] = ctx.get(k); - contexts.add(c); - - numEdges += 1; - } - - // package up - PhraseAndContexts instance = new PhraseAndContexts(); - // damn unboxing doesn't work with toArray - instance.phraseTokens = new int[ptoks.size()]; - for (int k = 0; k < ptoks.size(); ++k) - instance.phraseTokens[k] = ptoks.get(k); - instance.contexts = contexts.toArray(new Context[] {}); - instances.add(instance); - } - - training = instances.toArray(new PhraseAndContexts[] {}); - - System.out.println("Read in " + training.length + " phrases and " + numEdges + " edges"); - } - void displayPosterior() { - for (int i = 0; i < training.length; ++i) + for (int i = 0; i < training.getNumPhrases(); ++i) { - PhraseAndContexts instance = training[i]; - for (Context ctx : instance.contexts) + List<Corpus.Edge> edges = training.getEdgesForPhrase(i); + for (Corpus.Edge e: edges) { - double probs[] = posterior(i, ctx); + double probs[] = posterior(i, e); normalise(probs); // emit phrase - for (int t : instance.phraseTokens) - System.out.print(tokenLexicon.lookup(t) + " "); + System.out.print(e.getPhraseString()); System.out.print("\t"); - for (int c : ctx.tokens) - System.out.print(tokenLexicon.lookup(c) + " "); - System.out.print("||| C=" + ctx.count + " |||"); + System.out.print(e.getContextString()); + System.out.print("||| C=" + e.getCount() + " |||"); int t = argmax(probs); System.out.print(" " + t + " ||| " + probs[t]); @@ -376,24 +276,13 @@ class PhraseContextModel } } - class PhraseAndContexts - { - int phraseTokens[]; - Context contexts[]; - } - - class Context - { - int count; - int[] tokens; - } - public static void main(String[] args) { assert (args.length >= 2); try { - PhraseContextModel model = new PhraseContextModel(new File(args[0]), Integer.parseInt(args[1])); + Corpus corpus = Corpus.readFromFile(new FileReader(new File(args[0]))); + PhraseContextModel model = new PhraseContextModel(corpus, Integer.parseInt(args[1])); model.expectationMaximisation(Integer.parseInt(args[2])); model.displayPosterior(); } @@ -416,26 +305,27 @@ class PhraseContextModel { super(); // compute conditionals p(context, tag | phrase) for all training instances - conditionals = new ArrayList<List<TDoubleArrayList>>(training.length); - q = new ArrayList<List<TDoubleArrayList>>(training.length); - for (int i = 0; i < training.length; ++i) + conditionals = new ArrayList<List<TDoubleArrayList>>(training.getNumPhrases()); + q = new ArrayList<List<TDoubleArrayList>>(training.getNumPhrases()); + for (int i = 0; i < training.getNumPhrases(); ++i) { - PhraseAndContexts instance = training[i]; - conditionals.add(new ArrayList<TDoubleArrayList>(instance.contexts.length)); - q.add(new ArrayList<TDoubleArrayList>(instance.contexts.length)); + List<Corpus.Edge> edges = training.getEdgesForPhrase(i); - for (int j = 0; j < instance.contexts.length; ++j) + conditionals.add(new ArrayList<TDoubleArrayList>(edges.size())); + q.add(new ArrayList<TDoubleArrayList>(edges.size())); + + for (int j = 0; j < edges.size(); ++j) { - Context c = instance.contexts[j]; - double probs[] = posterior(i, c); + Corpus.Edge e = edges.get(j); + double probs[] = posterior(i, e); double z = normalise(probs); - llh += log(z) * c.count; + llh += log(z) * e.getCount(); conditionals.get(i).add(new TDoubleArrayList(probs)); q.get(i).add(new TDoubleArrayList(probs)); } } - gradient = new double[numEdges*numTags]; + gradient = new double[training.getNumEdges()*numTags]; setInitialParameters(lambda); } @@ -446,22 +336,22 @@ class PhraseContextModel double[] newPoint = point.clone(); int edgeIndex = 0; - for (int i = 0; i < training.length; ++i) + for (int i = 0; i < training.getNumPhrases(); ++i) { - PhraseAndContexts instance = training[i]; + List<Corpus.Edge> edges = training.getEdgesForPhrase(i); for (int t = 0; t < numTags; t++) { - double[] subPoint = new double[instance.contexts.length]; - for (int j = 0; j < instance.contexts.length; ++j) + double[] subPoint = new double[edges.size()]; + for (int j = 0; j < edges.size(); ++j) subPoint[j] = point[edgeIndex+j*numTags+t]; - + p.project(subPoint); - for (int j = 0; j < instance.contexts.length; ++j) + for (int j = 0; j < edges.size(); ++j) newPoint[edgeIndex+j*numTags+t] = subPoint[j]; } - edgeIndex += instance.contexts.length * numTags; + edgeIndex += edges.size() * numTags; } //System.out.println("Project point: " + Arrays.toString(point) // + " => " + Arrays.toString(newPoint)); @@ -492,13 +382,13 @@ class PhraseContextModel int edgeIndex = 0; objective = 0; Arrays.fill(gradient, 0); - for (int i = 0; i < training.length; ++i) + for (int i = 0; i < training.getNumPhrases(); ++i) { - PhraseAndContexts instance = training[i]; - - for (int j = 0; j < instance.contexts.length; ++j) + List<Corpus.Edge> edges = training.getEdgesForPhrase(i); + + for (int j = 0; j < edges.size(); ++j) { - Context c = instance.contexts[j]; + Corpus.Edge e = edges.get(j); double z = 0; for (int t = 0; t < numTags; t++) @@ -507,20 +397,21 @@ class PhraseContextModel q.get(i).get(j).set(t, v); z += v; } - objective = log(z) * c.count; + objective = log(z) * e.getCount(); for (int t = 0; t < numTags; t++) { double v = q.get(i).get(j).get(t) / z; q.get(i).get(j).set(t, v); - gradient[edgeIndex+t] -= c.count * v; + gradient[edgeIndex+t] -= e.getCount() * v; } edgeIndex += numTags; } } -// System.out.println("computeObjectiveAndGradient logz=" + objective); -// System.out.println("gradient=" + Arrays.toString(gradient)); + System.out.println("computeObjectiveAndGradient logz=" + objective); + System.out.println("lambda= " + Arrays.toString(parameters)); + System.out.println("gradient=" + Arrays.toString(gradient)); } public String toString() @@ -528,7 +419,7 @@ class PhraseContextModel StringBuilder sb = new StringBuilder(); sb.append(getClass().getCanonicalName()).append(" with "); sb.append(parameters.length).append(" parameters and "); - sb.append(training.length * numTags).append(" constraints"); + sb.append(training.getNumPhrases() * numTags).append(" constraints"); return sb.toString(); } @@ -538,16 +429,17 @@ class PhraseContextModel // kl = sum_Y q(Y) log q(Y) / p(Y|X) // = sum_Y q(Y) { -lambda . phi(Y) - log Z } // = -log Z - lambda . E_q[phi] + // = -objective + lambda . gradient - double kl = -objective - MathUtils.dotProduct(parameters, gradient); + double kl = -objective + MathUtils.dotProduct(parameters, gradient); double l1lmax = 0; - for (int i = 0; i < training.length; ++i) + for (int i = 0; i < training.getNumPhrases(); ++i) { - PhraseAndContexts instance = training[i]; + List<Corpus.Edge> edges = training.getEdgesForPhrase(i); for (int t = 0; t < numTags; t++) { double lmax = Double.NEGATIVE_INFINITY; - for (int j = 0; j < instance.contexts.length; ++j) + for (int j = 0; j < edges.size(); ++j) lmax = max(lmax, q.get(i).get(j).get(t)); l1lmax += lmax; } |