diff options
Diffstat (limited to 'gi/posterior-regularisation/prjava')
3 files changed, 92 insertions, 233 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java index 2afc18dc..ad092cc6 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java @@ -18,9 +18,6 @@ public class Corpus public int splitSentinel; public int phraseSentinel; public int rareSentinel; - private boolean[] rareWords; - private boolean[] rarePhrases; - private boolean[] rareContexts; public Corpus() { @@ -45,10 +42,6 @@ public class Corpus { return Corpus.this.getPhrase(phraseId); } - public TIntArrayList getRawPhrase() - { - return Corpus.this.getRawPhrase(phraseId); - } public String getPhraseString() { return Corpus.this.getPhraseString(phraseId); @@ -61,10 +54,6 @@ public class Corpus { return Corpus.this.getContext(contextId); } - public TIntArrayList getRawContext() - { - return Corpus.this.getRawContext(contextId); - } public String getContextString(boolean insertPhraseSentinel) { return Corpus.this.getContextString(contextId, insertPhraseSentinel); @@ -145,35 +134,13 @@ public class Corpus public TIntArrayList getPhrase(int phraseId) { - TIntArrayList phrase = phraseLexicon.lookup(phraseId); - if (rareWords != null) - { - boolean first = true; - for (int i = 0; i < phrase.size(); ++i) - { - if (rareWords[phrase.get(i)]) - { - if (first) - { - phrase = (TIntArrayList) phrase.clone(); - first = false; - } - phrase.set(i, rareSentinel); - } - } - } - return phrase; - } - - public TIntArrayList getRawPhrase(int phraseId) - { return phraseLexicon.lookup(phraseId); } public String getPhraseString(int phraseId) { StringBuffer b = new StringBuffer(); - for (int tid: getRawPhrase(phraseId).toNativeArray()) + for (int tid: getPhrase(phraseId).toNativeArray()) { if (b.length() > 0) b.append(" "); @@ -184,35 +151,13 @@ public class Corpus public TIntArrayList getContext(int contextId) { - TIntArrayList context = contextLexicon.lookup(contextId); - if (rareWords != null) - { - boolean first = true; - for (int i = 0; i < context.size(); ++i) - { - if (rareWords[context.get(i)]) - { - if (first) - { - context = (TIntArrayList) context.clone(); - first = false; - } - context.set(i, rareSentinel); - } - } - } - return context; - } - - public TIntArrayList getRawContext(int contextId) - { return contextLexicon.lookup(contextId); } public String getContextString(int contextId, boolean insertPhraseSentinel) { StringBuffer b = new StringBuffer(); - TIntArrayList c = getRawContext(contextId); + TIntArrayList c = getContext(contextId); for (int i = 0; i < c.size(); ++i) { if (i > 0) b.append(" "); @@ -227,15 +172,14 @@ public class Corpus return wordId == splitSentinel || wordId == phraseSentinel; } - static Corpus readFromFile(Reader in) throws IOException - { - Corpus c = new Corpus(); - + List<Edge> readEdges(Reader in) throws IOException + { // read in line-by-line BufferedReader bin = new BufferedReader(in); String line; Pattern separator = Pattern.compile(" \\|\\|\\| "); - + + List<Edge> edges = new ArrayList<Edge>(); while ((line = bin.readLine()) != null) { // split into phrase and contexts @@ -250,10 +194,8 @@ public class Corpus st = new StringTokenizer(phraseToks, " "); TIntArrayList ptoks = new TIntArrayList(); while (st.hasMoreTokens()) - ptoks.add(c.wordLexicon.insert(st.nextToken())); - int phraseId = c.phraseLexicon.insert(ptoks); - if (phraseId == c.phraseToContext.size()) - c.phraseToContext.add(new ArrayList<Edge>()); + ptoks.add(wordLexicon.insert(st.nextToken())); + int phraseId = phraseLexicon.insert(ptoks); // process contexts String[] parts = separator.split(rest); @@ -261,34 +203,45 @@ public class Corpus for (int i = 0; i < parts.length; i += 2) { // process pairs of strings - context and count - TIntArrayList ctx = new TIntArrayList(); String ctxString = parts[i]; String countString = parts[i + 1]; + + assert (countString.startsWith("C=")); + int count = Integer.parseInt(countString.substring(2).trim()); + + TIntArrayList ctx = new TIntArrayList(); StringTokenizer ctxStrtok = new StringTokenizer(ctxString, " "); while (ctxStrtok.hasMoreTokens()) { String token = ctxStrtok.nextToken(); - //if (!token.equals("<PHRASE>")) - ctx.add(c.wordLexicon.insert(token)); + ctx.add(wordLexicon.insert(token)); } - int contextId = c.contextLexicon.insert(ctx); - if (contextId == c.contextToPhrase.size()) - c.contextToPhrase.add(new ArrayList<Edge>()); + int contextId = contextLexicon.insert(ctx); - assert (countString.startsWith("C=")); - Edge e = c.new Edge(phraseId, contextId, - Integer.parseInt(countString.substring(2).trim())); - c.edges.add(e); - - // index the edge for fast phrase, context lookup - c.phraseToContext.get(phraseId).add(e); - c.contextToPhrase.get(contextId).add(e); + edges.add(new Edge(phraseId, contextId, count)); } } - - return c; + return edges; } + static Corpus readFromFile(Reader in) throws IOException + { + Corpus c = new Corpus(); + c.edges = c.readEdges(in); + for (Edge edge: c.edges) + { + while (edge.getPhraseId() >= c.phraseToContext.size()) + c.phraseToContext.add(new ArrayList<Edge>()); + while (edge.getContextId() >= c.contextToPhrase.size()) + c.contextToPhrase.add(new ArrayList<Edge>()); + + // index the edge for fast phrase, context lookup + c.phraseToContext.get(edge.getPhraseId()).add(edge); + c.contextToPhrase.get(edge.getContextId()).add(edge); + } + return c; + } + TIntArrayList phraseEdges(TIntArrayList phrase) { TIntArrayList r = new TIntArrayList(4); @@ -307,86 +260,4 @@ public class Corpus out.println("Corpus has " + edges.size() + " edges " + phraseLexicon.size() + " phrases " + contextLexicon.size() + " contexts and " + wordLexicon.size() + " word types"); } - - public void applyWordThreshold(int wordThreshold) - { - int[] counts = new int[wordLexicon.size()]; - for (Edge e: edges) - { - TIntArrayList phrase = e.getPhrase(); - for (int i = 0; i < phrase.size(); ++i) - counts[phrase.get(i)] += e.getCount(); - - TIntArrayList context = e.getContext(); - for (int i = 0; i < context.size(); ++i) - counts[context.get(i)] += e.getCount(); - } - - int count = 0; - rareWords = new boolean[wordLexicon.size()]; - for (int i = 0; i < wordLexicon.size(); ++i) - { - rareWords[i] = counts[i] < wordThreshold; - if (rareWords[i]) - count++; - } - System.err.println("There are " + count + " rare words"); - } - - public void applyPhraseThreshold(int threshold) - { - rarePhrases = new boolean[phraseLexicon.size()]; - - int n = 0; - for (int i = 0; i < phraseLexicon.size(); ++i) - { - List<Edge> contexts = phraseToContext.get(i); - int count = 0; - for (Edge edge: contexts) - { - count += edge.getCount(); - if (count >= threshold) - break; - } - - if (count < threshold) - { - rarePhrases[i] = true; - n++; - } - } - System.err.println("There are " + n + " rare phrases"); - } - - public void applyContextThreshold(int threshold) - { - rareContexts = new boolean[contextLexicon.size()]; - - int n = 0; - for (int i = 0; i < contextLexicon.size(); ++i) - { - List<Edge> phrases = contextToPhrase.get(i); - int count = 0; - for (Edge edge: phrases) - { - count += edge.getCount(); - if (count >= threshold) - break; - } - - if (count < threshold) - { - rareContexts[i] = true; - n++; - } - } - System.err.println("There are " + n + " rare contexts"); - } - - boolean isRare(Edge edge) - { - if (rarePhrases != null && rarePhrases[edge.getPhraseId()] == true) return true; - if (rareContexts != null && rareContexts[edge.getContextId()] == true) return true; - return false; - } }
\ No newline at end of file diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index 9ee766d4..e7e4af32 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -23,7 +23,7 @@ import phrase.Corpus.Edge; public class PhraseCluster {
public int K;
- private int n_phrases, n_words, n_contexts, n_positions, edge_threshold;
+ private int n_phrases, n_words, n_contexts, n_positions;
public Corpus c;
public ExecutorService pool;
@@ -44,7 +44,6 @@ public class PhraseCluster { n_phrases=c.getNumPhrases();
n_contexts=c.getNumContexts();
n_positions=c.getNumContextPositions();
- edge_threshold=0;
emit=new double [K][n_positions][n_words];
pi=new double[n_phrases][K];
@@ -76,7 +75,7 @@ public class PhraseCluster { pool = Executors.newFixedThreadPool(threads);
}
- public double EM(boolean skipBigPhrases)
+ public double EM(int phraseSizeLimit)
{
double [][][]exp_emit=new double [K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
@@ -92,19 +91,17 @@ public class PhraseCluster { //E
for(int phrase=0; phrase < n_phrases; phrase++)
{
- if (skipBigPhrases && c.getPhrase(phrase).size() >= 2)
+ if (phraseSizeLimit >= 1 && c.getPhrase(phrase).size() > phraseSizeLimit)
{
System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
continue;
- }
+ }
List<Edge> contexts = c.getEdgesForPhrase(phrase);
for (int ctx=0; ctx<contexts.size(); ctx++)
{
Edge edge = contexts.get(ctx);
- if (edge.getCount() < edge_threshold || c.isRare(edge))
- continue;
double p[]=posterior(edge);
double z = arr.F.l1norm(p);
@@ -138,10 +135,9 @@ public class PhraseCluster { return loglikelihood;
}
- public double VBEM(double alphaEmit, double alphaPi, boolean skipBigPhrases)
+ public double VBEM(double alphaEmit, double alphaPi)
{
// FIXME: broken - needs to be done entirely in log-space
- assert !skipBigPhrases : "FIXME: implement this!";
double [][][]exp_emit = new double [K][n_positions][n_words];
double [][]exp_pi = new double[n_phrases][K];
@@ -240,21 +236,21 @@ public class PhraseCluster { return kl;
}
- public double PREM(double scalePT, double scaleCT, boolean skipBigPhrases)
+ public double PREM(double scalePT, double scaleCT, int phraseSizeLimit)
{
if (scaleCT == 0)
{
if (pool != null)
- return PREM_phrase_constraints_parallel(scalePT, skipBigPhrases);
+ return PREM_phrase_constraints_parallel(scalePT, phraseSizeLimit);
else
- return PREM_phrase_constraints(scalePT, skipBigPhrases);
+ return PREM_phrase_constraints(scalePT, phraseSizeLimit);
}
- else
- return this.PREM_phrase_context_constraints(scalePT, scaleCT, skipBigPhrases);
+ else // FIXME: ignores phraseSizeLimit
+ return this.PREM_phrase_context_constraints(scalePT, scaleCT);
}
- public double PREM_phrase_constraints(double scalePT, boolean skipBigPhrases)
+ public double PREM_phrase_constraints(double scalePT, int phraseSizeLimit)
{
double [][][]exp_emit=new double[K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
@@ -272,8 +268,9 @@ public class PhraseCluster { int failures=0, iterations=0;
long start = System.currentTimeMillis();
//E
- for(int phrase=0; phrase<n_phrases; phrase++){
- if (skipBigPhrases && c.getPhrase(phrase).size() >= 2)
+ for(int phrase=0; phrase<n_phrases; phrase++)
+ {
+ if (phraseSizeLimit >= 1 && c.getPhrase(phrase).size() > phraseSizeLimit)
{
System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
continue;
@@ -328,7 +325,7 @@ public class PhraseCluster { return primal;
}
- public double PREM_phrase_constraints_parallel(final double scalePT, boolean skipBigPhrases)
+ public double PREM_phrase_constraints_parallel(final double scalePT, int phraseSizeLimit)
{
assert(pool != null);
@@ -338,12 +335,11 @@ public class PhraseCluster { double [][][]exp_emit=new double [K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
- if (skipBigPhrases)
- {
- for(double [][]i:exp_emit)
- for(double []j:i)
- Arrays.fill(j, 1e-100);
- }
+ for(double [][]i:exp_emit)
+ for(double []j:i)
+ Arrays.fill(j, 1e-10);
+ for(double []j:pi)
+ Arrays.fill(j, 1e-10);
double loglikelihood=0, kl=0, l1lmax=0, primal=0;
final AtomicInteger failures = new AtomicInteger(0);
@@ -356,12 +352,13 @@ public class PhraseCluster { //E
for(int phrase=0;phrase<n_phrases;phrase++){
- if (skipBigPhrases && c.getPhrase(phrase).size() >= 2)
+ if (phraseSizeLimit >= 1 && c.getPhrase(phrase).size() > phraseSizeLimit)
{
n -= 1;
System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
continue;
}
+
final int p=phrase;
pool.execute(new Runnable() {
public void run() {
@@ -445,10 +442,8 @@ public class PhraseCluster { return primal;
}
- public double PREM_phrase_context_constraints(double scalePT, double scaleCT, boolean skipBigPhrases)
+ public double PREM_phrase_context_constraints(double scalePT, double scaleCT)
{
- assert !skipBigPhrases : "Not supported yet - FIXME!"; //FIXME
-
double[][][] exp_emit = new double [K][n_positions][n_words];
double[][] exp_pi = new double[n_phrases][K];
@@ -500,10 +495,14 @@ public class PhraseCluster { */
public double[] posterior(Corpus.Edge edge)
{
- double[] prob=Arrays.copyOf(pi[edge.getPhraseId()], K);
-
- //if (edge.getCount() < edge_threshold)
- //System.out.println("Edge: " + edge + " probs for phrase " + Arrays.toString(prob));
+ double[] prob;
+ if (edge.getPhraseId() < n_phrases)
+ prob = Arrays.copyOf(pi[edge.getPhraseId()], K);
+ else
+ {
+ prob = new double[K];
+ Arrays.fill(prob, 1.0);
+ }
TIntArrayList ctx = edge.getContext();
for(int tag=0;tag<K;tag++)
@@ -511,23 +510,17 @@ public class PhraseCluster { for(int c=0;c<n_positions;c++)
{
int word = ctx.get(c);
- //if (edge.getCount() < edge_threshold)
- //System.out.println("\ttag: " + tag + " context word: " + word + " prob " + emit[tag][c][word]);
-
- if (!this.c.isSentinel(word))
+ if (!this.c.isSentinel(word) && word < n_words)
prob[tag]*=emit[tag][c][word];
}
}
-
- //if (edge.getCount() < edge_threshold)
- //System.out.println("prob " + Arrays.toString(prob));
return prob;
}
- public void displayPosterior(PrintStream ps)
+ public void displayPosterior(PrintStream ps, List<Edge> testing)
{
- for (Edge edge : c.getEdges())
+ for (Edge edge : testing)
{
double probs[] = posterior(edge);
arr.F.l1normalize(probs);
@@ -660,9 +653,4 @@ public class PhraseCluster { emit[t][p][w] = v;
}
}
-
- public void setEdgeThreshold(int edgeThreshold)
- {
- this.edge_threshold = edgeThreshold;
- }
}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java index 7f0b1970..ec1a5804 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java @@ -7,8 +7,11 @@ import java.io.File; import java.io.FileNotFoundException; import java.io.IOException; import java.io.PrintStream; +import java.util.List; import java.util.Random; +import phrase.Corpus.Edge; + import arr.F; public class Trainer @@ -34,10 +37,6 @@ public class Trainer parser.accepts("agree"); parser.accepts("no-parameter-cache"); parser.accepts("skip-large-phrases").withRequiredArg().ofType(Integer.class).defaultsTo(5); - parser.accepts("rare-word").withRequiredArg().ofType(Integer.class).defaultsTo(10); - parser.accepts("rare-edge").withRequiredArg().ofType(Integer.class).defaultsTo(1); - parser.accepts("rare-phrase").withRequiredArg().ofType(Integer.class).defaultsTo(2); - parser.accepts("rare-context").withRequiredArg().ofType(Integer.class).defaultsTo(2); OptionSet options = parser.parse(args); if (options.has("help") || !options.has("in")) @@ -61,10 +60,6 @@ public class Trainer double alphaEmit = (vb) ? (Double) options.valueOf("alpha-emit") : 0; double alphaPi = (vb) ? (Double) options.valueOf("alpha-pi") : 0; int skip = (Integer) options.valueOf("skip-large-phrases"); - int wordThreshold = (Integer) options.valueOf("rare-word"); - int edgeThreshold = (Integer) options.valueOf("rare-edge"); - int phraseThreshold = (Integer) options.valueOf("rare-phrase"); - int contextThreshold = (Integer) options.valueOf("rare-context"); if (options.has("seed")) F.rng = new Random((Long) options.valueOf("seed")); @@ -86,14 +81,7 @@ public class Trainer e.printStackTrace(); System.exit(1); } - - if (wordThreshold > 1) - corpus.applyWordThreshold(wordThreshold); - if (phraseThreshold > 1) - corpus.applyPhraseThreshold(phraseThreshold); - if (contextThreshold > 1) - corpus.applyContextThreshold(contextThreshold); - + if (!options.has("agree")) System.out.println("Running with " + tags + " tags " + "for " + iterations + " iterations " + @@ -127,7 +115,6 @@ public class Trainer e.printStackTrace(); } } - cluster.setEdgeThreshold(edgeThreshold); } double last = 0; @@ -138,20 +125,24 @@ public class Trainer o = agree.EM(); else { + if (i < skip) + System.out.println("Skipping phrases of length > " + (i+1)); + if (scale_phrase <= 0 && scale_context <= 0) { if (!vb) - o = cluster.EM(i < skip); + o = cluster.EM((i < skip) ? i+1 : 0); else - o = cluster.VBEM(alphaEmit, alphaPi, i < skip); + o = cluster.VBEM(alphaEmit, alphaPi); } else - o = cluster.PREM(scale_phrase, scale_context, i < skip); + o = cluster.PREM(scale_phrase, scale_context, (i < skip) ? i+1 : 0); } System.out.println("ITER: "+i+" objective: " + o); - if (i != 0 && Math.abs((o - last) / o) < threshold) + // sometimes takes a few iterations to break the ties + if (i > 5 && Math.abs((o - last) / o) < threshold) { last = o; break; @@ -171,10 +162,19 @@ public class Trainer File outfile = (File) options.valueOf("out"); try { PrintStream ps = FileUtil.printstream(outfile); - cluster.displayPosterior(ps); + List<Edge> test; + if (!options.has("test")) + test = corpus.getEdges(); + else + { + infile = (File) options.valueOf("test"); + System.out.println("Reading testing concordance from " + infile); + test = corpus.readEdges(FileUtil.reader(infile)); + } + cluster.displayPosterior(ps, test); ps.close(); } catch (IOException e) { - System.err.println("Failed to open output file: " + outfile); + System.err.println("Failed to open either testing file or output file"); e.printStackTrace(); System.exit(1); } |