Diffstat (limited to 'gi/pipeline')
-rwxr-xr-x  gi/pipeline/local-gi-pipeline.pl  |  34
1 file changed, 18 insertions(+), 16 deletions(-)
diff --git a/gi/pipeline/local-gi-pipeline.pl b/gi/pipeline/local-gi-pipeline.pl
index a705af3b..e757f4cd 100755
--- a/gi/pipeline/local-gi-pipeline.pl
+++ b/gi/pipeline/local-gi-pipeline.pl
@@ -18,8 +18,10 @@ my $BIDIR = 0;
my $TOPICS_CONFIG = "pyp-topics.conf";
my $MODEL = "pyp";
-my $NUM_EM_PR_ITERS = 20;
-my $PR_SCALE = 10.0;
+my $NUM_EM_ITERS = 100;
+my $NUM_PR_ITERS = 0;
+my $PR_SCALE_P = 10;
+my $PR_SCALE_C = 0;
my $PR_THREADS = 0;
my $EXTOOLS = "$SCRIPT_DIR/../../extools";
@@ -35,7 +37,7 @@ my $C2D = "$PYPSCRIPTS/contexts2documents.py";
my $S2L = "$PYPSCRIPTS/spans2labels.py";
my $PYP_TOPICS_TRAIN="$PYPTOOLS/pyp-contexts-train";
-my $PREM_TRAIN="java -ea -Xmx4g -jar $PRTOOLS/prjava.jar";
+my $PREM_TRAIN="$PRTOOLS/prjava/train-PR-cluster.sh";
my $SORT_KEYS = "$SCRIPT_DIR/scripts/sort-by-key.sh";
my $EXTRACTOR = "$EXTOOLS/extractor";
@@ -55,8 +57,10 @@ usage() unless &GetOptions('base_phrase_max_size=i' => \$BASE_PHRASE_MAX_SIZE,
'trg_context=i' => \$CONTEXT_SIZE,
'samples=i' => \$NUM_SAMPLES,
'topics-config=s' => \$TOPICS_CONFIG,
- 'em-iterations=i' => \$NUM_EM_PR_ITERS,
- 'pr-scale=f' => \$PR_SCALE,
+ 'em-iterations=i' => \$NUM_EM_ITERS,
+ 'pr-iterations=i' => \$NUM_PR_ITERS,
+ 'pr-scale-phrase=f' => \$PR_SCALE_P,
+ 'pr-scale-context=f' => \$PR_SCALE_C,
'pr-threads=i' => \$PR_THREADS,
'tagged_corpus=s' => \$TAGGED_CORPUS,
);
@@ -81,9 +85,9 @@ if(-e $TOPICS_CONFIG) {
extract_context();
if (lc($MODEL) eq "pyp") {
topic_train();
-} else {
+} elsif (lc($MODEL) eq "prem") {
prem_train();
-}
+} else { die "Unsupported model type: $MODEL. Must be one of PYP or PREM.\n"; }
label_spans_with_topics();
my $res;
if ($BIDIR) {
@@ -102,12 +106,13 @@ sub context_dir {
sub cluster_dir {
if (lc($MODEL) eq "pyp") {
return context_dir() . ".PYP.t$NUM_TOPICS.s$NUM_SAMPLES";
- } elsif (lc($MODEL) eq "em") {
- return context_dir() . ".EM.t$NUM_TOPICS.i$NUM_EM_PR_ITERS";
- } elsif (lc($MODEL) eq "pr") {
- return context_dir() . ".PR.t$NUM_TOPICS.i$NUM_EM_PR_ITERS.s$PR_SCALE";
+ } elsif (lc($MODEL) eq "prem") {
+ if ($NUM_PR_ITERS == 0) {
+ return context_dir() . ".PREM.t$NUM_TOPICS.ie$NUM_EM_ITERS.ip$NUM_PR_ITERS";
+ } else {
+ return context_dir() . ".PREM.t$NUM_TOPICS.ie$NUM_EM_ITERS.ip$NUM_PR_ITERS.sp$PR_SCALE_P.sc$PR_SCALE_C";
+ }
}
- die "Badness 10000\n";
}
sub grammar_dir {
@@ -175,10 +180,7 @@ sub prem_train {
if (-e $OUT_CLUSTERS) {
print STDERR "$OUT_CLUSTERS exists, reusing...\n";
} else {
- my $emflag="false";
- if (lc($MODEL) eq "em") { $emflag="true"; }
- elsif (lc($MODEL) ne "pr") { die "Unsupported model type: $MODEL"; }
- safesystem("$PREM_TRAIN $IN_CONTEXTS $NUM_TOPICS $OUT_CLUSTERS $NUM_EM_PR_ITERS $PR_SCALE $PR_THREADS $emflag") or die "Topic training failed.\n";
+ safesystem("$PREM_TRAIN --in $IN_CONTEXTS --topics $NUM_TOPICS --out $OUT_CLUSTERS --em $NUM_EM_ITERS --pr $NUM_PR_ITERS --scale-phrase $PR_SCALE_P --scale-context $PR_SCALE_C --threads $PR_THREADS") or die "Topic training failed.\n";
}
}
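
Example invocation with the options introduced in this diff (a hypothetical sketch based only on the flags added to GetOptions above; the input corpus, model selection, and any other required arguments are elided, and the numeric values are illustrative rather than recommended):

    # hypothetical sketch: shows only the options added in this commit
    ./local-gi-pipeline.pl --em-iterations 100 --pr-iterations 10 \
        --pr-scale-phrase 10 --pr-scale-context 5 --pr-threads 4 ...

When $MODEL is "prem", prem_train() now invokes train-PR-cluster.sh with separate --em/--pr iteration counts and separate phrase/context scale factors, replacing the old single em-iterations/pr-scale pair that was passed to prjava.jar.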