From 5a8ea689c8a4e9cf3e72f88a253b08153bf32dde Mon Sep 17 00:00:00 2001 From: "trevor.cohn" Date: Mon, 19 Jul 2010 22:28:10 +0000 Subject: Reversed out broken thresholding git-svn-id: https://ws10smt.googlecode.com/svn/trunk@324 ec762483-ff6d-05da-a07a-a48fb63a330f --- .../prjava/src/phrase/Corpus.java | 199 ++++----------------- .../prjava/src/phrase/PhraseCluster.java | 80 ++++----- .../prjava/src/phrase/Trainer.java | 46 ++--- 3 files changed, 92 insertions(+), 233 deletions(-) diff --git a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java index 2afc18dc..ad092cc6 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java @@ -18,9 +18,6 @@ public class Corpus public int splitSentinel; public int phraseSentinel; public int rareSentinel; - private boolean[] rareWords; - private boolean[] rarePhrases; - private boolean[] rareContexts; public Corpus() { @@ -45,10 +42,6 @@ public class Corpus { return Corpus.this.getPhrase(phraseId); } - public TIntArrayList getRawPhrase() - { - return Corpus.this.getRawPhrase(phraseId); - } public String getPhraseString() { return Corpus.this.getPhraseString(phraseId); @@ -61,10 +54,6 @@ public class Corpus { return Corpus.this.getContext(contextId); } - public TIntArrayList getRawContext() - { - return Corpus.this.getRawContext(contextId); - } public String getContextString(boolean insertPhraseSentinel) { return Corpus.this.getContextString(contextId, insertPhraseSentinel); @@ -144,28 +133,6 @@ public class Corpus } public TIntArrayList getPhrase(int phraseId) - { - TIntArrayList phrase = phraseLexicon.lookup(phraseId); - if (rareWords != null) - { - boolean first = true; - for (int i = 0; i < phrase.size(); ++i) - { - if (rareWords[phrase.get(i)]) - { - if (first) - { - phrase = (TIntArrayList) phrase.clone(); - first = false; - } - phrase.set(i, rareSentinel); - } - } - } - return phrase; - } - - public TIntArrayList getRawPhrase(int phraseId) { return phraseLexicon.lookup(phraseId); } @@ -173,7 +140,7 @@ public class Corpus public String getPhraseString(int phraseId) { StringBuffer b = new StringBuffer(); - for (int tid: getRawPhrase(phraseId).toNativeArray()) + for (int tid: getPhrase(phraseId).toNativeArray()) { if (b.length() > 0) b.append(" "); @@ -183,28 +150,6 @@ public class Corpus } public TIntArrayList getContext(int contextId) - { - TIntArrayList context = contextLexicon.lookup(contextId); - if (rareWords != null) - { - boolean first = true; - for (int i = 0; i < context.size(); ++i) - { - if (rareWords[context.get(i)]) - { - if (first) - { - context = (TIntArrayList) context.clone(); - first = false; - } - context.set(i, rareSentinel); - } - } - } - return context; - } - - public TIntArrayList getRawContext(int contextId) { return contextLexicon.lookup(contextId); } @@ -212,7 +157,7 @@ public class Corpus public String getContextString(int contextId, boolean insertPhraseSentinel) { StringBuffer b = new StringBuffer(); - TIntArrayList c = getRawContext(contextId); + TIntArrayList c = getContext(contextId); for (int i = 0; i < c.size(); ++i) { if (i > 0) b.append(" "); @@ -227,15 +172,14 @@ public class Corpus return wordId == splitSentinel || wordId == phraseSentinel; } - static Corpus readFromFile(Reader in) throws IOException - { - Corpus c = new Corpus(); - + List readEdges(Reader in) throws IOException + { // read in line-by-line BufferedReader bin = new BufferedReader(in); String line; Pattern separator = Pattern.compile(" \\|\\|\\| "); - + + List edges = new ArrayList(); while ((line = bin.readLine()) != null) { // split into phrase and contexts @@ -250,10 +194,8 @@ public class Corpus st = new StringTokenizer(phraseToks, " "); TIntArrayList ptoks = new TIntArrayList(); while (st.hasMoreTokens()) - ptoks.add(c.wordLexicon.insert(st.nextToken())); - int phraseId = c.phraseLexicon.insert(ptoks); - if (phraseId == c.phraseToContext.size()) - c.phraseToContext.add(new ArrayList()); + ptoks.add(wordLexicon.insert(st.nextToken())); + int phraseId = phraseLexicon.insert(ptoks); // process contexts String[] parts = separator.split(rest); @@ -261,34 +203,45 @@ public class Corpus for (int i = 0; i < parts.length; i += 2) { // process pairs of strings - context and count - TIntArrayList ctx = new TIntArrayList(); String ctxString = parts[i]; String countString = parts[i + 1]; + + assert (countString.startsWith("C=")); + int count = Integer.parseInt(countString.substring(2).trim()); + + TIntArrayList ctx = new TIntArrayList(); StringTokenizer ctxStrtok = new StringTokenizer(ctxString, " "); while (ctxStrtok.hasMoreTokens()) { String token = ctxStrtok.nextToken(); - //if (!token.equals("")) - ctx.add(c.wordLexicon.insert(token)); + ctx.add(wordLexicon.insert(token)); } - int contextId = c.contextLexicon.insert(ctx); - if (contextId == c.contextToPhrase.size()) - c.contextToPhrase.add(new ArrayList()); + int contextId = contextLexicon.insert(ctx); - assert (countString.startsWith("C=")); - Edge e = c.new Edge(phraseId, contextId, - Integer.parseInt(countString.substring(2).trim())); - c.edges.add(e); - - // index the edge for fast phrase, context lookup - c.phraseToContext.get(phraseId).add(e); - c.contextToPhrase.get(contextId).add(e); + edges.add(new Edge(phraseId, contextId, count)); } } - - return c; + return edges; } + static Corpus readFromFile(Reader in) throws IOException + { + Corpus c = new Corpus(); + c.edges = c.readEdges(in); + for (Edge edge: c.edges) + { + while (edge.getPhraseId() >= c.phraseToContext.size()) + c.phraseToContext.add(new ArrayList()); + while (edge.getContextId() >= c.contextToPhrase.size()) + c.contextToPhrase.add(new ArrayList()); + + // index the edge for fast phrase, context lookup + c.phraseToContext.get(edge.getPhraseId()).add(edge); + c.contextToPhrase.get(edge.getContextId()).add(edge); + } + return c; + } + TIntArrayList phraseEdges(TIntArrayList phrase) { TIntArrayList r = new TIntArrayList(4); @@ -307,86 +260,4 @@ public class Corpus out.println("Corpus has " + edges.size() + " edges " + phraseLexicon.size() + " phrases " + contextLexicon.size() + " contexts and " + wordLexicon.size() + " word types"); } - - public void applyWordThreshold(int wordThreshold) - { - int[] counts = new int[wordLexicon.size()]; - for (Edge e: edges) - { - TIntArrayList phrase = e.getPhrase(); - for (int i = 0; i < phrase.size(); ++i) - counts[phrase.get(i)] += e.getCount(); - - TIntArrayList context = e.getContext(); - for (int i = 0; i < context.size(); ++i) - counts[context.get(i)] += e.getCount(); - } - - int count = 0; - rareWords = new boolean[wordLexicon.size()]; - for (int i = 0; i < wordLexicon.size(); ++i) - { - rareWords[i] = counts[i] < wordThreshold; - if (rareWords[i]) - count++; - } - System.err.println("There are " + count + " rare words"); - } - - public void applyPhraseThreshold(int threshold) - { - rarePhrases = new boolean[phraseLexicon.size()]; - - int n = 0; - for (int i = 0; i < phraseLexicon.size(); ++i) - { - List contexts = phraseToContext.get(i); - int count = 0; - for (Edge edge: contexts) - { - count += edge.getCount(); - if (count >= threshold) - break; - } - - if (count < threshold) - { - rarePhrases[i] = true; - n++; - } - } - System.err.println("There are " + n + " rare phrases"); - } - - public void applyContextThreshold(int threshold) - { - rareContexts = new boolean[contextLexicon.size()]; - - int n = 0; - for (int i = 0; i < contextLexicon.size(); ++i) - { - List phrases = contextToPhrase.get(i); - int count = 0; - for (Edge edge: phrases) - { - count += edge.getCount(); - if (count >= threshold) - break; - } - - if (count < threshold) - { - rareContexts[i] = true; - n++; - } - } - System.err.println("There are " + n + " rare contexts"); - } - - boolean isRare(Edge edge) - { - if (rarePhrases != null && rarePhrases[edge.getPhraseId()] == true) return true; - if (rareContexts != null && rareContexts[edge.getContextId()] == true) return true; - return false; - } } \ No newline at end of file diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index 9ee766d4..e7e4af32 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -23,7 +23,7 @@ import phrase.Corpus.Edge; public class PhraseCluster { public int K; - private int n_phrases, n_words, n_contexts, n_positions, edge_threshold; + private int n_phrases, n_words, n_contexts, n_positions; public Corpus c; public ExecutorService pool; @@ -44,7 +44,6 @@ public class PhraseCluster { n_phrases=c.getNumPhrases(); n_contexts=c.getNumContexts(); n_positions=c.getNumContextPositions(); - edge_threshold=0; emit=new double [K][n_positions][n_words]; pi=new double[n_phrases][K]; @@ -76,7 +75,7 @@ public class PhraseCluster { pool = Executors.newFixedThreadPool(threads); } - public double EM(boolean skipBigPhrases) + public double EM(int phraseSizeLimit) { double [][][]exp_emit=new double [K][n_positions][n_words]; double [][]exp_pi=new double[n_phrases][K]; @@ -92,19 +91,17 @@ public class PhraseCluster { //E for(int phrase=0; phrase < n_phrases; phrase++) { - if (skipBigPhrases && c.getPhrase(phrase).size() >= 2) + if (phraseSizeLimit >= 1 && c.getPhrase(phrase).size() > phraseSizeLimit) { System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K); continue; - } + } List contexts = c.getEdgesForPhrase(phrase); for (int ctx=0; ctx= 2) + for(int phrase=0; phrase= 1 && c.getPhrase(phrase).size() > phraseSizeLimit) { System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K); continue; @@ -328,7 +325,7 @@ public class PhraseCluster { return primal; } - public double PREM_phrase_constraints_parallel(final double scalePT, boolean skipBigPhrases) + public double PREM_phrase_constraints_parallel(final double scalePT, int phraseSizeLimit) { assert(pool != null); @@ -338,12 +335,11 @@ public class PhraseCluster { double [][][]exp_emit=new double [K][n_positions][n_words]; double [][]exp_pi=new double[n_phrases][K]; - if (skipBigPhrases) - { - for(double [][]i:exp_emit) - for(double []j:i) - Arrays.fill(j, 1e-100); - } + for(double [][]i:exp_emit) + for(double []j:i) + Arrays.fill(j, 1e-10); + for(double []j:pi) + Arrays.fill(j, 1e-10); double loglikelihood=0, kl=0, l1lmax=0, primal=0; final AtomicInteger failures = new AtomicInteger(0); @@ -356,12 +352,13 @@ public class PhraseCluster { //E for(int phrase=0;phrase= 2) + if (phraseSizeLimit >= 1 && c.getPhrase(phrase).size() > phraseSizeLimit) { n -= 1; System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K); continue; } + final int p=phrase; pool.execute(new Runnable() { public void run() { @@ -445,10 +442,8 @@ public class PhraseCluster { return primal; } - public double PREM_phrase_context_constraints(double scalePT, double scaleCT, boolean skipBigPhrases) + public double PREM_phrase_context_constraints(double scalePT, double scaleCT) { - assert !skipBigPhrases : "Not supported yet - FIXME!"; //FIXME - double[][][] exp_emit = new double [K][n_positions][n_words]; double[][] exp_pi = new double[n_phrases][K]; @@ -500,10 +495,14 @@ public class PhraseCluster { */ public double[] posterior(Corpus.Edge edge) { - double[] prob=Arrays.copyOf(pi[edge.getPhraseId()], K); - - //if (edge.getCount() < edge_threshold) - //System.out.println("Edge: " + edge + " probs for phrase " + Arrays.toString(prob)); + double[] prob; + if (edge.getPhraseId() < n_phrases) + prob = Arrays.copyOf(pi[edge.getPhraseId()], K); + else + { + prob = new double[K]; + Arrays.fill(prob, 1.0); + } TIntArrayList ctx = edge.getContext(); for(int tag=0;tag testing) { - for (Edge edge : c.getEdges()) + for (Edge edge : testing) { double probs[] = posterior(edge); arr.F.l1normalize(probs); @@ -660,9 +653,4 @@ public class PhraseCluster { emit[t][p][w] = v; } } - - public void setEdgeThreshold(int edgeThreshold) - { - this.edge_threshold = edgeThreshold; - } } diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java index 7f0b1970..ec1a5804 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java @@ -7,8 +7,11 @@ import java.io.File; import java.io.FileNotFoundException; import java.io.IOException; import java.io.PrintStream; +import java.util.List; import java.util.Random; +import phrase.Corpus.Edge; + import arr.F; public class Trainer @@ -34,10 +37,6 @@ public class Trainer parser.accepts("agree"); parser.accepts("no-parameter-cache"); parser.accepts("skip-large-phrases").withRequiredArg().ofType(Integer.class).defaultsTo(5); - parser.accepts("rare-word").withRequiredArg().ofType(Integer.class).defaultsTo(10); - parser.accepts("rare-edge").withRequiredArg().ofType(Integer.class).defaultsTo(1); - parser.accepts("rare-phrase").withRequiredArg().ofType(Integer.class).defaultsTo(2); - parser.accepts("rare-context").withRequiredArg().ofType(Integer.class).defaultsTo(2); OptionSet options = parser.parse(args); if (options.has("help") || !options.has("in")) @@ -61,10 +60,6 @@ public class Trainer double alphaEmit = (vb) ? (Double) options.valueOf("alpha-emit") : 0; double alphaPi = (vb) ? (Double) options.valueOf("alpha-pi") : 0; int skip = (Integer) options.valueOf("skip-large-phrases"); - int wordThreshold = (Integer) options.valueOf("rare-word"); - int edgeThreshold = (Integer) options.valueOf("rare-edge"); - int phraseThreshold = (Integer) options.valueOf("rare-phrase"); - int contextThreshold = (Integer) options.valueOf("rare-context"); if (options.has("seed")) F.rng = new Random((Long) options.valueOf("seed")); @@ -86,14 +81,7 @@ public class Trainer e.printStackTrace(); System.exit(1); } - - if (wordThreshold > 1) - corpus.applyWordThreshold(wordThreshold); - if (phraseThreshold > 1) - corpus.applyPhraseThreshold(phraseThreshold); - if (contextThreshold > 1) - corpus.applyContextThreshold(contextThreshold); - + if (!options.has("agree")) System.out.println("Running with " + tags + " tags " + "for " + iterations + " iterations " + @@ -127,7 +115,6 @@ public class Trainer e.printStackTrace(); } } - cluster.setEdgeThreshold(edgeThreshold); } double last = 0; @@ -138,20 +125,24 @@ public class Trainer o = agree.EM(); else { + if (i < skip) + System.out.println("Skipping phrases of length > " + (i+1)); + if (scale_phrase <= 0 && scale_context <= 0) { if (!vb) - o = cluster.EM(i < skip); + o = cluster.EM((i < skip) ? i+1 : 0); else - o = cluster.VBEM(alphaEmit, alphaPi, i < skip); + o = cluster.VBEM(alphaEmit, alphaPi); } else - o = cluster.PREM(scale_phrase, scale_context, i < skip); + o = cluster.PREM(scale_phrase, scale_context, (i < skip) ? i+1 : 0); } System.out.println("ITER: "+i+" objective: " + o); - if (i != 0 && Math.abs((o - last) / o) < threshold) + // sometimes takes a few iterations to break the ties + if (i > 5 && Math.abs((o - last) / o) < threshold) { last = o; break; @@ -171,10 +162,19 @@ public class Trainer File outfile = (File) options.valueOf("out"); try { PrintStream ps = FileUtil.printstream(outfile); - cluster.displayPosterior(ps); + List test; + if (!options.has("test")) + test = corpus.getEdges(); + else + { + infile = (File) options.valueOf("test"); + System.out.println("Reading testing concordance from " + infile); + test = corpus.readEdges(FileUtil.reader(infile)); + } + cluster.displayPosterior(ps, test); ps.close(); } catch (IOException e) { - System.err.println("Failed to open output file: " + outfile); + System.err.println("Failed to open either testing file or output file"); e.printStackTrace(); System.exit(1); } -- cgit v1.2.3