summary | refs | log | tree | commit | diff
path: root/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java')
-rw-r--r-- gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java 165
1 files changed, 126 insertions, 39 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index 8b1e0a8c..cd28c12e 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -1,53 +1,65 @@
package phrase;
import io.FileUtil;
-
import java.io.PrintStream;
import java.util.Arrays;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.LinkedBlockingQueue;
public class PhraseCluster {
-
- /**@brief number of clusters*/
+
public int K;
+ public double scale;
private int n_phrase;
private int n_words;
public PhraseCorpus c;
+ private ExecutorService pool;
/**@brief
* emit[tag][position][word]
*/
private double emit[][][];
private double pi[][];
+
- public static int ITER=20;
- public static String postFilename="../pdata/posterior.out";
- public static String phraseStatFilename="../pdata/phrase_stat.out";
- private static int NUM_TAG=3;
public static void main(String[] args) {
-
- PhraseCorpus c=new PhraseCorpus(PhraseCorpus.DATA_FILENAME);
-
- PhraseCluster cluster=new PhraseCluster(NUM_TAG,c);
- PhraseObjective.ps=FileUtil.openOutFile(phraseStatFilename);
- for(int i=0;i<ITER;i++){
- PhraseObjective.ps.println("ITER: "+i);
- cluster.PREM();
- // cluster.EM();
+ String input_fname = args[0]; // args[0]: name handed to the PhraseCorpus constructor. NOTE(review): no argument-count check; fewer than six args throws ArrayIndexOutOfBoundsException
+ int tags = Integer.parseInt(args[1]); // args[1]: number of clusters (becomes K)
+ String outputDir = args[2]; // args[2]: directory that receives phrase_stat.out and posterior.out
+ int iterations = Integer.parseInt(args[3]); // args[3]: number of training iterations to run
+ double scale = Double.parseDouble(args[4]); // args[4]: stored in the cluster's public scale field
+ int threads = Integer.parseInt(args[5]); // args[5]: pool size; a value <= 0 creates no pool, so PREM() takes its serial path
+
+ PhraseCorpus corpus = new PhraseCorpus(input_fname);
+ PhraseCluster cluster = new PhraseCluster(tags, corpus, scale, threads);
+
+ PhraseObjective.ps = FileUtil.openOutFile(outputDir + "/phrase_stat.out");
+
+ for(int i=0;i<iterations;i++){
+ double o = cluster.PREM(); // value returned by this training iteration, logged below
+ //double o = cluster.EM();
+ PhraseObjective.ps.println("ITER: "+i+" objective: " + o);
}
- PrintStream ps=io.FileUtil.openOutFile(postFilename);
+ PrintStream ps=io.FileUtil.openOutFile(outputDir + "/posterior.out"); // posteriors and model parameters are written to this file
cluster.displayPosterior(ps);
ps.println();
cluster.displayModelParam(ps);
ps.close();
PhraseObjective.ps.close();
+
+ cluster.finish(); // shut the thread pool down (no-op when none was created)
}
- public PhraseCluster(int numCluster,PhraseCorpus corpus){
+ public PhraseCluster(int numCluster, PhraseCorpus corpus, double scale, int threads){
K=numCluster;
c=corpus;
n_words=c.wordLex.size();
n_phrase=c.data.length;
+ this.scale = scale;
+ if (threads > 0)
+ pool = Executors.newFixedThreadPool(threads);
emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
pi=new double[n_phrase][K];
@@ -61,28 +73,15 @@ public class PhraseCluster {
for(double []j:pi){
arr.F.randomise(j);
}
-
- pi[0]=new double[]{
- 0.3,0.5,0.2
- };
-
- double temp[][]=new double[][]{
- {0.11,0.16,0.19,0.11,0.1},
- {0.10,0.15,0.18,0.1,0.11},
- {0.09,0.07,0.12,0.14,0.13}
- };
-
- for(int tag=0;tag<3;tag++){
- for(int word=0;word<4;word++){
- for(int pos=0;pos<4;pos++){
- emit[tag][pos][word]=temp[tag][word];
- }
- }
- }
-
+ }
+
+ public void finish() // Shuts down the worker pool if one exists; main() calls this after all output has been written.
+ {
+ if (pool != null) // the constructor only creates a pool when threads > 0
+ pool.shutdown(); // ExecutorService.shutdown(): no new tasks are accepted; it does not block waiting for running ones
}
- public void EM(){
+ public double EM(){
double [][][]exp_emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
double [][]exp_pi=new double[n_phrase][K];
@@ -125,9 +124,14 @@ public class PhraseCluster {
}
pi=exp_pi;
+
+ return loglikelihood;
}
- public void PREM(){
+ public double PREM(){
+ if (pool != null)
+ return PREMParallel();
+
double [][][]exp_emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
double [][]exp_pi=new double[n_phrase][K];
@@ -171,6 +175,89 @@ public class PhraseCluster {
}
pi=exp_pi;
+
+ return primal;
+ }
+
+ public double PREMParallel(){ // One PREM iteration using the pool: per-phrase E-step tasks run on worker threads, this thread aggregates their results and performs the M-step; returns the summed primal objective.
+ assert(pool != null); // PREM() only delegates here when a pool exists; the assert itself is checked only under -ea
+ final LinkedBlockingQueue<PhraseObjective> expectations
+ = new LinkedBlockingQueue<PhraseObjective>(); // hand-off queue: workers put finished objectives, this thread takes them
+
+ double [][][]exp_emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
+ double [][]exp_pi=new double[n_phrase][K];
+
+ double loglikelihood=0;
+ double primal=0;
+ // E-step: submit one optimisation task per phrase to the pool
+ for(int phrase=0;phrase<c.data.length;phrase++){
+ final int p=phrase; // final copy so the anonymous Runnable can capture the index
+ pool.execute(new Runnable() {
+ public void run() {
+ try {
+ //System.out.println("" + Thread.currentThread().getId() + " optimising lambda for " + p);
+ PhraseObjective po = new PhraseObjective(PhraseCluster.this, p);
+ po.optimizeWithProjectedGradientDescent();
+ //System.out.println("" + Thread.currentThread().getId() + " done optimising lambda for " + p);
+ expectations.put(po);
+ //System.out.println("" + Thread.currentThread().getId() + " added to queue " + p);
+ } catch (InterruptedException e) { // NOTE(review): nothing is queued on this path (nor if the task throws an unchecked exception), so the take() loop below would then wait forever
+ System.err.println(Thread.currentThread().getId() + " Local e-step thread interrupted; will cause deadlock.");
+ e.printStackTrace();
+ }
+ }
+ });
+ }
+
+ // aggregate the expectations as they become available (completion order, not phrase order)
+ for(int count=0;count<c.data.length;count++) { // expects exactly one queued result per phrase
+ try {
+ //System.out.println("" + Thread.currentThread().getId() + " reading queue #" + count);
+
+ // wait (blocking) until something is ready
+ PhraseObjective po = expectations.take();
+ // fold this phrase's posterior into the running totals
+ int phrase = po.phrase;
+ //System.out.println("" + Thread.currentThread().getId() + " taken phrase " + phrase);
+ double [][] q=po.posterior(); // indexed below as q[edge][tag]
+ loglikelihood+=po.getValue();
+ primal+=po.primal();
+ for(int edge=0;edge<q.length;edge++){
+ int []context=c.data[phrase][edge];
+ int contextCnt=context[context.length-1]; // the last array entry is read as the count; the entries before it are word ids per position
+ // increment expected counts, weighting each posterior by the context count
+ for(int tag=0;tag<K;tag++){
+ for(int pos=0;pos<context.length-1;pos++){
+ exp_emit[tag][pos][context[pos]]+=q[edge][tag]*contextCnt;
+ }
+ exp_pi[phrase][tag]+=q[edge][tag]*contextCnt;
+ }
+ }
+ } catch (InterruptedException e){ // NOTE(review): the loop still advances count after this, so an interrupted take() leaves one result unprocessed
+ System.err.println("M-step thread interrupted. Probably fatal!");
+ e.printStackTrace();
+ }
+ }
+
+ System.out.println("Log likelihood: "+loglikelihood);
+ System.out.println("Primal Objective: "+primal);
+
+ // M-step: apply arr.F.l1normalize to every expected-count row, then install the arrays as the new emit and pi
+ for(double [][]i:exp_emit){
+ for(double []j:i){
+ arr.F.l1normalize(j);
+ }
+ }
+
+ emit=exp_emit;
+
+ for(double []j:exp_pi){
+ arr.F.l1normalize(j);
+ }
+
+ pi=exp_pi;
+
+ return primal; // loglikelihood is only printed above; the primal objective is what callers receive
}
/**