summaryrefslogtreecommitdiff
path: root/gi
diff options
context:
space:
mode:
authortrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-15 22:48:44 +0000
committertrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-15 22:48:44 +0000
commit000093dc417088be9a99278dc59203a30f976289 (patch)
tree184ecb94f8ec09ccae59bdd64780a76fb59b5de6 /gi
parentcad8dd252814fdf76caf3d5576d1f8f877524253 (diff)
Option to run on single word phrases before moving to larger ones.
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@272 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi')
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/C2F.java16
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/Corpus.java26
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java67
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/Trainer.java13
4 files changed, 99 insertions, 23 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/C2F.java b/gi/posterior-regularisation/prjava/src/phrase/C2F.java
index a8e557f2..31fd4fda 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/C2F.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/C2F.java
@@ -38,10 +38,10 @@ public class C2F {
n_contexts=c.getNumContexts();
//number of words in a phrase to be considered
- //currently the first and last word
- //if the phrase has length 1
- //use the same word for two positions
- n_positions=2;
+ //currently the first and last word in source and target
+ //if the phrase has length 1 in either dimension then
+ //we use the same word for two positions
+ n_positions=c.phraseEdges(c.getEdges().get(0).getPhrase()).size();
emit=new double [K][n_positions][n_words];
pi=new double[n_contexts][K];
@@ -156,9 +156,13 @@ public class C2F {
double[] prob=Arrays.copyOf(pi[edge.getContextId()], K);
TIntArrayList phrase = edge.getPhrase();
+ TIntArrayList offsets = c.phraseEdges(phrase);
for(int tag=0;tag<K;tag++)
- prob[tag]*=emit[tag][0][phrase.get(0)]
- *emit[tag][1][phrase.get(phrase.size()-1)];
+ {
+ for (int i=0; i < offsets.size(); ++i)
+ prob[tag]*=emit[tag][i][phrase.get(offsets.get(i))];
+ }
+
return prob;
}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java
index d57f3c04..2de2797b 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java
@@ -15,6 +15,14 @@ public class Corpus
private List<Edge> edges = new ArrayList<Edge>();
private List<List<Edge>> phraseToContext = new ArrayList<List<Edge>>();
private List<List<Edge>> contextToPhrase = new ArrayList<List<Edge>>();
+ public int splitSentinel;
+ public int phraseSentinel;
+
+ public Corpus()
+ {
+ splitSentinel = wordLexicon.insert("<SPLIT>");
+ phraseSentinel = wordLexicon.insert("<PHRASE>");
+ }
public class Edge
{
@@ -157,6 +165,11 @@ public class Corpus
return b.toString();
}
+ public boolean isSentinel(int wordId)
+ {
+ return wordId == splitSentinel || wordId == phraseSentinel;
+ }
+
static Corpus readFromFile(Reader in) throws IOException
{
Corpus c = new Corpus();
@@ -218,6 +231,19 @@ public class Corpus
return c;
}
+
+ TIntArrayList phraseEdges(TIntArrayList phrase)
+ {
+ TIntArrayList r = new TIntArrayList(4);
+ for (int p = 0; p < phrase.size(); ++p)
+ {
+ if (p == 0 || phrase.get(p-1) == splitSentinel)
+ r.add(p);
+ if (p == phrase.size() - 1 || phrase.get(p+1) == splitSentinel)
+ r.add(p);
+ }
+ return r;
+ }
public void printStats(PrintStream out)
{
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index a369b319..5efaf52e 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -69,16 +69,29 @@ public class PhraseCluster {
pool = Executors.newFixedThreadPool(threads);
}
- public double EM()
+ public double EM(boolean skipBigPhrases)
{
double [][][]exp_emit=new double [K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
+ if (skipBigPhrases)
+ {
+ for(double [][]i:exp_emit)
+ for(double []j:i)
+ Arrays.fill(j, 1e-100);
+ }
+
double loglikelihood=0;
//E
for(int phrase=0; phrase < n_phrases; phrase++)
{
+ if (skipBigPhrases && c.getPhrase(phrase).size() >= 2)
+ {
+ System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
+ continue;
+ }
+
List<Edge> contexts = c.getEdgesForPhrase(phrase);
for (int ctx=0; ctx<contexts.size(); ctx++)
@@ -116,9 +129,10 @@ public class PhraseCluster {
return loglikelihood;
}
- public double VBEM(double alphaEmit, double alphaPi)
+ public double VBEM(double alphaEmit, double alphaPi, boolean skipBigPhrases)
{
// FIXME: broken - needs to be done entirely in log-space
+ assert !skipBigPhrases : "FIXME: implement this!";
double [][][]exp_emit = new double [K][n_positions][n_words];
double [][]exp_pi = new double[n_phrases][K];
@@ -217,24 +231,31 @@ public class PhraseCluster {
return kl;
}
- public double PREM(double scalePT, double scaleCT)
+ public double PREM(double scalePT, double scaleCT, boolean skipBigPhrases)
{
if (scaleCT == 0)
{
if (pool != null)
- return PREM_phrase_constraints_parallel(scalePT);
+ return PREM_phrase_constraints_parallel(scalePT, skipBigPhrases);
else
- return PREM_phrase_constraints(scalePT);
+ return PREM_phrase_constraints(scalePT, skipBigPhrases);
}
else
- return this.PREM_phrase_context_constraints(scalePT, scaleCT);
+ return this.PREM_phrase_context_constraints(scalePT, scaleCT, skipBigPhrases);
}
- public double PREM_phrase_constraints(double scalePT)
+ public double PREM_phrase_constraints(double scalePT, boolean skipBigPhrases)
{
double [][][]exp_emit=new double[K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
+
+ if (skipBigPhrases)
+ {
+ for(double [][]i:exp_emit)
+ for(double []j:i)
+ Arrays.fill(j, 1e-100);
+ }
if (lambdaPT == null && cacheLambda)
lambdaPT = new double[n_phrases][];
@@ -244,6 +265,12 @@ public class PhraseCluster {
long start = System.currentTimeMillis();
//E
for(int phrase=0; phrase<n_phrases; phrase++){
+ if (skipBigPhrases && c.getPhrase(phrase).size() >= 2)
+ {
+ System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
+ continue;
+ }
+
PhraseObjective po = new PhraseObjective(this, phrase, scalePT, (cacheLambda) ? lambdaPT[phrase] : null);
boolean ok = po.optimizeWithProjectedGradientDescent();
if (!ok) ++failures;
@@ -292,7 +319,7 @@ public class PhraseCluster {
return primal;
}
- public double PREM_phrase_constraints_parallel(final double scalePT)
+ public double PREM_phrase_constraints_parallel(final double scalePT, boolean skipBigPhrases)
{
assert(pool != null);
@@ -302,10 +329,17 @@ public class PhraseCluster {
double [][][]exp_emit=new double [K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
+ if (skipBigPhrases)
+ {
+ for(double [][]i:exp_emit)
+ for(double []j:i)
+ Arrays.fill(j, 1e-100);
+ }
+
double loglikelihood=0, kl=0, l1lmax=0, primal=0;
final AtomicInteger failures = new AtomicInteger(0);
final AtomicLong elapsed = new AtomicLong(0l);
- int iterations=0;
+ int iterations=0, n=n_phrases;
long start = System.currentTimeMillis();
if (lambdaPT == null && cacheLambda)
@@ -313,6 +347,12 @@ public class PhraseCluster {
//E
for(int phrase=0;phrase<n_phrases;phrase++){
+ if (skipBigPhrases && c.getPhrase(phrase).size() >= 2)
+ {
+ n -= 1;
+ System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
+ continue;
+ }
final int p=phrase;
pool.execute(new Runnable() {
public void run() {
@@ -337,7 +377,7 @@ public class PhraseCluster {
}
// aggregate the expectations as they become available
- for(int count=0;count<n_phrases;count++) {
+ for(int count=0;count<n;count++) {
try {
//System.out.println("" + Thread.currentThread().getId() + " reading queue #" + count);
@@ -396,8 +436,10 @@ public class PhraseCluster {
return primal;
}
- public double PREM_phrase_context_constraints(double scalePT, double scaleCT)
+ public double PREM_phrase_context_constraints(double scalePT, double scaleCT, boolean skipBigPhrases)
{
+ assert !skipBigPhrases : "Not supported yet - FIXME!"; //FIXME
+
double[][][] exp_emit = new double [K][n_positions][n_words];
double[][] exp_pi = new double[n_phrases][K];
@@ -454,7 +496,8 @@ public class PhraseCluster {
TIntArrayList ctx = edge.getContext();
for(int tag=0;tag<K;tag++)
for(int c=0;c<n_positions;c++)
- prob[tag]*=emit[tag][c][ctx.get(c)];
+ if (!this.c.isSentinel(ctx.get(c)))
+ prob[tag]*=emit[tag][c][ctx.get(c)];
return prob;
}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
index 20f6c905..a67c17a2 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
@@ -32,6 +32,7 @@ public class Trainer
parser.accepts("alpha-pi").withRequiredArg().ofType(Double.class).defaultsTo(0.01);
parser.accepts("agree");
parser.accepts("no-parameter-cache");
+ parser.accepts("skip-large-phrases").withRequiredArg().ofType(Integer.class).defaultsTo(5);
OptionSet options = parser.parse(args);
if (options.has("help") || !options.has("in"))
@@ -55,6 +56,7 @@ public class Trainer
boolean vb = options.has("variational-bayes");
double alphaEmit = (vb) ? (Double) options.valueOf("alpha-emit") : 0;
double alphaPi = (vb) ? (Double) options.valueOf("alpha-pi") : 0;
+ int skip = (Integer) options.valueOf("skip-large-phrases");
if (options.has("seed"))
F.rng = new Random((Long) options.valueOf("seed"));
@@ -80,6 +82,7 @@ public class Trainer
if (!options.has("agree"))
System.out.println("Running with " + tags + " tags " +
"for " + em_iterations + " EM and " + pr_iterations + " PR iterations " +
+ "skipping large phrases for first " + skip + " iterations " +
"with scale " + scale_phrase + " phrase and " + scale_context + " context " +
"and " + threads + " threads");
else
@@ -112,12 +115,12 @@ public class Trainer
if (i < em_iterations)
{
if (!vb)
- o = cluster.EM();
+ o = cluster.EM(i < skip);
else
- o = cluster.VBEM(alphaEmit, alphaPi);
+ o = cluster.VBEM(alphaEmit, alphaPi, i < skip);
}
else
- o = cluster.PREM(scale_phrase, scale_context);
+ o = cluster.PREM(scale_phrase, scale_context, i < skip);
}
System.out.println("ITER: "+i+" objective: " + o);
@@ -125,9 +128,9 @@ public class Trainer
if (i != 0 && Math.abs((o - last) / o) < threshold)
{
last = o;
- if (i < em_iterations)
+ if (i < Math.max(em_iterations, skip))
{
- i = em_iterations - 1;
+ i = Math.max(em_iterations, skip) - 1;
continue;
}
else