diff options
Diffstat (limited to 'gi')
| -rw-r--r-- | gi/posterior-regularisation/prjava/build.xml | 1 | ||||
| -rw-r--r-- | gi/posterior-regularisation/prjava/prjava-20100707.jar | bin | 933814 -> 0 bytes | |||
| -rw-r--r-- | gi/posterior-regularisation/prjava/prjava-20100707_1.jar | bin | 934198 -> 0 bytes | |||
| -rw-r--r-- | gi/posterior-regularisation/prjava/src/arr/F.java | 2 | ||||
| -rw-r--r-- | gi/posterior-regularisation/prjava/src/io/FileUtil.java | 47 | ||||
| -rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/Corpus.java | 8 | ||||
| -rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java | 71 | ||||
| -rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java | 2 | ||||
| -rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java | 17 | ||||
| -rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java | 2 | ||||
| -rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/Trainer.java | 150 | ||||
| -rwxr-xr-x | gi/posterior-regularisation/prjava/train-PR-cluster.sh | 2 | 
12 files changed, 194 insertions, 108 deletions
| diff --git a/gi/posterior-regularisation/prjava/build.xml b/gi/posterior-regularisation/prjava/build.xml index c9ed2e8d..40199144 100644 --- a/gi/posterior-regularisation/prjava/build.xml +++ b/gi/posterior-regularisation/prjava/build.xml @@ -6,6 +6,7 @@    <path id="classpath">        <pathelement location="lib/trove-2.0.2.jar"/>        <pathelement location="lib/optimization.jar"/> +      <pathelement location="lib/jopt-simple-3.2.jar"/>    </path>    <target name="init"> diff --git a/gi/posterior-regularisation/prjava/prjava-20100707.jar b/gi/posterior-regularisation/prjava/prjava-20100707.jarBinary files differ deleted file mode 100644 index 195374d9..00000000 --- a/gi/posterior-regularisation/prjava/prjava-20100707.jar +++ /dev/null diff --git a/gi/posterior-regularisation/prjava/prjava-20100707_1.jar b/gi/posterior-regularisation/prjava/prjava-20100707_1.jarBinary files differ deleted file mode 100644 index a65f4d43..00000000 --- a/gi/posterior-regularisation/prjava/prjava-20100707_1.jar +++ /dev/null diff --git a/gi/posterior-regularisation/prjava/src/arr/F.java b/gi/posterior-regularisation/prjava/src/arr/F.java index 5821af42..7f2b140a 100644 --- a/gi/posterior-regularisation/prjava/src/arr/F.java +++ b/gi/posterior-regularisation/prjava/src/arr/F.java @@ -3,7 +3,7 @@ package arr;  import java.util.Random;
  public class F {
 -	private static Random rng = new Random(); //(9562724l);
 +	public static Random rng = new Random();
  	public static void randomise(double probs[])
  	{
 diff --git a/gi/posterior-regularisation/prjava/src/io/FileUtil.java b/gi/posterior-regularisation/prjava/src/io/FileUtil.java index 67ce571e..81e7747b 100644 --- a/gi/posterior-regularisation/prjava/src/io/FileUtil.java +++ b/gi/posterior-regularisation/prjava/src/io/FileUtil.java @@ -3,7 +3,24 @@ import java.util.*;  import java.util.zip.GZIPInputStream;
  import java.util.zip.GZIPOutputStream;
  import java.io.*;
 -public class FileUtil {
 +public class FileUtil 
 +{
 +	public static BufferedReader reader(File file) throws FileNotFoundException, IOException
 +	{
 +		if (file.getName().endsWith(".gz"))
 +			return new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(file))));
 +		else
 +			return new BufferedReader(new FileReader(file));
 +	}
 +	
 +	public static PrintStream printstream(File file) throws FileNotFoundException, IOException
 +	{
 +		if (file.getName().endsWith(".gz"))
 +			return new PrintStream(new GZIPOutputStream(new FileOutputStream(file)));
 +		else
 +			return new PrintStream(new FileOutputStream(file));
 +	}
 +
  	public static Scanner openInFile(String filename){
  		Scanner localsc=null;
  		try
 @@ -16,34 +33,6 @@ public class FileUtil {  		return localsc;
  	}
 -	public static BufferedReader openBufferedReader(String filename){
 -		BufferedReader r=null;
 -		try
 -		{
 -			if (filename.endsWith(".gz"))
 -				r=(new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(new File(filename))))));
 -			else
 -				r=(new BufferedReader(new FileReader(new File(filename))));
 -		}catch(IOException ioe){
 -			System.out.println(ioe.getMessage());
 -		}
 -		return r;		
 -	}
 -	
 -	public static PrintStream  openOutFile(String filename){
 -		PrintStream localps=null;
 -		try
 -		{
 -			if (filename.endsWith(".gz"))
 -				localps=new PrintStream (new GZIPOutputStream(new FileOutputStream(filename)));
 -			else
 -				localps=new PrintStream (new FileOutputStream(filename));
 -
 -		}catch(IOException ioe){
 -			System.out.println(ioe.getMessage());
 -		}
 -		return localps;
 -	}
  	public static FileInputStream openInputStream(String infilename){
  		FileInputStream fis=null;
  		try {
 diff --git a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java index d5e856ca..81264ab9 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java @@ -217,5 +217,11 @@ public class Corpus  		}  		return c; +	} + +	public void printStats(PrintStream out)  +	{ +		out.println("Corpus has " + edges.size() + " edges " + phraseLexicon.size() + " phrases "  +				+ contextLexicon.size() + " contexts and " + wordLexicon.size() + " word types");  	}	 -} +}
\ No newline at end of file diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index 63a60682..7d7c46dd 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -25,73 +25,6 @@ public class PhraseCluster {  	// pi[phrase][tag] = p(tag | phrase)
  	private double pi[][];
 -	public static void main(String[] args) 
 -	{
 -		String input_fname = args[0];
 -		int tags = Integer.parseInt(args[1]);
 -		String output_fname = args[2];
 -		int iterations = Integer.parseInt(args[3]);
 -		double scalePT = Double.parseDouble(args[4]);
 -		double scaleCT = Double.parseDouble(args[5]);
 -		int threads = Integer.parseInt(args[6]);
 -		boolean runEM = Boolean.parseBoolean(args[7]);
 -		
 -		assert(tags >= 2);
 -		assert(scalePT >= 0);
 -		assert(scaleCT >= 0);
 -		
 -		Corpus corpus = null;
 -		try {
 -			corpus = Corpus.readFromFile(FileUtil.openBufferedReader(input_fname));
 -		} catch (IOException e) {
 -			System.err.println("Failed to open input file: " + input_fname);
 -			e.printStackTrace();
 -			System.exit(1);
 -		}
 -		PhraseCluster cluster = new PhraseCluster(tags, corpus, scalePT, scaleCT, threads);
 -		
 -		//PhraseObjective.ps = FileUtil.openOutFile(outputDir + "/phrase_stat.out");
 -		
 -		double last = 0;
 -		for(int i=0;i<iterations;i++){
 -			
 -			double o;
 -			if (runEM || i < 3) 
 -				o = cluster.EM();
 -			else if (scaleCT == 0)
 -			{
 -				if (threads >= 1)
 -					o = cluster.PREM_phrase_constraints_parallel();
 -				else
 -					o = cluster.PREM_phrase_constraints();
 -			}
 -			else 
 -				o = cluster.PREM_phrase_context_constraints();
 -			
 -			//PhraseObjective.ps.
 -			System.out.println("ITER: "+i+" objective: " + o);
 -			last = o;
 -		}
 -		
 -		double pl1lmax = cluster.phrase_l1lmax();
 -		double cl1lmax = cluster.context_l1lmax();
 -		System.out.println("Final posterior phrase l1lmax " + pl1lmax + " context l1lmax " + cl1lmax);
 -		if (runEM) System.out.println("With PR objective " + (last - scalePT*pl1lmax - scaleCT*cl1lmax));
 -		
 -		PrintStream ps=io.FileUtil.openOutFile(output_fname);
 -		cluster.displayPosterior(ps);
 -		ps.close();
 -		
 -		//PhraseObjective.ps.close();
 -
 -		//ps = io.FileUtil.openOutFile(outputDir + "/parameters.out");
 -		//cluster.displayModelParam(ps);
 -		//ps.close();
 -		
 -		if (cluster.pool != null)
 -			cluster.pool.shutdown();
 -	}
 -
  	public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads){
  		K=numCluster;
  		c=corpus;
 @@ -134,7 +67,7 @@ public class PhraseCluster {  				double p[]=posterior(edge);
  				double z = arr.F.l1norm(p);
  				assert z > 0;
 -				loglikelihood+=Math.log(z);
 +				loglikelihood += edge.getCount() * Math.log(z);
  				arr.F.l1normalize(p);
  				int count = edge.getCount();
 @@ -150,7 +83,7 @@ public class PhraseCluster {  			}
  		}
 -		System.out.println("Log likelihood: "+loglikelihood);
 +		//System.out.println("Log likelihood: "+loglikelihood);
  		//M
  		for(double [][]i:exp_emit){
 diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java index fbf43a7f..15bd29c2 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java @@ -26,7 +26,7 @@ import phrase.Corpus.Edge;  public class PhraseContextObjective extends ProjectedObjective
  {
  	private static final double GRAD_DIFF = 0.00002;
 -	private static double INIT_STEP_SIZE = 10;
 +	private static double INIT_STEP_SIZE = 300;
  	private static double VAL_DIFF = 1e-4; // FIXME needs to be tuned
  	private static int ITERATIONS = 100;
 diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java index 11e948ff..903e47c8 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java @@ -1,7 +1,11 @@  package phrase;
 +import io.FileUtil;
 +
  import java.io.BufferedInputStream;
  import java.io.BufferedReader;
 +import java.io.File;
 +import java.io.FileNotFoundException;
  import java.io.IOException;
  import java.io.PrintStream;
  import java.util.ArrayList;
 @@ -20,8 +24,9 @@ public class PhraseCorpus  	public int data[][][];
  	public int numContexts;	
 -	public PhraseCorpus(String filename){
 -		BufferedReader r=io.FileUtil.openBufferedReader(filename);
 +	public PhraseCorpus(String filename) throws FileNotFoundException, IOException
 +	{
 +		BufferedReader r = FileUtil.reader(new File(filename));
  		phraseLex=new HashMap<String,Integer>();
  		wordLex=new HashMap<String,Integer>();
 @@ -84,8 +89,9 @@ public class PhraseCorpus  	}
  	//for debugging
 -	public void saveLex(String lexFilename){
 -		PrintStream ps=io.FileUtil.openOutFile(lexFilename);
 +	public void saveLex(String lexFilename) throws FileNotFoundException, IOException
 +	{
 +		PrintStream ps = FileUtil.printstream(new File(lexFilename));
  		ps.println("Phrase Lexicon");
  		ps.println(phraseLex.size());
  		printDict(phraseLex,ps);
 @@ -175,7 +181,8 @@ public class PhraseCorpus  		return null;
  	}
 -	public static void main(String[] args) {
 +	public static void main(String[] args) throws Exception 
 +	{
  		String LEX_FILENAME="../pdata/lex.out";
  		String DATA_FILENAME="../pdata/btec.con";
  		PhraseCorpus c=new PhraseCorpus(DATA_FILENAME);
 diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java index 0a76e2dc..3314f74a 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java @@ -21,7 +21,7 @@ import optimization.util.MathUtils;  public class PhraseObjective extends ProjectedObjective
  {
  	static final double GRAD_DIFF = 0.00002;
 -	static double INIT_STEP_SIZE = 10;
 +	static double INIT_STEP_SIZE = 300;
  	static double VAL_DIFF = 1e-4; // FIXME needs to be tuned - and this might be too weak
  	static int ITERATIONS = 100;
  	//private double c1=0.0001; // wolf stuff
 diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java new file mode 100644 index 00000000..b19f3fb9 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java @@ -0,0 +1,150 @@ +package phrase; + +import io.FileUtil; +import joptsimple.OptionParser; +import joptsimple.OptionSet; +import java.io.File; +import java.io.IOException; +import java.io.PrintStream; +import java.util.Random; + +import arr.F; + +public class Trainer  +{ +	public static void main(String[] args)  +	{ +        OptionParser parser = new OptionParser(); +        parser.accepts("help"); +        parser.accepts("in").withRequiredArg().ofType(File.class); +        parser.accepts("out").withRequiredArg().ofType(File.class); +        parser.accepts("parameters").withRequiredArg().ofType(File.class); +        parser.accepts("topics").withRequiredArg().ofType(Integer.class).defaultsTo(5); +        parser.accepts("em-iterations").withRequiredArg().ofType(Integer.class).defaultsTo(5); +        parser.accepts("pr-iterations").withRequiredArg().ofType(Integer.class).defaultsTo(0); +        parser.accepts("threads").withRequiredArg().ofType(Integer.class).defaultsTo(0); +        parser.accepts("scale-phrase").withRequiredArg().ofType(Double.class).defaultsTo(5.0); +        parser.accepts("scale-context").withRequiredArg().ofType(Double.class).defaultsTo(0.0); +        parser.accepts("seed").withRequiredArg().ofType(Long.class).defaultsTo(0l); +        parser.accepts("convergence-threshold").withRequiredArg().ofType(Double.class).defaultsTo(1e-6); +        OptionSet options = parser.parse(args); + +        if (options.has("help") || !options.has("in")) +        { +        	try { +				parser.printHelpOn(System.err); +			} catch (IOException e) { +				System.err.println("This should never happen. Really."); +				e.printStackTrace(); +			} +        	System.exit(1);      +        } +		 +		int tags = (Integer) options.valueOf("topics"); +		int em_iterations = (Integer) options.valueOf("em-iterations"); +		int pr_iterations = (Integer) options.valueOf("pr-iterations"); +		double scale_phrase = (Double) options.valueOf("scale-phrase"); +		double scale_context = (Double) options.valueOf("scale-context"); +		int threads = (Integer) options.valueOf("threads"); +		double threshold = (Double) options.valueOf("convergence-threshold"); +		 +		if (options.has("seed")) +			F.rng = new Random((Long) options.valueOf("seed")); +		 +		if (tags <= 1 || scale_phrase < 0 || scale_context < 0 || threshold < 0) +		{ +			System.err.println("Invalid arguments. Try again!"); +			System.exit(1); +		} +		 +		Corpus corpus = null; +		File infile = (File) options.valueOf("in"); +		try { +			System.out.println("Reading concordance from " + infile); +			corpus = Corpus.readFromFile(FileUtil.reader(infile)); +			corpus.printStats(System.out); +		} catch (IOException e) { +			System.err.println("Failed to open input file: " + infile); +			e.printStackTrace(); +			System.exit(1); +		} +		 + 		System.out.println("Running with " + tags + " tags " + + 				"for " + em_iterations + " EM and " + pr_iterations + " PR iterations " + + 				"with scale " + scale_phrase + " phrase and " + scale_context + " context " + + 				"and " + threads + " threads"); + 		System.out.println(); +		 +		PhraseCluster cluster = new PhraseCluster(tags, corpus, scale_phrase, scale_context, threads); +				 +		double last = 0; +		for (int i=0; i<em_iterations+pr_iterations; i++) +		{ +			double o; +			if (i < em_iterations)  +				o = cluster.EM(); +			else if (scale_context == 0) +			{ +				if (threads >= 1) +					o = cluster.PREM_phrase_constraints_parallel(); +				else +					o = cluster.PREM_phrase_constraints(); +			} +			else  +				o = cluster.PREM_phrase_context_constraints(); +			 +			System.out.println("ITER: "+i+" objective: " + o); +			 +			if (i != 0 && Math.abs((o - last) / o) < threshold) +			{ +				last = o; +				if (i < em_iterations) +				{ +					i = em_iterations - 1; +					continue; +				} +				else +					break; +			} +			last = o; +		} +		 +		double pl1lmax = cluster.phrase_l1lmax(); +		double cl1lmax = cluster.context_l1lmax(); +		System.out.println("\nFinal posterior phrase l1lmax " + pl1lmax + " context l1lmax " + cl1lmax); +		if (pr_iterations == 0)  +			System.out.println("With PR objective " + (last - scale_phrase*pl1lmax - scale_context*cl1lmax)); +		 +		if (options.has("out")) +		{ +			File outfile = (File) options.valueOf("out"); +			try { +				PrintStream ps = FileUtil.printstream(outfile); +				cluster.displayPosterior(ps); +				ps.close(); +			} catch (IOException e) { +				System.err.println("Failed to open output file: " + outfile); +				e.printStackTrace(); +				System.exit(1); +			} +		} + +		if (options.has("parameters")) +		{ +			File outfile = (File) options.valueOf("parameters"); +			PrintStream ps; +			try { +				ps = FileUtil.printstream(outfile); +				cluster.displayModelParam(ps); +				ps.close(); +			} catch (IOException e) { +				System.err.println("Failed to open output parameters file: " + outfile); +				e.printStackTrace(); +				System.exit(1); +			} +		} +		 +		if (cluster.pool != null) +			cluster.pool.shutdown(); +	} +} diff --git a/gi/posterior-regularisation/prjava/train-PR-cluster.sh b/gi/posterior-regularisation/prjava/train-PR-cluster.sh index b86d564b..41bb403f 100755 --- a/gi/posterior-regularisation/prjava/train-PR-cluster.sh +++ b/gi/posterior-regularisation/prjava/train-PR-cluster.sh @@ -1,4 +1,4 @@  #!/bin/sh  d=`dirname $0` -java -ea -Xmx8g -cp $d/prjava.jar:$d/lib/trove-2.0.2.jar:$d/lib/optimization.jar phrase.PhraseCluster $* +java -ea -Xmx8g -cp $d/prjava.jar:$d/lib/trove-2.0.2.jar:$d/lib/optimization.jar:$d/lib/jopt-simple-3.2.jar phrase.Trainer $* | 
