summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java')
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java24
1 files changed, 10 insertions, 14 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index e4db2a1a..63a60682 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -18,7 +18,7 @@ public class PhraseCluster {
public double scalePT, scaleCT;
private int n_phrases, n_words, n_contexts, n_positions;
public Corpus c;
- private ExecutorService pool;
+ public ExecutorService pool;
// emit[tag][position][word] = p(word | tag, position in context)
private double emit[][][];
@@ -88,7 +88,8 @@ public class PhraseCluster {
//cluster.displayModelParam(ps);
//ps.close();
- cluster.finish();
+ if (cluster.pool != null)
+ cluster.pool.shutdown();
}
public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads){
@@ -100,7 +101,7 @@ public class PhraseCluster {
n_positions=c.getNumContextPositions();
this.scalePT = scalep;
this.scaleCT = scalec;
- if (threads > 0 && scalec <= 0)
+ if (threads > 0)
pool = Executors.newFixedThreadPool(threads);
emit=new double [K][n_positions][n_words];
@@ -116,12 +117,7 @@ public class PhraseCluster {
arr.F.randomise(j);
}
}
-
- public void finish()
- {
- if (pool != null)
- pool.shutdown();
- }
+
public double EM(){
double [][][]exp_emit=new double [K][n_positions][n_words];
@@ -318,13 +314,13 @@ public class PhraseCluster {
public double PREM_phrase_context_constraints(){
assert (scaleCT > 0);
- double [][][]exp_emit=new double [K][n_positions][n_words];
- double [][]exp_pi=new double[n_phrases][K];
+ double[][][] exp_emit = new double [K][n_positions][n_words];
+ double[][] exp_pi = new double[n_phrases][K];
+ double[] lambda = null;
//E step
- // TODO: cache the lambda values (the null below)
- PhraseContextObjective pco = new PhraseContextObjective(this, null);
- pco.optimizeWithProjectedGradientDescent();
+ PhraseContextObjective pco = new PhraseContextObjective(this, lambda, pool);
+ lambda = pco.optimizeWithProjectedGradientDescent();
//now extract expectations
List<Corpus.Edge> edges = c.getEdges();