From 5517e0b82f9503c59c10fc0167fa9d7fbdca1e64 Mon Sep 17 00:00:00 2001 From: "trevor.cohn" Date: Wed, 7 Jul 2010 14:11:42 +0000 Subject: git-svn-id: https://ws10smt.googlecode.com/svn/trunk@173 ec762483-ff6d-05da-a07a-a48fb63a330f --- .../prjava/src/io/FileUtil.java | 12 ++- .../prjava/src/phrase/PhraseCluster.java | 110 +++++++++++++++------ .../prjava/src/phrase/PhraseCorpus.java | 19 ++-- .../prjava/src/phrase/PhraseObjective.java | 19 ++-- 4 files changed, 114 insertions(+), 46 deletions(-) (limited to 'gi/posterior-regularisation/prjava/src') diff --git a/gi/posterior-regularisation/prjava/src/io/FileUtil.java b/gi/posterior-regularisation/prjava/src/io/FileUtil.java index 7d9f2bc5..67ce571e 100644 --- a/gi/posterior-regularisation/prjava/src/io/FileUtil.java +++ b/gi/posterior-regularisation/prjava/src/io/FileUtil.java @@ -1,5 +1,7 @@ package io; import java.util.*; +import java.util.zip.GZIPInputStream; +import java.util.zip.GZIPOutputStream; import java.io.*; public class FileUtil { public static Scanner openInFile(String filename){ @@ -18,7 +20,10 @@ public class FileUtil { BufferedReader r=null; try { - r=(new BufferedReader(new FileReader(new File(filename)))); + if (filename.endsWith(".gz")) + r=(new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(new File(filename)))))); + else + r=(new BufferedReader(new FileReader(new File(filename)))); }catch(IOException ioe){ System.out.println(ioe.getMessage()); } @@ -29,7 +34,10 @@ public class FileUtil { PrintStream localps=null; try { - localps=new PrintStream (new FileOutputStream(filename)); + if (filename.endsWith(".gz")) + localps=new PrintStream (new GZIPOutputStream(new FileOutputStream(filename))); + else + localps=new PrintStream (new FileOutputStream(filename)); }catch(IOException ioe){ System.out.println(ioe.getMessage()); diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index cd28c12e..731d03ac 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -1,11 +1,16 @@ package phrase; import io.FileUtil; + +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; import java.io.PrintStream; import java.util.Arrays; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.LinkedBlockingQueue; +import java.util.zip.GZIPOutputStream; public class PhraseCluster { @@ -26,28 +31,46 @@ public class PhraseCluster { public static void main(String[] args) { String input_fname = args[0]; int tags = Integer.parseInt(args[1]); - String outputDir = args[2]; + String output_fname = args[2]; int iterations = Integer.parseInt(args[3]); double scale = Double.parseDouble(args[4]); int threads = Integer.parseInt(args[5]); + boolean runEM = Boolean.parseBoolean(args[6]); PhraseCorpus corpus = new PhraseCorpus(input_fname); PhraseCluster cluster = new PhraseCluster(tags, corpus, scale, threads); - PhraseObjective.ps = FileUtil.openOutFile(outputDir + "/phrase_stat.out"); + //PhraseObjective.ps = FileUtil.openOutFile(outputDir + "/phrase_stat.out"); + double last = 0; for(int i=0;i 0) pool = Executors.newFixedThreadPool(threads); - emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words]; + emit=new double [K][c.numContexts][n_words]; pi=new double[n_phrase][K]; for(double [][]i:emit){ @@ -82,7 +105,7 @@ public class PhraseCluster { } public double EM(){ - double [][][]exp_emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words]; + double [][][]exp_emit=new double [K][c.numContexts][n_words]; double [][]exp_pi=new double[n_phrase][K]; double loglikelihood=0; @@ -93,7 +116,9 @@ public class PhraseCluster { for(int ctx=0;ctx 0; + loglikelihood+=Math.log(z); arr.F.l1normalize(p); int contextCnt=context[context.length-1]; @@ -132,7 +157,7 @@ public class PhraseCluster { if (pool != null) return PREMParallel(); - double [][][]exp_emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words]; + double [][][]exp_emit=new double [K][c.numContexts][n_words]; double [][]exp_pi=new double[n_phrase][K]; double loglikelihood=0; @@ -142,7 +167,7 @@ public class PhraseCluster { PhraseObjective po=new PhraseObjective(this,phrase); po.optimizeWithProjectedGradientDescent(); double [][] q=po.posterior(); - loglikelihood+=po.getValue(); + loglikelihood+=po.llh; primal+=po.primal(); for(int edge=0;edge expectations = new LinkedBlockingQueue(); - double [][][]exp_emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words]; + double [][][]exp_emit=new double [K][c.numContexts][n_words]; double [][]exp_pi=new double[n_phrase][K]; double loglikelihood=0; @@ -220,7 +245,7 @@ public class PhraseCluster { int phrase = po.phrase; //System.out.println("" + Thread.currentThread().getId() + " taken phrase " + phrase); double [][] q=po.posterior(); - loglikelihood+=po.getValue(); + loglikelihood+=po.llh; primal+=po.primal(); for(int edge=0;edge 1e-10) + ps.print(c.wordList[word]+"="+emit[i][position][word]+"\t"); } ps.println(); } @@ -344,4 +370,26 @@ public class PhraseCluster { } } + + double posterior_l1lmax() + { + double sum=0; + for(int phrase=0;phrasewordLex; public HashMapphraseLex; @@ -23,6 +21,7 @@ public class PhraseCorpus { //data[phrase][num context][position] public int data[][][]; + public int numContexts; public static void main(String[] args) { // TODO Auto-generated method stub @@ -40,6 +39,7 @@ public class PhraseCorpus { ArrayListdataList=new ArrayList(); String line=null; + numContexts = 0; while((line=readLine(r))!=null){ @@ -54,7 +54,12 @@ public class PhraseCorpus { for(int i=0;i")){ @@ -68,9 +73,7 @@ public class PhraseCorpus { String count=toks[i+1]; context[idx]=Integer.parseInt(count.trim().substring(2)); - ctxList.add(context); - } dataList.add(ctxList.toArray(new int [0][])); @@ -157,13 +160,17 @@ public class PhraseCorpus { return dict; } - public String getContextString(int context[]) + public String getContextString(int context[], boolean addPhraseMarker) { StringBuffer b = new StringBuffer(); for (int i=0;i 0) b.append(" "); + + if (i == context.length/2) + b.append(" "); + b.append(wordList[context[i]]); } return b.toString(); diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java index 71c91b96..b7c62261 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java @@ -20,17 +20,17 @@ import optimization.util.MathUtils; public class PhraseObjective extends ProjectedObjective{ private static final double GRAD_DIFF = 0.002; - public static double INIT_STEP_SIZE=1; - public static double VAL_DIFF=0.001; - private double c1=0.0001; - private double c2=0.9; + public static double INIT_STEP_SIZE = 10; + public static double VAL_DIFF = 0.001; // FIXME needs to be tuned + //private double c1=0.0001; // wolf stuff + //private double c2=0.9; private PhraseCluster c; /**@brief * for debugging purposes */ - public static PrintStream ps; + //public static PrintStream ps; /**@brief current phrase being optimzed*/ public int phrase; @@ -61,7 +61,7 @@ public class PhraseObjective extends ProjectedObjective{ /**@brief likelihood under p * */ - private double llh; + public double llh; public PhraseObjective(PhraseCluster cluster, int phraseIdx){ phrase=phraseIdx; @@ -181,7 +181,7 @@ public class PhraseObjective extends ProjectedObjective{ boolean succed = optimizer.optimize(this,stats,compositeStop); // System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1)); if(succed){ - System.out.println("Ended optimization in " + optimizer.getCurrentIteration()); + //System.out.println("Ended optimization in " + optimizer.getCurrentIteration()); }else{ System.out.println("Failed to optimize"); } @@ -208,6 +208,10 @@ public class PhraseObjective extends ProjectedObjective{ double kl=-loglikelihood +MathUtils.dotProduct(parameters, gradient); // ps.print(", "+kl); + //System.out.println("llh " + llh); + //System.out.println("kl " + kl); + + l=l-kl; double sum=0; for(int tag=0;tag