From a900eeb513e71ecbf5de9ba545da052002184fdc Mon Sep 17 00:00:00 2001 From: "trevor.cohn" Date: Fri, 16 Jul 2010 21:34:28 +0000 Subject: Added various flags to filter out low count events (words, edges). git-svn-id: https://ws10smt.googlecode.com/svn/trunk@298 ec762483-ff6d-05da-a07a-a48fb63a330f --- .../prjava/src/phrase/PhraseCluster.java | 147 ++++++++++++++++----- 1 file changed, 113 insertions(+), 34 deletions(-) (limited to 'gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java') diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index 5efaf52e..feab5eda 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -2,14 +2,20 @@ package phrase; import gnu.trove.TIntArrayList; import org.apache.commons.math.special.Gamma; + +import java.io.BufferedReader; +import java.io.File; +import java.io.IOException; import java.io.PrintStream; import java.util.Arrays; import java.util.List; +import java.util.StringTokenizer; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; +import java.util.regex.Pattern; import phrase.Corpus.Edge; @@ -17,7 +23,7 @@ import phrase.Corpus.Edge; public class PhraseCluster { public int K; - private int n_phrases, n_words, n_contexts, n_positions; + private int n_phrases, n_words, n_contexts, n_positions, edge_threshold; public Corpus c; public ExecutorService pool; @@ -38,6 +44,7 @@ public class PhraseCluster { n_phrases=c.getNumPhrases(); n_contexts=c.getNumContexts(); n_positions=c.getNumContextPositions(); + edge_threshold=0; emit=new double [K][n_positions][n_words]; pi=new double[n_phrases][K]; @@ -74,12 +81,11 @@ public class PhraseCluster { double [][][]exp_emit=new double [K][n_positions][n_words]; double [][]exp_pi=new double[n_phrases][K]; - if (skipBigPhrases) - { - for(double [][]i:exp_emit) - for(double []j:i) - Arrays.fill(j, 1e-100); - } + for(double [][]i:exp_emit) + for(double []j:i) + Arrays.fill(j, 1e-10); + for(double []j:pi) + Arrays.fill(j, 1e-10); double loglikelihood=0; @@ -97,6 +103,9 @@ public class PhraseCluster { for (int ctx=0; ctx 0; @@ -121,7 +130,7 @@ public class PhraseCluster { arr.F.l1normalize(j); for(double []j:exp_pi) - arr.F.l1normalize(j); + arr.F.l1normalize(j); emit=exp_emit; pi=exp_pi; @@ -250,12 +259,11 @@ public class PhraseCluster { double [][][]exp_emit=new double[K][n_positions][n_words]; double [][]exp_pi=new double[n_phrases][K]; - if (skipBigPhrases) - { - for(double [][]i:exp_emit) - for(double []j:i) - Arrays.fill(j, 1e-100); - } + for(double [][]i:exp_emit) + for(double []j:i) + Arrays.fill(j, 1e-10); + for(double []j:pi) + Arrays.fill(j, 1e-10); if (lambdaPT == null && cacheLambda) lambdaPT = new double[n_phrases][]; @@ -271,6 +279,7 @@ public class PhraseCluster { continue; } + // FIXME: add rare edge check to phrase objective & posterior processing PhraseObjective po = new PhraseObjective(this, phrase, scalePT, (cacheLambda) ? lambdaPT[phrase] : null); boolean ok = po.optimizeWithProjectedGradientDescent(); if (!ok) ++failures; @@ -493,11 +502,25 @@ public class PhraseCluster { { double[] prob=Arrays.copyOf(pi[edge.getPhraseId()], K); + //if (edge.getCount() < edge_threshold) + //System.out.println("Edge: " + edge + " probs for phrase " + Arrays.toString(prob)); + TIntArrayList ctx = edge.getContext(); for(int tag=0;tag EPS) - ps.print("\t" + j + ": " + pi[i][j]); - } - ps.println(); - } - - ps.println("P(word|tag,position)"); + ps.println(i + " " + j + " " + pi[i][j]); + + ps.println(); for (int i = 0; i < K; ++i) { - for(int position=0;position EPS) - ps.print(c.getWord(word)+"="+emit[i][position][word]+"\t"); + ps.println(i + " " + position + " " + word + " " + emit[i][position][word]); } - ps.println(); } - ps.println(); } - } double phrase_l1lmax() @@ -586,4 +603,66 @@ public class PhraseCluster { } return sum; } + + public void loadParameters(BufferedReader input) throws IOException + { + final double EPS = 1e-50; + + // overwrite pi, emit with ~zeros + for(double [][]i:emit) + for(double []j:i) + Arrays.fill(j, EPS); + + for(double []j:pi) + Arrays.fill(j, EPS); + + String line = input.readLine(); + assert line != null; + + Pattern space = Pattern.compile(" +"); + String[] parts = space.split(line); + assert parts.length == 6; + + assert parts[0].equals("phrases"); + int phrases = Integer.parseInt(parts[1]); + int tags = Integer.parseInt(parts[3]); + int positions = Integer.parseInt(parts[5]); + + assert phrases == n_phrases; + assert tags == K; + assert positions == n_positions; + + // read in pi + while ((line = input.readLine()) != null) + { + line = line.trim(); + if (line.isEmpty()) break; + + String[] tokens = space.split(line); + assert tokens.length == 3; + int p = Integer.parseInt(tokens[0]); + int t = Integer.parseInt(tokens[1]); + double v = Double.parseDouble(tokens[2]); + + pi[p][t] = v; + } + + // read in emissions + while ((line = input.readLine()) != null) + { + String[] tokens = space.split(line); + assert tokens.length == 4; + int t = Integer.parseInt(tokens[0]); + int p = Integer.parseInt(tokens[1]); + int w = Integer.parseInt(tokens[2]); + double v = Double.parseDouble(tokens[3]); + + emit[t][p][w] = v; + } + } + + public void setEdgeThreshold(int edgeThreshold) + { + this.edge_threshold = edgeThreshold; + } } -- cgit v1.2.3