From fc5860d6df8c30149cee280f38d7a11889102663 Mon Sep 17 00:00:00 2001 From: "trevor.cohn" Date: Mon, 12 Jul 2010 19:48:54 +0000 Subject: Updated launcher to include agreement model. git-svn-id: https://ws10smt.googlecode.com/svn/trunk@226 ec762483-ff6d-05da-a07a-a48fb63a330f --- .../prjava/src/phrase/PhraseCluster.java | 91 ++++++++++++---------- 1 file changed, 49 insertions(+), 42 deletions(-) (limited to 'gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java') diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index 7bc63c33..abd868c4 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -18,21 +18,16 @@ import util.MathUtil; public class PhraseCluster { public int K; - public double scalePT, scaleCT; private int n_phrases, n_words, n_contexts, n_positions; public Corpus c; public ExecutorService pool; // emit[tag][position][word] = p(word | tag, position in context) - private double emit[][][]; + double emit[][][]; // pi[phrase][tag] = p(tag | phrase) - private double pi[][]; + double pi[][]; - double alphaEmit; - double alphaPi; - - public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads, - double alphaEmit, double alphaPi) + public PhraseCluster(int numCluster, Corpus corpus) { K=numCluster; c=corpus; @@ -40,33 +35,34 @@ public class PhraseCluster { n_phrases=c.getNumPhrases(); n_contexts=c.getNumContexts(); n_positions=c.getNumContextPositions(); - this.scalePT = scalep; - this.scaleCT = scalec; - if (threads > 0) - pool = Executors.newFixedThreadPool(threads); - + emit=new double [K][n_positions][n_words]; pi=new double[n_phrases][K]; for(double [][]i:emit) - { for(double []j:i) - { - arr.F.randomise(j, alphaEmit <= 0); - if (alphaEmit > 0) - digammaNormalize(j, alphaEmit); - } - } - + arr.F.randomise(j, true); for(double []j:pi) - { - arr.F.randomise(j, alphaPi <= 0); - if (alphaPi > 0) - digammaNormalize(j, alphaPi); - } + arr.F.randomise(j, true); + } + + public void initialiseVB(double alphaEmit, double alphaPi) + { + assert alphaEmit > 0; + assert alphaPi > 0; + + for(double [][]i:emit) + for(double []j:i) + digammaNormalize(j, alphaEmit); - this.alphaEmit = alphaEmit; - this.alphaPi = alphaPi; + for(double []j:pi) + digammaNormalize(j, alphaPi); + } + + void useThreadPool(int threads) + { + assert threads > 0; + pool = Executors.newFixedThreadPool(threads); } public double EM() @@ -116,7 +112,7 @@ public class PhraseCluster { return loglikelihood; } - public double VBEM() + public double VBEM(double alphaEmit, double alphaPi) { // FIXME: broken - needs to be done entirely in log-space @@ -216,9 +212,22 @@ public class PhraseCluster { return kl; } - public double PREM_phrase_constraints(){ - assert (scaleCT <= 0); - + public double PREM(double scalePT, double scaleCT) + { + if (scaleCT == 0) + { + if (pool != null) + return PREM_phrase_constraints_parallel(scalePT); + else + return PREM_phrase_constraints(scalePT); + } + else + return this.PREM_phrase_context_constraints(scalePT, scaleCT); + } + + + public double PREM_phrase_constraints(double scalePT) + { double [][][]exp_emit=new double[K][n_positions][n_words]; double [][]exp_pi=new double[n_phrases][K]; @@ -226,7 +235,7 @@ public class PhraseCluster { int failures=0, iterations=0; //E for(int phrase=0; phrase edges = c.getEdgesForPhrase(phrase); for(int edge=0;edge expectations = new LinkedBlockingQueue(); @@ -294,7 +302,7 @@ public class PhraseCluster { public void run() { try { //System.out.println("" + Thread.currentThread().getId() + " optimising lambda for " + p); - PhraseObjective po = new PhraseObjective(PhraseCluster.this, p); + PhraseObjective po = new PhraseObjective(PhraseCluster.this, p, scalePT); boolean ok = po.optimizeWithProjectedGradientDescent(); if (!ok) failures.incrementAndGet(); //System.out.println("" + Thread.currentThread().getId() + " done optimising lambda for " + p); @@ -322,7 +330,7 @@ public class PhraseCluster { loglikelihood += po.loglikelihood(); kl += po.KL_divergence(); l1lmax += po.l1lmax(); - primal += po.primal(); + primal += po.primal(scalePT); iterations += po.getNumberUpdateCalls(); @@ -366,15 +374,14 @@ public class PhraseCluster { return primal; } - public double PREM_phrase_context_constraints(){ - assert (scaleCT > 0); - + public double PREM_phrase_context_constraints(double scalePT, double scaleCT) + { double[][][] exp_emit = new double [K][n_positions][n_words]; double[][] exp_pi = new double[n_phrases][K]; double[] lambda = null; //E step - PhraseContextObjective pco = new PhraseContextObjective(this, lambda, pool); + PhraseContextObjective pco = new PhraseContextObjective(this, lambda, pool, scalePT, scaleCT); lambda = pco.optimizeWithProjectedGradientDescent(); //now extract expectations -- cgit v1.2.3