diff options
author | trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-12 19:48:54 +0000 |
---|---|---|
committer | trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-12 19:48:54 +0000 |
commit | 77c25d9f30f95ccb7843f9dce71a4f4e018cc727 (patch) | |
tree | 41ed18b1fe81fd4288a4e81e5f4e6efae392643b /gi/posterior-regularisation/prjava/src | |
parent | 2f90f9e203da01583ee0b82d4769f25b198835dd (diff) |
Updated launcher to include agreement model.
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@226 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation/prjava/src')
5 files changed, 104 insertions, 75 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Agree.java b/gi/posterior-regularisation/prjava/src/phrase/Agree.java index d5b949b0..d61e6eef 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Agree.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Agree.java @@ -12,8 +12,8 @@ import java.util.List; import phrase.Corpus.Edge;
public class Agree {
- private PhraseCluster model1;
- private C2F model2;
+ PhraseCluster model1;
+ C2F model2;
Corpus c;
private int K,n_phrases, n_words, n_contexts, n_positions1,n_positions2;
@@ -32,7 +32,7 @@ public class Agree { */
public Agree(int numCluster, Corpus corpus){
- model1=new PhraseCluster(numCluster, corpus, 0, 0, 0);
+ model1=new PhraseCluster(numCluster, corpus);
model2=new C2F(numCluster,corpus);
c=corpus;
n_words=c.getNumWords();
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index 7bc63c33..abd868c4 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -18,21 +18,16 @@ import util.MathUtil; public class PhraseCluster {
public int K;
- public double scalePT, scaleCT;
private int n_phrases, n_words, n_contexts, n_positions;
public Corpus c;
public ExecutorService pool;
// emit[tag][position][word] = p(word | tag, position in context)
- private double emit[][][];
+ double emit[][][];
// pi[phrase][tag] = p(tag | phrase)
- private double pi[][];
+ double pi[][];
- double alphaEmit;
- double alphaPi;
-
- public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads,
- double alphaEmit, double alphaPi)
+ public PhraseCluster(int numCluster, Corpus corpus)
{
K=numCluster;
c=corpus;
@@ -40,33 +35,34 @@ public class PhraseCluster { n_phrases=c.getNumPhrases();
n_contexts=c.getNumContexts();
n_positions=c.getNumContextPositions();
- this.scalePT = scalep;
- this.scaleCT = scalec;
- if (threads > 0)
- pool = Executors.newFixedThreadPool(threads);
-
+
emit=new double [K][n_positions][n_words];
pi=new double[n_phrases][K];
for(double [][]i:emit)
- {
for(double []j:i)
- {
- arr.F.randomise(j, alphaEmit <= 0);
- if (alphaEmit > 0)
- digammaNormalize(j, alphaEmit);
- }
- }
-
+ arr.F.randomise(j, true);
for(double []j:pi)
- {
- arr.F.randomise(j, alphaPi <= 0);
- if (alphaPi > 0)
- digammaNormalize(j, alphaPi);
- }
+ arr.F.randomise(j, true);
+ }
+
+ public void initialiseVB(double alphaEmit, double alphaPi)
+ {
+ assert alphaEmit > 0;
+ assert alphaPi > 0;
+
+ for(double [][]i:emit)
+ for(double []j:i)
+ digammaNormalize(j, alphaEmit);
- this.alphaEmit = alphaEmit;
- this.alphaPi = alphaPi;
+ for(double []j:pi)
+ digammaNormalize(j, alphaPi);
+ }
+
+ void useThreadPool(int threads)
+ {
+ assert threads > 0;
+ pool = Executors.newFixedThreadPool(threads);
}
public double EM()
@@ -116,7 +112,7 @@ public class PhraseCluster { return loglikelihood;
}
- public double VBEM()
+ public double VBEM(double alphaEmit, double alphaPi)
{
// FIXME: broken - needs to be done entirely in log-space
@@ -216,9 +212,22 @@ public class PhraseCluster { return kl;
}
- public double PREM_phrase_constraints(){
- assert (scaleCT <= 0);
-
+ public double PREM(double scalePT, double scaleCT)
+ {
+ if (scaleCT == 0)
+ {
+ if (pool != null)
+ return PREM_phrase_constraints_parallel(scalePT);
+ else
+ return PREM_phrase_constraints(scalePT);
+ }
+ else
+ return this.PREM_phrase_context_constraints(scalePT, scaleCT);
+ }
+
+
+ public double PREM_phrase_constraints(double scalePT)
+ {
double [][][]exp_emit=new double[K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
@@ -226,7 +235,7 @@ public class PhraseCluster { int failures=0, iterations=0;
//E
for(int phrase=0; phrase<n_phrases; phrase++){
- PhraseObjective po=new PhraseObjective(this,phrase);
+ PhraseObjective po=new PhraseObjective(this, phrase, scalePT);
boolean ok = po.optimizeWithProjectedGradientDescent();
if (!ok) ++failures;
iterations += po.getNumberUpdateCalls();
@@ -234,7 +243,7 @@ public class PhraseCluster { loglikelihood += po.loglikelihood();
kl += po.KL_divergence();
l1lmax += po.l1lmax();
- primal += po.primal();
+ primal += po.primal(scalePT);
List<Edge> edges = c.getEdgesForPhrase(phrase);
for(int edge=0;edge<q.length;edge++){
@@ -272,10 +281,9 @@ public class PhraseCluster { return primal;
}
- public double PREM_phrase_constraints_parallel()
+ public double PREM_phrase_constraints_parallel(final double scalePT)
{
assert(pool != null);
- assert(scaleCT <= 0);
final LinkedBlockingQueue<PhraseObjective> expectations
= new LinkedBlockingQueue<PhraseObjective>();
@@ -294,7 +302,7 @@ public class PhraseCluster { public void run() {
try {
//System.out.println("" + Thread.currentThread().getId() + " optimising lambda for " + p);
- PhraseObjective po = new PhraseObjective(PhraseCluster.this, p);
+ PhraseObjective po = new PhraseObjective(PhraseCluster.this, p, scalePT);
boolean ok = po.optimizeWithProjectedGradientDescent();
if (!ok) failures.incrementAndGet();
//System.out.println("" + Thread.currentThread().getId() + " done optimising lambda for " + p);
@@ -322,7 +330,7 @@ public class PhraseCluster { loglikelihood += po.loglikelihood();
kl += po.KL_divergence();
l1lmax += po.l1lmax();
- primal += po.primal();
+ primal += po.primal(scalePT);
iterations += po.getNumberUpdateCalls();
@@ -366,15 +374,14 @@ public class PhraseCluster { return primal;
}
- public double PREM_phrase_context_constraints(){
- assert (scaleCT > 0);
-
+ public double PREM_phrase_context_constraints(double scalePT, double scaleCT)
+ {
double[][][] exp_emit = new double [K][n_positions][n_words];
double[][] exp_pi = new double[n_phrases][K];
double[] lambda = null;
//E step
- PhraseContextObjective pco = new PhraseContextObjective(this, lambda, pool);
+ PhraseContextObjective pco = new PhraseContextObjective(this, lambda, pool, scalePT, scaleCT);
lambda = pco.optimizeWithProjectedGradientDescent();
//now extract expectations
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java index 15bd29c2..ff135a3d 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java @@ -59,12 +59,18 @@ public class PhraseContextObjective extends ProjectedObjective private long actualProjectionTime;
private ExecutorService pool;
- public PhraseContextObjective(PhraseCluster cluster, double[] startingParameters, ExecutorService pool)
+ double scalePT;
+ double scaleCT;
+
+ public PhraseContextObjective(PhraseCluster cluster, double[] startingParameters, ExecutorService pool,
+ double scalePT, double scaleCT)
{
c=cluster;
data=c.c.getEdges();
n_param=data.size()*c.K*2;
this.pool=pool;
+ this.scalePT = scalePT;
+ this.scaleCT = scaleCT;
parameters = startingParameters;
if (parameters == null)
@@ -73,8 +79,8 @@ public class PhraseContextObjective extends ProjectedObjective newPoint = new double[n_param];
gradient = new double[n_param];
initP();
- projectionPhrase = new SimplexProjection(c.scalePT);
- projectionContext = new SimplexProjection(c.scaleCT);
+ projectionPhrase = new SimplexProjection(scalePT);
+ projectionContext = new SimplexProjection(scaleCT);
q=new double [data.size()][c.K];
edgeIndex = new HashMap<Edge, Integer>();
@@ -151,7 +157,7 @@ public class PhraseContextObjective extends ProjectedObjective //System.out.println("projectPoint: " + Arrays.toString(point));
Arrays.fill(newPoint, 0, newPoint.length, 0);
- if (c.scalePT > 0)
+ if (scalePT > 0)
{
// first project using the phrase-tag constraints,
// for all p,t: sum_c lambda_ptc < scaleP
@@ -201,7 +207,7 @@ public class PhraseContextObjective extends ProjectedObjective }
//System.out.println("after PT " + Arrays.toString(newPoint));
- if (c.scaleCT > 1e-6)
+ if (scaleCT > 1e-6)
{
// now project using the context-tag constraints,
// for all c,t: sum_p omega_pct < scaleC
@@ -399,6 +405,6 @@ public class PhraseContextObjective extends ProjectedObjective // L - KL(q||p) - scalePT * l1lmax_phrase - scaleCT * l1lmax_context
public double primal()
{
- return loglikelihood() - KL_divergence() - c.scalePT * phrase_l1lmax() - c.scalePT * context_l1lmax();
+ return loglikelihood() - KL_divergence() - scalePT * phrase_l1lmax() - scalePT * context_l1lmax();
}
}
\ No newline at end of file diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java index cc12546d..33167c20 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java @@ -63,7 +63,7 @@ public class PhraseObjective extends ProjectedObjective */
public double llh;
- public PhraseObjective(PhraseCluster cluster, int phraseIdx){
+ public PhraseObjective(PhraseCluster cluster, int phraseIdx, double scale){
phrase=phraseIdx;
c=cluster;
data=c.c.getEdgesForPhrase(phrase);
@@ -81,7 +81,7 @@ public class PhraseObjective extends ProjectedObjective newPoint = new double[n_param];
gradient = new double[n_param];
initP();
- projection=new SimplexProjection(c.scalePT);
+ projection=new SimplexProjection(scale);
q=new double [data.size()][c.K];
setParameters(parameters);
@@ -220,8 +220,8 @@ public class PhraseObjective extends ProjectedObjective return sum;
}
- public double primal()
+ public double primal(double scale)
{
- return loglikelihood() - KL_divergence() - c.scalePT * l1lmax();
+ return loglikelihood() - KL_divergence() - scale * l1lmax();
}
}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java index 439fb337..240c4d64 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java @@ -30,6 +30,7 @@ public class Trainer parser.accepts("variational-bayes"); parser.accepts("alpha-emit").withRequiredArg().ofType(Double.class).defaultsTo(0.1); parser.accepts("alpha-pi").withRequiredArg().ofType(Double.class).defaultsTo(0.01); + parser.accepts("agree"); OptionSet options = parser.parse(args); if (options.has("help") || !options.has("in")) @@ -37,7 +38,7 @@ public class Trainer try { parser.printHelpOn(System.err); } catch (IOException e) { - System.err.println("This should never happen. Really."); + System.err.println("This should never happen."); e.printStackTrace(); } System.exit(1); @@ -75,34 +76,46 @@ public class Trainer System.exit(1); } - System.out.println("Running with " + tags + " tags " + - "for " + em_iterations + " EM and " + pr_iterations + " PR iterations " + - "with scale " + scale_phrase + " phrase and " + scale_context + " context " + - "and " + threads + " threads"); - System.out.println(); + if (!options.has("agree")) + System.out.println("Running with " + tags + " tags " + + "for " + em_iterations + " EM and " + pr_iterations + " PR iterations " + + "with scale " + scale_phrase + " phrase and " + scale_context + " context " + + "and " + threads + " threads"); + else + System.out.println("Running agreement model with " + tags + " tags " + + "for " + em_iterations); + + System.out.println(); - PhraseCluster cluster = new PhraseCluster(tags, corpus, scale_phrase, scale_context, threads, alphaEmit, alphaPi); + PhraseCluster cluster = null; + Agree agree = null; + if (options.has("agree")) + agree = new Agree(tags, corpus); + else + { + cluster = new PhraseCluster(tags, corpus); + if (threads > 0) cluster.useThreadPool(threads); + if (vb) cluster.initialiseVB(alphaEmit, alphaPi); + } double last = 0; for (int i=0; i<em_iterations+pr_iterations; i++) { double o; - if (i < em_iterations) - { - if (!vb) - o = cluster.EM(); - else - o = cluster.VBEM(); - } - else if (scale_context == 0) + if (agree != null) + o = agree.EM(); + else { - if (threads >= 1) - o = cluster.PREM_phrase_constraints_parallel(); + if (i < em_iterations) + { + if (!vb) + o = cluster.EM(); + else + o = cluster.VBEM(alphaEmit, alphaPi); + } else - o = cluster.PREM_phrase_constraints(); + o = cluster.PREM(scale_phrase, scale_context); } - else - o = cluster.PREM_phrase_context_constraints(); System.out.println("ITER: "+i+" objective: " + o); @@ -120,6 +133,9 @@ public class Trainer last = o; } + if (cluster == null) + cluster = agree.model1; + double pl1lmax = cluster.phrase_l1lmax(); double cl1lmax = cluster.context_l1lmax(); System.out.println("\nFinal posterior phrase l1lmax " + pl1lmax + " context l1lmax " + cl1lmax); |