diff options
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java')
-rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java | 147 |
1 files changed, 113 insertions, 34 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index 5efaf52e..feab5eda 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -2,14 +2,20 @@ package phrase; import gnu.trove.TIntArrayList;
import org.apache.commons.math.special.Gamma;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.List;
+import java.util.StringTokenizer;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
+import java.util.regex.Pattern;
import phrase.Corpus.Edge;
@@ -17,7 +23,7 @@ import phrase.Corpus.Edge; public class PhraseCluster {
public int K;
- private int n_phrases, n_words, n_contexts, n_positions;
+ private int n_phrases, n_words, n_contexts, n_positions, edge_threshold;
public Corpus c;
public ExecutorService pool;
@@ -38,6 +44,7 @@ public class PhraseCluster { n_phrases=c.getNumPhrases();
n_contexts=c.getNumContexts();
n_positions=c.getNumContextPositions();
+ edge_threshold=0;
emit=new double [K][n_positions][n_words];
pi=new double[n_phrases][K];
@@ -74,12 +81,11 @@ public class PhraseCluster { double [][][]exp_emit=new double [K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
- if (skipBigPhrases)
- {
- for(double [][]i:exp_emit)
- for(double []j:i)
- Arrays.fill(j, 1e-100);
- }
+ for(double [][]i:exp_emit)
+ for(double []j:i)
+ Arrays.fill(j, 1e-10);
+ for(double []j:pi)
+ Arrays.fill(j, 1e-10);
double loglikelihood=0;
@@ -97,6 +103,9 @@ public class PhraseCluster { for (int ctx=0; ctx<contexts.size(); ctx++)
{
Edge edge = contexts.get(ctx);
+ if (edge.getCount() < edge_threshold)
+ continue;
+
double p[]=posterior(edge);
double z = arr.F.l1norm(p);
assert z > 0;
@@ -121,7 +130,7 @@ public class PhraseCluster { arr.F.l1normalize(j);
for(double []j:exp_pi)
- arr.F.l1normalize(j);
+ arr.F.l1normalize(j);
emit=exp_emit;
pi=exp_pi;
@@ -250,12 +259,11 @@ public class PhraseCluster { double [][][]exp_emit=new double[K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
- if (skipBigPhrases)
- {
- for(double [][]i:exp_emit)
- for(double []j:i)
- Arrays.fill(j, 1e-100);
- }
+ for(double [][]i:exp_emit)
+ for(double []j:i)
+ Arrays.fill(j, 1e-10);
+ for(double []j:pi)
+ Arrays.fill(j, 1e-10);
if (lambdaPT == null && cacheLambda)
lambdaPT = new double[n_phrases][];
@@ -271,6 +279,7 @@ public class PhraseCluster { continue;
}
+ // FIXME: add rare edge check to phrase objective & posterior processing
PhraseObjective po = new PhraseObjective(this, phrase, scalePT, (cacheLambda) ? lambdaPT[phrase] : null);
boolean ok = po.optimizeWithProjectedGradientDescent();
if (!ok) ++failures;
@@ -493,11 +502,25 @@ public class PhraseCluster { {
double[] prob=Arrays.copyOf(pi[edge.getPhraseId()], K);
+ //if (edge.getCount() < edge_threshold)
+ //System.out.println("Edge: " + edge + " probs for phrase " + Arrays.toString(prob));
+
TIntArrayList ctx = edge.getContext();
for(int tag=0;tag<K;tag++)
+ {
for(int c=0;c<n_positions;c++)
- if (!this.c.isSentinel(ctx.get(c)))
- prob[tag]*=emit[tag][c][ctx.get(c)];
+ {
+ int word = ctx.get(c);
+ //if (edge.getCount() < edge_threshold)
+ //System.out.println("\ttag: " + tag + " context word: " + word + " prob " + emit[tag][c][word]);
+
+ if (!this.c.isSentinel(word))
+ prob[tag]*=emit[tag][c][word];
+ }
+ }
+
+ //if (edge.getCount() < edge_threshold)
+ //System.out.println("prob " + Arrays.toString(prob));
return prob;
}
@@ -514,39 +537,33 @@ public class PhraseCluster { ps.print("\t");
ps.print(edge.getContextString(true));
int t=arr.F.argmax(probs);
- ps.println(" ||| C=" + t);
+ ps.println(" ||| C=" + t + " T=" + edge.getCount() + " P=" + probs[t]);
+ //ps.println("# probs " + Arrays.toString(probs));
}
}
public void displayModelParam(PrintStream ps)
{
final double EPS = 1e-6;
+ ps.println("phrases " + n_phrases + " tags " + K + " positions " + n_positions);
- ps.println("P(tag|phrase)");
for (int i = 0; i < n_phrases; ++i)
- {
- ps.print(c.getPhrase(i));
- for(int j=0;j<pi[i].length;j++){
+ for(int j=0;j<pi[i].length;j++)
if (pi[i][j] > EPS)
- ps.print("\t" + j + ": " + pi[i][j]);
- }
- ps.println();
- }
-
- ps.println("P(word|tag,position)");
+ ps.println(i + " " + j + " " + pi[i][j]);
+
+ ps.println();
for (int i = 0; i < K; ++i)
{
- for(int position=0;position<n_positions;position++){
- ps.println("tag " + i + " position " + position);
- for(int word=0;word<emit[i][position].length;word++){
+ for(int position=0;position<n_positions;position++)
+ {
+ for(int word=0;word<emit[i][position].length;word++)
+ {
if (emit[i][position][word] > EPS)
- ps.print(c.getWord(word)+"="+emit[i][position][word]+"\t");
+ ps.println(i + " " + position + " " + word + " " + emit[i][position][word]);
}
- ps.println();
}
- ps.println();
}
-
}
double phrase_l1lmax()
@@ -586,4 +603,66 @@ public class PhraseCluster { }
return sum;
}
+
+ public void loadParameters(BufferedReader input) throws IOException
+ {
+ final double EPS = 1e-50;
+
+ // overwrite pi, emit with ~zeros
+ for(double [][]i:emit)
+ for(double []j:i)
+ Arrays.fill(j, EPS);
+
+ for(double []j:pi)
+ Arrays.fill(j, EPS);
+
+ String line = input.readLine();
+ assert line != null;
+
+ Pattern space = Pattern.compile(" +");
+ String[] parts = space.split(line);
+ assert parts.length == 6;
+
+ assert parts[0].equals("phrases");
+ int phrases = Integer.parseInt(parts[1]);
+ int tags = Integer.parseInt(parts[3]);
+ int positions = Integer.parseInt(parts[5]);
+
+ assert phrases == n_phrases;
+ assert tags == K;
+ assert positions == n_positions;
+
+ // read in pi
+ while ((line = input.readLine()) != null)
+ {
+ line = line.trim();
+ if (line.isEmpty()) break;
+
+ String[] tokens = space.split(line);
+ assert tokens.length == 3;
+ int p = Integer.parseInt(tokens[0]);
+ int t = Integer.parseInt(tokens[1]);
+ double v = Double.parseDouble(tokens[2]);
+
+ pi[p][t] = v;
+ }
+
+ // read in emissions
+ while ((line = input.readLine()) != null)
+ {
+ String[] tokens = space.split(line);
+ assert tokens.length == 4;
+ int t = Integer.parseInt(tokens[0]);
+ int p = Integer.parseInt(tokens[1]);
+ int w = Integer.parseInt(tokens[2]);
+ double v = Double.parseDouble(tokens[3]);
+
+ emit[t][p][w] = v;
+ }
+ }
+
+ public void setEdgeThreshold(int edgeThreshold)
+ {
+ this.edge_threshold = edgeThreshold;
+ }
}
|