diff options
author | trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-15 22:48:44 +0000 |
---|---|---|
committer | trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-15 22:48:44 +0000 |
commit | 4037e35c511aec96f780276aa4e3c1493e19eba1 (patch) | |
tree | 0e43592cc58682fe44d7d11abc6a9a835a0547a3 /gi/posterior-regularisation | |
parent | c14b17b45c1215b1d4a1495c161531d3e8936a34 (diff) |
Option to run on single word phrases before moving to larger ones.
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@272 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation')
4 files changed, 99 insertions, 23 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/C2F.java b/gi/posterior-regularisation/prjava/src/phrase/C2F.java index a8e557f2..31fd4fda 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/C2F.java +++ b/gi/posterior-regularisation/prjava/src/phrase/C2F.java @@ -38,10 +38,10 @@ public class C2F { n_contexts=c.getNumContexts();
//number of words in a phrase to be considered
- //currently the first and last word
- //if the phrase has length 1
- //use the same word for two positions
- n_positions=2;
+ //currently the first and last word in source and target
+ //if the phrase has length 1 in either dimension then
+ //we use the same word for two positions
+ n_positions=c.phraseEdges(c.getEdges().get(0).getPhrase()).size();
emit=new double [K][n_positions][n_words];
pi=new double[n_contexts][K];
@@ -156,9 +156,13 @@ public class C2F { double[] prob=Arrays.copyOf(pi[edge.getContextId()], K);
TIntArrayList phrase = edge.getPhrase();
+ TIntArrayList offsets = c.phraseEdges(phrase);
for(int tag=0;tag<K;tag++)
- prob[tag]*=emit[tag][0][phrase.get(0)]
- *emit[tag][1][phrase.get(phrase.size()-1)];
+ {
+ for (int i=0; i < offsets.size(); ++i)
+ prob[tag]*=emit[tag][i][phrase.get(offsets.get(i))];
+ }
+
return prob;
}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java index d57f3c04..2de2797b 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java @@ -15,6 +15,14 @@ public class Corpus private List<Edge> edges = new ArrayList<Edge>(); private List<List<Edge>> phraseToContext = new ArrayList<List<Edge>>(); private List<List<Edge>> contextToPhrase = new ArrayList<List<Edge>>(); + public int splitSentinel; + public int phraseSentinel; + + public Corpus() + { + splitSentinel = wordLexicon.insert("<SPLIT>"); + phraseSentinel = wordLexicon.insert("<PHRASE>"); + } public class Edge { @@ -157,6 +165,11 @@ public class Corpus return b.toString(); } + public boolean isSentinel(int wordId) + { + return wordId == splitSentinel || wordId == phraseSentinel; + } + static Corpus readFromFile(Reader in) throws IOException { Corpus c = new Corpus(); @@ -218,6 +231,19 @@ public class Corpus return c; } + + TIntArrayList phraseEdges(TIntArrayList phrase) + { + TIntArrayList r = new TIntArrayList(4); + for (int p = 0; p < phrase.size(); ++p) + { + if (p == 0 || phrase.get(p-1) == splitSentinel) + r.add(p); + if (p == phrase.size() - 1 || phrase.get(p+1) == splitSentinel) + r.add(p); + } + return r; + } public void printStats(PrintStream out) { diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index a369b319..5efaf52e 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -69,16 +69,29 @@ public class PhraseCluster { pool = Executors.newFixedThreadPool(threads);
}
- public double EM()
+ public double EM(boolean skipBigPhrases)
{
double [][][]exp_emit=new double [K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
+ if (skipBigPhrases)
+ {
+ for(double [][]i:exp_emit)
+ for(double []j:i)
+ Arrays.fill(j, 1e-100);
+ }
+
double loglikelihood=0;
//E
for(int phrase=0; phrase < n_phrases; phrase++)
{
+ if (skipBigPhrases && c.getPhrase(phrase).size() >= 2)
+ {
+ System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
+ continue;
+ }
+
List<Edge> contexts = c.getEdgesForPhrase(phrase);
for (int ctx=0; ctx<contexts.size(); ctx++)
@@ -116,9 +129,10 @@ public class PhraseCluster { return loglikelihood;
}
- public double VBEM(double alphaEmit, double alphaPi)
+ public double VBEM(double alphaEmit, double alphaPi, boolean skipBigPhrases)
{
// FIXME: broken - needs to be done entirely in log-space
+ assert !skipBigPhrases : "FIXME: implement this!";
double [][][]exp_emit = new double [K][n_positions][n_words];
double [][]exp_pi = new double[n_phrases][K];
@@ -217,24 +231,31 @@ public class PhraseCluster { return kl;
}
- public double PREM(double scalePT, double scaleCT)
+ public double PREM(double scalePT, double scaleCT, boolean skipBigPhrases)
{
if (scaleCT == 0)
{
if (pool != null)
- return PREM_phrase_constraints_parallel(scalePT);
+ return PREM_phrase_constraints_parallel(scalePT, skipBigPhrases);
else
- return PREM_phrase_constraints(scalePT);
+ return PREM_phrase_constraints(scalePT, skipBigPhrases);
}
else
- return this.PREM_phrase_context_constraints(scalePT, scaleCT);
+ return this.PREM_phrase_context_constraints(scalePT, scaleCT, skipBigPhrases);
}
- public double PREM_phrase_constraints(double scalePT)
+ public double PREM_phrase_constraints(double scalePT, boolean skipBigPhrases)
{
double [][][]exp_emit=new double[K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
+
+ if (skipBigPhrases)
+ {
+ for(double [][]i:exp_emit)
+ for(double []j:i)
+ Arrays.fill(j, 1e-100);
+ }
if (lambdaPT == null && cacheLambda)
lambdaPT = new double[n_phrases][];
@@ -244,6 +265,12 @@ public class PhraseCluster { long start = System.currentTimeMillis();
//E
for(int phrase=0; phrase<n_phrases; phrase++){
+ if (skipBigPhrases && c.getPhrase(phrase).size() >= 2)
+ {
+ System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
+ continue;
+ }
+
PhraseObjective po = new PhraseObjective(this, phrase, scalePT, (cacheLambda) ? lambdaPT[phrase] : null);
boolean ok = po.optimizeWithProjectedGradientDescent();
if (!ok) ++failures;
@@ -292,7 +319,7 @@ public class PhraseCluster { return primal;
}
- public double PREM_phrase_constraints_parallel(final double scalePT)
+ public double PREM_phrase_constraints_parallel(final double scalePT, boolean skipBigPhrases)
{
assert(pool != null);
@@ -302,10 +329,17 @@ public class PhraseCluster { double [][][]exp_emit=new double [K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
+ if (skipBigPhrases)
+ {
+ for(double [][]i:exp_emit)
+ for(double []j:i)
+ Arrays.fill(j, 1e-100);
+ }
+
double loglikelihood=0, kl=0, l1lmax=0, primal=0;
final AtomicInteger failures = new AtomicInteger(0);
final AtomicLong elapsed = new AtomicLong(0l);
- int iterations=0;
+ int iterations=0, n=n_phrases;
long start = System.currentTimeMillis();
if (lambdaPT == null && cacheLambda)
@@ -313,6 +347,12 @@ public class PhraseCluster { //E
for(int phrase=0;phrase<n_phrases;phrase++){
+ if (skipBigPhrases && c.getPhrase(phrase).size() >= 2)
+ {
+ n -= 1;
+ System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
+ continue;
+ }
final int p=phrase;
pool.execute(new Runnable() {
public void run() {
@@ -337,7 +377,7 @@ public class PhraseCluster { }
// aggregate the expectations as they become available
- for(int count=0;count<n_phrases;count++) {
+ for(int count=0;count<n;count++) {
try {
//System.out.println("" + Thread.currentThread().getId() + " reading queue #" + count);
@@ -396,8 +436,10 @@ public class PhraseCluster { return primal;
}
- public double PREM_phrase_context_constraints(double scalePT, double scaleCT)
+ public double PREM_phrase_context_constraints(double scalePT, double scaleCT, boolean skipBigPhrases)
{
+ assert !skipBigPhrases : "Not supported yet - FIXME!"; //FIXME
+
double[][][] exp_emit = new double [K][n_positions][n_words];
double[][] exp_pi = new double[n_phrases][K];
@@ -454,7 +496,8 @@ public class PhraseCluster { TIntArrayList ctx = edge.getContext();
for(int tag=0;tag<K;tag++)
for(int c=0;c<n_positions;c++)
- prob[tag]*=emit[tag][c][ctx.get(c)];
+ if (!this.c.isSentinel(ctx.get(c)))
+ prob[tag]*=emit[tag][c][ctx.get(c)];
return prob;
}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java index 20f6c905..a67c17a2 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java @@ -32,6 +32,7 @@ public class Trainer parser.accepts("alpha-pi").withRequiredArg().ofType(Double.class).defaultsTo(0.01); parser.accepts("agree"); parser.accepts("no-parameter-cache"); + parser.accepts("skip-large-phrases").withRequiredArg().ofType(Integer.class).defaultsTo(5); OptionSet options = parser.parse(args); if (options.has("help") || !options.has("in")) @@ -55,6 +56,7 @@ public class Trainer boolean vb = options.has("variational-bayes"); double alphaEmit = (vb) ? (Double) options.valueOf("alpha-emit") : 0; double alphaPi = (vb) ? (Double) options.valueOf("alpha-pi") : 0; + int skip = (Integer) options.valueOf("skip-large-phrases"); if (options.has("seed")) F.rng = new Random((Long) options.valueOf("seed")); @@ -80,6 +82,7 @@ public class Trainer if (!options.has("agree")) System.out.println("Running with " + tags + " tags " + "for " + em_iterations + " EM and " + pr_iterations + " PR iterations " + + "skipping large phrases for first " + skip + " iterations " + "with scale " + scale_phrase + " phrase and " + scale_context + " context " + "and " + threads + " threads"); else @@ -112,12 +115,12 @@ public class Trainer if (i < em_iterations) { if (!vb) - o = cluster.EM(); + o = cluster.EM(i < skip); else - o = cluster.VBEM(alphaEmit, alphaPi); + o = cluster.VBEM(alphaEmit, alphaPi, i < skip); } else - o = cluster.PREM(scale_phrase, scale_context); + o = cluster.PREM(scale_phrase, scale_context, i < skip); } System.out.println("ITER: "+i+" objective: " + o); @@ -125,9 +128,9 @@ public class Trainer if (i != 0 && Math.abs((o - last) / o) < threshold) { last = o; - if (i < em_iterations) + if (i < Math.max(em_iterations, skip)) { - i = em_iterations - 1; + i = Math.max(em_iterations, skip) - 1; continue; } else |