summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java')
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java45
1 files changed, 32 insertions, 13 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index 1f73764e..a369b319 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -2,8 +2,6 @@ package phrase;
import gnu.trove.TIntArrayList;
import org.apache.commons.math.special.Gamma;
-import io.FileUtil;
-import java.io.IOException;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.List;
@@ -11,9 +9,10 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
import phrase.Corpus.Edge;
-import util.MathUtil;
+
public class PhraseCluster {
@@ -21,7 +20,11 @@ public class PhraseCluster {
private int n_phrases, n_words, n_contexts, n_positions;
public Corpus c;
public ExecutorService pool;
-
+
+ double[] lambdaPTCT;
+ double[][] lambdaPT;
+ boolean cacheLambda = true;
+
// emit[tag][position][word] = p(word | tag, position in context)
double emit[][][];
// pi[phrase][tag] = p(tag | phrase)
@@ -232,14 +235,19 @@ public class PhraseCluster {
{
double [][][]exp_emit=new double[K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
+
+ if (lambdaPT == null && cacheLambda)
+ lambdaPT = new double[n_phrases][];
double loglikelihood=0, kl=0, l1lmax=0, primal=0;
int failures=0, iterations=0;
+ long start = System.currentTimeMillis();
//E
for(int phrase=0; phrase<n_phrases; phrase++){
- PhraseObjective po=new PhraseObjective(this, phrase, scalePT);
+ PhraseObjective po = new PhraseObjective(this, phrase, scalePT, (cacheLambda) ? lambdaPT[phrase] : null);
boolean ok = po.optimizeWithProjectedGradientDescent();
if (!ok) ++failures;
+ if (cacheLambda) lambdaPT[phrase] = po.getParameters();
iterations += po.getNumberUpdateCalls();
double [][] q=po.posterior();
loglikelihood += po.loglikelihood();
@@ -263,9 +271,10 @@ public class PhraseCluster {
}
}
+ long end = System.currentTimeMillis();
if (failures > 0)
System.out.println("WARNING: failed to converge in " + failures + "/" + n_phrases + " cases");
- System.out.println("\tmean iters: " + iterations/(double)n_phrases);
+ System.out.println("\tmean iters: " + iterations/(double)n_phrases + " elapsed time " + (end - start) / 1000.0);
System.out.println("\tllh: " + loglikelihood);
System.out.println("\tKL: " + kl);
System.out.println("\tphrase l1lmax: " + l1lmax);
@@ -295,7 +304,12 @@ public class PhraseCluster {
double loglikelihood=0, kl=0, l1lmax=0, primal=0;
final AtomicInteger failures = new AtomicInteger(0);
+ final AtomicLong elapsed = new AtomicLong(0l);
int iterations=0;
+ long start = System.currentTimeMillis();
+
+ if (lambdaPT == null && cacheLambda)
+ lambdaPT = new double[n_phrases][];
//E
for(int phrase=0;phrase<n_phrases;phrase++){
@@ -304,9 +318,13 @@ public class PhraseCluster {
public void run() {
try {
//System.out.println("" + Thread.currentThread().getId() + " optimising lambda for " + p);
- PhraseObjective po = new PhraseObjective(PhraseCluster.this, p, scalePT);
+ long start = System.currentTimeMillis();
+ PhraseObjective po = new PhraseObjective(PhraseCluster.this, p, scalePT, (cacheLambda) ? lambdaPT[p] : null);
boolean ok = po.optimizeWithProjectedGradientDescent();
if (!ok) failures.incrementAndGet();
+ long end = System.currentTimeMillis();
+ elapsed.addAndGet(end - start);
+
//System.out.println("" + Thread.currentThread().getId() + " done optimising lambda for " + p);
expectations.put(po);
//System.out.println("" + Thread.currentThread().getId() + " added to queue " + p);
@@ -327,6 +345,7 @@ public class PhraseCluster {
PhraseObjective po = expectations.take();
// process
int phrase = po.phrase;
+ if (cacheLambda) lambdaPT[phrase] = po.getParameters();
//System.out.println("" + Thread.currentThread().getId() + " taken phrase " + phrase);
double [][] q=po.posterior();
loglikelihood += po.loglikelihood();
@@ -335,7 +354,6 @@ public class PhraseCluster {
primal += po.primal(scalePT);
iterations += po.getNumberUpdateCalls();
-
List<Edge> edges = c.getEdgesForPhrase(phrase);
for(int edge=0;edge<q.length;edge++){
Edge e = edges.get(edge);
@@ -356,9 +374,11 @@ public class PhraseCluster {
}
}
+ long end = System.currentTimeMillis();
+
if (failures.get() > 0)
System.out.println("WARNING: failed to converge in " + failures.get() + "/" + n_phrases + " cases");
- System.out.println("\tmean iters: " + iterations/(double)n_phrases);
+ System.out.println("\tmean iters: " + iterations/(double)n_phrases + " walltime " + (end-start)/1000.0 + " threads " + elapsed.get() / 1000.0);
System.out.println("\tllh: " + loglikelihood);
System.out.println("\tKL: " + kl);
System.out.println("\tphrase l1lmax: " + l1lmax);
@@ -376,16 +396,15 @@ public class PhraseCluster {
return primal;
}
- double[] lambda;
-
public double PREM_phrase_context_constraints(double scalePT, double scaleCT)
{
double[][][] exp_emit = new double [K][n_positions][n_words];
double[][] exp_pi = new double[n_phrases][K];
//E step
- PhraseContextObjective pco = new PhraseContextObjective(this, lambda, pool, scalePT, scaleCT);
- lambda = pco.optimizeWithProjectedGradientDescent();
+ PhraseContextObjective pco = new PhraseContextObjective(this, lambdaPTCT, pool, scalePT, scaleCT);
+ boolean ok = pco.optimizeWithProjectedGradientDescent();
+ if (cacheLambda) lambdaPTCT = pco.getParameters();
//now extract expectations
List<Corpus.Edge> edges = c.getEdges();