summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
diff options
context:
space:
mode:
authortrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-23 19:26:17 +0000
committertrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-23 19:26:17 +0000
commit01739cab52552013a68843d2f64b02e868dcd281 (patch)
treeec18da9acba3d5edb463daadc0fe31236077f113 /gi/posterior-regularisation/prjava/src/phrase/Trainer.java
parent00be683d744ea87e67dd13db08e4a17531bfb1f3 (diff)
Parallelised VB-EM
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@384 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/Trainer.java')
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/Trainer.java83
1 files changed, 56 insertions, 27 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
index f205ce67..6f302b20 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
@@ -4,11 +4,12 @@ import io.FileUtil;
import joptsimple.OptionParser;
import joptsimple.OptionSet;
import java.io.File;
-import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintStream;
import java.util.List;
import java.util.Random;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
import phrase.Corpus.Edge;
@@ -18,7 +19,6 @@ public class Trainer
{
public static void main(String[] args)
{
-
OptionParser parser = new OptionParser();
parser.accepts("help");
parser.accepts("in").withRequiredArg().ofType(File.class);
@@ -68,6 +68,10 @@ public class Trainer
if (options.has("seed"))
F.rng = new Random((Long) options.valueOf("seed"));
+ ExecutorService threadPool = null;
+ if (threads > 0)
+ threadPool = Executors.newFixedThreadPool(threads);
+
if (tags <= 1 || scale_phrase < 0 || scale_context < 0 || threshold < 0)
{
System.err.println("Invalid arguments. Try again!");
@@ -114,26 +118,30 @@ public class Trainer
agree = new Agree(tags, corpus);
else
{
- cluster = new PhraseCluster(tags, corpus);
- if (threads > 0) cluster.useThreadPool(threads);
-
- if (vb) {
- //cluster.initialiseVB(alphaEmit, alphaPi);
+ if (vb)
+ {
vbModel=new VB(tags,corpus);
vbModel.alpha=alphaPi;
vbModel.lambda=alphaEmit;
- }
- if (options.has("no-parameter-cache"))
- cluster.cacheLambda = false;
- if (options.has("start"))
+ if (threadPool != null) vbModel.useThreadPool(threadPool);
+ }
+ else
{
- try {
- System.err.println("Reading starting parameters from " + options.valueOf("start"));
- cluster.loadParameters(FileUtil.reader((File)options.valueOf("start")));
- } catch (IOException e) {
- System.err.println("Failed to open input file: " + options.valueOf("start"));
- e.printStackTrace();
- }
+ cluster = new PhraseCluster(tags, corpus);
+ if (threadPool != null) cluster.useThreadPool(threadPool);
+
+ if (options.has("no-parameter-cache"))
+ cluster.cacheLambda = false;
+ if (options.has("start"))
+ {
+ try {
+ System.err.println("Reading starting parameters from " + options.valueOf("start"));
+ cluster.loadParameters(FileUtil.reader((File)options.valueOf("start")));
+ } catch (IOException e) {
+ System.err.println("Failed to open input file: " + options.valueOf("start"));
+ e.printStackTrace();
+ }
+ }
}
}
@@ -143,9 +151,8 @@ public class Trainer
double o;
if (agree != null)
o = agree.EM();
- else if(agree2sides!=null){
+ else if(agree2sides!=null)
o = agree2sides.EM();
- }
else
{
if (i < skip)
@@ -173,11 +180,25 @@ public class Trainer
last = o;
}
- if (cluster == null)
- cluster = agree.model1;
+ double pl1lmax = 0, cl1lmax = 0;
+ if (cluster != null)
+ {
+ pl1lmax = cluster.phrase_l1lmax();
+ cl1lmax = cluster.context_l1lmax();
+ }
+ else if (agree != null)
+ {
+ // fairly arbitrary choice of model1 cf model2
+ pl1lmax = agree.model1.phrase_l1lmax();
+ cl1lmax = agree.model1.context_l1lmax();
+ }
+ else if (agree2sides != null)
+ {
+ // fairly arbitrary choice of model1 cf model2
+ pl1lmax = agree2sides.model1.phrase_l1lmax();
+ cl1lmax = agree2sides.model1.context_l1lmax();
+ }
- double pl1lmax = cluster.phrase_l1lmax();
- double cl1lmax = cluster.context_l1lmax();
System.out.println("\nFinal posterior phrase l1lmax " + pl1lmax + " context l1lmax " + cl1lmax);
if (options.has("out"))
@@ -194,11 +215,18 @@ public class Trainer
System.out.println("Reading testing concordance from " + infile);
test = corpus.readEdges(FileUtil.reader(infile));
}
- if(vb){
+ if(vb) {
+ assert !options.has("test");
vbModel.displayPosterior(ps);
- }else{
+ } else if (cluster != null)
cluster.displayPosterior(ps, test);
+ else if (agree != null)
+ agree.displayPosterior(ps, test);
+ else if (agree2sides != null) {
+ assert !options.has("test");
+ agree2sides.displayPosterior(ps);
}
+
ps.close();
} catch (IOException e) {
System.err.println("Failed to open either testing file or output file");
@@ -209,6 +237,7 @@ public class Trainer
if (options.has("parameters"))
{
+ assert !vb;
File outfile = (File) options.valueOf("parameters");
PrintStream ps;
try {
@@ -222,7 +251,7 @@ public class Trainer
}
}
- if (cluster.pool != null)
+ if (cluster != null && cluster.pool != null)
cluster.pool.shutdown();
}
}