From 1207aaee1f55dbaac8a46f37635a4d1baf392760 Mon Sep 17 00:00:00 2001 From: "trevor.cohn" Date: Fri, 16 Jul 2010 21:34:28 +0000 Subject: Added various flags to filter out low count events (words, edges). git-svn-id: https://ws10smt.googlecode.com/svn/trunk@298 ec762483-ff6d-05da-a07a-a48fb63a330f --- gi/posterior-regularisation/prjava/src/arr/F.java | 13 +- .../prjava/src/phrase/Corpus.java | 83 +++++++++++- .../prjava/src/phrase/PhraseCluster.java | 147 ++++++++++++++++----- .../prjava/src/phrase/Trainer.java | 48 ++++--- 4 files changed, 228 insertions(+), 63 deletions(-) (limited to 'gi/posterior-regularisation/prjava/src') diff --git a/gi/posterior-regularisation/prjava/src/arr/F.java b/gi/posterior-regularisation/prjava/src/arr/F.java index 2201179e..0f74cbab 100644 --- a/gi/posterior-regularisation/prjava/src/arr/F.java +++ b/gi/posterior-regularisation/prjava/src/arr/F.java @@ -1,5 +1,6 @@ package arr; +import java.util.Arrays; import java.util.Random; public class F { @@ -36,11 +37,13 @@ public class F { for(int i=0;i> contextToPhrase = new ArrayList>(); public int splitSentinel; public int phraseSentinel; - + public int rareSentinel; + private boolean[] wordIsRare; + public Corpus() { splitSentinel = wordLexicon.insert(""); phraseSentinel = wordLexicon.insert(""); + rareSentinel = wordLexicon.insert(""); } public class Edge @@ -40,6 +43,10 @@ public class Corpus { return Corpus.this.getPhrase(phraseId); } + public TIntArrayList getRawPhrase() + { + return Corpus.this.getRawPhrase(phraseId); + } public String getPhraseString() { return Corpus.this.getPhraseString(phraseId); @@ -52,7 +59,10 @@ public class Corpus { return Corpus.this.getContext(contextId); } - public String getContextString(boolean insertPhraseSentinel) + public TIntArrayList getRawContext() + { + return Corpus.this.getRawContext(contextId); + } public String getContextString(boolean insertPhraseSentinel) { return Corpus.this.getContextString(contextId, insertPhraseSentinel); } @@ -131,6 +141,28 @@ public class Corpus } public TIntArrayList getPhrase(int phraseId) + { + TIntArrayList phrase = phraseLexicon.lookup(phraseId); + if (wordIsRare != null) + { + boolean first = true; + for (int i = 0; i < phrase.size(); ++i) + { + if (wordIsRare[phrase.get(i)]) + { + if (first) + { + phrase = (TIntArrayList) phrase.clone(); + first = false; + } + phrase.set(i, rareSentinel); + } + } + } + return phrase; + } + + public TIntArrayList getRawPhrase(int phraseId) { return phraseLexicon.lookup(phraseId); } @@ -138,7 +170,7 @@ public class Corpus public String getPhraseString(int phraseId) { StringBuffer b = new StringBuffer(); - for (int tid: getPhrase(phraseId).toNativeArray()) + for (int tid: getRawPhrase(phraseId).toNativeArray()) { if (b.length() > 0) b.append(" "); @@ -148,6 +180,28 @@ public class Corpus } public TIntArrayList getContext(int contextId) + { + TIntArrayList context = contextLexicon.lookup(contextId); + if (wordIsRare != null) + { + boolean first = true; + for (int i = 0; i < context.size(); ++i) + { + if (wordIsRare[context.get(i)]) + { + if (first) + { + context = (TIntArrayList) context.clone(); + first = false; + } + context.set(i, rareSentinel); + } + } + } + return context; + } + + public TIntArrayList getRawContext(int contextId) { return contextLexicon.lookup(contextId); } @@ -155,7 +209,7 @@ public class Corpus public String getContextString(int contextId, boolean insertPhraseSentinel) { StringBuffer b = new StringBuffer(); - TIntArrayList c = getContext(contextId); + TIntArrayList c = getRawContext(contextId); for (int i = 0; i < c.size(); ++i) { if (i > 0) b.append(" "); @@ -249,5 +303,24 @@ public class Corpus { out.println("Corpus has " + edges.size() + " edges " + phraseLexicon.size() + " phrases " + contextLexicon.size() + " contexts and " + wordLexicon.size() + " word types"); - } + } + + public void applyWordThreshold(int wordThreshold) + { + int[] counts = new int[wordLexicon.size()]; + for (Edge e: edges) + { + TIntArrayList phrase = e.getPhrase(); + for (int i = 0; i < phrase.size(); ++i) + counts[phrase.get(i)] += e.getCount(); + + TIntArrayList context = e.getContext(); + for (int i = 0; i < context.size(); ++i) + counts[context.get(i)] += e.getCount(); + } + + wordIsRare = new boolean[wordLexicon.size()]; + for (int i = 0; i < wordLexicon.size(); ++i) + wordIsRare[i] = counts[i] < wordThreshold; + } } \ No newline at end of file diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index 5efaf52e..feab5eda 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -2,14 +2,20 @@ package phrase; import gnu.trove.TIntArrayList; import org.apache.commons.math.special.Gamma; + +import java.io.BufferedReader; +import java.io.File; +import java.io.IOException; import java.io.PrintStream; import java.util.Arrays; import java.util.List; +import java.util.StringTokenizer; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; +import java.util.regex.Pattern; import phrase.Corpus.Edge; @@ -17,7 +23,7 @@ import phrase.Corpus.Edge; public class PhraseCluster { public int K; - private int n_phrases, n_words, n_contexts, n_positions; + private int n_phrases, n_words, n_contexts, n_positions, edge_threshold; public Corpus c; public ExecutorService pool; @@ -38,6 +44,7 @@ public class PhraseCluster { n_phrases=c.getNumPhrases(); n_contexts=c.getNumContexts(); n_positions=c.getNumContextPositions(); + edge_threshold=0; emit=new double [K][n_positions][n_words]; pi=new double[n_phrases][K]; @@ -74,12 +81,11 @@ public class PhraseCluster { double [][][]exp_emit=new double [K][n_positions][n_words]; double [][]exp_pi=new double[n_phrases][K]; - if (skipBigPhrases) - { - for(double [][]i:exp_emit) - for(double []j:i) - Arrays.fill(j, 1e-100); - } + for(double [][]i:exp_emit) + for(double []j:i) + Arrays.fill(j, 1e-10); + for(double []j:pi) + Arrays.fill(j, 1e-10); double loglikelihood=0; @@ -97,6 +103,9 @@ public class PhraseCluster { for (int ctx=0; ctx 0; @@ -121,7 +130,7 @@ public class PhraseCluster { arr.F.l1normalize(j); for(double []j:exp_pi) - arr.F.l1normalize(j); + arr.F.l1normalize(j); emit=exp_emit; pi=exp_pi; @@ -250,12 +259,11 @@ public class PhraseCluster { double [][][]exp_emit=new double[K][n_positions][n_words]; double [][]exp_pi=new double[n_phrases][K]; - if (skipBigPhrases) - { - for(double [][]i:exp_emit) - for(double []j:i) - Arrays.fill(j, 1e-100); - } + for(double [][]i:exp_emit) + for(double []j:i) + Arrays.fill(j, 1e-10); + for(double []j:pi) + Arrays.fill(j, 1e-10); if (lambdaPT == null && cacheLambda) lambdaPT = new double[n_phrases][]; @@ -271,6 +279,7 @@ public class PhraseCluster { continue; } + // FIXME: add rare edge check to phrase objective & posterior processing PhraseObjective po = new PhraseObjective(this, phrase, scalePT, (cacheLambda) ? lambdaPT[phrase] : null); boolean ok = po.optimizeWithProjectedGradientDescent(); if (!ok) ++failures; @@ -493,11 +502,25 @@ public class PhraseCluster { { double[] prob=Arrays.copyOf(pi[edge.getPhraseId()], K); + //if (edge.getCount() < edge_threshold) + //System.out.println("Edge: " + edge + " probs for phrase " + Arrays.toString(prob)); + TIntArrayList ctx = edge.getContext(); for(int tag=0;tag EPS) - ps.print("\t" + j + ": " + pi[i][j]); - } - ps.println(); - } - - ps.println("P(word|tag,position)"); + ps.println(i + " " + j + " " + pi[i][j]); + + ps.println(); for (int i = 0; i < K; ++i) { - for(int position=0;position EPS) - ps.print(c.getWord(word)+"="+emit[i][position][word]+"\t"); + ps.println(i + " " + position + " " + word + " " + emit[i][position][word]); } - ps.println(); } - ps.println(); } - } double phrase_l1lmax() @@ -586,4 +603,66 @@ public class PhraseCluster { } return sum; } + + public void loadParameters(BufferedReader input) throws IOException + { + final double EPS = 1e-50; + + // overwrite pi, emit with ~zeros + for(double [][]i:emit) + for(double []j:i) + Arrays.fill(j, EPS); + + for(double []j:pi) + Arrays.fill(j, EPS); + + String line = input.readLine(); + assert line != null; + + Pattern space = Pattern.compile(" +"); + String[] parts = space.split(line); + assert parts.length == 6; + + assert parts[0].equals("phrases"); + int phrases = Integer.parseInt(parts[1]); + int tags = Integer.parseInt(parts[3]); + int positions = Integer.parseInt(parts[5]); + + assert phrases == n_phrases; + assert tags == K; + assert positions == n_positions; + + // read in pi + while ((line = input.readLine()) != null) + { + line = line.trim(); + if (line.isEmpty()) break; + + String[] tokens = space.split(line); + assert tokens.length == 3; + int p = Integer.parseInt(tokens[0]); + int t = Integer.parseInt(tokens[1]); + double v = Double.parseDouble(tokens[2]); + + pi[p][t] = v; + } + + // read in emissions + while ((line = input.readLine()) != null) + { + String[] tokens = space.split(line); + assert tokens.length == 4; + int t = Integer.parseInt(tokens[0]); + int p = Integer.parseInt(tokens[1]); + int w = Integer.parseInt(tokens[2]); + double v = Double.parseDouble(tokens[3]); + + emit[t][p][w] = v; + } + } + + public void setEdgeThreshold(int edgeThreshold) + { + this.edge_threshold = edgeThreshold; + } } diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java index a67c17a2..ed7a6bbe 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java @@ -4,6 +4,7 @@ import io.FileUtil; import joptsimple.OptionParser; import joptsimple.OptionSet; import java.io.File; +import java.io.FileNotFoundException; import java.io.IOException; import java.io.PrintStream; import java.util.Random; @@ -18,12 +19,12 @@ public class Trainer parser.accepts("help"); parser.accepts("in").withRequiredArg().ofType(File.class); parser.accepts("out").withRequiredArg().ofType(File.class); + parser.accepts("start").withRequiredArg().ofType(File.class); parser.accepts("parameters").withRequiredArg().ofType(File.class); parser.accepts("topics").withRequiredArg().ofType(Integer.class).defaultsTo(5); - parser.accepts("em-iterations").withRequiredArg().ofType(Integer.class).defaultsTo(5); - parser.accepts("pr-iterations").withRequiredArg().ofType(Integer.class).defaultsTo(0); + parser.accepts("iterations").withRequiredArg().ofType(Integer.class).defaultsTo(10); parser.accepts("threads").withRequiredArg().ofType(Integer.class).defaultsTo(0); - parser.accepts("scale-phrase").withRequiredArg().ofType(Double.class).defaultsTo(5.0); + parser.accepts("scale-phrase").withRequiredArg().ofType(Double.class).defaultsTo(0.0); parser.accepts("scale-context").withRequiredArg().ofType(Double.class).defaultsTo(0.0); parser.accepts("seed").withRequiredArg().ofType(Long.class).defaultsTo(0l); parser.accepts("convergence-threshold").withRequiredArg().ofType(Double.class).defaultsTo(1e-6); @@ -33,6 +34,8 @@ public class Trainer parser.accepts("agree"); parser.accepts("no-parameter-cache"); parser.accepts("skip-large-phrases").withRequiredArg().ofType(Integer.class).defaultsTo(5); + parser.accepts("rare-word").withRequiredArg().ofType(Integer.class).defaultsTo(0); + parser.accepts("rare-edge").withRequiredArg().ofType(Integer.class).defaultsTo(0); OptionSet options = parser.parse(args); if (options.has("help") || !options.has("in")) @@ -47,8 +50,7 @@ public class Trainer } int tags = (Integer) options.valueOf("topics"); - int em_iterations = (Integer) options.valueOf("em-iterations"); - int pr_iterations = (Integer) options.valueOf("pr-iterations"); + int iterations = (Integer) options.valueOf("iterations"); double scale_phrase = (Double) options.valueOf("scale-phrase"); double scale_context = (Double) options.valueOf("scale-context"); int threads = (Integer) options.valueOf("threads"); @@ -57,6 +59,8 @@ public class Trainer double alphaEmit = (vb) ? (Double) options.valueOf("alpha-emit") : 0; double alphaPi = (vb) ? (Double) options.valueOf("alpha-pi") : 0; int skip = (Integer) options.valueOf("skip-large-phrases"); + int wordThreshold = (Integer) options.valueOf("rare-word"); + int edgeThreshold = (Integer) options.valueOf("rare-edge"); if (options.has("seed")) F.rng = new Random((Long) options.valueOf("seed")); @@ -79,15 +83,18 @@ public class Trainer System.exit(1); } + if (wordThreshold > 0) + corpus.applyWordThreshold(wordThreshold); + if (!options.has("agree")) System.out.println("Running with " + tags + " tags " + - "for " + em_iterations + " EM and " + pr_iterations + " PR iterations " + - "skipping large phrases for first " + skip + " iterations " + + "for " + iterations + " iterations " + + ((skip > 0) ? "skipping large phrases for first " + skip + " iterations " : "") + "with scale " + scale_phrase + " phrase and " + scale_context + " context " + "and " + threads + " threads"); else System.out.println("Running agreement model with " + tags + " tags " + - "for " + em_iterations); + "for " + iterations); System.out.println(); @@ -102,17 +109,28 @@ public class Trainer if (vb) cluster.initialiseVB(alphaEmit, alphaPi); if (options.has("no-parameter-cache")) cluster.cacheLambda = false; + if (options.has("start")) + { + try { + System.err.println("Reading starting parameters from " + options.valueOf("start")); + cluster.loadParameters(FileUtil.reader((File)options.valueOf("start"))); + } catch (IOException e) { + System.err.println("Failed to open input file: " + options.valueOf("start")); + e.printStackTrace(); + } + } + cluster.setEdgeThreshold(edgeThreshold); } double last = 0; - for (int i=0; i