summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
diff options
context:
space:
mode:
authortrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-07 14:11:42 +0000
committertrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-07 14:11:42 +0000
commit5517e0b82f9503c59c10fc0167fa9d7fbdca1e64 (patch)
treea5ff0f26fc78627a52288fcabde843b4f0dfefc9 /gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
parent946f7569a487209a35567e804d92edd1a84f2619 (diff)
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@173 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java')
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java110
1 files changed, 79 insertions, 31 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index cd28c12e..731d03ac 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -1,11 +1,16 @@
package phrase;
import io.FileUtil;
+
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
+import java.util.zip.GZIPOutputStream;
public class PhraseCluster {
@@ -26,28 +31,46 @@ public class PhraseCluster {
public static void main(String[] args) {
String input_fname = args[0];
int tags = Integer.parseInt(args[1]);
- String outputDir = args[2];
+ String output_fname = args[2];
int iterations = Integer.parseInt(args[3]);
double scale = Double.parseDouble(args[4]);
int threads = Integer.parseInt(args[5]);
+ boolean runEM = Boolean.parseBoolean(args[6]);
PhraseCorpus corpus = new PhraseCorpus(input_fname);
PhraseCluster cluster = new PhraseCluster(tags, corpus, scale, threads);
- PhraseObjective.ps = FileUtil.openOutFile(outputDir + "/phrase_stat.out");
+ //PhraseObjective.ps = FileUtil.openOutFile(outputDir + "/phrase_stat.out");
+ double last = 0;
for(int i=0;i<iterations;i++){
- double o = cluster.PREM();
- //double o = cluster.EM();
- PhraseObjective.ps.println("ITER: "+i+" objective: " + o);
+
+ double o;
+ if (runEM || i < 3)
+ o = cluster.EM();
+ else
+ o = cluster.PREM();
+ //PhraseObjective.ps.
+ System.out.println("ITER: "+i+" objective: " + o);
+ last = o;
+ }
+
+ if (runEM)
+ {
+ double l1lmax = cluster.posterior_l1lmax();
+ System.out.println("Final l1lmax term " + l1lmax + ", total PR objective " + (last - scale*l1lmax));
+ // nb. KL is 0 by definition
}
- PrintStream ps=io.FileUtil.openOutFile(outputDir + "/posterior.out");
+ PrintStream ps=io.FileUtil.openOutFile(output_fname);
cluster.displayPosterior(ps);
- ps.println();
- cluster.displayModelParam(ps);
ps.close();
- PhraseObjective.ps.close();
+
+ //PhraseObjective.ps.close();
+
+ //ps = io.FileUtil.openOutFile(outputDir + "/parameters.out");
+ //cluster.displayModelParam(ps);
+ //ps.close();
cluster.finish();
}
@@ -61,7 +84,7 @@ public class PhraseCluster {
if (threads > 0)
pool = Executors.newFixedThreadPool(threads);
- emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
+ emit=new double [K][c.numContexts][n_words];
pi=new double[n_phrase][K];
for(double [][]i:emit){
@@ -82,7 +105,7 @@ public class PhraseCluster {
}
public double EM(){
- double [][][]exp_emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
+ double [][][]exp_emit=new double [K][c.numContexts][n_words];
double [][]exp_pi=new double[n_phrase][K];
double loglikelihood=0;
@@ -93,7 +116,9 @@ public class PhraseCluster {
for(int ctx=0;ctx<data.length;ctx++){
int context[]=data[ctx];
double p[]=posterior(phrase,context);
- loglikelihood+=Math.log(arr.F.l1norm(p));
+ double z = arr.F.l1norm(p);
+ assert z > 0;
+ loglikelihood+=Math.log(z);
arr.F.l1normalize(p);
int contextCnt=context[context.length-1];
@@ -132,7 +157,7 @@ public class PhraseCluster {
if (pool != null)
return PREMParallel();
- double [][][]exp_emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
+ double [][][]exp_emit=new double [K][c.numContexts][n_words];
double [][]exp_pi=new double[n_phrase][K];
double loglikelihood=0;
@@ -142,7 +167,7 @@ public class PhraseCluster {
PhraseObjective po=new PhraseObjective(this,phrase);
po.optimizeWithProjectedGradientDescent();
double [][] q=po.posterior();
- loglikelihood+=po.getValue();
+ loglikelihood+=po.llh;
primal+=po.primal();
for(int edge=0;edge<q.length;edge++){
int []context=c.data[phrase][edge];
@@ -184,7 +209,7 @@ public class PhraseCluster {
final LinkedBlockingQueue<PhraseObjective> expectations
= new LinkedBlockingQueue<PhraseObjective>();
- double [][][]exp_emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
+ double [][][]exp_emit=new double [K][c.numContexts][n_words];
double [][]exp_pi=new double[n_phrase][K];
double loglikelihood=0;
@@ -220,7 +245,7 @@ public class PhraseCluster {
int phrase = po.phrase;
//System.out.println("" + Thread.currentThread().getId() + " taken phrase " + phrase);
double [][] q=po.posterior();
- loglikelihood+=po.getValue();
+ loglikelihood+=po.llh;
primal+=po.primal();
for(int edge=0;edge<q.length;edge++){
int []context=c.data[phrase][edge];
@@ -295,18 +320,19 @@ public class PhraseCluster {
// emit phrase
ps.print(c.phraseList[i]);
ps.print("\t");
- ps.print(c.getContextString(e));
- ps.print("||| C=" + e[e.length-1] + " |||");
-
+ ps.print(c.getContextString(e, true));
int t=arr.F.argmax(probs);
+ ps.println(" ||| C=" + t);
+
+ //ps.print("||| C=" + e[e.length-1] + " |||");
- ps.print(t+"||| [");
- for(t=0;t<K;t++){
- ps.print(probs[t]+", ");
- }
+ //ps.print(t+"||| [");
+ //for(t=0;t<K;t++){
+ // ps.print(probs[t]+", ");
+ //}
// for (int t = 0; t < numTags; ++t)
// System.out.print(" " + probs[t]);
- ps.println("]");
+ //ps.println("]");
}
}
}
@@ -329,14 +355,14 @@ public class PhraseCluster {
ps.println("P(word|tag,position)");
for (int i = 0; i < K; ++i)
{
- ps.println(i);
- for(int position=0;position<PhraseCorpus.NUM_CONTEXT;position++){
- ps.println(position);
+ for(int position=0;position<c.numContexts;position++){
+ ps.println("tag " + i + " position " + position);
for(int word=0;word<emit[i][position].length;word++){
- if((word+1)%100==0){
- ps.println();
- }
- ps.print(c.wordList[word]+"="+emit[i][position][word]+"\t");
+ //if((word+1)%100==0){
+ // ps.println();
+ //}
+ if (emit[i][position][word] > 1e-10)
+ ps.print(c.wordList[word]+"="+emit[i][position][word]+"\t");
}
ps.println();
}
@@ -344,4 +370,26 @@ public class PhraseCluster {
}
}
+
+ double posterior_l1lmax()
+ {
+ double sum=0;
+ for(int phrase=0;phrase<c.data.length;phrase++)
+ {
+ int [][] data = c.data[phrase];
+ double [] maxes = new double[K];
+ for(int ctx=0;ctx<data.length;ctx++)
+ {
+ int context[]=data[ctx];
+ double p[]=posterior(phrase,context);
+ arr.F.l1normalize(p);
+
+ for(int tag=0;tag<K;tag++)
+ maxes[tag] = Math.max(maxes[tag], p[tag]);
+ }
+ for(int tag=0;tag<K;tag++)
+ sum += maxes[tag];
+ }
+ return sum;
+ }
}