From 808aa98dfdc0f2beb42503172de61f72981d6735 Mon Sep 17 00:00:00 2001
From: "trevor.cohn" <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>
Date: Fri, 9 Jul 2010 16:22:54 +0000
Subject: Added formal command line options & new main class.

git-svn-id: https://ws10smt.googlecode.com/svn/trunk@200 ec762483-ff6d-05da-a07a-a48fb63a330f
---
 gi/posterior-regularisation/prjava/build.xml       |   1 +
 .../prjava/prjava-20100707.jar                     | Bin 933814 -> 0 bytes
 .../prjava/prjava-20100707_1.jar                   | Bin 934198 -> 0 bytes
 gi/posterior-regularisation/prjava/src/arr/F.java  |   2 +-
 .../prjava/src/io/FileUtil.java                    |  47 +++----
 .../prjava/src/phrase/Corpus.java                  |   8 +-
 .../prjava/src/phrase/PhraseCluster.java           |  71 +---------
 .../prjava/src/phrase/PhraseContextObjective.java  |   2 +-
 .../prjava/src/phrase/PhraseCorpus.java            |  17 ++-
 .../prjava/src/phrase/PhraseObjective.java         |   2 +-
 .../prjava/src/phrase/Trainer.java                 | 150 +++++++++++++++++++++
 .../prjava/train-PR-cluster.sh                     |   2 +-
 12 files changed, 194 insertions(+), 108 deletions(-)
 delete mode 100644 gi/posterior-regularisation/prjava/prjava-20100707.jar
 delete mode 100644 gi/posterior-regularisation/prjava/prjava-20100707_1.jar
 create mode 100644 gi/posterior-regularisation/prjava/src/phrase/Trainer.java

(limited to 'gi/posterior-regularisation/prjava')

diff --git a/gi/posterior-regularisation/prjava/build.xml b/gi/posterior-regularisation/prjava/build.xml
index c9ed2e8d..40199144 100644
--- a/gi/posterior-regularisation/prjava/build.xml
+++ b/gi/posterior-regularisation/prjava/build.xml
@@ -6,6 +6,7 @@
   <path id="classpath">
       <pathelement location="lib/trove-2.0.2.jar"/>
       <pathelement location="lib/optimization.jar"/>
+      <pathelement location="lib/jopt-simple-3.2.jar"/>
   </path>
 
   <target name="init">
diff --git a/gi/posterior-regularisation/prjava/prjava-20100707.jar b/gi/posterior-regularisation/prjava/prjava-20100707.jar
deleted file mode 100644
index 195374d9..00000000
Binary files a/gi/posterior-regularisation/prjava/prjava-20100707.jar and /dev/null differ
diff --git a/gi/posterior-regularisation/prjava/prjava-20100707_1.jar b/gi/posterior-regularisation/prjava/prjava-20100707_1.jar
deleted file mode 100644
index a65f4d43..00000000
Binary files a/gi/posterior-regularisation/prjava/prjava-20100707_1.jar and /dev/null differ
diff --git a/gi/posterior-regularisation/prjava/src/arr/F.java b/gi/posterior-regularisation/prjava/src/arr/F.java
index 5821af42..7f2b140a 100644
--- a/gi/posterior-regularisation/prjava/src/arr/F.java
+++ b/gi/posterior-regularisation/prjava/src/arr/F.java
@@ -3,7 +3,7 @@ package arr;
 import java.util.Random;
 
 public class F {
-	private static Random rng = new Random(); //(9562724l);
+	public static Random rng = new Random();
 	
 	public static void randomise(double probs[])
 	{
diff --git a/gi/posterior-regularisation/prjava/src/io/FileUtil.java b/gi/posterior-regularisation/prjava/src/io/FileUtil.java
index 67ce571e..81e7747b 100644
--- a/gi/posterior-regularisation/prjava/src/io/FileUtil.java
+++ b/gi/posterior-regularisation/prjava/src/io/FileUtil.java
@@ -3,7 +3,24 @@ import java.util.*;
 import java.util.zip.GZIPInputStream;
 import java.util.zip.GZIPOutputStream;
 import java.io.*;
-public class FileUtil {
+public class FileUtil 
+{
+	public static BufferedReader reader(File file) throws FileNotFoundException, IOException
+	{
+		if (file.getName().endsWith(".gz"))
+			return new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(file))));
+		else
+			return new BufferedReader(new FileReader(file));
+	}
+	
+	public static PrintStream printstream(File file) throws FileNotFoundException, IOException
+	{
+		if (file.getName().endsWith(".gz"))
+			return new PrintStream(new GZIPOutputStream(new FileOutputStream(file)));
+		else
+			return new PrintStream(new FileOutputStream(file));
+	}
+
 	public static Scanner openInFile(String filename){
 		Scanner localsc=null;
 		try
@@ -16,34 +33,6 @@ public class FileUtil {
 		return localsc;
 	}
 	
-	public static BufferedReader openBufferedReader(String filename){
-		BufferedReader r=null;
-		try
-		{
-			if (filename.endsWith(".gz"))
-				r=(new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(new File(filename))))));
-			else
-				r=(new BufferedReader(new FileReader(new File(filename))));
-		}catch(IOException ioe){
-			System.out.println(ioe.getMessage());
-		}
-		return r;		
-	}
-	
-	public static PrintStream  openOutFile(String filename){
-		PrintStream localps=null;
-		try
-		{
-			if (filename.endsWith(".gz"))
-				localps=new PrintStream (new GZIPOutputStream(new FileOutputStream(filename)));
-			else
-				localps=new PrintStream (new FileOutputStream(filename));
-
-		}catch(IOException ioe){
-			System.out.println(ioe.getMessage());
-		}
-		return localps;
-	}
 	public static FileInputStream openInputStream(String infilename){
 		FileInputStream fis=null;
 		try {
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java
index d5e856ca..81264ab9 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java
@@ -217,5 +217,11 @@ public class Corpus
 		}
 		
 		return c;
+	}
+
+	public void printStats(PrintStream out) 
+	{
+		out.println("Corpus has " + edges.size() + " edges " + phraseLexicon.size() + " phrases " 
+				+ contextLexicon.size() + " contexts and " + wordLexicon.size() + " word types");
 	}	
-}
+}
\ No newline at end of file
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index 63a60682..7d7c46dd 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -25,73 +25,6 @@ public class PhraseCluster {
 	// pi[phrase][tag] = p(tag | phrase)
 	private double pi[][];
 	
-	public static void main(String[] args) 
-	{
-		String input_fname = args[0];
-		int tags = Integer.parseInt(args[1]);
-		String output_fname = args[2];
-		int iterations = Integer.parseInt(args[3]);
-		double scalePT = Double.parseDouble(args[4]);
-		double scaleCT = Double.parseDouble(args[5]);
-		int threads = Integer.parseInt(args[6]);
-		boolean runEM = Boolean.parseBoolean(args[7]);
-		
-		assert(tags >= 2);
-		assert(scalePT >= 0);
-		assert(scaleCT >= 0);
-		
-		Corpus corpus = null;
-		try {
-			corpus = Corpus.readFromFile(FileUtil.openBufferedReader(input_fname));
-		} catch (IOException e) {
-			System.err.println("Failed to open input file: " + input_fname);
-			e.printStackTrace();
-			System.exit(1);
-		}
-		PhraseCluster cluster = new PhraseCluster(tags, corpus, scalePT, scaleCT, threads);
-		
-		//PhraseObjective.ps = FileUtil.openOutFile(outputDir + "/phrase_stat.out");
-		
-		double last = 0;
-		for(int i=0;i<iterations;i++){
-			
-			double o;
-			if (runEM || i < 3) 
-				o = cluster.EM();
-			else if (scaleCT == 0)
-			{
-				if (threads >= 1)
-					o = cluster.PREM_phrase_constraints_parallel();
-				else
-					o = cluster.PREM_phrase_constraints();
-			}
-			else 
-				o = cluster.PREM_phrase_context_constraints();
-			
-			//PhraseObjective.ps.
-			System.out.println("ITER: "+i+" objective: " + o);
-			last = o;
-		}
-		
-		double pl1lmax = cluster.phrase_l1lmax();
-		double cl1lmax = cluster.context_l1lmax();
-		System.out.println("Final posterior phrase l1lmax " + pl1lmax + " context l1lmax " + cl1lmax);
-		if (runEM) System.out.println("With PR objective " + (last - scalePT*pl1lmax - scaleCT*cl1lmax));
-		
-		PrintStream ps=io.FileUtil.openOutFile(output_fname);
-		cluster.displayPosterior(ps);
-		ps.close();
-		
-		//PhraseObjective.ps.close();
-
-		//ps = io.FileUtil.openOutFile(outputDir + "/parameters.out");
-		//cluster.displayModelParam(ps);
-		//ps.close();
-		
-		if (cluster.pool != null)
-			cluster.pool.shutdown();
-	}
-
 	public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads){
 		K=numCluster;
 		c=corpus;
@@ -134,7 +67,7 @@ public class PhraseCluster {
 				double p[]=posterior(edge);
 				double z = arr.F.l1norm(p);
 				assert z > 0;
-				loglikelihood+=Math.log(z);
+				loglikelihood += edge.getCount() * Math.log(z);
 				arr.F.l1normalize(p);
 				
 				int count = edge.getCount();
@@ -150,7 +83,7 @@ public class PhraseCluster {
 			}
 		}
 		
-		System.out.println("Log likelihood: "+loglikelihood);
+		//System.out.println("Log likelihood: "+loglikelihood);
 		
 		//M
 		for(double [][]i:exp_emit){
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
index fbf43a7f..15bd29c2 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
@@ -26,7 +26,7 @@ import phrase.Corpus.Edge;
 public class PhraseContextObjective extends ProjectedObjective
 {
 	private static final double GRAD_DIFF = 0.00002;
-	private static double INIT_STEP_SIZE = 10;
+	private static double INIT_STEP_SIZE = 300;
 	private static double VAL_DIFF = 1e-4; // FIXME needs to be tuned
 	private static int ITERATIONS = 100;
 	
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java
index 11e948ff..903e47c8 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java
@@ -1,7 +1,11 @@
 package phrase;
 
+import io.FileUtil;
+
 import java.io.BufferedInputStream;
 import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileNotFoundException;
 import java.io.IOException;
 import java.io.PrintStream;
 import java.util.ArrayList;
@@ -20,8 +24,9 @@ public class PhraseCorpus
 	public int data[][][];
 	public int numContexts;	
 
-	public PhraseCorpus(String filename){
-		BufferedReader r=io.FileUtil.openBufferedReader(filename);
+	public PhraseCorpus(String filename) throws FileNotFoundException, IOException
+	{
+		BufferedReader r = FileUtil.reader(new File(filename));
 		
 		phraseLex=new HashMap<String,Integer>();
 		wordLex=new HashMap<String,Integer>();
@@ -84,8 +89,9 @@ public class PhraseCorpus
 	}
 	
 	//for debugging
-	public void saveLex(String lexFilename){
-		PrintStream ps=io.FileUtil.openOutFile(lexFilename);
+	public void saveLex(String lexFilename) throws FileNotFoundException, IOException
+	{
+		PrintStream ps = FileUtil.printstream(new File(lexFilename));
 		ps.println("Phrase Lexicon");
 		ps.println(phraseLex.size());
 		printDict(phraseLex,ps);
@@ -175,7 +181,8 @@ public class PhraseCorpus
 		return null;
 	}
 
-	public static void main(String[] args) {
+	public static void main(String[] args) throws Exception 
+	{
 		String LEX_FILENAME="../pdata/lex.out";
 		String DATA_FILENAME="../pdata/btec.con";
 		PhraseCorpus c=new PhraseCorpus(DATA_FILENAME);
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
index 0a76e2dc..3314f74a 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
@@ -21,7 +21,7 @@ import optimization.util.MathUtils;
 public class PhraseObjective extends ProjectedObjective
 {
 	static final double GRAD_DIFF = 0.00002;
-	static double INIT_STEP_SIZE = 10;
+	static double INIT_STEP_SIZE = 300;
 	static double VAL_DIFF = 1e-4; // FIXME needs to be tuned - and this might be too weak
 	static int ITERATIONS = 100;
 	//private double c1=0.0001; // wolf stuff
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
new file mode 100644
index 00000000..b19f3fb9
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
@@ -0,0 +1,150 @@
+package phrase;
+
+import io.FileUtil;
+import joptsimple.OptionParser;
+import joptsimple.OptionSet;
+import java.io.File;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.Random;
+
+import arr.F;
+
+public class Trainer 
+{
+	public static void main(String[] args) 
+	{
+        OptionParser parser = new OptionParser();
+        parser.accepts("help");
+        parser.accepts("in").withRequiredArg().ofType(File.class);
+        parser.accepts("out").withRequiredArg().ofType(File.class);
+        parser.accepts("parameters").withRequiredArg().ofType(File.class);
+        parser.accepts("topics").withRequiredArg().ofType(Integer.class).defaultsTo(5);
+        parser.accepts("em-iterations").withRequiredArg().ofType(Integer.class).defaultsTo(5);
+        parser.accepts("pr-iterations").withRequiredArg().ofType(Integer.class).defaultsTo(0);
+        parser.accepts("threads").withRequiredArg().ofType(Integer.class).defaultsTo(0);
+        parser.accepts("scale-phrase").withRequiredArg().ofType(Double.class).defaultsTo(5.0);
+        parser.accepts("scale-context").withRequiredArg().ofType(Double.class).defaultsTo(0.0);
+        parser.accepts("seed").withRequiredArg().ofType(Long.class).defaultsTo(0l);
+        parser.accepts("convergence-threshold").withRequiredArg().ofType(Double.class).defaultsTo(1e-6);
+        OptionSet options = parser.parse(args);
+
+        if (options.has("help") || !options.has("in"))
+        {
+        	try {
+				parser.printHelpOn(System.err);
+			} catch (IOException e) {
+				System.err.println("This should never happen. Really.");
+				e.printStackTrace();
+			}
+        	System.exit(1);     
+        }
+		
+		int tags = (Integer) options.valueOf("topics");
+		int em_iterations = (Integer) options.valueOf("em-iterations");
+		int pr_iterations = (Integer) options.valueOf("pr-iterations");
+		double scale_phrase = (Double) options.valueOf("scale-phrase");
+		double scale_context = (Double) options.valueOf("scale-context");
+		int threads = (Integer) options.valueOf("threads");
+		double threshold = (Double) options.valueOf("convergence-threshold");
+		
+		if (options.has("seed"))
+			F.rng = new Random((Long) options.valueOf("seed"));
+		
+		if (tags <= 1 || scale_phrase < 0 || scale_context < 0 || threshold < 0)
+		{
+			System.err.println("Invalid arguments. Try again!");
+			System.exit(1);
+		}
+		
+		Corpus corpus = null;
+		File infile = (File) options.valueOf("in");
+		try {
+			System.out.println("Reading concordance from " + infile);
+			corpus = Corpus.readFromFile(FileUtil.reader(infile));
+			corpus.printStats(System.out);
+		} catch (IOException e) {
+			System.err.println("Failed to open input file: " + infile);
+			e.printStackTrace();
+			System.exit(1);
+		}
+		
+ 		System.out.println("Running with " + tags + " tags " +
+ 				"for " + em_iterations + " EM and " + pr_iterations + " PR iterations " +
+ 				"with scale " + scale_phrase + " phrase and " + scale_context + " context " +
+ 				"and " + threads + " threads");
+ 		System.out.println();
+		
+		PhraseCluster cluster = new PhraseCluster(tags, corpus, scale_phrase, scale_context, threads);
+				
+		double last = 0;
+		for (int i=0; i<em_iterations+pr_iterations; i++)
+		{
+			double o;
+			if (i < em_iterations) 
+				o = cluster.EM();
+			else if (scale_context == 0)
+			{
+				if (threads >= 1)
+					o = cluster.PREM_phrase_constraints_parallel();
+				else
+					o = cluster.PREM_phrase_constraints();
+			}
+			else 
+				o = cluster.PREM_phrase_context_constraints();
+			
+			System.out.println("ITER: "+i+" objective: " + o);
+			
+			if (i != 0 && Math.abs((o - last) / o) < threshold)
+			{
+				last = o;
+				if (i < em_iterations)
+				{
+					i = em_iterations - 1;
+					continue;
+				}
+				else
+					break;
+			}
+			last = o;
+		}
+		
+		double pl1lmax = cluster.phrase_l1lmax();
+		double cl1lmax = cluster.context_l1lmax();
+		System.out.println("\nFinal posterior phrase l1lmax " + pl1lmax + " context l1lmax " + cl1lmax);
+		if (pr_iterations == 0) 
+			System.out.println("With PR objective " + (last - scale_phrase*pl1lmax - scale_context*cl1lmax));
+		
+		if (options.has("out"))
+		{
+			File outfile = (File) options.valueOf("out");
+			try {
+				PrintStream ps = FileUtil.printstream(outfile);
+				cluster.displayPosterior(ps);
+				ps.close();
+			} catch (IOException e) {
+				System.err.println("Failed to open output file: " + outfile);
+				e.printStackTrace();
+				System.exit(1);
+			}
+		}
+
+		if (options.has("parameters"))
+		{
+			File outfile = (File) options.valueOf("parameters");
+			PrintStream ps;
+			try {
+				ps = FileUtil.printstream(outfile);
+				cluster.displayModelParam(ps);
+				ps.close();
+			} catch (IOException e) {
+				System.err.println("Failed to open output parameters file: " + outfile);
+				e.printStackTrace();
+				System.exit(1);
+			}
+		}
+		
+		if (cluster.pool != null)
+			cluster.pool.shutdown();
+	}
+}
diff --git a/gi/posterior-regularisation/prjava/train-PR-cluster.sh b/gi/posterior-regularisation/prjava/train-PR-cluster.sh
index b86d564b..41bb403f 100755
--- a/gi/posterior-regularisation/prjava/train-PR-cluster.sh
+++ b/gi/posterior-regularisation/prjava/train-PR-cluster.sh
@@ -1,4 +1,4 @@
 #!/bin/sh
 
 d=`dirname $0`
-java -ea -Xmx8g -cp $d/prjava.jar:$d/lib/trove-2.0.2.jar:$d/lib/optimization.jar phrase.PhraseCluster $*
+java -ea -Xmx8g -cp $d/prjava.jar:$d/lib/trove-2.0.2.jar:$d/lib/optimization.jar:$d/lib/jopt-simple-3.2.jar phrase.Trainer $*
-- 
cgit v1.2.3