diff options
Diffstat (limited to 'gi/posterior-regularisation/prjava/src')
4 files changed, 114 insertions, 46 deletions
| diff --git a/gi/posterior-regularisation/prjava/src/io/FileUtil.java b/gi/posterior-regularisation/prjava/src/io/FileUtil.java index 7d9f2bc5..67ce571e 100644 --- a/gi/posterior-regularisation/prjava/src/io/FileUtil.java +++ b/gi/posterior-regularisation/prjava/src/io/FileUtil.java @@ -1,5 +1,7 @@  package io;
  import java.util.*;
 +import java.util.zip.GZIPInputStream;
 +import java.util.zip.GZIPOutputStream;
  import java.io.*;
  public class FileUtil {
  	public static Scanner openInFile(String filename){
 @@ -18,7 +20,10 @@ public class FileUtil {  		BufferedReader r=null;
  		try
  		{
 -			r=(new BufferedReader(new FileReader(new File(filename))));
 +			if (filename.endsWith(".gz"))
 +				r=(new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(new File(filename))))));
 +			else
 +				r=(new BufferedReader(new FileReader(new File(filename))));
  		}catch(IOException ioe){
  			System.out.println(ioe.getMessage());
  		}
 @@ -29,7 +34,10 @@ public class FileUtil {  		PrintStream localps=null;
  		try
  		{
 -			localps=new PrintStream (new FileOutputStream(filename));
 +			if (filename.endsWith(".gz"))
 +				localps=new PrintStream (new GZIPOutputStream(new FileOutputStream(filename)));
 +			else
 +				localps=new PrintStream (new FileOutputStream(filename));
  		}catch(IOException ioe){
  			System.out.println(ioe.getMessage());
 diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index cd28c12e..731d03ac 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -1,11 +1,16 @@  package phrase;
  import io.FileUtil;
 +
 +import java.io.FileOutputStream;
 +import java.io.IOException;
 +import java.io.OutputStream;
  import java.io.PrintStream;
  import java.util.Arrays;
  import java.util.concurrent.ExecutorService;
  import java.util.concurrent.Executors;
  import java.util.concurrent.LinkedBlockingQueue;
 +import java.util.zip.GZIPOutputStream;
  public class PhraseCluster {
 @@ -26,28 +31,46 @@ public class PhraseCluster {  	public static void main(String[] args) {
  		String input_fname = args[0];
  		int tags = Integer.parseInt(args[1]);
 -		String outputDir = args[2];
 +		String output_fname = args[2];
  		int iterations = Integer.parseInt(args[3]);
  		double scale = Double.parseDouble(args[4]);
  		int threads = Integer.parseInt(args[5]);
 +		boolean runEM = Boolean.parseBoolean(args[6]);
  		PhraseCorpus corpus = new PhraseCorpus(input_fname);
  		PhraseCluster cluster = new PhraseCluster(tags, corpus, scale, threads);
 -		PhraseObjective.ps = FileUtil.openOutFile(outputDir + "/phrase_stat.out");
 +		//PhraseObjective.ps = FileUtil.openOutFile(outputDir + "/phrase_stat.out");
 +		double last = 0;
  		for(int i=0;i<iterations;i++){
 -			double o = cluster.PREM();
 -		    //double o = cluster.EM();
 -			PhraseObjective.ps.println("ITER: "+i+" objective: " + o);
 +			
 +			double o;
 +			if (runEM || i < 3) 
 +				o = cluster.EM();
 +			else
 +				o = cluster.PREM();
 +			//PhraseObjective.ps.
 +			System.out.println("ITER: "+i+" objective: " + o);
 +			last = o;
 +		}
 +		
 +		if (runEM)
 +		{
 +			double l1lmax = cluster.posterior_l1lmax();
 +			System.out.println("Final l1lmax term " + l1lmax + ", total PR objective " + (last - scale*l1lmax));
 +			// nb. KL is 0 by definition
  		}
 -		PrintStream ps=io.FileUtil.openOutFile(outputDir + "/posterior.out");
 +		PrintStream ps=io.FileUtil.openOutFile(output_fname);
  		cluster.displayPosterior(ps);
 -		ps.println();
 -		cluster.displayModelParam(ps);
  		ps.close();
 -		PhraseObjective.ps.close();
 +		
 +		//PhraseObjective.ps.close();
 +
 +		//ps = io.FileUtil.openOutFile(outputDir + "/parameters.out");
 +		//cluster.displayModelParam(ps);
 +		//ps.close();
  		cluster.finish();
  	}
 @@ -61,7 +84,7 @@ public class PhraseCluster {  		if (threads > 0)
  			pool = Executors.newFixedThreadPool(threads);
 -		emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
 +		emit=new double [K][c.numContexts][n_words];
  		pi=new double[n_phrase][K];
  		for(double [][]i:emit){
 @@ -82,7 +105,7 @@ public class PhraseCluster {  	}
  	public double EM(){
 -		double [][][]exp_emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
 +		double [][][]exp_emit=new double [K][c.numContexts][n_words];
  		double [][]exp_pi=new double[n_phrase][K];
  		double loglikelihood=0;
 @@ -93,7 +116,9 @@ public class PhraseCluster {  			for(int ctx=0;ctx<data.length;ctx++){
  				int context[]=data[ctx];
  				double p[]=posterior(phrase,context);
 -				loglikelihood+=Math.log(arr.F.l1norm(p));
 +				double z = arr.F.l1norm(p);
 +				assert z > 0;
 +				loglikelihood+=Math.log(z);
  				arr.F.l1normalize(p);
  				int contextCnt=context[context.length-1];
 @@ -132,7 +157,7 @@ public class PhraseCluster {  		if (pool != null)
  			return PREMParallel();
 -		double [][][]exp_emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
 +		double [][][]exp_emit=new double [K][c.numContexts][n_words];
  		double [][]exp_pi=new double[n_phrase][K];
  		double loglikelihood=0;
 @@ -142,7 +167,7 @@ public class PhraseCluster {  			PhraseObjective po=new PhraseObjective(this,phrase);
  			po.optimizeWithProjectedGradientDescent();
  			double [][] q=po.posterior();
 -			loglikelihood+=po.getValue();
 +			loglikelihood+=po.llh;
  			primal+=po.primal();
  			for(int edge=0;edge<q.length;edge++){
  				int []context=c.data[phrase][edge];
 @@ -184,7 +209,7 @@ public class PhraseCluster {  		final LinkedBlockingQueue<PhraseObjective> expectations 
  			= new LinkedBlockingQueue<PhraseObjective>();
 -		double [][][]exp_emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
 +		double [][][]exp_emit=new double [K][c.numContexts][n_words];
  		double [][]exp_pi=new double[n_phrase][K];
  		double loglikelihood=0;
 @@ -220,7 +245,7 @@ public class PhraseCluster {  				int phrase = po.phrase;
  				//System.out.println("" + Thread.currentThread().getId() + " taken phrase " + phrase);
  				double [][] q=po.posterior();
 -				loglikelihood+=po.getValue();
 +				loglikelihood+=po.llh;
  				primal+=po.primal();
  				for(int edge=0;edge<q.length;edge++){
  					int []context=c.data[phrase][edge];
 @@ -295,18 +320,19 @@ public class PhraseCluster {  				// emit phrase
  				ps.print(c.phraseList[i]);
  				ps.print("\t");
 -				ps.print(c.getContextString(e));
 -				ps.print("||| C=" + e[e.length-1] + " |||");
 -
 +				ps.print(c.getContextString(e, true));
  				int t=arr.F.argmax(probs);
 +				ps.println(" ||| C=" + t);
 +
 +				//ps.print("||| C=" + e[e.length-1] + " |||");
 -				ps.print(t+"||| [");
 -				for(t=0;t<K;t++){
 -					ps.print(probs[t]+", ");
 -				}
 +				//ps.print(t+"||| [");
 +				//for(t=0;t<K;t++){
 +				//	ps.print(probs[t]+", ");
 +				//}
  				// for (int t = 0; t < numTags; ++t)
  				// System.out.print(" " + probs[t]);
 -				ps.println("]");
 +				//ps.println("]");
  			}
  		}
  	}
 @@ -329,14 +355,14 @@ public class PhraseCluster {  		ps.println("P(word|tag,position)");
  		for (int i = 0; i < K; ++i)
  		{
 -			ps.println(i);
 -			for(int position=0;position<PhraseCorpus.NUM_CONTEXT;position++){
 -				ps.println(position);
 +			for(int position=0;position<c.numContexts;position++){
 +				ps.println("tag " + i + " position " + position);
  				for(int word=0;word<emit[i][position].length;word++){
 -					if((word+1)%100==0){
 -						ps.println();
 -					}
 -					ps.print(c.wordList[word]+"="+emit[i][position][word]+"\t");
 +					//if((word+1)%100==0){
 +					//	ps.println();
 +					//}
 +					if (emit[i][position][word] > 1e-10)
 +						ps.print(c.wordList[word]+"="+emit[i][position][word]+"\t");
  				}
  				ps.println();
  			}
 @@ -344,4 +370,26 @@ public class PhraseCluster {  		}
  	}
 +	
 +	double posterior_l1lmax()
 +	{
 +		double sum=0;
 +		for(int phrase=0;phrase<c.data.length;phrase++)
 +		{
 +			int [][] data = c.data[phrase];
 +			double [] maxes = new double[K];
 +			for(int ctx=0;ctx<data.length;ctx++)
 +			{
 +				int context[]=data[ctx];
 +				double p[]=posterior(phrase,context);
 +				arr.F.l1normalize(p);
 +
 +				for(int tag=0;tag<K;tag++)
 +					maxes[tag] = Math.max(maxes[tag], p[tag]);
 +			}
 +			for(int tag=0;tag<K;tag++)
 +				sum += maxes[tag];
 +		}
 +		return sum;
 +	}
  }
 diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java index 99545371..b8f1f24a 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java @@ -9,11 +9,9 @@ import java.util.HashMap;  import java.util.Scanner;
  public class PhraseCorpus {
 -
  	public static String LEX_FILENAME="../pdata/lex.out";
  	public static String DATA_FILENAME="../pdata/btec.con";
 -	public static int NUM_CONTEXT=4;
  	public HashMap<String,Integer>wordLex;
  	public HashMap<String,Integer>phraseLex;
 @@ -23,6 +21,7 @@ public class PhraseCorpus {  	//data[phrase][num context][position]
  	public int data[][][];
 +	public int numContexts;
  	public static void main(String[] args) {
  		// TODO Auto-generated method stub
 @@ -40,6 +39,7 @@ public class PhraseCorpus {  		ArrayList<int[][]>dataList=new ArrayList<int[][]>();
  		String line=null;
 +		numContexts = 0;
  		while((line=readLine(r))!=null){
 @@ -54,7 +54,12 @@ public class PhraseCorpus {  			for(int i=0;i<toks.length;i+=2){
  				String ctx=toks[i];
  				String words[]=ctx.split(" ");
 -				int []context=new int [NUM_CONTEXT+1];
 +				if (numContexts == 0)
 +					numContexts = words.length - 1;
 +				else
 +					assert numContexts == words.length - 1;
 +				
 +				int []context=new int [numContexts+1];
  				int idx=0;
  				for(String word:words){
  					if(word.equals("<PHRASE>")){
 @@ -68,9 +73,7 @@ public class PhraseCorpus {  				String count=toks[i+1];
  				context[idx]=Integer.parseInt(count.trim().substring(2));
 -				
  				ctxList.add(context);
 -				
  			}
  			dataList.add(ctxList.toArray(new int [0][]));
 @@ -157,13 +160,17 @@ public class PhraseCorpus {  		return dict;
  	}
 -	public String getContextString(int context[])
 +	public String getContextString(int context[], boolean addPhraseMarker)
  	{
  		StringBuffer b = new StringBuffer();
  		for (int i=0;i<context.length-1;i++)
  		{
  			if (b.length() > 0)
  				b.append(" ");
 +
 +			if (i == context.length/2)
 +				b.append("<PHRASE> ");
 +			
  			b.append(wordList[context[i]]);
  		}
  		return b.toString();
 diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java index 71c91b96..b7c62261 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java @@ -20,17 +20,17 @@ import optimization.util.MathUtils;  public class PhraseObjective extends ProjectedObjective{
  	private static final double GRAD_DIFF = 0.002;
 -	public static double INIT_STEP_SIZE=1;
 -	public static double VAL_DIFF=0.001;
 -	private double c1=0.0001;
 -	private double c2=0.9;
 +	public static double INIT_STEP_SIZE = 10;
 +	public static double VAL_DIFF = 0.001; // FIXME needs to be tuned
 +	//private double c1=0.0001; // wolf stuff
 +	//private double c2=0.9;
  	private PhraseCluster c;
  	/**@brief
  	 *  for debugging purposes
  	 */
 -	public static PrintStream ps;
 +	//public static PrintStream ps;
  	/**@brief current phrase being optimzed*/
  	public int phrase;
 @@ -61,7 +61,7 @@ public class PhraseObjective extends ProjectedObjective{  	/**@brief likelihood under p
  	 * 
  	 */
 -	private double llh;
 +	public double llh;
  	public PhraseObjective(PhraseCluster cluster, int phraseIdx){
  		phrase=phraseIdx;
 @@ -181,7 +181,7 @@ public class PhraseObjective extends ProjectedObjective{  		boolean succed = optimizer.optimize(this,stats,compositeStop);
  //		System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
  		if(succed){
 -			System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
 +			//System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
  		}else{
  			System.out.println("Failed to optimize");
  		}
 @@ -208,6 +208,10 @@ public class PhraseObjective extends ProjectedObjective{  		double kl=-loglikelihood
  			+MathUtils.dotProduct(parameters, gradient);
  //		ps.print(", "+kl);
 +		//System.out.println("llh " + llh);
 +		//System.out.println("kl " + kl);
 +		
 +
  		l=l-kl;
  		double sum=0;
  		for(int tag=0;tag<c.K;tag++){
 @@ -219,6 +223,7 @@ public class PhraseObjective extends ProjectedObjective{  			}
  			sum+=max;
  		}
 +		//System.out.println("l1lmax " + sum);
  //		ps.println(", "+sum);
  		l=l-c.scale*sum;
  		return l;
 | 
