summary | refs | log | tree | commit | diff
path: root/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/Trainer.java')
-rw-r--r-- gi/posterior-regularisation/prjava/src/phrase/Trainer.java 46
1 files changed, 23 insertions, 23 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
index 7f0b1970..ec1a5804 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
@@ -7,8 +7,11 @@ import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintStream;
+import java.util.List;
import java.util.Random;
+import phrase.Corpus.Edge;
+
import arr.F;
public class Trainer
@@ -34,10 +37,6 @@ public class Trainer
parser.accepts("agree");
parser.accepts("no-parameter-cache");
parser.accepts("skip-large-phrases").withRequiredArg().ofType(Integer.class).defaultsTo(5);
- parser.accepts("rare-word").withRequiredArg().ofType(Integer.class).defaultsTo(10);
- parser.accepts("rare-edge").withRequiredArg().ofType(Integer.class).defaultsTo(1);
- parser.accepts("rare-phrase").withRequiredArg().ofType(Integer.class).defaultsTo(2);
- parser.accepts("rare-context").withRequiredArg().ofType(Integer.class).defaultsTo(2);
OptionSet options = parser.parse(args);
if (options.has("help") || !options.has("in"))
@@ -61,10 +60,6 @@ public class Trainer
double alphaEmit = (vb) ? (Double) options.valueOf("alpha-emit") : 0;
double alphaPi = (vb) ? (Double) options.valueOf("alpha-pi") : 0;
int skip = (Integer) options.valueOf("skip-large-phrases");
- int wordThreshold = (Integer) options.valueOf("rare-word");
- int edgeThreshold = (Integer) options.valueOf("rare-edge");
- int phraseThreshold = (Integer) options.valueOf("rare-phrase");
- int contextThreshold = (Integer) options.valueOf("rare-context");
if (options.has("seed"))
F.rng = new Random((Long) options.valueOf("seed"));
@@ -86,14 +81,7 @@ public class Trainer
e.printStackTrace();
System.exit(1);
}
-
- if (wordThreshold > 1)
- corpus.applyWordThreshold(wordThreshold);
- if (phraseThreshold > 1)
- corpus.applyPhraseThreshold(phraseThreshold);
- if (contextThreshold > 1)
- corpus.applyContextThreshold(contextThreshold);
-
+
if (!options.has("agree"))
System.out.println("Running with " + tags + " tags " +
"for " + iterations + " iterations " +
@@ -127,7 +115,6 @@ public class Trainer
e.printStackTrace();
}
}
- cluster.setEdgeThreshold(edgeThreshold);
}
double last = 0;
@@ -138,20 +125,24 @@ public class Trainer
o = agree.EM();
else
{
+ if (i < skip)
+ System.out.println("Skipping phrases of length > " + (i+1));
+
if (scale_phrase <= 0 && scale_context <= 0)
{
if (!vb)
- o = cluster.EM(i < skip);
+ o = cluster.EM((i < skip) ? i+1 : 0);
else
- o = cluster.VBEM(alphaEmit, alphaPi, i < skip);
+ o = cluster.VBEM(alphaEmit, alphaPi);
}
else
- o = cluster.PREM(scale_phrase, scale_context, i < skip);
+ o = cluster.PREM(scale_phrase, scale_context, (i < skip) ? i+1 : 0);
}
System.out.println("ITER: "+i+" objective: " + o);
- if (i != 0 && Math.abs((o - last) / o) < threshold)
+ // sometimes takes a few iterations to break the ties
+ if (i > 5 && Math.abs((o - last) / o) < threshold)
{
last = o;
break;
@@ -171,10 +162,19 @@ public class Trainer
File outfile = (File) options.valueOf("out");
try {
PrintStream ps = FileUtil.printstream(outfile);
- cluster.displayPosterior(ps);
+ List<Edge> test;
+ if (!options.has("test"))
+ test = corpus.getEdges();
+ else
+ {
+ infile = (File) options.valueOf("test");
+ System.out.println("Reading testing concordance from " + infile);
+ test = corpus.readEdges(FileUtil.reader(infile));
+ }
+ cluster.displayPosterior(ps, test);
ps.close();
} catch (IOException e) {
- System.err.println("Failed to open output file: " + outfile);
+ System.err.println("Failed to open either testing file or output file");
e.printStackTrace();
System.exit(1);
}