From ab3534c45f463e541f3baf05006a50b64e3bbe31 Mon Sep 17 00:00:00 2001 From: "trevor.cohn" Date: Mon, 28 Jun 2010 19:34:58 +0000 Subject: First bits of code for PR training git-svn-id: https://ws10smt.googlecode.com/svn/trunk@44 ec762483-ff6d-05da-a07a-a48fb63a330f --- .../PhraseContextModel.java | 317 +++++++++++++++++++++ 1 file changed, 317 insertions(+) create mode 100644 gi/posterior-regularisation/PhraseContextModel.java (limited to 'gi/posterior-regularisation/PhraseContextModel.java') diff --git a/gi/posterior-regularisation/PhraseContextModel.java b/gi/posterior-regularisation/PhraseContextModel.java new file mode 100644 index 00000000..3af72d54 --- /dev/null +++ b/gi/posterior-regularisation/PhraseContextModel.java @@ -0,0 +1,317 @@ +// Input of the form: +// " the phantom of the opera " tickets for tonight ? ||| C=1 ||| seats for ? ||| C=1 ||| i see ? ||| C=1 +// phrase TAB [context]+ +// where context = phrase ||| C=... which are separated by ||| + +// Model parameterised as follows: +// - each phrase, p, is allocated a latent state, t +// - this is used to generate the contexts, c +// - each context is generated using 4 independent multinomials, one for each position LL, L, R, RR + +// Training with EM: +// - e-step is estimating q(t) = P(t|p,c) for all x,c +// - m-step is estimating model parameters P(c,t|p) = P(t) P(c|t) +// - PR uses alternate e-step, which first optimizes lambda +// min_q KL(q||p) + delta sum_pt max_c E_q[phi_ptc] +// where +// q(t|p,c) propto p(t,c|p) exp( -phi_ptc ) +// Then q is used to obtain expectations for vanilla M-step. + +// Sexing it up: +// - learn p-specific conditionals P(t|p) +// - or generate phrase internals, e.g., generate edge words from +// different distribution to central words +// - agreement between phrase->context model and context->phrase model + +import java.io.*; +import java.util.*; +import java.util.regex.*; +import static java.lang.Math.*; + +class Lexicon +{ + public int insert(T word) + { + Integer i = wordToIndex.get(word); + if (i == null) + { + i = indexToWord.size(); + wordToIndex.put(word, i); + indexToWord.add(word); + } + return i; + } + + public T lookup(int index) + { + return indexToWord.get(index); + } + + public int size() + { + return indexToWord.size(); + } + + private Map wordToIndex = new HashMap(); + private List indexToWord = new ArrayList(); +} + +class PhraseContextModel +{ + // model/optimisation configuration parameters + int numTags; + int numPRIterations = 5; + boolean posteriorRegularisation = false; + double constraintScale = 10; + + // book keeping + Lexicon tokenLexicon = new Lexicon(); + int numPositions; + Random rng = new Random(); + + // training set; 1 entry for each unique phrase + PhraseAndContexts training[]; + + // model parameters (learnt) + double emissions[][][]; // position in 0 .. 3 x tag x word Pr(word | tag, position) + double prior[][]; // phrase x tag Pr(tag | phrase) + //double lambda[][][]; // word x context x tag langrange multipliers + + PhraseContextModel(File infile, int tags) throws IOException + { + numTags = tags; + readTrainingFromFile(new FileReader(infile)); + assert(training.length > 0); + + // now initialise emissions + assert(training[0].contexts.length > 0); + numPositions = training[0].contexts[0].tokens.length; + + emissions = new double[numPositions][numTags][tokenLexicon.size()]; + prior = new double[training.length][numTags]; + //lambda = new double[tokenLexicon.size()][???][numTags] + + for (double[][] emissionTW: emissions) + for (double[] emissionW: emissionTW) + randomise(emissionW); + + for (double[] priorTag: prior) + randomise(priorTag); + } + + void expectationMaximisation(int numIterations) + { + for (int iteration = 0; iteration < numIterations; ++iteration) + { + double emissionsCounts[][][] = new double[numPositions][numTags][tokenLexicon.size()]; + double priorCounts[][] = new double[training.length][numTags]; + + // E-step + double llh = 0; + for (int i = 0; i < training.length; ++i) + { + PhraseAndContexts instance = training[i]; + for (Context ctx: instance.contexts) + { + double probs[] = posterior(i, ctx); + double z = normalise(probs); + llh += log(z) * ctx.count; + + for (int t = 0; t < numTags; ++t) + { + priorCounts[i][t] += ctx.count * probs[t]; + for (int j = 0; j < ctx.tokens.length; ++j) + emissionsCounts[j][t][ctx.tokens[j]] += ctx.count * probs[t]; + } + } + } + + // M-step: normalise + for (double[][] emissionTW: emissionsCounts) + for (double[] emissionW: emissionTW) + normalise(emissionW); + + for (double[] priorTag: priorCounts) + normalise(priorTag); + + emissions = emissionsCounts; + prior = priorCounts; + + System.out.println("Iteration " + iteration + " llh " + llh); + } + } + + static double normalise(double probs[]) + { + double z = 0; + for (double p: probs) + z += p; + for (int i = 0; i < probs.length; ++i) + probs[i] /= z; + return z; + } + + void randomise(double probs[]) + { + double z = 0; + for (int i = 0; i < probs.length; ++i) + { + probs[i] = 10 + rng.nextDouble(); + z += probs[i]; + } + + for (int i = 0; i < probs.length; ++i) + probs[i] /= z; + } + + static int argmax(double probs[]) + { + double m = Double.NEGATIVE_INFINITY; + int mi = -1; + for (int i = 0; i < probs.length; ++i) + { + if (probs[i] > m) + { + m = probs[i]; + mi = i; + } + } + return mi; + } + + double[] posterior(int phraseId, Context c) // unnormalised + { + double probs[] = new double[numTags]; + for (int t = 0; t < numTags; ++t) + { + probs[t] = prior[phraseId][t]; + for (int j = 0; j < c.tokens.length; ++j) + probs[t] *= emissions[j][t][c.tokens[j]]; + } + return probs; + } + + private void readTrainingFromFile(Reader in) throws IOException + { + // read in line-by-line + BufferedReader bin = new BufferedReader(in); + String line; + List instances = new ArrayList(); + Pattern separator = Pattern.compile(" \\|\\|\\| "); + + int numEdges = 0; + while ((line = bin.readLine()) != null) + { + // split into phrase and contexts + StringTokenizer st = new StringTokenizer(line, "\t"); + assert(st.hasMoreTokens()); + String phrase = st.nextToken(); + assert(st.hasMoreTokens()); + String rest = st.nextToken(); + assert(!st.hasMoreTokens()); + + // process phrase + st = new StringTokenizer(phrase, " "); + List ptoks = new ArrayList(); + while (st.hasMoreTokens()) + ptoks.add(tokenLexicon.insert(st.nextToken())); + + // process contexts + ArrayList contexts = new ArrayList(); + String[] parts = separator.split(rest); + assert(parts.length % 2 == 0); + for (int i = 0; i < parts.length; i += 2) + { + // process pairs of strings - context and count + ArrayList ctx = new ArrayList(); + String ctxString = parts[i]; + String countString = parts[i+1]; + StringTokenizer ctxStrtok = new StringTokenizer(ctxString, " "); + while (ctxStrtok.hasMoreTokens()) + { + String token = ctxStrtok.nextToken(); + if (!token.equals("")) + ctx.add(tokenLexicon.insert(token)); + } + + assert(countString.startsWith("C=")); + Context c = new Context(); + c.count = Integer.parseInt(countString.substring(2).trim()); + // damn unboxing doesn't work with toArray + c.tokens = new int[ctx.size()]; + for (int k = 0; k < ctx.size(); ++k) + c.tokens[k] = ctx.get(k); + contexts.add(c); + + numEdges += 1; + } + + // package up + PhraseAndContexts instance = new PhraseAndContexts(); + // damn unboxing doesn't work with toArray + instance.phraseTokens = new int[ptoks.size()]; + for (int k = 0; k < ptoks.size(); ++k) + instance.phraseTokens[k] = ptoks.get(k); + instance.contexts = contexts.toArray(new Context[] {}); + instances.add(instance); + } + + training = instances.toArray(new PhraseAndContexts[] {}); + + System.out.println("Read in " + training.length + " phrases and " + numEdges + " edges"); + } + + void displayPosterior() + { + for (int i = 0; i < training.length; ++i) + { + PhraseAndContexts instance = training[i]; + for (Context ctx: instance.contexts) + { + double probs[] = posterior(i, ctx); + double z = normalise(probs); + + // emit phrase + for (int t: instance.phraseTokens) + System.out.print(tokenLexicon.lookup(t) + " "); + System.out.print("\t"); + for (int c: ctx.tokens) + System.out.print(tokenLexicon.lookup(c) + " "); + System.out.print("||| C=" + ctx.count + " |||"); + + System.out.print(" " + argmax(probs)); + //for (int t = 0; t < numTags; ++t) + //System.out.print(" " + probs[t]); + System.out.println(); + } + } + } + + class PhraseAndContexts + { + int phraseTokens[]; + Context contexts[]; + } + + class Context + { + int count; + int[] tokens; + } + + public static void main(String []args) + { + assert(args.length >= 2); + try + { + PhraseContextModel model = new PhraseContextModel(new File(args[0]), Integer.parseInt(args[1])); + model.expectationMaximisation(Integer.parseInt(args[2])); + model.displayPosterior(); + } + catch (IOException e) + { + System.out.println("Failed to read input file: " + args[0]); + e.printStackTrace(); + } + } +} -- cgit v1.2.3