summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation/prjava')
-rw-r--r-- gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java 165
-rw-r--r-- gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java 1
-rw-r--r-- gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java 8
3 files changed, 129 insertions, 45 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index 8b1e0a8c..cd28c12e 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -1,53 +1,65 @@
package phrase;
import io.FileUtil;
-
import java.io.PrintStream;
import java.util.Arrays;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.LinkedBlockingQueue;
public class PhraseCluster {
-
- /**@brief number of clusters*/
+
public int K;
+ public double scale;
private int n_phrase;
private int n_words;
public PhraseCorpus c;
+ private ExecutorService pool;
/**@brief
* emit[tag][position][word]
*/
private double emit[][][];
private double pi[][];
+
- public static int ITER=20;
- public static String postFilename="../pdata/posterior.out";
- public static String phraseStatFilename="../pdata/phrase_stat.out";
- private static int NUM_TAG=3;
public static void main(String[] args) { // entry point: builds the corpus and cluster, runs the PREM iterations, then writes statistics and posteriors
-
- PhraseCorpus c=new PhraseCorpus(PhraseCorpus.DATA_FILENAME);
-
- PhraseCluster cluster=new PhraseCluster(NUM_TAG,c);
- PhraseObjective.ps=FileUtil.openOutFile(phraseStatFilename);
- for(int i=0;i<ITER;i++){
- PhraseObjective.ps.println("ITER: "+i);
- cluster.PREM();
- // cluster.EM();
+ String input_fname = args[0]; // corpus file name, handed to the PhraseCorpus constructor; NOTE(review): args are read positionally with no length check
+ int tags = Integer.parseInt(args[1]); // number of clusters; becomes PhraseCluster.K
+ String outputDir = args[2]; // directory that receives phrase_stat.out and posterior.out
+ int iterations = Integer.parseInt(args[3]); // how many times PREM() is run
+ double scale = Double.parseDouble(args[4]); // stored in PhraseCluster.scale, which PhraseObjective reads as c.scale
+ int threads = Integer.parseInt(args[5]); // a value > 0 makes the constructor create a fixed thread pool of that size
+
+ PhraseCorpus corpus = new PhraseCorpus(input_fname);
+ PhraseCluster cluster = new PhraseCluster(tags, corpus, scale, threads);
+
+ PhraseObjective.ps = FileUtil.openOutFile(outputDir + "/phrase_stat.out");
+
+ for(int i=0;i<iterations;i++){
+ double o = cluster.PREM(); // PREM() returns the primal objective of this iteration
+ //double o = cluster.EM();
+ PhraseObjective.ps.println("ITER: "+i+" objective: " + o);
}
- PrintStream ps=io.FileUtil.openOutFile(postFilename);
+ PrintStream ps=io.FileUtil.openOutFile(outputDir + "/posterior.out");
cluster.displayPosterior(ps);
ps.println();
cluster.displayModelParam(ps);
ps.close();
PhraseObjective.ps.close();
+
+ cluster.finish(); // shuts the worker pool down if one was created
}
- public PhraseCluster(int numCluster,PhraseCorpus corpus){
+ public PhraseCluster(int numCluster, PhraseCorpus corpus, double scale, int threads){
K=numCluster;
c=corpus;
n_words=c.wordLex.size();
n_phrase=c.data.length;
+ this.scale = scale;
+ if (threads > 0)
+ pool = Executors.newFixedThreadPool(threads);
emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
pi=new double[n_phrase][K];
@@ -61,28 +73,15 @@ public class PhraseCluster {
for(double []j:pi){
arr.F.randomise(j);
}
-
- pi[0]=new double[]{
- 0.3,0.5,0.2
- };
-
- double temp[][]=new double[][]{
- {0.11,0.16,0.19,0.11,0.1},
- {0.10,0.15,0.18,0.1,0.11},
- {0.09,0.07,0.12,0.14,0.13}
- };
-
- for(int tag=0;tag<3;tag++){
- for(int word=0;word<4;word++){
- for(int pos=0;pos<4;pos++){
- emit[tag][pos][word]=temp[tag][word];
- }
- }
- }
-
+ }
+
+ public void finish() // shuts down the thread pool, if the constructor created one (threads > 0)
+ {
+ if (pool != null)
+ pool.shutdown();
}
- public void EM(){
+ public double EM(){
double [][][]exp_emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
double [][]exp_pi=new double[n_phrase][K];
@@ -125,9 +124,14 @@ public class PhraseCluster {
}
pi=exp_pi;
+
+ return loglikelihood;
}
- public void PREM(){
+ public double PREM(){
+ if (pool != null)
+ return PREMParallel();
+
double [][][]exp_emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
double [][]exp_pi=new double[n_phrase][K];
@@ -171,6 +175,89 @@ public class PhraseCluster {
}
pi=exp_pi;
+
+ return primal;
+ }
+
+ public double PREMParallel(){ // one PREM iteration with the per-phrase E-step run on the thread pool; returns the summed primal objective
+ assert(pool != null);
+ final LinkedBlockingQueue<PhraseObjective> expectations
+ = new LinkedBlockingQueue<PhraseObjective>(); // workers hand their finished objectives to the aggregating thread through this queue
+
+ double [][][]exp_emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
+ double [][]exp_pi=new double[n_phrase][K];
+
+ double loglikelihood=0;
+ double primal=0;
+ // E-step: submit one task per phrase; each task builds and optimises that phrase's PhraseObjective, then queues it
+ for(int phrase=0;phrase<c.data.length;phrase++){
+ final int p=phrase;
+ pool.execute(new Runnable() {
+ public void run() {
+ try {
+ //System.out.println("" + Thread.currentThread().getId() + " optimising lambda for " + p);
+ PhraseObjective po = new PhraseObjective(PhraseCluster.this, p);
+ po.optimizeWithProjectedGradientDescent();
+ //System.out.println("" + Thread.currentThread().getId() + " done optimising lambda for " + p);
+ expectations.put(po);
+ //System.out.println("" + Thread.currentThread().getId() + " added to queue " + p);
+ } catch (InterruptedException e) {
+ System.err.println(Thread.currentThread().getId() + " Local e-step thread interrupted; will cause deadlock.");
+ e.printStackTrace();
+ }
+ }
+ });
+ }
+
+ // aggregate the expectations as they become available; exactly c.data.length results are expected, one per submitted task
+ for(int count=0;count<c.data.length;count++) {
+ try {
+ //System.out.println("" + Thread.currentThread().getId() + " reading queue #" + count);
+
+ // wait (blocking) until something is ready; NOTE(review): a task that dies with an unchecked exception never calls put(), so this take() would then block forever
+ PhraseObjective po = expectations.take();
+ // fold this phrase's posterior into the running totals and the expected counts
+ int phrase = po.phrase;
+ //System.out.println("" + Thread.currentThread().getId() + " taken phrase " + phrase);
+ double [][] q=po.posterior();
+ loglikelihood+=po.getValue();
+ primal+=po.primal();
+ for(int edge=0;edge<q.length;edge++){
+ int []context=c.data[phrase][edge];
+ int contextCnt=context[context.length-1]; // the last element of the context array is used as its count
+ // increment expected counts, weighting each edge's posterior by the context count
+ for(int tag=0;tag<K;tag++){
+ for(int pos=0;pos<context.length-1;pos++){
+ exp_emit[tag][pos][context[pos]]+=q[edge][tag]*contextCnt;
+ }
+ exp_pi[phrase][tag]+=q[edge][tag]*contextCnt;
+ }
+ }
+ } catch (InterruptedException e){
+ System.err.println("M-step thread interrupted. Probably fatal!");
+ e.printStackTrace();
+ }
+ }
+
+ System.out.println("Log likelihood: "+loglikelihood);
+ System.out.println("Primal Objective: "+primal);
+
+ // M-step: pass every expected-count row through arr.F.l1normalize, then install the results as the new emit and pi
+ for(double [][]i:exp_emit){
+ for(double []j:i){
+ arr.F.l1normalize(j);
+ }
+ }
+
+ emit=exp_emit;
+
+ for(double []j:exp_pi){
+ arr.F.l1normalize(j);
+ }
+
+ pi=exp_pi;
+
+ return primal;
}
/**
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java
index 3902f665..99545371 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java
@@ -12,7 +12,6 @@ public class PhraseCorpus {
public static String LEX_FILENAME="../pdata/lex.out";
- //public static String DATA_FILENAME="../pdata/canned.con";
public static String DATA_FILENAME="../pdata/btec.con";
public static int NUM_CONTEXT=4;
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
index e9e063d6..71c91b96 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
@@ -22,7 +22,6 @@ public class PhraseObjective extends ProjectedObjective{
private static final double GRAD_DIFF = 0.002;
public static double INIT_STEP_SIZE=1;
public static double VAL_DIFF=0.001;
- private double scale=5;
private double c1=0.0001;
private double c2=0.9;
@@ -73,7 +72,7 @@ public class PhraseObjective extends ProjectedObjective{
newPoint = new double[n_param];
gradient = new double[n_param];
initP();
- projection=new SimplexProjection (scale);
+ projection=new SimplexProjection(c.scale);
q=new double [data.length][c.K];
setParameters(parameters);
@@ -111,8 +110,7 @@ public class PhraseObjective extends ProjectedObjective{
}
for(int edge=0;edge<data.length;edge++){
- loglikelihood+=Math.log
- (data[edge][countIdx]*arr.F.l1norm(q[edge]));
+ loglikelihood+=data[edge][countIdx] * Math.log(arr.F.l1norm(q[edge]));
arr.F.l1normalize(q[edge]);
}
@@ -222,7 +220,7 @@ public class PhraseObjective extends ProjectedObjective{
sum+=max;
}
// ps.println(", "+sum);
- l=l-scale*sum;
+ l=l-c.scale*sum;
return l;
}