diff options
author | trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-16 21:34:28 +0000 |
---|---|---|
committer | trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-16 21:34:28 +0000 |
commit | 1207aaee1f55dbaac8a46f37635a4d1baf392760 (patch) | |
tree | ad335c14a9df152e4603cc70957103137817d018 /gi/posterior-regularisation/prjava/src/phrase/Trainer.java | |
parent | 9ffba9b2a35582df415384117450f994e64d7cdb (diff) |
Added various flags to filter out low count events (words, edges).
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@298 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/Trainer.java')
-rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/Trainer.java | 48 |
1 files changed, 29 insertions, 19 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java index a67c17a2..ed7a6bbe 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java @@ -4,6 +4,7 @@ import io.FileUtil; import joptsimple.OptionParser; import joptsimple.OptionSet; import java.io.File; +import java.io.FileNotFoundException; import java.io.IOException; import java.io.PrintStream; import java.util.Random; @@ -18,12 +19,12 @@ public class Trainer parser.accepts("help"); parser.accepts("in").withRequiredArg().ofType(File.class); parser.accepts("out").withRequiredArg().ofType(File.class); + parser.accepts("start").withRequiredArg().ofType(File.class); parser.accepts("parameters").withRequiredArg().ofType(File.class); parser.accepts("topics").withRequiredArg().ofType(Integer.class).defaultsTo(5); - parser.accepts("em-iterations").withRequiredArg().ofType(Integer.class).defaultsTo(5); - parser.accepts("pr-iterations").withRequiredArg().ofType(Integer.class).defaultsTo(0); + parser.accepts("iterations").withRequiredArg().ofType(Integer.class).defaultsTo(10); parser.accepts("threads").withRequiredArg().ofType(Integer.class).defaultsTo(0); - parser.accepts("scale-phrase").withRequiredArg().ofType(Double.class).defaultsTo(5.0); + parser.accepts("scale-phrase").withRequiredArg().ofType(Double.class).defaultsTo(0.0); parser.accepts("scale-context").withRequiredArg().ofType(Double.class).defaultsTo(0.0); parser.accepts("seed").withRequiredArg().ofType(Long.class).defaultsTo(0l); parser.accepts("convergence-threshold").withRequiredArg().ofType(Double.class).defaultsTo(1e-6); @@ -33,6 +34,8 @@ public class Trainer parser.accepts("agree"); parser.accepts("no-parameter-cache"); parser.accepts("skip-large-phrases").withRequiredArg().ofType(Integer.class).defaultsTo(5); + parser.accepts("rare-word").withRequiredArg().ofType(Integer.class).defaultsTo(0); + parser.accepts("rare-edge").withRequiredArg().ofType(Integer.class).defaultsTo(0); OptionSet options = parser.parse(args); if (options.has("help") || !options.has("in")) @@ -47,8 +50,7 @@ public class Trainer } int tags = (Integer) options.valueOf("topics"); - int em_iterations = (Integer) options.valueOf("em-iterations"); - int pr_iterations = (Integer) options.valueOf("pr-iterations"); + int iterations = (Integer) options.valueOf("iterations"); double scale_phrase = (Double) options.valueOf("scale-phrase"); double scale_context = (Double) options.valueOf("scale-context"); int threads = (Integer) options.valueOf("threads"); @@ -57,6 +59,8 @@ public class Trainer double alphaEmit = (vb) ? (Double) options.valueOf("alpha-emit") : 0; double alphaPi = (vb) ? (Double) options.valueOf("alpha-pi") : 0; int skip = (Integer) options.valueOf("skip-large-phrases"); + int wordThreshold = (Integer) options.valueOf("rare-word"); + int edgeThreshold = (Integer) options.valueOf("rare-edge"); if (options.has("seed")) F.rng = new Random((Long) options.valueOf("seed")); @@ -79,15 +83,18 @@ public class Trainer System.exit(1); } + if (wordThreshold > 0) + corpus.applyWordThreshold(wordThreshold); + if (!options.has("agree")) System.out.println("Running with " + tags + " tags " + - "for " + em_iterations + " EM and " + pr_iterations + " PR iterations " + - "skipping large phrases for first " + skip + " iterations " + + "for " + iterations + " iterations " + + ((skip > 0) ? "skipping large phrases for first " + skip + " iterations " : "") + "with scale " + scale_phrase + " phrase and " + scale_context + " context " + "and " + threads + " threads"); else System.out.println("Running agreement model with " + tags + " tags " + - "for " + em_iterations); + "for " + iterations); System.out.println(); @@ -102,17 +109,28 @@ public class Trainer if (vb) cluster.initialiseVB(alphaEmit, alphaPi); if (options.has("no-parameter-cache")) cluster.cacheLambda = false; + if (options.has("start")) + { + try { + System.err.println("Reading starting parameters from " + options.valueOf("start")); + cluster.loadParameters(FileUtil.reader((File)options.valueOf("start"))); + } catch (IOException e) { + System.err.println("Failed to open input file: " + options.valueOf("start")); + e.printStackTrace(); + } + } + cluster.setEdgeThreshold(edgeThreshold); } double last = 0; - for (int i=0; i<em_iterations+pr_iterations; i++) + for (int i=0; i < iterations; i++) { double o; if (agree != null) o = agree.EM(); else { - if (i < em_iterations) + if (scale_phrase <= 0 && scale_context <= 0) { if (!vb) o = cluster.EM(i < skip); @@ -128,13 +146,7 @@ public class Trainer if (i != 0 && Math.abs((o - last) / o) < threshold) { last = o; - if (i < Math.max(em_iterations, skip)) - { - i = Math.max(em_iterations, skip) - 1; - continue; - } - else - break; + break; } last = o; } @@ -145,8 +157,6 @@ public class Trainer double pl1lmax = cluster.phrase_l1lmax(); double cl1lmax = cluster.context_l1lmax(); System.out.println("\nFinal posterior phrase l1lmax " + pl1lmax + " context l1lmax " + cl1lmax); - if (pr_iterations == 0) - System.out.println("With PR objective " + (last - scale_phrase*pl1lmax - scale_context*cl1lmax)); if (options.has("out")) { |