diff options
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java')
-rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java | 347 |
1 files changed, 209 insertions, 138 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index 731d03ac..e4db2a1a 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -1,44 +1,54 @@
package phrase;
+import gnu.trove.TIntArrayList;
import io.FileUtil;
-
-import java.io.FileOutputStream;
import java.io.IOException;
-import java.io.OutputStream;
import java.io.PrintStream;
import java.util.Arrays;
+import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
-import java.util.zip.GZIPOutputStream;
+
+import phrase.Corpus.Edge;
public class PhraseCluster {
public int K;
- public double scale;
- private int n_phrase;
- private int n_words;
- public PhraseCorpus c;
+ public double scalePT, scaleCT;
+ private int n_phrases, n_words, n_contexts, n_positions;
+ public Corpus c;
private ExecutorService pool;
- /**@brief
- * emit[tag][position][word]
- */
+ // emit[tag][position][word] = p(word | tag, position in context)
private double emit[][][];
+ // pi[phrase][tag] = p(tag | phrase)
private double pi[][];
-
- public static void main(String[] args) {
+ public static void main(String[] args)
+ {
String input_fname = args[0];
int tags = Integer.parseInt(args[1]);
String output_fname = args[2];
int iterations = Integer.parseInt(args[3]);
- double scale = Double.parseDouble(args[4]);
- int threads = Integer.parseInt(args[5]);
- boolean runEM = Boolean.parseBoolean(args[6]);
-
- PhraseCorpus corpus = new PhraseCorpus(input_fname);
- PhraseCluster cluster = new PhraseCluster(tags, corpus, scale, threads);
+ double scalePT = Double.parseDouble(args[4]);
+ double scaleCT = Double.parseDouble(args[5]);
+ int threads = Integer.parseInt(args[6]);
+ boolean runEM = Boolean.parseBoolean(args[7]);
+
+ assert(tags >= 2);
+ assert(scalePT >= 0);
+ assert(scaleCT >= 0);
+
+ Corpus corpus = null;
+ try {
+ corpus = Corpus.readFromFile(FileUtil.openBufferedReader(input_fname));
+ } catch (IOException e) {
+ System.err.println("Failed to open input file: " + input_fname);
+ e.printStackTrace();
+ System.exit(1);
+ }
+ PhraseCluster cluster = new PhraseCluster(tags, corpus, scalePT, scaleCT, threads);
//PhraseObjective.ps = FileUtil.openOutFile(outputDir + "/phrase_stat.out");
@@ -48,19 +58,25 @@ public class PhraseCluster {
double o;
if (runEM || i < 3)
o = cluster.EM();
- else
- o = cluster.PREM();
+ else if (scaleCT == 0)
+ {
+ if (threads >= 1)
+ o = cluster.PREM_phrase_constraints_parallel();
+ else
+ o = cluster.PREM_phrase_constraints();
+ }
+ else
+ o = cluster.PREM_phrase_context_constraints();
+
//PhraseObjective.ps.
System.out.println("ITER: "+i+" objective: " + o);
last = o;
}
- if (runEM)
- {
- double l1lmax = cluster.posterior_l1lmax();
- System.out.println("Final l1lmax term " + l1lmax + ", total PR objective " + (last - scale*l1lmax));
- // nb. KL is 0 by definition
- }
+ double pl1lmax = cluster.phrase_l1lmax();
+ double cl1lmax = cluster.context_l1lmax();
+ System.out.println("Final posterior phrase l1lmax " + pl1lmax + " context l1lmax " + cl1lmax);
+ if (runEM) System.out.println("With PR objective " + (last - scalePT*pl1lmax - scaleCT*cl1lmax));
PrintStream ps=io.FileUtil.openOutFile(output_fname);
cluster.displayPosterior(ps);
@@ -75,17 +91,20 @@ public class PhraseCluster {
cluster.finish();
}
- public PhraseCluster(int numCluster, PhraseCorpus corpus, double scale, int threads){
+ public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads){
K=numCluster;
c=corpus;
- n_words=c.wordLex.size();
- n_phrase=c.data.length;
- this.scale = scale;
- if (threads > 0)
+ n_words=c.getNumWords();
+ n_phrases=c.getNumPhrases();
+ n_contexts=c.getNumContexts();
+ n_positions=c.getNumContextPositions();
+ this.scalePT = scalep;
+ this.scaleCT = scalec;
+ if (threads > 0 && scalec <= 0)
pool = Executors.newFixedThreadPool(threads);
- emit=new double [K][c.numContexts][n_words];
- pi=new double[n_phrase][K];
+ emit=new double [K][n_positions][n_words];
+ pi=new double[n_phrases][K];
for(double [][]i:emit){
for(double []j:i){
@@ -105,30 +124,32 @@ public class PhraseCluster {
}
public double EM(){
- double [][][]exp_emit=new double [K][c.numContexts][n_words];
- double [][]exp_pi=new double[n_phrase][K];
+ double [][][]exp_emit=new double [K][n_positions][n_words];
+ double [][]exp_pi=new double[n_phrases][K];
double loglikelihood=0;
//E
- for(int phrase=0;phrase<c.data.length;phrase++){
- int [][] data=c.data[phrase];
- for(int ctx=0;ctx<data.length;ctx++){
- int context[]=data[ctx];
- double p[]=posterior(phrase,context);
+ for(int phrase=0; phrase < n_phrases; phrase++){
+ List<Edge> contexts = c.getEdgesForPhrase(phrase);
+
+ for (int ctx=0; ctx<contexts.size(); ctx++){
+ Edge edge = contexts.get(ctx);
+ double p[]=posterior(edge);
double z = arr.F.l1norm(p);
assert z > 0;
loglikelihood+=Math.log(z);
arr.F.l1normalize(p);
- int contextCnt=context[context.length-1];
+ int count = edge.getCount();
//increment expected count
+ TIntArrayList context = edge.getContext();
for(int tag=0;tag<K;tag++){
- for(int pos=0;pos<context.length-1;pos++){
- exp_emit[tag][pos][context[pos]]+=p[tag]*contextCnt;
+ for(int pos=0;pos<n_positions;pos++){
+ exp_emit[tag][pos][context.get(pos)]+=p[tag]*count;
}
- exp_pi[phrase][tag]+=p[tag]*contextCnt;
+ exp_pi[phrase][tag]+=p[tag]*count;
}
}
}
@@ -153,29 +174,32 @@ public class PhraseCluster {
return loglikelihood;
}
- public double PREM(){
- if (pool != null)
- return PREMParallel();
+ public double PREM_phrase_constraints(){
+ assert (scaleCT <= 0);
- double [][][]exp_emit=new double [K][c.numContexts][n_words];
- double [][]exp_pi=new double[n_phrase][K];
+ double [][][]exp_emit=new double[K][n_positions][n_words];
+ double [][]exp_pi=new double[n_phrases][K];
- double loglikelihood=0;
- double primal=0;
+ double loglikelihood=0, kl=0, l1lmax=0, primal=0;
//E
- for(int phrase=0;phrase<c.data.length;phrase++){
+ for(int phrase=0; phrase<n_phrases; phrase++){
PhraseObjective po=new PhraseObjective(this,phrase);
po.optimizeWithProjectedGradientDescent();
double [][] q=po.posterior();
- loglikelihood+=po.llh;
- primal+=po.primal();
+ loglikelihood += po.loglikelihood();
+ kl += po.KL_divergence();
+ l1lmax += po.l1lmax();
+ primal += po.primal();
+ List<Edge> edges = c.getEdgesForPhrase(phrase);
+
for(int edge=0;edge<q.length;edge++){
- int []context=c.data[phrase][edge];
- int contextCnt=context[context.length-1];
+ Edge e = edges.get(edge);
+ TIntArrayList context = e.getContext();
+ int contextCnt = e.getCount();
//increment expected count
for(int tag=0;tag<K;tag++){
- for(int pos=0;pos<context.length-1;pos++){
- exp_emit[tag][pos][context[pos]]+=q[edge][tag]*contextCnt;
+ for(int pos=0;pos<n_positions;pos++){
+ exp_emit[tag][pos][context.get(pos)]+=q[edge][tag]*contextCnt;
}
exp_pi[phrase][tag]+=q[edge][tag]*contextCnt;
@@ -183,8 +207,9 @@ public class PhraseCluster {
}
}
- System.out.println("Log likelihood: "+loglikelihood);
- System.out.println("Primal Objective: "+primal);
+ System.out.println("\tllh: " + loglikelihood);
+ System.out.println("\tKL: " + kl);
+ System.out.println("\tphrase l1lmax: " + l1lmax);
//M
for(double [][]i:exp_emit){
@@ -204,18 +229,21 @@ public class PhraseCluster {
return primal;
}
- public double PREMParallel(){
+ public double PREM_phrase_constraints_parallel()
+ {
assert(pool != null);
+ assert(scaleCT <= 0);
+
final LinkedBlockingQueue<PhraseObjective> expectations
= new LinkedBlockingQueue<PhraseObjective>();
- double [][][]exp_emit=new double [K][c.numContexts][n_words];
- double [][]exp_pi=new double[n_phrase][K];
+ double [][][]exp_emit=new double [K][n_positions][n_words];
+ double [][]exp_pi=new double[n_phrases][K];
- double loglikelihood=0;
- double primal=0;
+ double loglikelihood=0, kl=0, l1lmax=0, primal=0;
+
//E
- for(int phrase=0;phrase<c.data.length;phrase++){
+ for(int phrase=0;phrase<n_phrases;phrase++){
final int p=phrase;
pool.execute(new Runnable() {
public void run() {
@@ -235,7 +263,7 @@ public class PhraseCluster {
}
// aggregate the expectations as they become available
- for(int count=0;count<c.data.length;count++) {
+ for(int count=0;count<n_phrases;count++) {
try {
//System.out.println("" + Thread.currentThread().getId() + " reading queue #" + count);
@@ -245,109 +273,139 @@ public class PhraseCluster {
int phrase = po.phrase;
//System.out.println("" + Thread.currentThread().getId() + " taken phrase " + phrase);
double [][] q=po.posterior();
- loglikelihood+=po.llh;
- primal+=po.primal();
+ loglikelihood += po.loglikelihood();
+ kl += po.KL_divergence();
+ l1lmax += po.l1lmax();
+ primal += po.primal();
+
+ List<Edge> edges = c.getEdgesForPhrase(phrase);
for(int edge=0;edge<q.length;edge++){
- int []context=c.data[phrase][edge];
- int contextCnt=context[context.length-1];
+ Edge e = edges.get(edge);
+ TIntArrayList context = e.getContext();
+ int contextCnt = e.getCount();
//increment expected count
for(int tag=0;tag<K;tag++){
- for(int pos=0;pos<context.length-1;pos++){
- exp_emit[tag][pos][context[pos]]+=q[edge][tag]*contextCnt;
+ for(int pos=0;pos<n_positions;pos++){
+ exp_emit[tag][pos][context.get(pos)]+=q[edge][tag]*contextCnt;
}
exp_pi[phrase][tag]+=q[edge][tag]*contextCnt;
}
}
- } catch (InterruptedException e){
+ } catch (InterruptedException e)
+ {
System.err.println("M-step thread interrupted. Probably fatal!");
e.printStackTrace();
}
}
- System.out.println("Log likelihood: "+loglikelihood);
- System.out.println("Primal Objective: "+primal);
+ System.out.println("\tllh: " + loglikelihood);
+ System.out.println("\tKL: " + kl);
+ System.out.println("\tphrase l1lmax: " + l1lmax);
//M
- for(double [][]i:exp_emit){
- for(double []j:i){
+ for(double [][]i:exp_emit)
+ for(double []j:i)
arr.F.l1normalize(j);
+ emit=exp_emit;
+
+ for(double []j:exp_pi)
+ arr.F.l1normalize(j);
+ pi=exp_pi;
+
+ return primal;
+ }
+
+ public double PREM_phrase_context_constraints(){
+ assert (scaleCT > 0);
+
+ double [][][]exp_emit=new double [K][n_positions][n_words];
+ double [][]exp_pi=new double[n_phrases][K];
+
+ //E step
+ // TODO: cache the lambda values (the null below)
+ PhraseContextObjective pco = new PhraseContextObjective(this, null);
+ pco.optimizeWithProjectedGradientDescent();
+
+ //now extract expectations
+ List<Corpus.Edge> edges = c.getEdges();
+ for(int e = 0; e < edges.size(); ++e)
+ {
+ double [] q = pco.posterior(e);
+ Corpus.Edge edge = edges.get(e);
+
+ TIntArrayList context = edge.getContext();
+ int contextCnt = edge.getCount();
+ //increment expected count
+ for(int tag=0;tag<K;tag++)
+ {
+ for(int pos=0;pos<n_positions;pos++)
+ exp_emit[tag][pos][context.get(pos)]+=q[tag]*contextCnt;
+ exp_pi[edge.getPhraseId()][tag]+=q[tag]*contextCnt;
}
}
+ System.out.println("\tllh: " + pco.loglikelihood());
+ System.out.println("\tKL: " + pco.KL_divergence());
+ System.out.println("\tphrase l1lmax: " + pco.phrase_l1lmax());
+ System.out.println("\tcontext l1lmax: " + pco.context_l1lmax());
+
+ //M step
+ for(double [][]i:exp_emit)
+ for(double []j:i)
+ arr.F.l1normalize(j);
emit=exp_emit;
- for(double []j:exp_pi){
+ for(double []j:exp_pi)
arr.F.l1normalize(j);
- }
-
pi=exp_pi;
- return primal;
- }
-
+ return pco.primal();
+ }
+
/**
- *
* @param phrase index of phrase
* @param ctx array of context
* @return unnormalized posterior
*/
- public double[]posterior(int phrase, int[]ctx){
- double[] prob=Arrays.copyOf(pi[phrase], K);
+ public double[] posterior(Corpus.Edge edge)
+ {
+ double[] prob=Arrays.copyOf(pi[edge.getPhraseId()], K);
- for(int tag=0;tag<K;tag++){
- for(int c=0;c<ctx.length-1;c++){
- int word=ctx[c];
- prob[tag]*=emit[tag][c][word];
- }
- }
+ TIntArrayList ctx = edge.getContext();
+ for(int tag=0;tag<K;tag++)
+ for(int c=0;c<n_positions;c++)
+ prob[tag]*=emit[tag][c][ctx.get(c)];
return prob;
}
public void displayPosterior(PrintStream ps)
- {
-
- c.buildList();
-
- for (int i = 0; i < n_phrase; ++i)
+ {
+ for (Edge edge : c.getEdges())
{
- int [][]data=c.data[i];
- for (int[] e: data)
- {
- double probs[] = posterior(i, e);
- arr.F.l1normalize(probs);
+ double probs[] = posterior(edge);
+ arr.F.l1normalize(probs);
- // emit phrase
- ps.print(c.phraseList[i]);
- ps.print("\t");
- ps.print(c.getContextString(e, true));
- int t=arr.F.argmax(probs);
- ps.println(" ||| C=" + t);
-
- //ps.print("||| C=" + e[e.length-1] + " |||");
-
- //ps.print(t+"||| [");
- //for(t=0;t<K;t++){
- // ps.print(probs[t]+", ");
- //}
- // for (int t = 0; t < numTags; ++t)
- // System.out.print(" " + probs[t]);
- //ps.println("]");
- }
+ // emit phrase
+ ps.print(edge.getPhraseString());
+ ps.print("\t");
+ ps.print(edge.getContextString(true));
+ int t=arr.F.argmax(probs);
+ ps.println(" ||| C=" + t);
}
}
public void displayModelParam(PrintStream ps)
{
-
- c.buildList();
+ final double EPS = 1e-6;
ps.println("P(tag|phrase)");
- for (int i = 0; i < n_phrase; ++i)
+ for (int i = 0; i < n_phrases; ++i)
{
- ps.print(c.phraseList[i]);
+ ps.print(c.getPhrase(i));
for(int j=0;j<pi[i].length;j++){
- ps.print("\t"+pi[i][j]);
+ if (pi[i][j] > EPS)
+ ps.print("\t" + j + ": " + pi[i][j]);
}
ps.println();
}
@@ -355,14 +413,11 @@ public class PhraseCluster {
ps.println("P(word|tag,position)");
for (int i = 0; i < K; ++i)
{
- for(int position=0;position<c.numContexts;position++){
+ for(int position=0;position<n_positions;position++){
ps.println("tag " + i + " position " + position);
for(int word=0;word<emit[i][position].length;word++){
- //if((word+1)%100==0){
- // ps.println();
- //}
- if (emit[i][position][word] > 1e-10)
- ps.print(c.wordList[word]+"="+emit[i][position][word]+"\t");
+ if (emit[i][position][word] > EPS)
+ ps.print(c.getWord(word)+"="+emit[i][position][word]+"\t");
}
ps.println();
}
@@ -371,19 +426,35 @@ public class PhraseCluster {
}
- double posterior_l1lmax()
+ double phrase_l1lmax()
{
double sum=0;
- for(int phrase=0;phrase<c.data.length;phrase++)
+ for(int phrase=0; phrase<n_phrases; phrase++)
{
- int [][] data = c.data[phrase];
double [] maxes = new double[K];
- for(int ctx=0;ctx<data.length;ctx++)
+ for (Edge edge : c.getEdgesForPhrase(phrase))
{
- int context[]=data[ctx];
- double p[]=posterior(phrase,context);
+ double p[] = posterior(edge);
arr.F.l1normalize(p);
+ for(int tag=0;tag<K;tag++)
+ maxes[tag] = Math.max(maxes[tag], p[tag]);
+ }
+ for(int tag=0;tag<K;tag++)
+ sum += maxes[tag];
+ }
+ return sum;
+ }
+ double context_l1lmax()
+ {
+ double sum=0;
+ for(int context=0; context<n_contexts; context++)
+ {
+ double [] maxes = new double[K];
+ for (Edge edge : c.getEdgesForContext(context))
+ {
+ double p[] = posterior(edge);
+ arr.F.l1normalize(p);
for(int tag=0;tag<K;tag++)
maxes[tag] = Math.max(maxes[tag], p[tag]);
}
|