summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/Trainer.java')
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/Trainer.java13
1 files changed, 8 insertions, 5 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
index 20f6c905..a67c17a2 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
@@ -32,6 +32,7 @@ public class Trainer
parser.accepts("alpha-pi").withRequiredArg().ofType(Double.class).defaultsTo(0.01);
parser.accepts("agree");
parser.accepts("no-parameter-cache");
+ parser.accepts("skip-large-phrases").withRequiredArg().ofType(Integer.class).defaultsTo(5);
OptionSet options = parser.parse(args);
if (options.has("help") || !options.has("in"))
@@ -55,6 +56,7 @@ public class Trainer
boolean vb = options.has("variational-bayes");
double alphaEmit = (vb) ? (Double) options.valueOf("alpha-emit") : 0;
double alphaPi = (vb) ? (Double) options.valueOf("alpha-pi") : 0;
+ int skip = (Integer) options.valueOf("skip-large-phrases");
if (options.has("seed"))
F.rng = new Random((Long) options.valueOf("seed"));
@@ -80,6 +82,7 @@ public class Trainer
if (!options.has("agree"))
System.out.println("Running with " + tags + " tags " +
"for " + em_iterations + " EM and " + pr_iterations + " PR iterations " +
+ "skipping large phrases for first " + skip + " iterations " +
"with scale " + scale_phrase + " phrase and " + scale_context + " context " +
"and " + threads + " threads");
else
@@ -112,12 +115,12 @@ public class Trainer
if (i < em_iterations)
{
if (!vb)
- o = cluster.EM();
+ o = cluster.EM(i < skip);
else
- o = cluster.VBEM(alphaEmit, alphaPi);
+ o = cluster.VBEM(alphaEmit, alphaPi, i < skip);
}
else
- o = cluster.PREM(scale_phrase, scale_context);
+ o = cluster.PREM(scale_phrase, scale_context, i < skip);
}
System.out.println("ITER: "+i+" objective: " + o);
@@ -125,9 +128,9 @@ public class Trainer
if (i != 0 && Math.abs((o - last) / o) < threshold)
{
last = o;
- if (i < em_iterations)
+ if (i < Math.max(em_iterations, skip))
{
- i = em_iterations - 1;
+ i = Math.max(em_iterations, skip) - 1;
continue;
}
else