diff options
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/Trainer.java')
-rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/Trainer.java | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java index 20f6c905..a67c17a2 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java @@ -32,6 +32,7 @@ public class Trainer parser.accepts("alpha-pi").withRequiredArg().ofType(Double.class).defaultsTo(0.01); parser.accepts("agree"); parser.accepts("no-parameter-cache"); + parser.accepts("skip-large-phrases").withRequiredArg().ofType(Integer.class).defaultsTo(5); OptionSet options = parser.parse(args); if (options.has("help") || !options.has("in")) @@ -55,6 +56,7 @@ public class Trainer boolean vb = options.has("variational-bayes"); double alphaEmit = (vb) ? (Double) options.valueOf("alpha-emit") : 0; double alphaPi = (vb) ? (Double) options.valueOf("alpha-pi") : 0; + int skip = (Integer) options.valueOf("skip-large-phrases"); if (options.has("seed")) F.rng = new Random((Long) options.valueOf("seed")); @@ -80,6 +82,7 @@ public class Trainer if (!options.has("agree")) System.out.println("Running with " + tags + " tags " + "for " + em_iterations + " EM and " + pr_iterations + " PR iterations " + + "skipping large phrases for first " + skip + " iterations " + "with scale " + scale_phrase + " phrase and " + scale_context + " context " + "and " + threads + " threads"); else @@ -112,12 +115,12 @@ public class Trainer if (i < em_iterations) { if (!vb) - o = cluster.EM(); + o = cluster.EM(i < skip); else - o = cluster.VBEM(alphaEmit, alphaPi); + o = cluster.VBEM(alphaEmit, alphaPi, i < skip); } else - o = cluster.PREM(scale_phrase, scale_context); + o = cluster.PREM(scale_phrase, scale_context, i < skip); } System.out.println("ITER: "+i+" objective: " + o); @@ -125,9 +128,9 @@ public class Trainer if (i != 0 && Math.abs((o - last) / o) < threshold) { last = o; - if (i < em_iterations) + if (i < Math.max(em_iterations, skip)) { - i = em_iterations - 1; + i = Math.max(em_iterations, skip) - 1; continue; } else |