Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java')
-rw-r--r--  gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java | 91
1 file changed, 49 insertions(+), 42 deletions(-)
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index 7bc63c33..abd868c4 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -18,21 +18,16 @@ import util.MathUtil;
public class PhraseCluster {
public int K;
- public double scalePT, scaleCT;
private int n_phrases, n_words, n_contexts, n_positions;
public Corpus c;
public ExecutorService pool;
// emit[tag][position][word] = p(word | tag, position in context)
- private double emit[][][];
+ double emit[][][];
// pi[phrase][tag] = p(tag | phrase)
- private double pi[][];
+ double pi[][];
- double alphaEmit;
- double alphaPi;
-
- public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads,
- double alphaEmit, double alphaPi)
+ public PhraseCluster(int numCluster, Corpus corpus)
{
K=numCluster;
c=corpus;
@@ -40,33 +35,34 @@ public class PhraseCluster {
n_phrases=c.getNumPhrases();
n_contexts=c.getNumContexts();
n_positions=c.getNumContextPositions();
- this.scalePT = scalep;
- this.scaleCT = scalec;
- if (threads > 0)
- pool = Executors.newFixedThreadPool(threads);
-
+
emit=new double [K][n_positions][n_words];
pi=new double[n_phrases][K];
for(double [][]i:emit)
- {
for(double []j:i)
- {
- arr.F.randomise(j, alphaEmit <= 0);
- if (alphaEmit > 0)
- digammaNormalize(j, alphaEmit);
- }
- }
-
+ arr.F.randomise(j, true);
for(double []j:pi)
- {
- arr.F.randomise(j, alphaPi <= 0);
- if (alphaPi > 0)
- digammaNormalize(j, alphaPi);
- }
+ arr.F.randomise(j, true);
+ }
+
+ public void initialiseVB(double alphaEmit, double alphaPi)
+ {
+ assert alphaEmit > 0;
+ assert alphaPi > 0;
+
+ for(double [][]i:emit)
+ for(double []j:i)
+ digammaNormalize(j, alphaEmit);
- this.alphaEmit = alphaEmit;
- this.alphaPi = alphaPi;
+ for(double []j:pi)
+ digammaNormalize(j, alphaPi);
+ }
+
+ void useThreadPool(int threads)
+ {
+ assert threads > 0;
+ pool = Executors.newFixedThreadPool(threads);
}
public double EM()
@@ -116,7 +112,7 @@ public class PhraseCluster {
return loglikelihood;
}
- public double VBEM()
+ public double VBEM(double alphaEmit, double alphaPi)
{
// FIXME: broken - needs to be done entirely in log-space
@@ -216,9 +212,22 @@ public class PhraseCluster {
return kl;
}
- public double PREM_phrase_constraints(){
- assert (scaleCT <= 0);
-
+ public double PREM(double scalePT, double scaleCT)
+ {
+ if (scaleCT == 0)
+ {
+ if (pool != null)
+ return PREM_phrase_constraints_parallel(scalePT);
+ else
+ return PREM_phrase_constraints(scalePT);
+ }
+ else
+ return this.PREM_phrase_context_constraints(scalePT, scaleCT);
+ }
+
+
+ public double PREM_phrase_constraints(double scalePT)
+ {
double [][][]exp_emit=new double[K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
@@ -226,7 +235,7 @@ public class PhraseCluster {
int failures=0, iterations=0;
//E
for(int phrase=0; phrase<n_phrases; phrase++){
- PhraseObjective po=new PhraseObjective(this,phrase);
+ PhraseObjective po=new PhraseObjective(this, phrase, scalePT);
boolean ok = po.optimizeWithProjectedGradientDescent();
if (!ok) ++failures;
iterations += po.getNumberUpdateCalls();
@@ -234,7 +243,7 @@ public class PhraseCluster {
loglikelihood += po.loglikelihood();
kl += po.KL_divergence();
l1lmax += po.l1lmax();
- primal += po.primal();
+ primal += po.primal(scalePT);
List<Edge> edges = c.getEdgesForPhrase(phrase);
for(int edge=0;edge<q.length;edge++){
@@ -272,10 +281,9 @@ public class PhraseCluster {
return primal;
}
- public double PREM_phrase_constraints_parallel()
+ public double PREM_phrase_constraints_parallel(final double scalePT)
{
assert(pool != null);
- assert(scaleCT <= 0);
final LinkedBlockingQueue<PhraseObjective> expectations
= new LinkedBlockingQueue<PhraseObjective>();
@@ -294,7 +302,7 @@ public class PhraseCluster {
public void run() {
try {
//System.out.println("" + Thread.currentThread().getId() + " optimising lambda for " + p);
- PhraseObjective po = new PhraseObjective(PhraseCluster.this, p);
+ PhraseObjective po = new PhraseObjective(PhraseCluster.this, p, scalePT);
boolean ok = po.optimizeWithProjectedGradientDescent();
if (!ok) failures.incrementAndGet();
//System.out.println("" + Thread.currentThread().getId() + " done optimising lambda for " + p);
@@ -322,7 +330,7 @@ public class PhraseCluster {
loglikelihood += po.loglikelihood();
kl += po.KL_divergence();
l1lmax += po.l1lmax();
- primal += po.primal();
+ primal += po.primal(scalePT);
iterations += po.getNumberUpdateCalls();
@@ -366,15 +374,14 @@ public class PhraseCluster {
return primal;
}
- public double PREM_phrase_context_constraints(){
- assert (scaleCT > 0);
-
+ public double PREM_phrase_context_constraints(double scalePT, double scaleCT)
+ {
double[][][] exp_emit = new double [K][n_positions][n_words];
double[][] exp_pi = new double[n_phrases][K];
double[] lambda = null;
//E step
- PhraseContextObjective pco = new PhraseContextObjective(this, lambda, pool);
+ PhraseContextObjective pco = new PhraseContextObjective(this, lambda, pool, scalePT, scaleCT);
lambda = pco.optimizeWithProjectedGradientDescent();
//now extract expectations