From 808aa98dfdc0f2beb42503172de61f72981d6735 Mon Sep 17 00:00:00 2001 From: "trevor.cohn" Date: Fri, 9 Jul 2010 16:22:54 +0000 Subject: Added formal command line options & new main class. git-svn-id: https://ws10smt.googlecode.com/svn/trunk@200 ec762483-ff6d-05da-a07a-a48fb63a330f --- gi/posterior-regularisation/prjava/src/arr/F.java | 2 +- .../prjava/src/io/FileUtil.java | 47 +++---- .../prjava/src/phrase/Corpus.java | 8 +- .../prjava/src/phrase/PhraseCluster.java | 71 +--------- .../prjava/src/phrase/PhraseContextObjective.java | 2 +- .../prjava/src/phrase/PhraseCorpus.java | 17 ++- .../prjava/src/phrase/PhraseObjective.java | 2 +- .../prjava/src/phrase/Trainer.java | 150 +++++++++++++++++++++ 8 files changed, 192 insertions(+), 107 deletions(-) create mode 100644 gi/posterior-regularisation/prjava/src/phrase/Trainer.java (limited to 'gi/posterior-regularisation/prjava/src') diff --git a/gi/posterior-regularisation/prjava/src/arr/F.java b/gi/posterior-regularisation/prjava/src/arr/F.java index 5821af42..7f2b140a 100644 --- a/gi/posterior-regularisation/prjava/src/arr/F.java +++ b/gi/posterior-regularisation/prjava/src/arr/F.java @@ -3,7 +3,7 @@ package arr; import java.util.Random; public class F { - private static Random rng = new Random(); //(9562724l); + public static Random rng = new Random(); public static void randomise(double probs[]) { diff --git a/gi/posterior-regularisation/prjava/src/io/FileUtil.java b/gi/posterior-regularisation/prjava/src/io/FileUtil.java index 67ce571e..81e7747b 100644 --- a/gi/posterior-regularisation/prjava/src/io/FileUtil.java +++ b/gi/posterior-regularisation/prjava/src/io/FileUtil.java @@ -3,7 +3,24 @@ import java.util.*; import java.util.zip.GZIPInputStream; import java.util.zip.GZIPOutputStream; import java.io.*; -public class FileUtil { +public class FileUtil +{ + public static BufferedReader reader(File file) throws FileNotFoundException, IOException + { + if (file.getName().endsWith(".gz")) + return new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(file)))); + else + return new BufferedReader(new FileReader(file)); + } + + public static PrintStream printstream(File file) throws FileNotFoundException, IOException + { + if (file.getName().endsWith(".gz")) + return new PrintStream(new GZIPOutputStream(new FileOutputStream(file))); + else + return new PrintStream(new FileOutputStream(file)); + } + public static Scanner openInFile(String filename){ Scanner localsc=null; try @@ -16,34 +33,6 @@ public class FileUtil { return localsc; } - public static BufferedReader openBufferedReader(String filename){ - BufferedReader r=null; - try - { - if (filename.endsWith(".gz")) - r=(new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(new File(filename)))))); - else - r=(new BufferedReader(new FileReader(new File(filename)))); - }catch(IOException ioe){ - System.out.println(ioe.getMessage()); - } - return r; - } - - public static PrintStream openOutFile(String filename){ - PrintStream localps=null; - try - { - if (filename.endsWith(".gz")) - localps=new PrintStream (new GZIPOutputStream(new FileOutputStream(filename))); - else - localps=new PrintStream (new FileOutputStream(filename)); - - }catch(IOException ioe){ - System.out.println(ioe.getMessage()); - } - return localps; - } public static FileInputStream openInputStream(String infilename){ FileInputStream fis=null; try { diff --git a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java index d5e856ca..81264ab9 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java @@ -217,5 +217,11 @@ public class Corpus } return c; + } + + public void printStats(PrintStream out) + { + out.println("Corpus has " + edges.size() + " edges " + phraseLexicon.size() + " phrases " + + contextLexicon.size() + " contexts and " + wordLexicon.size() + " word types"); } -} +} \ No newline at end of file diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index 63a60682..7d7c46dd 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -25,73 +25,6 @@ public class PhraseCluster { // pi[phrase][tag] = p(tag | phrase) private double pi[][]; - public static void main(String[] args) - { - String input_fname = args[0]; - int tags = Integer.parseInt(args[1]); - String output_fname = args[2]; - int iterations = Integer.parseInt(args[3]); - double scalePT = Double.parseDouble(args[4]); - double scaleCT = Double.parseDouble(args[5]); - int threads = Integer.parseInt(args[6]); - boolean runEM = Boolean.parseBoolean(args[7]); - - assert(tags >= 2); - assert(scalePT >= 0); - assert(scaleCT >= 0); - - Corpus corpus = null; - try { - corpus = Corpus.readFromFile(FileUtil.openBufferedReader(input_fname)); - } catch (IOException e) { - System.err.println("Failed to open input file: " + input_fname); - e.printStackTrace(); - System.exit(1); - } - PhraseCluster cluster = new PhraseCluster(tags, corpus, scalePT, scaleCT, threads); - - //PhraseObjective.ps = FileUtil.openOutFile(outputDir + "/phrase_stat.out"); - - double last = 0; - for(int i=0;i= 1) - o = cluster.PREM_phrase_constraints_parallel(); - else - o = cluster.PREM_phrase_constraints(); - } - else - o = cluster.PREM_phrase_context_constraints(); - - //PhraseObjective.ps. - System.out.println("ITER: "+i+" objective: " + o); - last = o; - } - - double pl1lmax = cluster.phrase_l1lmax(); - double cl1lmax = cluster.context_l1lmax(); - System.out.println("Final posterior phrase l1lmax " + pl1lmax + " context l1lmax " + cl1lmax); - if (runEM) System.out.println("With PR objective " + (last - scalePT*pl1lmax - scaleCT*cl1lmax)); - - PrintStream ps=io.FileUtil.openOutFile(output_fname); - cluster.displayPosterior(ps); - ps.close(); - - //PhraseObjective.ps.close(); - - //ps = io.FileUtil.openOutFile(outputDir + "/parameters.out"); - //cluster.displayModelParam(ps); - //ps.close(); - - if (cluster.pool != null) - cluster.pool.shutdown(); - } - public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads){ K=numCluster; c=corpus; @@ -134,7 +67,7 @@ public class PhraseCluster { double p[]=posterior(edge); double z = arr.F.l1norm(p); assert z > 0; - loglikelihood+=Math.log(z); + loglikelihood += edge.getCount() * Math.log(z); arr.F.l1normalize(p); int count = edge.getCount(); @@ -150,7 +83,7 @@ public class PhraseCluster { } } - System.out.println("Log likelihood: "+loglikelihood); + //System.out.println("Log likelihood: "+loglikelihood); //M for(double [][]i:exp_emit){ diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java index fbf43a7f..15bd29c2 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java @@ -26,7 +26,7 @@ import phrase.Corpus.Edge; public class PhraseContextObjective extends ProjectedObjective { private static final double GRAD_DIFF = 0.00002; - private static double INIT_STEP_SIZE = 10; + private static double INIT_STEP_SIZE = 300; private static double VAL_DIFF = 1e-4; // FIXME needs to be tuned private static int ITERATIONS = 100; diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java index 11e948ff..903e47c8 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java @@ -1,7 +1,11 @@ package phrase; +import io.FileUtil; + import java.io.BufferedInputStream; import java.io.BufferedReader; +import java.io.File; +import java.io.FileNotFoundException; import java.io.IOException; import java.io.PrintStream; import java.util.ArrayList; @@ -20,8 +24,9 @@ public class PhraseCorpus public int data[][][]; public int numContexts; - public PhraseCorpus(String filename){ - BufferedReader r=io.FileUtil.openBufferedReader(filename); + public PhraseCorpus(String filename) throws FileNotFoundException, IOException + { + BufferedReader r = FileUtil.reader(new File(filename)); phraseLex=new HashMap(); wordLex=new HashMap(); @@ -84,8 +89,9 @@ public class PhraseCorpus } //for debugging - public void saveLex(String lexFilename){ - PrintStream ps=io.FileUtil.openOutFile(lexFilename); + public void saveLex(String lexFilename) throws FileNotFoundException, IOException + { + PrintStream ps = FileUtil.printstream(new File(lexFilename)); ps.println("Phrase Lexicon"); ps.println(phraseLex.size()); printDict(phraseLex,ps); @@ -175,7 +181,8 @@ public class PhraseCorpus return null; } - public static void main(String[] args) { + public static void main(String[] args) throws Exception + { String LEX_FILENAME="../pdata/lex.out"; String DATA_FILENAME="../pdata/btec.con"; PhraseCorpus c=new PhraseCorpus(DATA_FILENAME); diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java index 0a76e2dc..3314f74a 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java @@ -21,7 +21,7 @@ import optimization.util.MathUtils; public class PhraseObjective extends ProjectedObjective { static final double GRAD_DIFF = 0.00002; - static double INIT_STEP_SIZE = 10; + static double INIT_STEP_SIZE = 300; static double VAL_DIFF = 1e-4; // FIXME needs to be tuned - and this might be too weak static int ITERATIONS = 100; //private double c1=0.0001; // wolf stuff diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java new file mode 100644 index 00000000..b19f3fb9 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java @@ -0,0 +1,150 @@ +package phrase; + +import io.FileUtil; +import joptsimple.OptionParser; +import joptsimple.OptionSet; +import java.io.File; +import java.io.IOException; +import java.io.PrintStream; +import java.util.Random; + +import arr.F; + +public class Trainer +{ + public static void main(String[] args) + { + OptionParser parser = new OptionParser(); + parser.accepts("help"); + parser.accepts("in").withRequiredArg().ofType(File.class); + parser.accepts("out").withRequiredArg().ofType(File.class); + parser.accepts("parameters").withRequiredArg().ofType(File.class); + parser.accepts("topics").withRequiredArg().ofType(Integer.class).defaultsTo(5); + parser.accepts("em-iterations").withRequiredArg().ofType(Integer.class).defaultsTo(5); + parser.accepts("pr-iterations").withRequiredArg().ofType(Integer.class).defaultsTo(0); + parser.accepts("threads").withRequiredArg().ofType(Integer.class).defaultsTo(0); + parser.accepts("scale-phrase").withRequiredArg().ofType(Double.class).defaultsTo(5.0); + parser.accepts("scale-context").withRequiredArg().ofType(Double.class).defaultsTo(0.0); + parser.accepts("seed").withRequiredArg().ofType(Long.class).defaultsTo(0l); + parser.accepts("convergence-threshold").withRequiredArg().ofType(Double.class).defaultsTo(1e-6); + OptionSet options = parser.parse(args); + + if (options.has("help") || !options.has("in")) + { + try { + parser.printHelpOn(System.err); + } catch (IOException e) { + System.err.println("This should never happen. Really."); + e.printStackTrace(); + } + System.exit(1); + } + + int tags = (Integer) options.valueOf("topics"); + int em_iterations = (Integer) options.valueOf("em-iterations"); + int pr_iterations = (Integer) options.valueOf("pr-iterations"); + double scale_phrase = (Double) options.valueOf("scale-phrase"); + double scale_context = (Double) options.valueOf("scale-context"); + int threads = (Integer) options.valueOf("threads"); + double threshold = (Double) options.valueOf("convergence-threshold"); + + if (options.has("seed")) + F.rng = new Random((Long) options.valueOf("seed")); + + if (tags <= 1 || scale_phrase < 0 || scale_context < 0 || threshold < 0) + { + System.err.println("Invalid arguments. Try again!"); + System.exit(1); + } + + Corpus corpus = null; + File infile = (File) options.valueOf("in"); + try { + System.out.println("Reading concordance from " + infile); + corpus = Corpus.readFromFile(FileUtil.reader(infile)); + corpus.printStats(System.out); + } catch (IOException e) { + System.err.println("Failed to open input file: " + infile); + e.printStackTrace(); + System.exit(1); + } + + System.out.println("Running with " + tags + " tags " + + "for " + em_iterations + " EM and " + pr_iterations + " PR iterations " + + "with scale " + scale_phrase + " phrase and " + scale_context + " context " + + "and " + threads + " threads"); + System.out.println(); + + PhraseCluster cluster = new PhraseCluster(tags, corpus, scale_phrase, scale_context, threads); + + double last = 0; + for (int i=0; i= 1) + o = cluster.PREM_phrase_constraints_parallel(); + else + o = cluster.PREM_phrase_constraints(); + } + else + o = cluster.PREM_phrase_context_constraints(); + + System.out.println("ITER: "+i+" objective: " + o); + + if (i != 0 && Math.abs((o - last) / o) < threshold) + { + last = o; + if (i < em_iterations) + { + i = em_iterations - 1; + continue; + } + else + break; + } + last = o; + } + + double pl1lmax = cluster.phrase_l1lmax(); + double cl1lmax = cluster.context_l1lmax(); + System.out.println("\nFinal posterior phrase l1lmax " + pl1lmax + " context l1lmax " + cl1lmax); + if (pr_iterations == 0) + System.out.println("With PR objective " + (last - scale_phrase*pl1lmax - scale_context*cl1lmax)); + + if (options.has("out")) + { + File outfile = (File) options.valueOf("out"); + try { + PrintStream ps = FileUtil.printstream(outfile); + cluster.displayPosterior(ps); + ps.close(); + } catch (IOException e) { + System.err.println("Failed to open output file: " + outfile); + e.printStackTrace(); + System.exit(1); + } + } + + if (options.has("parameters")) + { + File outfile = (File) options.valueOf("parameters"); + PrintStream ps; + try { + ps = FileUtil.printstream(outfile); + cluster.displayModelParam(ps); + ps.close(); + } catch (IOException e) { + System.err.println("Failed to open output parameters file: " + outfile); + e.printStackTrace(); + System.exit(1); + } + } + + if (cluster.pool != null) + cluster.pool.shutdown(); + } +} -- cgit v1.2.3