From 77c25d9f30f95ccb7843f9dce71a4f4e018cc727 Mon Sep 17 00:00:00 2001 From: "trevor.cohn" Date: Mon, 12 Jul 2010 19:48:54 +0000 Subject: Updated launcher to include agreement model. git-svn-id: https://ws10smt.googlecode.com/svn/trunk@226 ec762483-ff6d-05da-a07a-a48fb63a330f --- .../prjava/src/phrase/Agree.java | 6 +- .../prjava/src/phrase/PhraseCluster.java | 91 ++++++++++++---------- .../prjava/src/phrase/PhraseContextObjective.java | 18 +++-- .../prjava/src/phrase/PhraseObjective.java | 8 +- .../prjava/src/phrase/Trainer.java | 56 ++++++++----- 5 files changed, 104 insertions(+), 75 deletions(-) (limited to 'gi/posterior-regularisation/prjava/src') diff --git a/gi/posterior-regularisation/prjava/src/phrase/Agree.java b/gi/posterior-regularisation/prjava/src/phrase/Agree.java index d5b949b0..d61e6eef 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Agree.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Agree.java @@ -12,8 +12,8 @@ import java.util.List; import phrase.Corpus.Edge; public class Agree { - private PhraseCluster model1; - private C2F model2; + PhraseCluster model1; + C2F model2; Corpus c; private int K,n_phrases, n_words, n_contexts, n_positions1,n_positions2; @@ -32,7 +32,7 @@ public class Agree { */ public Agree(int numCluster, Corpus corpus){ - model1=new PhraseCluster(numCluster, corpus, 0, 0, 0); + model1=new PhraseCluster(numCluster, corpus); model2=new C2F(numCluster,corpus); c=corpus; n_words=c.getNumWords(); diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index 7bc63c33..abd868c4 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -18,21 +18,16 @@ import util.MathUtil; public class PhraseCluster { public int K; - public double scalePT, scaleCT; private int n_phrases, n_words, n_contexts, n_positions; public Corpus c; public ExecutorService pool; // emit[tag][position][word] = p(word | tag, position in context) - private double emit[][][]; + double emit[][][]; // pi[phrase][tag] = p(tag | phrase) - private double pi[][]; + double pi[][]; - double alphaEmit; - double alphaPi; - - public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads, - double alphaEmit, double alphaPi) + public PhraseCluster(int numCluster, Corpus corpus) { K=numCluster; c=corpus; @@ -40,33 +35,34 @@ public class PhraseCluster { n_phrases=c.getNumPhrases(); n_contexts=c.getNumContexts(); n_positions=c.getNumContextPositions(); - this.scalePT = scalep; - this.scaleCT = scalec; - if (threads > 0) - pool = Executors.newFixedThreadPool(threads); - + emit=new double [K][n_positions][n_words]; pi=new double[n_phrases][K]; for(double [][]i:emit) - { for(double []j:i) - { - arr.F.randomise(j, alphaEmit <= 0); - if (alphaEmit > 0) - digammaNormalize(j, alphaEmit); - } - } - + arr.F.randomise(j, true); for(double []j:pi) - { - arr.F.randomise(j, alphaPi <= 0); - if (alphaPi > 0) - digammaNormalize(j, alphaPi); - } + arr.F.randomise(j, true); + } + + public void initialiseVB(double alphaEmit, double alphaPi) + { + assert alphaEmit > 0; + assert alphaPi > 0; + + for(double [][]i:emit) + for(double []j:i) + digammaNormalize(j, alphaEmit); - this.alphaEmit = alphaEmit; - this.alphaPi = alphaPi; + for(double []j:pi) + digammaNormalize(j, alphaPi); + } + + void useThreadPool(int threads) + { + assert threads > 0; + pool = Executors.newFixedThreadPool(threads); } public double EM() @@ -116,7 +112,7 @@ public class PhraseCluster { return loglikelihood; } - public double VBEM() + public double VBEM(double alphaEmit, double alphaPi) { // FIXME: broken - needs to be done entirely in log-space @@ -216,9 +212,22 @@ public class PhraseCluster { return kl; } - public double PREM_phrase_constraints(){ - assert (scaleCT <= 0); - + public double PREM(double scalePT, double scaleCT) + { + if (scaleCT == 0) + { + if (pool != null) + return PREM_phrase_constraints_parallel(scalePT); + else + return PREM_phrase_constraints(scalePT); + } + else + return this.PREM_phrase_context_constraints(scalePT, scaleCT); + } + + + public double PREM_phrase_constraints(double scalePT) + { double [][][]exp_emit=new double[K][n_positions][n_words]; double [][]exp_pi=new double[n_phrases][K]; @@ -226,7 +235,7 @@ public class PhraseCluster { int failures=0, iterations=0; //E for(int phrase=0; phrase edges = c.getEdgesForPhrase(phrase); for(int edge=0;edge expectations = new LinkedBlockingQueue(); @@ -294,7 +302,7 @@ public class PhraseCluster { public void run() { try { //System.out.println("" + Thread.currentThread().getId() + " optimising lambda for " + p); - PhraseObjective po = new PhraseObjective(PhraseCluster.this, p); + PhraseObjective po = new PhraseObjective(PhraseCluster.this, p, scalePT); boolean ok = po.optimizeWithProjectedGradientDescent(); if (!ok) failures.incrementAndGet(); //System.out.println("" + Thread.currentThread().getId() + " done optimising lambda for " + p); @@ -322,7 +330,7 @@ public class PhraseCluster { loglikelihood += po.loglikelihood(); kl += po.KL_divergence(); l1lmax += po.l1lmax(); - primal += po.primal(); + primal += po.primal(scalePT); iterations += po.getNumberUpdateCalls(); @@ -366,15 +374,14 @@ public class PhraseCluster { return primal; } - public double PREM_phrase_context_constraints(){ - assert (scaleCT > 0); - + public double PREM_phrase_context_constraints(double scalePT, double scaleCT) + { double[][][] exp_emit = new double [K][n_positions][n_words]; double[][] exp_pi = new double[n_phrases][K]; double[] lambda = null; //E step - PhraseContextObjective pco = new PhraseContextObjective(this, lambda, pool); + PhraseContextObjective pco = new PhraseContextObjective(this, lambda, pool, scalePT, scaleCT); lambda = pco.optimizeWithProjectedGradientDescent(); //now extract expectations diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java index 15bd29c2..ff135a3d 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java @@ -59,12 +59,18 @@ public class PhraseContextObjective extends ProjectedObjective private long actualProjectionTime; private ExecutorService pool; - public PhraseContextObjective(PhraseCluster cluster, double[] startingParameters, ExecutorService pool) + double scalePT; + double scaleCT; + + public PhraseContextObjective(PhraseCluster cluster, double[] startingParameters, ExecutorService pool, + double scalePT, double scaleCT) { c=cluster; data=c.c.getEdges(); n_param=data.size()*c.K*2; this.pool=pool; + this.scalePT = scalePT; + this.scaleCT = scaleCT; parameters = startingParameters; if (parameters == null) @@ -73,8 +79,8 @@ public class PhraseContextObjective extends ProjectedObjective newPoint = new double[n_param]; gradient = new double[n_param]; initP(); - projectionPhrase = new SimplexProjection(c.scalePT); - projectionContext = new SimplexProjection(c.scaleCT); + projectionPhrase = new SimplexProjection(scalePT); + projectionContext = new SimplexProjection(scaleCT); q=new double [data.size()][c.K]; edgeIndex = new HashMap(); @@ -151,7 +157,7 @@ public class PhraseContextObjective extends ProjectedObjective //System.out.println("projectPoint: " + Arrays.toString(point)); Arrays.fill(newPoint, 0, newPoint.length, 0); - if (c.scalePT > 0) + if (scalePT > 0) { // first project using the phrase-tag constraints, // for all p,t: sum_c lambda_ptc < scaleP @@ -201,7 +207,7 @@ public class PhraseContextObjective extends ProjectedObjective } //System.out.println("after PT " + Arrays.toString(newPoint)); - if (c.scaleCT > 1e-6) + if (scaleCT > 1e-6) { // now project using the context-tag constraints, // for all c,t: sum_p omega_pct < scaleC @@ -399,6 +405,6 @@ public class PhraseContextObjective extends ProjectedObjective // L - KL(q||p) - scalePT * l1lmax_phrase - scaleCT * l1lmax_context public double primal() { - return loglikelihood() - KL_divergence() - c.scalePT * phrase_l1lmax() - c.scalePT * context_l1lmax(); + return loglikelihood() - KL_divergence() - scalePT * phrase_l1lmax() - scalePT * context_l1lmax(); } } \ No newline at end of file diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java index cc12546d..33167c20 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java @@ -63,7 +63,7 @@ public class PhraseObjective extends ProjectedObjective */ public double llh; - public PhraseObjective(PhraseCluster cluster, int phraseIdx){ + public PhraseObjective(PhraseCluster cluster, int phraseIdx, double scale){ phrase=phraseIdx; c=cluster; data=c.c.getEdgesForPhrase(phrase); @@ -81,7 +81,7 @@ public class PhraseObjective extends ProjectedObjective newPoint = new double[n_param]; gradient = new double[n_param]; initP(); - projection=new SimplexProjection(c.scalePT); + projection=new SimplexProjection(scale); q=new double [data.size()][c.K]; setParameters(parameters); @@ -220,8 +220,8 @@ public class PhraseObjective extends ProjectedObjective return sum; } - public double primal() + public double primal(double scale) { - return loglikelihood() - KL_divergence() - c.scalePT * l1lmax(); + return loglikelihood() - KL_divergence() - scale * l1lmax(); } } diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java index 439fb337..240c4d64 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java @@ -30,6 +30,7 @@ public class Trainer parser.accepts("variational-bayes"); parser.accepts("alpha-emit").withRequiredArg().ofType(Double.class).defaultsTo(0.1); parser.accepts("alpha-pi").withRequiredArg().ofType(Double.class).defaultsTo(0.01); + parser.accepts("agree"); OptionSet options = parser.parse(args); if (options.has("help") || !options.has("in")) @@ -37,7 +38,7 @@ public class Trainer try { parser.printHelpOn(System.err); } catch (IOException e) { - System.err.println("This should never happen. Really."); + System.err.println("This should never happen."); e.printStackTrace(); } System.exit(1); @@ -75,34 +76,46 @@ public class Trainer System.exit(1); } - System.out.println("Running with " + tags + " tags " + - "for " + em_iterations + " EM and " + pr_iterations + " PR iterations " + - "with scale " + scale_phrase + " phrase and " + scale_context + " context " + - "and " + threads + " threads"); - System.out.println(); + if (!options.has("agree")) + System.out.println("Running with " + tags + " tags " + + "for " + em_iterations + " EM and " + pr_iterations + " PR iterations " + + "with scale " + scale_phrase + " phrase and " + scale_context + " context " + + "and " + threads + " threads"); + else + System.out.println("Running agreement model with " + tags + " tags " + + "for " + em_iterations); + + System.out.println(); - PhraseCluster cluster = new PhraseCluster(tags, corpus, scale_phrase, scale_context, threads, alphaEmit, alphaPi); + PhraseCluster cluster = null; + Agree agree = null; + if (options.has("agree")) + agree = new Agree(tags, corpus); + else + { + cluster = new PhraseCluster(tags, corpus); + if (threads > 0) cluster.useThreadPool(threads); + if (vb) cluster.initialiseVB(alphaEmit, alphaPi); + } double last = 0; for (int i=0; i= 1) - o = cluster.PREM_phrase_constraints_parallel(); + if (i < em_iterations) + { + if (!vb) + o = cluster.EM(); + else + o = cluster.VBEM(alphaEmit, alphaPi); + } else - o = cluster.PREM_phrase_constraints(); + o = cluster.PREM(scale_phrase, scale_context); } - else - o = cluster.PREM_phrase_context_constraints(); System.out.println("ITER: "+i+" objective: " + o); @@ -120,6 +133,9 @@ public class Trainer last = o; } + if (cluster == null) + cluster = agree.model1; + double pl1lmax = cluster.phrase_l1lmax(); double cl1lmax = cluster.context_l1lmax(); System.out.println("\nFinal posterior phrase l1lmax " + pl1lmax + " context l1lmax " + cl1lmax); -- cgit v1.2.3