summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java')
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java67
1 files changed, 55 insertions, 12 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index a369b319..5efaf52e 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -69,16 +69,29 @@ public class PhraseCluster {
pool = Executors.newFixedThreadPool(threads);
}
- public double EM()
+ public double EM(boolean skipBigPhrases)
{
double [][][]exp_emit=new double [K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
+ if (skipBigPhrases)
+ {
+ for(double [][]i:exp_emit)
+ for(double []j:i)
+ Arrays.fill(j, 1e-100);
+ }
+
double loglikelihood=0;
//E
for(int phrase=0; phrase < n_phrases; phrase++)
{
+ if (skipBigPhrases && c.getPhrase(phrase).size() >= 2)
+ {
+ System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
+ continue;
+ }
+
List<Edge> contexts = c.getEdgesForPhrase(phrase);
for (int ctx=0; ctx<contexts.size(); ctx++)
@@ -116,9 +129,10 @@ public class PhraseCluster {
return loglikelihood;
}
- public double VBEM(double alphaEmit, double alphaPi)
+ public double VBEM(double alphaEmit, double alphaPi, boolean skipBigPhrases)
{
// FIXME: broken - needs to be done entirely in log-space
+ assert !skipBigPhrases : "FIXME: implement this!";
double [][][]exp_emit = new double [K][n_positions][n_words];
double [][]exp_pi = new double[n_phrases][K];
@@ -217,24 +231,31 @@ public class PhraseCluster {
return kl;
}
- public double PREM(double scalePT, double scaleCT)
+ public double PREM(double scalePT, double scaleCT, boolean skipBigPhrases)
{
if (scaleCT == 0)
{
if (pool != null)
- return PREM_phrase_constraints_parallel(scalePT);
+ return PREM_phrase_constraints_parallel(scalePT, skipBigPhrases);
else
- return PREM_phrase_constraints(scalePT);
+ return PREM_phrase_constraints(scalePT, skipBigPhrases);
}
else
- return this.PREM_phrase_context_constraints(scalePT, scaleCT);
+ return this.PREM_phrase_context_constraints(scalePT, scaleCT, skipBigPhrases);
}
- public double PREM_phrase_constraints(double scalePT)
+ public double PREM_phrase_constraints(double scalePT, boolean skipBigPhrases)
{
double [][][]exp_emit=new double[K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
+
+ if (skipBigPhrases)
+ {
+ for(double [][]i:exp_emit)
+ for(double []j:i)
+ Arrays.fill(j, 1e-100);
+ }
if (lambdaPT == null && cacheLambda)
lambdaPT = new double[n_phrases][];
@@ -244,6 +265,12 @@ public class PhraseCluster {
long start = System.currentTimeMillis();
//E
for(int phrase=0; phrase<n_phrases; phrase++){
+ if (skipBigPhrases && c.getPhrase(phrase).size() >= 2)
+ {
+ System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
+ continue;
+ }
+
PhraseObjective po = new PhraseObjective(this, phrase, scalePT, (cacheLambda) ? lambdaPT[phrase] : null);
boolean ok = po.optimizeWithProjectedGradientDescent();
if (!ok) ++failures;
@@ -292,7 +319,7 @@ public class PhraseCluster {
return primal;
}
- public double PREM_phrase_constraints_parallel(final double scalePT)
+ public double PREM_phrase_constraints_parallel(final double scalePT, boolean skipBigPhrases)
{
assert(pool != null);
@@ -302,10 +329,17 @@ public class PhraseCluster {
double [][][]exp_emit=new double [K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
+ if (skipBigPhrases)
+ {
+ for(double [][]i:exp_emit)
+ for(double []j:i)
+ Arrays.fill(j, 1e-100);
+ }
+
double loglikelihood=0, kl=0, l1lmax=0, primal=0;
final AtomicInteger failures = new AtomicInteger(0);
final AtomicLong elapsed = new AtomicLong(0l);
- int iterations=0;
+ int iterations=0, n=n_phrases;
long start = System.currentTimeMillis();
if (lambdaPT == null && cacheLambda)
@@ -313,6 +347,12 @@ public class PhraseCluster {
//E
for(int phrase=0;phrase<n_phrases;phrase++){
+ if (skipBigPhrases && c.getPhrase(phrase).size() >= 2)
+ {
+ n -= 1;
+ System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
+ continue;
+ }
final int p=phrase;
pool.execute(new Runnable() {
public void run() {
@@ -337,7 +377,7 @@ public class PhraseCluster {
}
// aggregate the expectations as they become available
- for(int count=0;count<n_phrases;count++) {
+ for(int count=0;count<n;count++) {
try {
//System.out.println("" + Thread.currentThread().getId() + " reading queue #" + count);
@@ -396,8 +436,10 @@ public class PhraseCluster {
return primal;
}
- public double PREM_phrase_context_constraints(double scalePT, double scaleCT)
+ public double PREM_phrase_context_constraints(double scalePT, double scaleCT, boolean skipBigPhrases)
{
+ assert !skipBigPhrases : "Not supported yet - FIXME!"; //FIXME
+
double[][][] exp_emit = new double [K][n_positions][n_words];
double[][] exp_pi = new double[n_phrases][K];
@@ -454,7 +496,8 @@ public class PhraseCluster {
TIntArrayList ctx = edge.getContext();
for(int tag=0;tag<K;tag++)
for(int c=0;c<n_positions;c++)
- prob[tag]*=emit[tag][c][ctx.get(c)];
+ if (!this.c.isSentinel(ctx.get(c)))
+ prob[tag]*=emit[tag][c][ctx.get(c)];
return prob;
}