summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation
diff options
context:
space:
mode:
authortrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-19 22:28:10 +0000
committertrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-19 22:28:10 +0000
commit5a8ea689c8a4e9cf3e72f88a253b08153bf32dde (patch)
treeee208f756ca67201b4a8c4cb34c965d11ccd001b /gi/posterior-regularisation
parent7b2b7938405fa198e075e6aa19e40c57ed8db2da (diff)
Reversed out broken thresholding
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@324 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation')
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/Corpus.java199
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java80
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/Trainer.java46
3 files changed, 92 insertions, 233 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java
index 2afc18dc..ad092cc6 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java
@@ -18,9 +18,6 @@ public class Corpus
public int splitSentinel;
public int phraseSentinel;
public int rareSentinel;
- private boolean[] rareWords;
- private boolean[] rarePhrases;
- private boolean[] rareContexts;
public Corpus()
{
@@ -45,10 +42,6 @@ public class Corpus
{
return Corpus.this.getPhrase(phraseId);
}
- public TIntArrayList getRawPhrase()
- {
- return Corpus.this.getRawPhrase(phraseId);
- }
public String getPhraseString()
{
return Corpus.this.getPhraseString(phraseId);
@@ -61,10 +54,6 @@ public class Corpus
{
return Corpus.this.getContext(contextId);
}
- public TIntArrayList getRawContext()
- {
- return Corpus.this.getRawContext(contextId);
- }
public String getContextString(boolean insertPhraseSentinel)
{
return Corpus.this.getContextString(contextId, insertPhraseSentinel);
@@ -145,35 +134,13 @@ public class Corpus
public TIntArrayList getPhrase(int phraseId)
{
- TIntArrayList phrase = phraseLexicon.lookup(phraseId);
- if (rareWords != null)
- {
- boolean first = true;
- for (int i = 0; i < phrase.size(); ++i)
- {
- if (rareWords[phrase.get(i)])
- {
- if (first)
- {
- phrase = (TIntArrayList) phrase.clone();
- first = false;
- }
- phrase.set(i, rareSentinel);
- }
- }
- }
- return phrase;
- }
-
- public TIntArrayList getRawPhrase(int phraseId)
- {
return phraseLexicon.lookup(phraseId);
}
public String getPhraseString(int phraseId)
{
StringBuffer b = new StringBuffer();
- for (int tid: getRawPhrase(phraseId).toNativeArray())
+ for (int tid: getPhrase(phraseId).toNativeArray())
{
if (b.length() > 0)
b.append(" ");
@@ -184,35 +151,13 @@ public class Corpus
public TIntArrayList getContext(int contextId)
{
- TIntArrayList context = contextLexicon.lookup(contextId);
- if (rareWords != null)
- {
- boolean first = true;
- for (int i = 0; i < context.size(); ++i)
- {
- if (rareWords[context.get(i)])
- {
- if (first)
- {
- context = (TIntArrayList) context.clone();
- first = false;
- }
- context.set(i, rareSentinel);
- }
- }
- }
- return context;
- }
-
- public TIntArrayList getRawContext(int contextId)
- {
return contextLexicon.lookup(contextId);
}
public String getContextString(int contextId, boolean insertPhraseSentinel)
{
StringBuffer b = new StringBuffer();
- TIntArrayList c = getRawContext(contextId);
+ TIntArrayList c = getContext(contextId);
for (int i = 0; i < c.size(); ++i)
{
if (i > 0) b.append(" ");
@@ -227,15 +172,14 @@ public class Corpus
return wordId == splitSentinel || wordId == phraseSentinel;
}
- static Corpus readFromFile(Reader in) throws IOException
- {
- Corpus c = new Corpus();
-
+ List<Edge> readEdges(Reader in) throws IOException
+ {
// read in line-by-line
BufferedReader bin = new BufferedReader(in);
String line;
Pattern separator = Pattern.compile(" \\|\\|\\| ");
-
+
+ List<Edge> edges = new ArrayList<Edge>();
while ((line = bin.readLine()) != null)
{
// split into phrase and contexts
@@ -250,10 +194,8 @@ public class Corpus
st = new StringTokenizer(phraseToks, " ");
TIntArrayList ptoks = new TIntArrayList();
while (st.hasMoreTokens())
- ptoks.add(c.wordLexicon.insert(st.nextToken()));
- int phraseId = c.phraseLexicon.insert(ptoks);
- if (phraseId == c.phraseToContext.size())
- c.phraseToContext.add(new ArrayList<Edge>());
+ ptoks.add(wordLexicon.insert(st.nextToken()));
+ int phraseId = phraseLexicon.insert(ptoks);
// process contexts
String[] parts = separator.split(rest);
@@ -261,34 +203,45 @@ public class Corpus
for (int i = 0; i < parts.length; i += 2)
{
// process pairs of strings - context and count
- TIntArrayList ctx = new TIntArrayList();
String ctxString = parts[i];
String countString = parts[i + 1];
+
+ assert (countString.startsWith("C="));
+ int count = Integer.parseInt(countString.substring(2).trim());
+
+ TIntArrayList ctx = new TIntArrayList();
StringTokenizer ctxStrtok = new StringTokenizer(ctxString, " ");
while (ctxStrtok.hasMoreTokens())
{
String token = ctxStrtok.nextToken();
- //if (!token.equals("<PHRASE>"))
- ctx.add(c.wordLexicon.insert(token));
+ ctx.add(wordLexicon.insert(token));
}
- int contextId = c.contextLexicon.insert(ctx);
- if (contextId == c.contextToPhrase.size())
- c.contextToPhrase.add(new ArrayList<Edge>());
+ int contextId = contextLexicon.insert(ctx);
- assert (countString.startsWith("C="));
- Edge e = c.new Edge(phraseId, contextId,
- Integer.parseInt(countString.substring(2).trim()));
- c.edges.add(e);
-
- // index the edge for fast phrase, context lookup
- c.phraseToContext.get(phraseId).add(e);
- c.contextToPhrase.get(contextId).add(e);
+ edges.add(new Edge(phraseId, contextId, count));
}
}
-
- return c;
+ return edges;
}
+ static Corpus readFromFile(Reader in) throws IOException
+ {
+ Corpus c = new Corpus();
+ c.edges = c.readEdges(in);
+ for (Edge edge: c.edges)
+ {
+ while (edge.getPhraseId() >= c.phraseToContext.size())
+ c.phraseToContext.add(new ArrayList<Edge>());
+ while (edge.getContextId() >= c.contextToPhrase.size())
+ c.contextToPhrase.add(new ArrayList<Edge>());
+
+ // index the edge for fast phrase, context lookup
+ c.phraseToContext.get(edge.getPhraseId()).add(edge);
+ c.contextToPhrase.get(edge.getContextId()).add(edge);
+ }
+ return c;
+ }
+
TIntArrayList phraseEdges(TIntArrayList phrase)
{
TIntArrayList r = new TIntArrayList(4);
@@ -307,86 +260,4 @@ public class Corpus
out.println("Corpus has " + edges.size() + " edges " + phraseLexicon.size() + " phrases "
+ contextLexicon.size() + " contexts and " + wordLexicon.size() + " word types");
}
-
- public void applyWordThreshold(int wordThreshold)
- {
- int[] counts = new int[wordLexicon.size()];
- for (Edge e: edges)
- {
- TIntArrayList phrase = e.getPhrase();
- for (int i = 0; i < phrase.size(); ++i)
- counts[phrase.get(i)] += e.getCount();
-
- TIntArrayList context = e.getContext();
- for (int i = 0; i < context.size(); ++i)
- counts[context.get(i)] += e.getCount();
- }
-
- int count = 0;
- rareWords = new boolean[wordLexicon.size()];
- for (int i = 0; i < wordLexicon.size(); ++i)
- {
- rareWords[i] = counts[i] < wordThreshold;
- if (rareWords[i])
- count++;
- }
- System.err.println("There are " + count + " rare words");
- }
-
- public void applyPhraseThreshold(int threshold)
- {
- rarePhrases = new boolean[phraseLexicon.size()];
-
- int n = 0;
- for (int i = 0; i < phraseLexicon.size(); ++i)
- {
- List<Edge> contexts = phraseToContext.get(i);
- int count = 0;
- for (Edge edge: contexts)
- {
- count += edge.getCount();
- if (count >= threshold)
- break;
- }
-
- if (count < threshold)
- {
- rarePhrases[i] = true;
- n++;
- }
- }
- System.err.println("There are " + n + " rare phrases");
- }
-
- public void applyContextThreshold(int threshold)
- {
- rareContexts = new boolean[contextLexicon.size()];
-
- int n = 0;
- for (int i = 0; i < contextLexicon.size(); ++i)
- {
- List<Edge> phrases = contextToPhrase.get(i);
- int count = 0;
- for (Edge edge: phrases)
- {
- count += edge.getCount();
- if (count >= threshold)
- break;
- }
-
- if (count < threshold)
- {
- rareContexts[i] = true;
- n++;
- }
- }
- System.err.println("There are " + n + " rare contexts");
- }
-
- boolean isRare(Edge edge)
- {
- if (rarePhrases != null && rarePhrases[edge.getPhraseId()] == true) return true;
- if (rareContexts != null && rareContexts[edge.getContextId()] == true) return true;
- return false;
- }
} \ No newline at end of file
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index 9ee766d4..e7e4af32 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -23,7 +23,7 @@ import phrase.Corpus.Edge;
public class PhraseCluster {
public int K;
- private int n_phrases, n_words, n_contexts, n_positions, edge_threshold;
+ private int n_phrases, n_words, n_contexts, n_positions;
public Corpus c;
public ExecutorService pool;
@@ -44,7 +44,6 @@ public class PhraseCluster {
n_phrases=c.getNumPhrases();
n_contexts=c.getNumContexts();
n_positions=c.getNumContextPositions();
- edge_threshold=0;
emit=new double [K][n_positions][n_words];
pi=new double[n_phrases][K];
@@ -76,7 +75,7 @@ public class PhraseCluster {
pool = Executors.newFixedThreadPool(threads);
}
- public double EM(boolean skipBigPhrases)
+ public double EM(int phraseSizeLimit)
{
double [][][]exp_emit=new double [K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
@@ -92,19 +91,17 @@ public class PhraseCluster {
//E
for(int phrase=0; phrase < n_phrases; phrase++)
{
- if (skipBigPhrases && c.getPhrase(phrase).size() >= 2)
+ if (phraseSizeLimit >= 1 && c.getPhrase(phrase).size() > phraseSizeLimit)
{
System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
continue;
- }
+ }
List<Edge> contexts = c.getEdgesForPhrase(phrase);
for (int ctx=0; ctx<contexts.size(); ctx++)
{
Edge edge = contexts.get(ctx);
- if (edge.getCount() < edge_threshold || c.isRare(edge))
- continue;
double p[]=posterior(edge);
double z = arr.F.l1norm(p);
@@ -138,10 +135,9 @@ public class PhraseCluster {
return loglikelihood;
}
- public double VBEM(double alphaEmit, double alphaPi, boolean skipBigPhrases)
+ public double VBEM(double alphaEmit, double alphaPi)
{
// FIXME: broken - needs to be done entirely in log-space
- assert !skipBigPhrases : "FIXME: implement this!";
double [][][]exp_emit = new double [K][n_positions][n_words];
double [][]exp_pi = new double[n_phrases][K];
@@ -240,21 +236,21 @@ public class PhraseCluster {
return kl;
}
- public double PREM(double scalePT, double scaleCT, boolean skipBigPhrases)
+ public double PREM(double scalePT, double scaleCT, int phraseSizeLimit)
{
if (scaleCT == 0)
{
if (pool != null)
- return PREM_phrase_constraints_parallel(scalePT, skipBigPhrases);
+ return PREM_phrase_constraints_parallel(scalePT, phraseSizeLimit);
else
- return PREM_phrase_constraints(scalePT, skipBigPhrases);
+ return PREM_phrase_constraints(scalePT, phraseSizeLimit);
}
- else
- return this.PREM_phrase_context_constraints(scalePT, scaleCT, skipBigPhrases);
+ else // FIXME: ignores phraseSizeLimit
+ return this.PREM_phrase_context_constraints(scalePT, scaleCT);
}
- public double PREM_phrase_constraints(double scalePT, boolean skipBigPhrases)
+ public double PREM_phrase_constraints(double scalePT, int phraseSizeLimit)
{
double [][][]exp_emit=new double[K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
@@ -272,8 +268,9 @@ public class PhraseCluster {
int failures=0, iterations=0;
long start = System.currentTimeMillis();
//E
- for(int phrase=0; phrase<n_phrases; phrase++){
- if (skipBigPhrases && c.getPhrase(phrase).size() >= 2)
+ for(int phrase=0; phrase<n_phrases; phrase++)
+ {
+ if (phraseSizeLimit >= 1 && c.getPhrase(phrase).size() > phraseSizeLimit)
{
System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
continue;
@@ -328,7 +325,7 @@ public class PhraseCluster {
return primal;
}
- public double PREM_phrase_constraints_parallel(final double scalePT, boolean skipBigPhrases)
+ public double PREM_phrase_constraints_parallel(final double scalePT, int phraseSizeLimit)
{
assert(pool != null);
@@ -338,12 +335,11 @@ public class PhraseCluster {
double [][][]exp_emit=new double [K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
- if (skipBigPhrases)
- {
- for(double [][]i:exp_emit)
- for(double []j:i)
- Arrays.fill(j, 1e-100);
- }
+ for(double [][]i:exp_emit)
+ for(double []j:i)
+ Arrays.fill(j, 1e-10);
+ for(double []j:pi)
+ Arrays.fill(j, 1e-10);
double loglikelihood=0, kl=0, l1lmax=0, primal=0;
final AtomicInteger failures = new AtomicInteger(0);
@@ -356,12 +352,13 @@ public class PhraseCluster {
//E
for(int phrase=0;phrase<n_phrases;phrase++){
- if (skipBigPhrases && c.getPhrase(phrase).size() >= 2)
+ if (phraseSizeLimit >= 1 && c.getPhrase(phrase).size() > phraseSizeLimit)
{
n -= 1;
System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
continue;
}
+
final int p=phrase;
pool.execute(new Runnable() {
public void run() {
@@ -445,10 +442,8 @@ public class PhraseCluster {
return primal;
}
- public double PREM_phrase_context_constraints(double scalePT, double scaleCT, boolean skipBigPhrases)
+ public double PREM_phrase_context_constraints(double scalePT, double scaleCT)
{
- assert !skipBigPhrases : "Not supported yet - FIXME!"; //FIXME
-
double[][][] exp_emit = new double [K][n_positions][n_words];
double[][] exp_pi = new double[n_phrases][K];
@@ -500,10 +495,14 @@ public class PhraseCluster {
*/
public double[] posterior(Corpus.Edge edge)
{
- double[] prob=Arrays.copyOf(pi[edge.getPhraseId()], K);
-
- //if (edge.getCount() < edge_threshold)
- //System.out.println("Edge: " + edge + " probs for phrase " + Arrays.toString(prob));
+ double[] prob;
+ if (edge.getPhraseId() < n_phrases)
+ prob = Arrays.copyOf(pi[edge.getPhraseId()], K);
+ else
+ {
+ prob = new double[K];
+ Arrays.fill(prob, 1.0);
+ }
TIntArrayList ctx = edge.getContext();
for(int tag=0;tag<K;tag++)
@@ -511,23 +510,17 @@ public class PhraseCluster {
for(int c=0;c<n_positions;c++)
{
int word = ctx.get(c);
- //if (edge.getCount() < edge_threshold)
- //System.out.println("\ttag: " + tag + " context word: " + word + " prob " + emit[tag][c][word]);
-
- if (!this.c.isSentinel(word))
+ if (!this.c.isSentinel(word) && word < n_words)
prob[tag]*=emit[tag][c][word];
}
}
-
- //if (edge.getCount() < edge_threshold)
- //System.out.println("prob " + Arrays.toString(prob));
return prob;
}
- public void displayPosterior(PrintStream ps)
+ public void displayPosterior(PrintStream ps, List<Edge> testing)
{
- for (Edge edge : c.getEdges())
+ for (Edge edge : testing)
{
double probs[] = posterior(edge);
arr.F.l1normalize(probs);
@@ -660,9 +653,4 @@ public class PhraseCluster {
emit[t][p][w] = v;
}
}
-
- public void setEdgeThreshold(int edgeThreshold)
- {
- this.edge_threshold = edgeThreshold;
- }
}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
index 7f0b1970..ec1a5804 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
@@ -7,8 +7,11 @@ import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintStream;
+import java.util.List;
import java.util.Random;
+import phrase.Corpus.Edge;
+
import arr.F;
public class Trainer
@@ -34,10 +37,6 @@ public class Trainer
parser.accepts("agree");
parser.accepts("no-parameter-cache");
parser.accepts("skip-large-phrases").withRequiredArg().ofType(Integer.class).defaultsTo(5);
- parser.accepts("rare-word").withRequiredArg().ofType(Integer.class).defaultsTo(10);
- parser.accepts("rare-edge").withRequiredArg().ofType(Integer.class).defaultsTo(1);
- parser.accepts("rare-phrase").withRequiredArg().ofType(Integer.class).defaultsTo(2);
- parser.accepts("rare-context").withRequiredArg().ofType(Integer.class).defaultsTo(2);
OptionSet options = parser.parse(args);
if (options.has("help") || !options.has("in"))
@@ -61,10 +60,6 @@ public class Trainer
double alphaEmit = (vb) ? (Double) options.valueOf("alpha-emit") : 0;
double alphaPi = (vb) ? (Double) options.valueOf("alpha-pi") : 0;
int skip = (Integer) options.valueOf("skip-large-phrases");
- int wordThreshold = (Integer) options.valueOf("rare-word");
- int edgeThreshold = (Integer) options.valueOf("rare-edge");
- int phraseThreshold = (Integer) options.valueOf("rare-phrase");
- int contextThreshold = (Integer) options.valueOf("rare-context");
if (options.has("seed"))
F.rng = new Random((Long) options.valueOf("seed"));
@@ -86,14 +81,7 @@ public class Trainer
e.printStackTrace();
System.exit(1);
}
-
- if (wordThreshold > 1)
- corpus.applyWordThreshold(wordThreshold);
- if (phraseThreshold > 1)
- corpus.applyPhraseThreshold(phraseThreshold);
- if (contextThreshold > 1)
- corpus.applyContextThreshold(contextThreshold);
-
+
if (!options.has("agree"))
System.out.println("Running with " + tags + " tags " +
"for " + iterations + " iterations " +
@@ -127,7 +115,6 @@ public class Trainer
e.printStackTrace();
}
}
- cluster.setEdgeThreshold(edgeThreshold);
}
double last = 0;
@@ -138,20 +125,24 @@ public class Trainer
o = agree.EM();
else
{
+ if (i < skip)
+ System.out.println("Skipping phrases of length > " + (i+1));
+
if (scale_phrase <= 0 && scale_context <= 0)
{
if (!vb)
- o = cluster.EM(i < skip);
+ o = cluster.EM((i < skip) ? i+1 : 0);
else
- o = cluster.VBEM(alphaEmit, alphaPi, i < skip);
+ o = cluster.VBEM(alphaEmit, alphaPi);
}
else
- o = cluster.PREM(scale_phrase, scale_context, i < skip);
+ o = cluster.PREM(scale_phrase, scale_context, (i < skip) ? i+1 : 0);
}
System.out.println("ITER: "+i+" objective: " + o);
- if (i != 0 && Math.abs((o - last) / o) < threshold)
+ // sometimes takes a few iterations to break the ties
+ if (i > 5 && Math.abs((o - last) / o) < threshold)
{
last = o;
break;
@@ -171,10 +162,19 @@ public class Trainer
File outfile = (File) options.valueOf("out");
try {
PrintStream ps = FileUtil.printstream(outfile);
- cluster.displayPosterior(ps);
+ List<Edge> test;
+ if (!options.has("test"))
+ test = corpus.getEdges();
+ else
+ {
+ infile = (File) options.valueOf("test");
+ System.out.println("Reading testing concordance from " + infile);
+ test = corpus.readEdges(FileUtil.reader(infile));
+ }
+ cluster.displayPosterior(ps, test);
ps.close();
} catch (IOException e) {
- System.err.println("Failed to open output file: " + outfile);
+ System.err.println("Failed to open either testing file or output file");
e.printStackTrace();
System.exit(1);
}