summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation
diff options
context:
space:
mode:
authortrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-16 21:34:28 +0000
committertrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-16 21:34:28 +0000
commit1207aaee1f55dbaac8a46f37635a4d1baf392760 (patch)
treead335c14a9df152e4603cc70957103137817d018 /gi/posterior-regularisation
parent9ffba9b2a35582df415384117450f994e64d7cdb (diff)
Added various flags to filter out low count events (words, edges).
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@298 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation')
-rw-r--r--gi/posterior-regularisation/prjava/src/arr/F.java13
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/Corpus.java83
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java147
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/Trainer.java48
4 files changed, 228 insertions, 63 deletions
diff --git a/gi/posterior-regularisation/prjava/src/arr/F.java b/gi/posterior-regularisation/prjava/src/arr/F.java
index 2201179e..0f74cbab 100644
--- a/gi/posterior-regularisation/prjava/src/arr/F.java
+++ b/gi/posterior-regularisation/prjava/src/arr/F.java
@@ -1,5 +1,6 @@
package arr;
+import java.util.Arrays;
import java.util.Random;
public class F {
@@ -36,11 +37,13 @@ public class F {
for(int i=0;i<a.length;i++){
sum+=a[i];
}
- if(sum==0){
- return ;
- }
- for(int i=0;i<a.length;i++){
- a[i]/=sum;
+ if(sum==0)
+ Arrays.fill(a, 1.0/a.length);
+ else
+ {
+ for(int i=0;i<a.length;i++){
+ a[i]/=sum;
+ }
}
}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java
index 2de2797b..f2c6b132 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java
@@ -17,11 +17,14 @@ public class Corpus
private List<List<Edge>> contextToPhrase = new ArrayList<List<Edge>>();
public int splitSentinel;
public int phraseSentinel;
-
+ public int rareSentinel;
+ private boolean[] wordIsRare;
+
public Corpus()
{
splitSentinel = wordLexicon.insert("<SPLIT>");
phraseSentinel = wordLexicon.insert("<PHRASE>");
+ rareSentinel = wordLexicon.insert("<RARE>");
}
public class Edge
@@ -40,6 +43,10 @@ public class Corpus
{
return Corpus.this.getPhrase(phraseId);
}
+ public TIntArrayList getRawPhrase()
+ {
+ return Corpus.this.getRawPhrase(phraseId);
+ }
public String getPhraseString()
{
return Corpus.this.getPhraseString(phraseId);
@@ -52,7 +59,10 @@ public class Corpus
{
return Corpus.this.getContext(contextId);
}
- public String getContextString(boolean insertPhraseSentinel)
+ public TIntArrayList getRawContext()
+ {
+ return Corpus.this.getRawContext(contextId);
+ } public String getContextString(boolean insertPhraseSentinel)
{
return Corpus.this.getContextString(contextId, insertPhraseSentinel);
}
@@ -132,13 +142,35 @@ public class Corpus
public TIntArrayList getPhrase(int phraseId)
{
+ TIntArrayList phrase = phraseLexicon.lookup(phraseId);
+ if (wordIsRare != null)
+ {
+ boolean first = true;
+ for (int i = 0; i < phrase.size(); ++i)
+ {
+ if (wordIsRare[phrase.get(i)])
+ {
+ if (first)
+ {
+ phrase = (TIntArrayList) phrase.clone();
+ first = false;
+ }
+ phrase.set(i, rareSentinel);
+ }
+ }
+ }
+ return phrase;
+ }
+
+ public TIntArrayList getRawPhrase(int phraseId)
+ {
return phraseLexicon.lookup(phraseId);
}
public String getPhraseString(int phraseId)
{
StringBuffer b = new StringBuffer();
- for (int tid: getPhrase(phraseId).toNativeArray())
+ for (int tid: getRawPhrase(phraseId).toNativeArray())
{
if (b.length() > 0)
b.append(" ");
@@ -149,13 +181,35 @@ public class Corpus
public TIntArrayList getContext(int contextId)
{
+ TIntArrayList context = contextLexicon.lookup(contextId);
+ if (wordIsRare != null)
+ {
+ boolean first = true;
+ for (int i = 0; i < context.size(); ++i)
+ {
+ if (wordIsRare[context.get(i)])
+ {
+ if (first)
+ {
+ context = (TIntArrayList) context.clone();
+ first = false;
+ }
+ context.set(i, rareSentinel);
+ }
+ }
+ }
+ return context;
+ }
+
+ public TIntArrayList getRawContext(int contextId)
+ {
return contextLexicon.lookup(contextId);
}
public String getContextString(int contextId, boolean insertPhraseSentinel)
{
StringBuffer b = new StringBuffer();
- TIntArrayList c = getContext(contextId);
+ TIntArrayList c = getRawContext(contextId);
for (int i = 0; i < c.size(); ++i)
{
if (i > 0) b.append(" ");
@@ -249,5 +303,24 @@ public class Corpus
{
out.println("Corpus has " + edges.size() + " edges " + phraseLexicon.size() + " phrases "
+ contextLexicon.size() + " contexts and " + wordLexicon.size() + " word types");
- }
+ }
+
+ public void applyWordThreshold(int wordThreshold)
+ {
+ int[] counts = new int[wordLexicon.size()];
+ for (Edge e: edges)
+ {
+ TIntArrayList phrase = e.getPhrase();
+ for (int i = 0; i < phrase.size(); ++i)
+ counts[phrase.get(i)] += e.getCount();
+
+ TIntArrayList context = e.getContext();
+ for (int i = 0; i < context.size(); ++i)
+ counts[context.get(i)] += e.getCount();
+ }
+
+ wordIsRare = new boolean[wordLexicon.size()];
+ for (int i = 0; i < wordLexicon.size(); ++i)
+ wordIsRare[i] = counts[i] < wordThreshold;
+ }
} \ No newline at end of file
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index 5efaf52e..feab5eda 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -2,14 +2,20 @@ package phrase;
import gnu.trove.TIntArrayList;
import org.apache.commons.math.special.Gamma;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.List;
+import java.util.StringTokenizer;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
+import java.util.regex.Pattern;
import phrase.Corpus.Edge;
@@ -17,7 +23,7 @@ import phrase.Corpus.Edge;
public class PhraseCluster {
public int K;
- private int n_phrases, n_words, n_contexts, n_positions;
+ private int n_phrases, n_words, n_contexts, n_positions, edge_threshold;
public Corpus c;
public ExecutorService pool;
@@ -38,6 +44,7 @@ public class PhraseCluster {
n_phrases=c.getNumPhrases();
n_contexts=c.getNumContexts();
n_positions=c.getNumContextPositions();
+ edge_threshold=0;
emit=new double [K][n_positions][n_words];
pi=new double[n_phrases][K];
@@ -74,12 +81,11 @@ public class PhraseCluster {
double [][][]exp_emit=new double [K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
- if (skipBigPhrases)
- {
- for(double [][]i:exp_emit)
- for(double []j:i)
- Arrays.fill(j, 1e-100);
- }
+ for(double [][]i:exp_emit)
+ for(double []j:i)
+ Arrays.fill(j, 1e-10);
+ for(double []j:pi)
+ Arrays.fill(j, 1e-10);
double loglikelihood=0;
@@ -97,6 +103,9 @@ public class PhraseCluster {
for (int ctx=0; ctx<contexts.size(); ctx++)
{
Edge edge = contexts.get(ctx);
+ if (edge.getCount() < edge_threshold)
+ continue;
+
double p[]=posterior(edge);
double z = arr.F.l1norm(p);
assert z > 0;
@@ -121,7 +130,7 @@ public class PhraseCluster {
arr.F.l1normalize(j);
for(double []j:exp_pi)
- arr.F.l1normalize(j);
+ arr.F.l1normalize(j);
emit=exp_emit;
pi=exp_pi;
@@ -250,12 +259,11 @@ public class PhraseCluster {
double [][][]exp_emit=new double[K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
- if (skipBigPhrases)
- {
- for(double [][]i:exp_emit)
- for(double []j:i)
- Arrays.fill(j, 1e-100);
- }
+ for(double [][]i:exp_emit)
+ for(double []j:i)
+ Arrays.fill(j, 1e-10);
+ for(double []j:pi)
+ Arrays.fill(j, 1e-10);
if (lambdaPT == null && cacheLambda)
lambdaPT = new double[n_phrases][];
@@ -271,6 +279,7 @@ public class PhraseCluster {
continue;
}
+ // FIXME: add rare edge check to phrase objective & posterior processing
PhraseObjective po = new PhraseObjective(this, phrase, scalePT, (cacheLambda) ? lambdaPT[phrase] : null);
boolean ok = po.optimizeWithProjectedGradientDescent();
if (!ok) ++failures;
@@ -493,11 +502,25 @@ public class PhraseCluster {
{
double[] prob=Arrays.copyOf(pi[edge.getPhraseId()], K);
+ //if (edge.getCount() < edge_threshold)
+ //System.out.println("Edge: " + edge + " probs for phrase " + Arrays.toString(prob));
+
TIntArrayList ctx = edge.getContext();
for(int tag=0;tag<K;tag++)
+ {
for(int c=0;c<n_positions;c++)
- if (!this.c.isSentinel(ctx.get(c)))
- prob[tag]*=emit[tag][c][ctx.get(c)];
+ {
+ int word = ctx.get(c);
+ //if (edge.getCount() < edge_threshold)
+ //System.out.println("\ttag: " + tag + " context word: " + word + " prob " + emit[tag][c][word]);
+
+ if (!this.c.isSentinel(word))
+ prob[tag]*=emit[tag][c][word];
+ }
+ }
+
+ //if (edge.getCount() < edge_threshold)
+ //System.out.println("prob " + Arrays.toString(prob));
return prob;
}
@@ -514,39 +537,33 @@ public class PhraseCluster {
ps.print("\t");
ps.print(edge.getContextString(true));
int t=arr.F.argmax(probs);
- ps.println(" ||| C=" + t);
+ ps.println(" ||| C=" + t + " T=" + edge.getCount() + " P=" + probs[t]);
+ //ps.println("# probs " + Arrays.toString(probs));
}
}
public void displayModelParam(PrintStream ps)
{
final double EPS = 1e-6;
+ ps.println("phrases " + n_phrases + " tags " + K + " positions " + n_positions);
- ps.println("P(tag|phrase)");
for (int i = 0; i < n_phrases; ++i)
- {
- ps.print(c.getPhrase(i));
- for(int j=0;j<pi[i].length;j++){
+ for(int j=0;j<pi[i].length;j++)
if (pi[i][j] > EPS)
- ps.print("\t" + j + ": " + pi[i][j]);
- }
- ps.println();
- }
-
- ps.println("P(word|tag,position)");
+ ps.println(i + " " + j + " " + pi[i][j]);
+
+ ps.println();
for (int i = 0; i < K; ++i)
{
- for(int position=0;position<n_positions;position++){
- ps.println("tag " + i + " position " + position);
- for(int word=0;word<emit[i][position].length;word++){
+ for(int position=0;position<n_positions;position++)
+ {
+ for(int word=0;word<emit[i][position].length;word++)
+ {
if (emit[i][position][word] > EPS)
- ps.print(c.getWord(word)+"="+emit[i][position][word]+"\t");
+ ps.println(i + " " + position + " " + word + " " + emit[i][position][word]);
}
- ps.println();
}
- ps.println();
}
-
}
double phrase_l1lmax()
@@ -586,4 +603,66 @@ public class PhraseCluster {
}
return sum;
}
+
+ public void loadParameters(BufferedReader input) throws IOException
+ {
+ final double EPS = 1e-50;
+
+ // overwrite pi, emit with ~zeros
+ for(double [][]i:emit)
+ for(double []j:i)
+ Arrays.fill(j, EPS);
+
+ for(double []j:pi)
+ Arrays.fill(j, EPS);
+
+ String line = input.readLine();
+ assert line != null;
+
+ Pattern space = Pattern.compile(" +");
+ String[] parts = space.split(line);
+ assert parts.length == 6;
+
+ assert parts[0].equals("phrases");
+ int phrases = Integer.parseInt(parts[1]);
+ int tags = Integer.parseInt(parts[3]);
+ int positions = Integer.parseInt(parts[5]);
+
+ assert phrases == n_phrases;
+ assert tags == K;
+ assert positions == n_positions;
+
+ // read in pi
+ while ((line = input.readLine()) != null)
+ {
+ line = line.trim();
+ if (line.isEmpty()) break;
+
+ String[] tokens = space.split(line);
+ assert tokens.length == 3;
+ int p = Integer.parseInt(tokens[0]);
+ int t = Integer.parseInt(tokens[1]);
+ double v = Double.parseDouble(tokens[2]);
+
+ pi[p][t] = v;
+ }
+
+ // read in emissions
+ while ((line = input.readLine()) != null)
+ {
+ String[] tokens = space.split(line);
+ assert tokens.length == 4;
+ int t = Integer.parseInt(tokens[0]);
+ int p = Integer.parseInt(tokens[1]);
+ int w = Integer.parseInt(tokens[2]);
+ double v = Double.parseDouble(tokens[3]);
+
+ emit[t][p][w] = v;
+ }
+ }
+
+ public void setEdgeThreshold(int edgeThreshold)
+ {
+ this.edge_threshold = edgeThreshold;
+ }
}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
index a67c17a2..ed7a6bbe 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
@@ -4,6 +4,7 @@ import io.FileUtil;
import joptsimple.OptionParser;
import joptsimple.OptionSet;
import java.io.File;
+import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintStream;
import java.util.Random;
@@ -18,12 +19,12 @@ public class Trainer
parser.accepts("help");
parser.accepts("in").withRequiredArg().ofType(File.class);
parser.accepts("out").withRequiredArg().ofType(File.class);
+ parser.accepts("start").withRequiredArg().ofType(File.class);
parser.accepts("parameters").withRequiredArg().ofType(File.class);
parser.accepts("topics").withRequiredArg().ofType(Integer.class).defaultsTo(5);
- parser.accepts("em-iterations").withRequiredArg().ofType(Integer.class).defaultsTo(5);
- parser.accepts("pr-iterations").withRequiredArg().ofType(Integer.class).defaultsTo(0);
+ parser.accepts("iterations").withRequiredArg().ofType(Integer.class).defaultsTo(10);
parser.accepts("threads").withRequiredArg().ofType(Integer.class).defaultsTo(0);
- parser.accepts("scale-phrase").withRequiredArg().ofType(Double.class).defaultsTo(5.0);
+ parser.accepts("scale-phrase").withRequiredArg().ofType(Double.class).defaultsTo(0.0);
parser.accepts("scale-context").withRequiredArg().ofType(Double.class).defaultsTo(0.0);
parser.accepts("seed").withRequiredArg().ofType(Long.class).defaultsTo(0l);
parser.accepts("convergence-threshold").withRequiredArg().ofType(Double.class).defaultsTo(1e-6);
@@ -33,6 +34,8 @@ public class Trainer
parser.accepts("agree");
parser.accepts("no-parameter-cache");
parser.accepts("skip-large-phrases").withRequiredArg().ofType(Integer.class).defaultsTo(5);
+ parser.accepts("rare-word").withRequiredArg().ofType(Integer.class).defaultsTo(0);
+ parser.accepts("rare-edge").withRequiredArg().ofType(Integer.class).defaultsTo(0);
OptionSet options = parser.parse(args);
if (options.has("help") || !options.has("in"))
@@ -47,8 +50,7 @@ public class Trainer
}
int tags = (Integer) options.valueOf("topics");
- int em_iterations = (Integer) options.valueOf("em-iterations");
- int pr_iterations = (Integer) options.valueOf("pr-iterations");
+ int iterations = (Integer) options.valueOf("iterations");
double scale_phrase = (Double) options.valueOf("scale-phrase");
double scale_context = (Double) options.valueOf("scale-context");
int threads = (Integer) options.valueOf("threads");
@@ -57,6 +59,8 @@ public class Trainer
double alphaEmit = (vb) ? (Double) options.valueOf("alpha-emit") : 0;
double alphaPi = (vb) ? (Double) options.valueOf("alpha-pi") : 0;
int skip = (Integer) options.valueOf("skip-large-phrases");
+ int wordThreshold = (Integer) options.valueOf("rare-word");
+ int edgeThreshold = (Integer) options.valueOf("rare-edge");
if (options.has("seed"))
F.rng = new Random((Long) options.valueOf("seed"));
@@ -79,15 +83,18 @@ public class Trainer
System.exit(1);
}
+ if (wordThreshold > 0)
+ corpus.applyWordThreshold(wordThreshold);
+
if (!options.has("agree"))
System.out.println("Running with " + tags + " tags " +
- "for " + em_iterations + " EM and " + pr_iterations + " PR iterations " +
- "skipping large phrases for first " + skip + " iterations " +
+ "for " + iterations + " iterations " +
+ ((skip > 0) ? "skipping large phrases for first " + skip + " iterations " : "") +
"with scale " + scale_phrase + " phrase and " + scale_context + " context " +
"and " + threads + " threads");
else
System.out.println("Running agreement model with " + tags + " tags " +
- "for " + em_iterations);
+ "for " + iterations);
System.out.println();
@@ -102,17 +109,28 @@ public class Trainer
if (vb) cluster.initialiseVB(alphaEmit, alphaPi);
if (options.has("no-parameter-cache"))
cluster.cacheLambda = false;
+ if (options.has("start"))
+ {
+ try {
+ System.err.println("Reading starting parameters from " + options.valueOf("start"));
+ cluster.loadParameters(FileUtil.reader((File)options.valueOf("start")));
+ } catch (IOException e) {
+ System.err.println("Failed to open input file: " + options.valueOf("start"));
+ e.printStackTrace();
+ }
+ }
+ cluster.setEdgeThreshold(edgeThreshold);
}
double last = 0;
- for (int i=0; i<em_iterations+pr_iterations; i++)
+ for (int i=0; i < iterations; i++)
{
double o;
if (agree != null)
o = agree.EM();
else
{
- if (i < em_iterations)
+ if (scale_phrase <= 0 && scale_context <= 0)
{
if (!vb)
o = cluster.EM(i < skip);
@@ -128,13 +146,7 @@ public class Trainer
if (i != 0 && Math.abs((o - last) / o) < threshold)
{
last = o;
- if (i < Math.max(em_iterations, skip))
- {
- i = Math.max(em_iterations, skip) - 1;
- continue;
- }
- else
- break;
+ break;
}
last = o;
}
@@ -145,8 +157,6 @@ public class Trainer
double pl1lmax = cluster.phrase_l1lmax();
double cl1lmax = cluster.context_l1lmax();
System.out.println("\nFinal posterior phrase l1lmax " + pl1lmax + " context l1lmax " + cl1lmax);
- if (pr_iterations == 0)
- System.out.println("With PR objective " + (last - scale_phrase*pl1lmax - scale_context*cl1lmax));
if (options.has("out"))
{