diff options
Diffstat (limited to 'gi/posterior-regularisation')
4 files changed, 228 insertions, 63 deletions
diff --git a/gi/posterior-regularisation/prjava/src/arr/F.java b/gi/posterior-regularisation/prjava/src/arr/F.java index 2201179e..0f74cbab 100644 --- a/gi/posterior-regularisation/prjava/src/arr/F.java +++ b/gi/posterior-regularisation/prjava/src/arr/F.java @@ -1,5 +1,6 @@ package arr;
+import java.util.Arrays;
import java.util.Random;
public class F {
@@ -36,11 +37,13 @@ public class F { for(int i=0;i<a.length;i++){
sum+=a[i];
}
- if(sum==0){
- return ;
- }
- for(int i=0;i<a.length;i++){
- a[i]/=sum;
+ if(sum==0)
+ Arrays.fill(a, 1.0/a.length);
+ else
+ {
+ for(int i=0;i<a.length;i++){
+ a[i]/=sum;
+ }
}
}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java index 2de2797b..f2c6b132 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java @@ -17,11 +17,14 @@ public class Corpus private List<List<Edge>> contextToPhrase = new ArrayList<List<Edge>>(); public int splitSentinel; public int phraseSentinel; - + public int rareSentinel; + private boolean[] wordIsRare; + public Corpus() { splitSentinel = wordLexicon.insert("<SPLIT>"); phraseSentinel = wordLexicon.insert("<PHRASE>"); + rareSentinel = wordLexicon.insert("<RARE>"); } public class Edge @@ -40,6 +43,10 @@ public class Corpus { return Corpus.this.getPhrase(phraseId); } + public TIntArrayList getRawPhrase() + { + return Corpus.this.getRawPhrase(phraseId); + } public String getPhraseString() { return Corpus.this.getPhraseString(phraseId); @@ -52,7 +59,10 @@ public class Corpus { return Corpus.this.getContext(contextId); } - public String getContextString(boolean insertPhraseSentinel) + public TIntArrayList getRawContext() + { + return Corpus.this.getRawContext(contextId); + } public String getContextString(boolean insertPhraseSentinel) { return Corpus.this.getContextString(contextId, insertPhraseSentinel); } @@ -132,13 +142,35 @@ public class Corpus public TIntArrayList getPhrase(int phraseId) { + TIntArrayList phrase = phraseLexicon.lookup(phraseId); + if (wordIsRare != null) + { + boolean first = true; + for (int i = 0; i < phrase.size(); ++i) + { + if (wordIsRare[phrase.get(i)]) + { + if (first) + { + phrase = (TIntArrayList) phrase.clone(); + first = false; + } + phrase.set(i, rareSentinel); + } + } + } + return phrase; + } + + public TIntArrayList getRawPhrase(int phraseId) + { return phraseLexicon.lookup(phraseId); } public String getPhraseString(int phraseId) { StringBuffer b = new StringBuffer(); - for (int tid: getPhrase(phraseId).toNativeArray()) + for (int tid: getRawPhrase(phraseId).toNativeArray()) { if (b.length() > 0) b.append(" "); @@ -149,13 +181,35 @@ public class Corpus public TIntArrayList getContext(int contextId) { + TIntArrayList context = contextLexicon.lookup(contextId); + if (wordIsRare != null) + { + boolean first = true; + for (int i = 0; i < context.size(); ++i) + { + if (wordIsRare[context.get(i)]) + { + if (first) + { + context = (TIntArrayList) context.clone(); + first = false; + } + context.set(i, rareSentinel); + } + } + } + return context; + } + + public TIntArrayList getRawContext(int contextId) + { return contextLexicon.lookup(contextId); } public String getContextString(int contextId, boolean insertPhraseSentinel) { StringBuffer b = new StringBuffer(); - TIntArrayList c = getContext(contextId); + TIntArrayList c = getRawContext(contextId); for (int i = 0; i < c.size(); ++i) { if (i > 0) b.append(" "); @@ -249,5 +303,24 @@ public class Corpus { out.println("Corpus has " + edges.size() + " edges " + phraseLexicon.size() + " phrases " + contextLexicon.size() + " contexts and " + wordLexicon.size() + " word types"); - } + } + + public void applyWordThreshold(int wordThreshold) + { + int[] counts = new int[wordLexicon.size()]; + for (Edge e: edges) + { + TIntArrayList phrase = e.getPhrase(); + for (int i = 0; i < phrase.size(); ++i) + counts[phrase.get(i)] += e.getCount(); + + TIntArrayList context = e.getContext(); + for (int i = 0; i < context.size(); ++i) + counts[context.get(i)] += e.getCount(); + } + + wordIsRare = new boolean[wordLexicon.size()]; + for (int i = 0; i < wordLexicon.size(); ++i) + wordIsRare[i] = counts[i] < wordThreshold; + } }
\ No newline at end of file diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index 5efaf52e..feab5eda 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -2,14 +2,20 @@ package phrase; import gnu.trove.TIntArrayList;
import org.apache.commons.math.special.Gamma;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.List;
+import java.util.StringTokenizer;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
+import java.util.regex.Pattern;
import phrase.Corpus.Edge;
@@ -17,7 +23,7 @@ import phrase.Corpus.Edge; public class PhraseCluster {
public int K;
- private int n_phrases, n_words, n_contexts, n_positions;
+ private int n_phrases, n_words, n_contexts, n_positions, edge_threshold;
public Corpus c;
public ExecutorService pool;
@@ -38,6 +44,7 @@ public class PhraseCluster { n_phrases=c.getNumPhrases();
n_contexts=c.getNumContexts();
n_positions=c.getNumContextPositions();
+ edge_threshold=0;
emit=new double [K][n_positions][n_words];
pi=new double[n_phrases][K];
@@ -74,12 +81,11 @@ public class PhraseCluster { double [][][]exp_emit=new double [K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
- if (skipBigPhrases)
- {
- for(double [][]i:exp_emit)
- for(double []j:i)
- Arrays.fill(j, 1e-100);
- }
+ for(double [][]i:exp_emit)
+ for(double []j:i)
+ Arrays.fill(j, 1e-10);
+ for(double []j:pi)
+ Arrays.fill(j, 1e-10);
double loglikelihood=0;
@@ -97,6 +103,9 @@ public class PhraseCluster { for (int ctx=0; ctx<contexts.size(); ctx++)
{
Edge edge = contexts.get(ctx);
+ if (edge.getCount() < edge_threshold)
+ continue;
+
double p[]=posterior(edge);
double z = arr.F.l1norm(p);
assert z > 0;
@@ -121,7 +130,7 @@ public class PhraseCluster { arr.F.l1normalize(j);
for(double []j:exp_pi)
- arr.F.l1normalize(j);
+ arr.F.l1normalize(j);
emit=exp_emit;
pi=exp_pi;
@@ -250,12 +259,11 @@ public class PhraseCluster { double [][][]exp_emit=new double[K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
- if (skipBigPhrases)
- {
- for(double [][]i:exp_emit)
- for(double []j:i)
- Arrays.fill(j, 1e-100);
- }
+ for(double [][]i:exp_emit)
+ for(double []j:i)
+ Arrays.fill(j, 1e-10);
+ for(double []j:pi)
+ Arrays.fill(j, 1e-10);
if (lambdaPT == null && cacheLambda)
lambdaPT = new double[n_phrases][];
@@ -271,6 +279,7 @@ public class PhraseCluster { continue;
}
+ // FIXME: add rare edge check to phrase objective & posterior processing
PhraseObjective po = new PhraseObjective(this, phrase, scalePT, (cacheLambda) ? lambdaPT[phrase] : null);
boolean ok = po.optimizeWithProjectedGradientDescent();
if (!ok) ++failures;
@@ -493,11 +502,25 @@ public class PhraseCluster { {
double[] prob=Arrays.copyOf(pi[edge.getPhraseId()], K);
+ //if (edge.getCount() < edge_threshold)
+ //System.out.println("Edge: " + edge + " probs for phrase " + Arrays.toString(prob));
+
TIntArrayList ctx = edge.getContext();
for(int tag=0;tag<K;tag++)
+ {
for(int c=0;c<n_positions;c++)
- if (!this.c.isSentinel(ctx.get(c)))
- prob[tag]*=emit[tag][c][ctx.get(c)];
+ {
+ int word = ctx.get(c);
+ //if (edge.getCount() < edge_threshold)
+ //System.out.println("\ttag: " + tag + " context word: " + word + " prob " + emit[tag][c][word]);
+
+ if (!this.c.isSentinel(word))
+ prob[tag]*=emit[tag][c][word];
+ }
+ }
+
+ //if (edge.getCount() < edge_threshold)
+ //System.out.println("prob " + Arrays.toString(prob));
return prob;
}
@@ -514,39 +537,33 @@ public class PhraseCluster { ps.print("\t");
ps.print(edge.getContextString(true));
int t=arr.F.argmax(probs);
- ps.println(" ||| C=" + t);
+ ps.println(" ||| C=" + t + " T=" + edge.getCount() + " P=" + probs[t]);
+ //ps.println("# probs " + Arrays.toString(probs));
}
}
public void displayModelParam(PrintStream ps)
{
final double EPS = 1e-6;
+ ps.println("phrases " + n_phrases + " tags " + K + " positions " + n_positions);
- ps.println("P(tag|phrase)");
for (int i = 0; i < n_phrases; ++i)
- {
- ps.print(c.getPhrase(i));
- for(int j=0;j<pi[i].length;j++){
+ for(int j=0;j<pi[i].length;j++)
if (pi[i][j] > EPS)
- ps.print("\t" + j + ": " + pi[i][j]);
- }
- ps.println();
- }
-
- ps.println("P(word|tag,position)");
+ ps.println(i + " " + j + " " + pi[i][j]);
+
+ ps.println();
for (int i = 0; i < K; ++i)
{
- for(int position=0;position<n_positions;position++){
- ps.println("tag " + i + " position " + position);
- for(int word=0;word<emit[i][position].length;word++){
+ for(int position=0;position<n_positions;position++)
+ {
+ for(int word=0;word<emit[i][position].length;word++)
+ {
if (emit[i][position][word] > EPS)
- ps.print(c.getWord(word)+"="+emit[i][position][word]+"\t");
+ ps.println(i + " " + position + " " + word + " " + emit[i][position][word]);
}
- ps.println();
}
- ps.println();
}
-
}
double phrase_l1lmax()
@@ -586,4 +603,66 @@ public class PhraseCluster { }
return sum;
}
+
+ public void loadParameters(BufferedReader input) throws IOException
+ {
+ final double EPS = 1e-50;
+
+ // overwrite pi, emit with ~zeros
+ for(double [][]i:emit)
+ for(double []j:i)
+ Arrays.fill(j, EPS);
+
+ for(double []j:pi)
+ Arrays.fill(j, EPS);
+
+ String line = input.readLine();
+ assert line != null;
+
+ Pattern space = Pattern.compile(" +");
+ String[] parts = space.split(line);
+ assert parts.length == 6;
+
+ assert parts[0].equals("phrases");
+ int phrases = Integer.parseInt(parts[1]);
+ int tags = Integer.parseInt(parts[3]);
+ int positions = Integer.parseInt(parts[5]);
+
+ assert phrases == n_phrases;
+ assert tags == K;
+ assert positions == n_positions;
+
+ // read in pi
+ while ((line = input.readLine()) != null)
+ {
+ line = line.trim();
+ if (line.isEmpty()) break;
+
+ String[] tokens = space.split(line);
+ assert tokens.length == 3;
+ int p = Integer.parseInt(tokens[0]);
+ int t = Integer.parseInt(tokens[1]);
+ double v = Double.parseDouble(tokens[2]);
+
+ pi[p][t] = v;
+ }
+
+ // read in emissions
+ while ((line = input.readLine()) != null)
+ {
+ String[] tokens = space.split(line);
+ assert tokens.length == 4;
+ int t = Integer.parseInt(tokens[0]);
+ int p = Integer.parseInt(tokens[1]);
+ int w = Integer.parseInt(tokens[2]);
+ double v = Double.parseDouble(tokens[3]);
+
+ emit[t][p][w] = v;
+ }
+ }
+
+ public void setEdgeThreshold(int edgeThreshold)
+ {
+ this.edge_threshold = edgeThreshold;
+ }
}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java index a67c17a2..ed7a6bbe 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java @@ -4,6 +4,7 @@ import io.FileUtil; import joptsimple.OptionParser; import joptsimple.OptionSet; import java.io.File; +import java.io.FileNotFoundException; import java.io.IOException; import java.io.PrintStream; import java.util.Random; @@ -18,12 +19,12 @@ public class Trainer parser.accepts("help"); parser.accepts("in").withRequiredArg().ofType(File.class); parser.accepts("out").withRequiredArg().ofType(File.class); + parser.accepts("start").withRequiredArg().ofType(File.class); parser.accepts("parameters").withRequiredArg().ofType(File.class); parser.accepts("topics").withRequiredArg().ofType(Integer.class).defaultsTo(5); - parser.accepts("em-iterations").withRequiredArg().ofType(Integer.class).defaultsTo(5); - parser.accepts("pr-iterations").withRequiredArg().ofType(Integer.class).defaultsTo(0); + parser.accepts("iterations").withRequiredArg().ofType(Integer.class).defaultsTo(10); parser.accepts("threads").withRequiredArg().ofType(Integer.class).defaultsTo(0); - parser.accepts("scale-phrase").withRequiredArg().ofType(Double.class).defaultsTo(5.0); + parser.accepts("scale-phrase").withRequiredArg().ofType(Double.class).defaultsTo(0.0); parser.accepts("scale-context").withRequiredArg().ofType(Double.class).defaultsTo(0.0); parser.accepts("seed").withRequiredArg().ofType(Long.class).defaultsTo(0l); parser.accepts("convergence-threshold").withRequiredArg().ofType(Double.class).defaultsTo(1e-6); @@ -33,6 +34,8 @@ public class Trainer parser.accepts("agree"); parser.accepts("no-parameter-cache"); parser.accepts("skip-large-phrases").withRequiredArg().ofType(Integer.class).defaultsTo(5); + parser.accepts("rare-word").withRequiredArg().ofType(Integer.class).defaultsTo(0); + parser.accepts("rare-edge").withRequiredArg().ofType(Integer.class).defaultsTo(0); OptionSet options = parser.parse(args); if (options.has("help") || !options.has("in")) @@ -47,8 +50,7 @@ public class Trainer } int tags = (Integer) options.valueOf("topics"); - int em_iterations = (Integer) options.valueOf("em-iterations"); - int pr_iterations = (Integer) options.valueOf("pr-iterations"); + int iterations = (Integer) options.valueOf("iterations"); double scale_phrase = (Double) options.valueOf("scale-phrase"); double scale_context = (Double) options.valueOf("scale-context"); int threads = (Integer) options.valueOf("threads"); @@ -57,6 +59,8 @@ public class Trainer double alphaEmit = (vb) ? (Double) options.valueOf("alpha-emit") : 0; double alphaPi = (vb) ? (Double) options.valueOf("alpha-pi") : 0; int skip = (Integer) options.valueOf("skip-large-phrases"); + int wordThreshold = (Integer) options.valueOf("rare-word"); + int edgeThreshold = (Integer) options.valueOf("rare-edge"); if (options.has("seed")) F.rng = new Random((Long) options.valueOf("seed")); @@ -79,15 +83,18 @@ public class Trainer System.exit(1); } + if (wordThreshold > 0) + corpus.applyWordThreshold(wordThreshold); + if (!options.has("agree")) System.out.println("Running with " + tags + " tags " + - "for " + em_iterations + " EM and " + pr_iterations + " PR iterations " + - "skipping large phrases for first " + skip + " iterations " + + "for " + iterations + " iterations " + + ((skip > 0) ? "skipping large phrases for first " + skip + " iterations " : "") + "with scale " + scale_phrase + " phrase and " + scale_context + " context " + "and " + threads + " threads"); else System.out.println("Running agreement model with " + tags + " tags " + - "for " + em_iterations); + "for " + iterations); System.out.println(); @@ -102,17 +109,28 @@ public class Trainer if (vb) cluster.initialiseVB(alphaEmit, alphaPi); if (options.has("no-parameter-cache")) cluster.cacheLambda = false; + if (options.has("start")) + { + try { + System.err.println("Reading starting parameters from " + options.valueOf("start")); + cluster.loadParameters(FileUtil.reader((File)options.valueOf("start"))); + } catch (IOException e) { + System.err.println("Failed to open input file: " + options.valueOf("start")); + e.printStackTrace(); + } + } + cluster.setEdgeThreshold(edgeThreshold); } double last = 0; - for (int i=0; i<em_iterations+pr_iterations; i++) + for (int i=0; i < iterations; i++) { double o; if (agree != null) o = agree.EM(); else { - if (i < em_iterations) + if (scale_phrase <= 0 && scale_context <= 0) { if (!vb) o = cluster.EM(i < skip); @@ -128,13 +146,7 @@ public class Trainer if (i != 0 && Math.abs((o - last) / o) < threshold) { last = o; - if (i < Math.max(em_iterations, skip)) - { - i = Math.max(em_iterations, skip) - 1; - continue; - } - else - break; + break; } last = o; } @@ -145,8 +157,6 @@ public class Trainer double pl1lmax = cluster.phrase_l1lmax(); double cl1lmax = cluster.context_l1lmax(); System.out.println("\nFinal posterior phrase l1lmax " + pl1lmax + " context l1lmax " + cl1lmax); - if (pr_iterations == 0) - System.out.println("With PR objective " + (last - scale_phrase*pl1lmax - scale_context*cl1lmax)); if (options.has("out")) { |