summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
diff options
context:
space:
mode:
authortrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-16 21:34:28 +0000
committertrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-16 21:34:28 +0000
commit1207aaee1f55dbaac8a46f37635a4d1baf392760 (patch)
treead335c14a9df152e4603cc70957103137817d018 /gi/posterior-regularisation/prjava/src/phrase/Trainer.java
parent9ffba9b2a35582df415384117450f994e64d7cdb (diff)
Added various flags to filter out low count events (words, edges).
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@298 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/Trainer.java')
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/Trainer.java48
1 files changed, 29 insertions, 19 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
index a67c17a2..ed7a6bbe 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
@@ -4,6 +4,7 @@ import io.FileUtil;
import joptsimple.OptionParser;
import joptsimple.OptionSet;
import java.io.File;
+import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintStream;
import java.util.Random;
@@ -18,12 +19,12 @@ public class Trainer
parser.accepts("help");
parser.accepts("in").withRequiredArg().ofType(File.class);
parser.accepts("out").withRequiredArg().ofType(File.class);
+ parser.accepts("start").withRequiredArg().ofType(File.class);
parser.accepts("parameters").withRequiredArg().ofType(File.class);
parser.accepts("topics").withRequiredArg().ofType(Integer.class).defaultsTo(5);
- parser.accepts("em-iterations").withRequiredArg().ofType(Integer.class).defaultsTo(5);
- parser.accepts("pr-iterations").withRequiredArg().ofType(Integer.class).defaultsTo(0);
+ parser.accepts("iterations").withRequiredArg().ofType(Integer.class).defaultsTo(10);
parser.accepts("threads").withRequiredArg().ofType(Integer.class).defaultsTo(0);
- parser.accepts("scale-phrase").withRequiredArg().ofType(Double.class).defaultsTo(5.0);
+ parser.accepts("scale-phrase").withRequiredArg().ofType(Double.class).defaultsTo(0.0);
parser.accepts("scale-context").withRequiredArg().ofType(Double.class).defaultsTo(0.0);
parser.accepts("seed").withRequiredArg().ofType(Long.class).defaultsTo(0l);
parser.accepts("convergence-threshold").withRequiredArg().ofType(Double.class).defaultsTo(1e-6);
@@ -33,6 +34,8 @@ public class Trainer
parser.accepts("agree");
parser.accepts("no-parameter-cache");
parser.accepts("skip-large-phrases").withRequiredArg().ofType(Integer.class).defaultsTo(5);
+ parser.accepts("rare-word").withRequiredArg().ofType(Integer.class).defaultsTo(0);
+ parser.accepts("rare-edge").withRequiredArg().ofType(Integer.class).defaultsTo(0);
OptionSet options = parser.parse(args);
if (options.has("help") || !options.has("in"))
@@ -47,8 +50,7 @@ public class Trainer
}
int tags = (Integer) options.valueOf("topics");
- int em_iterations = (Integer) options.valueOf("em-iterations");
- int pr_iterations = (Integer) options.valueOf("pr-iterations");
+ int iterations = (Integer) options.valueOf("iterations");
double scale_phrase = (Double) options.valueOf("scale-phrase");
double scale_context = (Double) options.valueOf("scale-context");
int threads = (Integer) options.valueOf("threads");
@@ -57,6 +59,8 @@ public class Trainer
double alphaEmit = (vb) ? (Double) options.valueOf("alpha-emit") : 0;
double alphaPi = (vb) ? (Double) options.valueOf("alpha-pi") : 0;
int skip = (Integer) options.valueOf("skip-large-phrases");
+ int wordThreshold = (Integer) options.valueOf("rare-word");
+ int edgeThreshold = (Integer) options.valueOf("rare-edge");
if (options.has("seed"))
F.rng = new Random((Long) options.valueOf("seed"));
@@ -79,15 +83,18 @@ public class Trainer
System.exit(1);
}
+ if (wordThreshold > 0)
+ corpus.applyWordThreshold(wordThreshold);
+
if (!options.has("agree"))
System.out.println("Running with " + tags + " tags " +
- "for " + em_iterations + " EM and " + pr_iterations + " PR iterations " +
- "skipping large phrases for first " + skip + " iterations " +
+ "for " + iterations + " iterations " +
+ ((skip > 0) ? "skipping large phrases for first " + skip + " iterations " : "") +
"with scale " + scale_phrase + " phrase and " + scale_context + " context " +
"and " + threads + " threads");
else
System.out.println("Running agreement model with " + tags + " tags " +
- "for " + em_iterations);
+ "for " + iterations);
System.out.println();
@@ -102,17 +109,28 @@ public class Trainer
if (vb) cluster.initialiseVB(alphaEmit, alphaPi);
if (options.has("no-parameter-cache"))
cluster.cacheLambda = false;
+ if (options.has("start"))
+ {
+ try {
+ System.err.println("Reading starting parameters from " + options.valueOf("start"));
+ cluster.loadParameters(FileUtil.reader((File)options.valueOf("start")));
+ } catch (IOException e) {
+ System.err.println("Failed to open input file: " + options.valueOf("start"));
+ e.printStackTrace();
+ }
+ }
+ cluster.setEdgeThreshold(edgeThreshold);
}
double last = 0;
- for (int i=0; i<em_iterations+pr_iterations; i++)
+ for (int i=0; i < iterations; i++)
{
double o;
if (agree != null)
o = agree.EM();
else
{
- if (i < em_iterations)
+ if (scale_phrase <= 0 && scale_context <= 0)
{
if (!vb)
o = cluster.EM(i < skip);
@@ -128,13 +146,7 @@ public class Trainer
if (i != 0 && Math.abs((o - last) / o) < threshold)
{
last = o;
- if (i < Math.max(em_iterations, skip))
- {
- i = Math.max(em_iterations, skip) - 1;
- continue;
- }
- else
- break;
+ break;
}
last = o;
}
@@ -145,8 +157,6 @@ public class Trainer
double pl1lmax = cluster.phrase_l1lmax();
double cl1lmax = cluster.context_l1lmax();
System.out.println("\nFinal posterior phrase l1lmax " + pl1lmax + " context l1lmax " + cl1lmax);
- if (pr_iterations == 0)
- System.out.println("With PR objective " + (last - scale_phrase*pl1lmax - scale_context*cl1lmax));
if (options.has("out"))
{