From 1207aaee1f55dbaac8a46f37635a4d1baf392760 Mon Sep 17 00:00:00 2001 From: "trevor.cohn" Date: Fri, 16 Jul 2010 21:34:28 +0000 Subject: Added various flags to filter out low count events (words, edges). git-svn-id: https://ws10smt.googlecode.com/svn/trunk@298 ec762483-ff6d-05da-a07a-a48fb63a330f --- .../prjava/src/phrase/Trainer.java | 48 +++++++++++++--------- 1 file changed, 29 insertions(+), 19 deletions(-) (limited to 'gi/posterior-regularisation/prjava/src/phrase/Trainer.java') diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java index a67c17a2..ed7a6bbe 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java @@ -4,6 +4,7 @@ import io.FileUtil; import joptsimple.OptionParser; import joptsimple.OptionSet; import java.io.File; +import java.io.FileNotFoundException; import java.io.IOException; import java.io.PrintStream; import java.util.Random; @@ -18,12 +19,12 @@ public class Trainer parser.accepts("help"); parser.accepts("in").withRequiredArg().ofType(File.class); parser.accepts("out").withRequiredArg().ofType(File.class); + parser.accepts("start").withRequiredArg().ofType(File.class); parser.accepts("parameters").withRequiredArg().ofType(File.class); parser.accepts("topics").withRequiredArg().ofType(Integer.class).defaultsTo(5); - parser.accepts("em-iterations").withRequiredArg().ofType(Integer.class).defaultsTo(5); - parser.accepts("pr-iterations").withRequiredArg().ofType(Integer.class).defaultsTo(0); + parser.accepts("iterations").withRequiredArg().ofType(Integer.class).defaultsTo(10); parser.accepts("threads").withRequiredArg().ofType(Integer.class).defaultsTo(0); - parser.accepts("scale-phrase").withRequiredArg().ofType(Double.class).defaultsTo(5.0); + parser.accepts("scale-phrase").withRequiredArg().ofType(Double.class).defaultsTo(0.0); parser.accepts("scale-context").withRequiredArg().ofType(Double.class).defaultsTo(0.0); parser.accepts("seed").withRequiredArg().ofType(Long.class).defaultsTo(0l); parser.accepts("convergence-threshold").withRequiredArg().ofType(Double.class).defaultsTo(1e-6); @@ -33,6 +34,8 @@ public class Trainer parser.accepts("agree"); parser.accepts("no-parameter-cache"); parser.accepts("skip-large-phrases").withRequiredArg().ofType(Integer.class).defaultsTo(5); + parser.accepts("rare-word").withRequiredArg().ofType(Integer.class).defaultsTo(0); + parser.accepts("rare-edge").withRequiredArg().ofType(Integer.class).defaultsTo(0); OptionSet options = parser.parse(args); if (options.has("help") || !options.has("in")) @@ -47,8 +50,7 @@ public class Trainer } int tags = (Integer) options.valueOf("topics"); - int em_iterations = (Integer) options.valueOf("em-iterations"); - int pr_iterations = (Integer) options.valueOf("pr-iterations"); + int iterations = (Integer) options.valueOf("iterations"); double scale_phrase = (Double) options.valueOf("scale-phrase"); double scale_context = (Double) options.valueOf("scale-context"); int threads = (Integer) options.valueOf("threads"); @@ -57,6 +59,8 @@ public class Trainer double alphaEmit = (vb) ? (Double) options.valueOf("alpha-emit") : 0; double alphaPi = (vb) ? (Double) options.valueOf("alpha-pi") : 0; int skip = (Integer) options.valueOf("skip-large-phrases"); + int wordThreshold = (Integer) options.valueOf("rare-word"); + int edgeThreshold = (Integer) options.valueOf("rare-edge"); if (options.has("seed")) F.rng = new Random((Long) options.valueOf("seed")); @@ -79,15 +83,18 @@ public class Trainer System.exit(1); } + if (wordThreshold > 0) + corpus.applyWordThreshold(wordThreshold); + if (!options.has("agree")) System.out.println("Running with " + tags + " tags " + - "for " + em_iterations + " EM and " + pr_iterations + " PR iterations " + - "skipping large phrases for first " + skip + " iterations " + + "for " + iterations + " iterations " + + ((skip > 0) ? "skipping large phrases for first " + skip + " iterations " : "") + "with scale " + scale_phrase + " phrase and " + scale_context + " context " + "and " + threads + " threads"); else System.out.println("Running agreement model with " + tags + " tags " + - "for " + em_iterations); + "for " + iterations); System.out.println(); @@ -102,17 +109,28 @@ public class Trainer if (vb) cluster.initialiseVB(alphaEmit, alphaPi); if (options.has("no-parameter-cache")) cluster.cacheLambda = false; + if (options.has("start")) + { + try { + System.err.println("Reading starting parameters from " + options.valueOf("start")); + cluster.loadParameters(FileUtil.reader((File)options.valueOf("start"))); + } catch (IOException e) { + System.err.println("Failed to open input file: " + options.valueOf("start")); + e.printStackTrace(); + } + } + cluster.setEdgeThreshold(edgeThreshold); } double last = 0; - for (int i=0; i