summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java')
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java80
1 files changed, 34 insertions, 46 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index 9ee766d4..e7e4af32 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -23,7 +23,7 @@ import phrase.Corpus.Edge;
public class PhraseCluster {
public int K;
- private int n_phrases, n_words, n_contexts, n_positions, edge_threshold;
+ private int n_phrases, n_words, n_contexts, n_positions;
public Corpus c;
public ExecutorService pool;
@@ -44,7 +44,6 @@ public class PhraseCluster {
n_phrases=c.getNumPhrases();
n_contexts=c.getNumContexts();
n_positions=c.getNumContextPositions();
- edge_threshold=0;
emit=new double [K][n_positions][n_words];
pi=new double[n_phrases][K];
@@ -76,7 +75,7 @@ public class PhraseCluster {
pool = Executors.newFixedThreadPool(threads);
}
- public double EM(boolean skipBigPhrases)
+ public double EM(int phraseSizeLimit)
{
double [][][]exp_emit=new double [K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
@@ -92,19 +91,17 @@ public class PhraseCluster {
//E
for(int phrase=0; phrase < n_phrases; phrase++)
{
- if (skipBigPhrases && c.getPhrase(phrase).size() >= 2)
+ if (phraseSizeLimit >= 1 && c.getPhrase(phrase).size() > phraseSizeLimit)
{
System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
continue;
- }
+ }
List<Edge> contexts = c.getEdgesForPhrase(phrase);
for (int ctx=0; ctx<contexts.size(); ctx++)
{
Edge edge = contexts.get(ctx);
- if (edge.getCount() < edge_threshold || c.isRare(edge))
- continue;
double p[]=posterior(edge);
double z = arr.F.l1norm(p);
@@ -138,10 +135,9 @@ public class PhraseCluster {
return loglikelihood;
}
- public double VBEM(double alphaEmit, double alphaPi, boolean skipBigPhrases)
+ public double VBEM(double alphaEmit, double alphaPi)
{
// FIXME: broken - needs to be done entirely in log-space
- assert !skipBigPhrases : "FIXME: implement this!";
double [][][]exp_emit = new double [K][n_positions][n_words];
double [][]exp_pi = new double[n_phrases][K];
@@ -240,21 +236,21 @@ public class PhraseCluster {
return kl;
}
- public double PREM(double scalePT, double scaleCT, boolean skipBigPhrases)
+ public double PREM(double scalePT, double scaleCT, int phraseSizeLimit)
{
if (scaleCT == 0)
{
if (pool != null)
- return PREM_phrase_constraints_parallel(scalePT, skipBigPhrases);
+ return PREM_phrase_constraints_parallel(scalePT, phraseSizeLimit);
else
- return PREM_phrase_constraints(scalePT, skipBigPhrases);
+ return PREM_phrase_constraints(scalePT, phraseSizeLimit);
}
- else
- return this.PREM_phrase_context_constraints(scalePT, scaleCT, skipBigPhrases);
+ else // FIXME: ignores phraseSizeLimit
+ return this.PREM_phrase_context_constraints(scalePT, scaleCT);
}
- public double PREM_phrase_constraints(double scalePT, boolean skipBigPhrases)
+ public double PREM_phrase_constraints(double scalePT, int phraseSizeLimit)
{
double [][][]exp_emit=new double[K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
@@ -272,8 +268,9 @@ public class PhraseCluster {
int failures=0, iterations=0;
long start = System.currentTimeMillis();
//E
- for(int phrase=0; phrase<n_phrases; phrase++){
- if (skipBigPhrases && c.getPhrase(phrase).size() >= 2)
+ for(int phrase=0; phrase<n_phrases; phrase++)
+ {
+ if (phraseSizeLimit >= 1 && c.getPhrase(phrase).size() > phraseSizeLimit)
{
System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
continue;
@@ -328,7 +325,7 @@ public class PhraseCluster {
return primal;
}
- public double PREM_phrase_constraints_parallel(final double scalePT, boolean skipBigPhrases)
+ public double PREM_phrase_constraints_parallel(final double scalePT, int phraseSizeLimit)
{
assert(pool != null);
@@ -338,12 +335,11 @@ public class PhraseCluster {
double [][][]exp_emit=new double [K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
- if (skipBigPhrases)
- {
- for(double [][]i:exp_emit)
- for(double []j:i)
- Arrays.fill(j, 1e-100);
- }
+ for(double [][]i:exp_emit)
+ for(double []j:i)
+ Arrays.fill(j, 1e-10);
+ for(double []j:pi)
+ Arrays.fill(j, 1e-10);
double loglikelihood=0, kl=0, l1lmax=0, primal=0;
final AtomicInteger failures = new AtomicInteger(0);
@@ -356,12 +352,13 @@ public class PhraseCluster {
//E
for(int phrase=0;phrase<n_phrases;phrase++){
- if (skipBigPhrases && c.getPhrase(phrase).size() >= 2)
+ if (phraseSizeLimit >= 1 && c.getPhrase(phrase).size() > phraseSizeLimit)
{
n -= 1;
System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
continue;
}
+
final int p=phrase;
pool.execute(new Runnable() {
public void run() {
@@ -445,10 +442,8 @@ public class PhraseCluster {
return primal;
}
- public double PREM_phrase_context_constraints(double scalePT, double scaleCT, boolean skipBigPhrases)
+ public double PREM_phrase_context_constraints(double scalePT, double scaleCT)
{
- assert !skipBigPhrases : "Not supported yet - FIXME!"; //FIXME
-
double[][][] exp_emit = new double [K][n_positions][n_words];
double[][] exp_pi = new double[n_phrases][K];
@@ -500,10 +495,14 @@ public class PhraseCluster {
*/
public double[] posterior(Corpus.Edge edge)
{
- double[] prob=Arrays.copyOf(pi[edge.getPhraseId()], K);
-
- //if (edge.getCount() < edge_threshold)
- //System.out.println("Edge: " + edge + " probs for phrase " + Arrays.toString(prob));
+ double[] prob;
+ if (edge.getPhraseId() < n_phrases)
+ prob = Arrays.copyOf(pi[edge.getPhraseId()], K);
+ else
+ {
+ prob = new double[K];
+ Arrays.fill(prob, 1.0);
+ }
TIntArrayList ctx = edge.getContext();
for(int tag=0;tag<K;tag++)
@@ -511,23 +510,17 @@ public class PhraseCluster {
for(int c=0;c<n_positions;c++)
{
int word = ctx.get(c);
- //if (edge.getCount() < edge_threshold)
- //System.out.println("\ttag: " + tag + " context word: " + word + " prob " + emit[tag][c][word]);
-
- if (!this.c.isSentinel(word))
+ if (!this.c.isSentinel(word) && word < n_words)
prob[tag]*=emit[tag][c][word];
}
}
-
- //if (edge.getCount() < edge_threshold)
- //System.out.println("prob " + Arrays.toString(prob));
return prob;
}
- public void displayPosterior(PrintStream ps)
+ public void displayPosterior(PrintStream ps, List<Edge> testing)
{
- for (Edge edge : c.getEdges())
+ for (Edge edge : testing)
{
double probs[] = posterior(edge);
arr.F.l1normalize(probs);
@@ -660,9 +653,4 @@ public class PhraseCluster {
emit[t][p][w] = v;
}
}
-
- public void setEdgeThreshold(int edgeThreshold)
- {
- this.edge_threshold = edgeThreshold;
- }
}