diff options
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/Trainer.java')
-rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/Trainer.java | 83 |
1 files changed, 56 insertions, 27 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java index f205ce67..6f302b20 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java @@ -4,11 +4,12 @@ import io.FileUtil; import joptsimple.OptionParser; import joptsimple.OptionSet; import java.io.File; -import java.io.FileNotFoundException; import java.io.IOException; import java.io.PrintStream; import java.util.List; import java.util.Random; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import phrase.Corpus.Edge; @@ -18,7 +19,6 @@ public class Trainer { public static void main(String[] args) { - OptionParser parser = new OptionParser(); parser.accepts("help"); parser.accepts("in").withRequiredArg().ofType(File.class); @@ -68,6 +68,10 @@ public class Trainer if (options.has("seed")) F.rng = new Random((Long) options.valueOf("seed")); + ExecutorService threadPool = null; + if (threads > 0) + threadPool = Executors.newFixedThreadPool(threads); + if (tags <= 1 || scale_phrase < 0 || scale_context < 0 || threshold < 0) { System.err.println("Invalid arguments. Try again!"); @@ -114,26 +118,30 @@ public class Trainer agree = new Agree(tags, corpus); else { - cluster = new PhraseCluster(tags, corpus); - if (threads > 0) cluster.useThreadPool(threads); - - if (vb) { - //cluster.initialiseVB(alphaEmit, alphaPi); + if (vb) + { vbModel=new VB(tags,corpus); vbModel.alpha=alphaPi; vbModel.lambda=alphaEmit; - } - if (options.has("no-parameter-cache")) - cluster.cacheLambda = false; - if (options.has("start")) + if (threadPool != null) vbModel.useThreadPool(threadPool); + } + else { - try { - System.err.println("Reading starting parameters from " + options.valueOf("start")); - cluster.loadParameters(FileUtil.reader((File)options.valueOf("start"))); - } catch (IOException e) { - System.err.println("Failed to open input file: " + options.valueOf("start")); - e.printStackTrace(); - } + cluster = new PhraseCluster(tags, corpus); + if (threadPool != null) cluster.useThreadPool(threadPool); + + if (options.has("no-parameter-cache")) + cluster.cacheLambda = false; + if (options.has("start")) + { + try { + System.err.println("Reading starting parameters from " + options.valueOf("start")); + cluster.loadParameters(FileUtil.reader((File)options.valueOf("start"))); + } catch (IOException e) { + System.err.println("Failed to open input file: " + options.valueOf("start")); + e.printStackTrace(); + } + } } } @@ -143,9 +151,8 @@ public class Trainer double o; if (agree != null) o = agree.EM(); - else if(agree2sides!=null){ + else if(agree2sides!=null) o = agree2sides.EM(); - } else { if (i < skip) @@ -173,11 +180,25 @@ public class Trainer last = o; } - if (cluster == null) - cluster = agree.model1; + double pl1lmax = 0, cl1lmax = 0; + if (cluster != null) + { + pl1lmax = cluster.phrase_l1lmax(); + cl1lmax = cluster.context_l1lmax(); + } + else if (agree != null) + { + // fairly arbitrary choice of model1 cf model2 + pl1lmax = agree.model1.phrase_l1lmax(); + cl1lmax = agree.model1.context_l1lmax(); + } + else if (agree2sides != null) + { + // fairly arbitrary choice of model1 cf model2 + pl1lmax = agree2sides.model1.phrase_l1lmax(); + cl1lmax = agree2sides.model1.context_l1lmax(); + } - double pl1lmax = cluster.phrase_l1lmax(); - double cl1lmax = cluster.context_l1lmax(); System.out.println("\nFinal posterior phrase l1lmax " + pl1lmax + " context l1lmax " + cl1lmax); if (options.has("out")) @@ -194,11 +215,18 @@ public class Trainer System.out.println("Reading testing concordance from " + infile); test = corpus.readEdges(FileUtil.reader(infile)); } - if(vb){ + if(vb) { + assert !options.has("test"); vbModel.displayPosterior(ps); - }else{ + } else if (cluster != null) cluster.displayPosterior(ps, test); + else if (agree != null) + agree.displayPosterior(ps, test); + else if (agree2sides != null) { + assert !options.has("test"); + agree2sides.displayPosterior(ps); } + ps.close(); } catch (IOException e) { System.err.println("Failed to open either testing file or output file"); @@ -209,6 +237,7 @@ public class Trainer if (options.has("parameters")) { + assert !vb; File outfile = (File) options.valueOf("parameters"); PrintStream ps; try { @@ -222,7 +251,7 @@ public class Trainer } } - if (cluster.pool != null) + if (cluster != null && cluster.pool != null) cluster.pool.shutdown(); } } |