summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/Trainer.java')
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/Trainer.java15
1 files changed, 13 insertions, 2 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
index b19f3fb9..439fb337 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
@@ -27,6 +27,9 @@ public class Trainer
parser.accepts("scale-context").withRequiredArg().ofType(Double.class).defaultsTo(0.0);
parser.accepts("seed").withRequiredArg().ofType(Long.class).defaultsTo(0l);
parser.accepts("convergence-threshold").withRequiredArg().ofType(Double.class).defaultsTo(1e-6);
+ parser.accepts("variational-bayes");
+ parser.accepts("alpha-emit").withRequiredArg().ofType(Double.class).defaultsTo(0.1);
+ parser.accepts("alpha-pi").withRequiredArg().ofType(Double.class).defaultsTo(0.01);
OptionSet options = parser.parse(args);
if (options.has("help") || !options.has("in"))
@@ -47,6 +50,9 @@ public class Trainer
double scale_context = (Double) options.valueOf("scale-context");
int threads = (Integer) options.valueOf("threads");
double threshold = (Double) options.valueOf("convergence-threshold");
+ boolean vb = options.has("variational-bayes");
+ double alphaEmit = (vb) ? (Double) options.valueOf("alpha-emit") : 0;
+ double alphaPi = (vb) ? (Double) options.valueOf("alpha-pi") : 0;
if (options.has("seed"))
F.rng = new Random((Long) options.valueOf("seed"));
@@ -75,14 +81,19 @@ public class Trainer
"and " + threads + " threads");
System.out.println();
- PhraseCluster cluster = new PhraseCluster(tags, corpus, scale_phrase, scale_context, threads);
+ PhraseCluster cluster = new PhraseCluster(tags, corpus, scale_phrase, scale_context, threads, alphaEmit, alphaPi);
double last = 0;
for (int i=0; i<em_iterations+pr_iterations; i++)
{
double o;
if (i < em_iterations)
- o = cluster.EM();
+ {
+ if (!vb)
+ o = cluster.EM();
+ else
+ o = cluster.VBEM();
+ }
else if (scale_context == 0)
{
if (threads >= 1)