diff options
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/Trainer.java')
-rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/Trainer.java | 15 |
1 files changed, 13 insertions, 2 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java index b19f3fb9..439fb337 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java @@ -27,6 +27,9 @@ public class Trainer parser.accepts("scale-context").withRequiredArg().ofType(Double.class).defaultsTo(0.0); parser.accepts("seed").withRequiredArg().ofType(Long.class).defaultsTo(0l); parser.accepts("convergence-threshold").withRequiredArg().ofType(Double.class).defaultsTo(1e-6); + parser.accepts("variational-bayes"); + parser.accepts("alpha-emit").withRequiredArg().ofType(Double.class).defaultsTo(0.1); + parser.accepts("alpha-pi").withRequiredArg().ofType(Double.class).defaultsTo(0.01); OptionSet options = parser.parse(args); if (options.has("help") || !options.has("in")) @@ -47,6 +50,9 @@ public class Trainer double scale_context = (Double) options.valueOf("scale-context"); int threads = (Integer) options.valueOf("threads"); double threshold = (Double) options.valueOf("convergence-threshold"); + boolean vb = options.has("variational-bayes"); + double alphaEmit = (vb) ? (Double) options.valueOf("alpha-emit") : 0; + double alphaPi = (vb) ? (Double) options.valueOf("alpha-pi") : 0; if (options.has("seed")) F.rng = new Random((Long) options.valueOf("seed")); @@ -75,14 +81,19 @@ public class Trainer "and " + threads + " threads"); System.out.println(); - PhraseCluster cluster = new PhraseCluster(tags, corpus, scale_phrase, scale_context, threads); + PhraseCluster cluster = new PhraseCluster(tags, corpus, scale_phrase, scale_context, threads, alphaEmit, alphaPi); double last = 0; for (int i=0; i<em_iterations+pr_iterations; i++) { double o; if (i < em_iterations) - o = cluster.EM(); + { + if (!vb) + o = cluster.EM(); + else + o = cluster.VBEM(); + } else if (scale_context == 0) { if (threads >= 1) |