summaryrefslogtreecommitdiff
path: root/gi/pyp-topics/src/train.cc
diff options
context:
space:
mode:
Diffstat (limited to 'gi/pyp-topics/src/train.cc')
-rw-r--r--gi/pyp-topics/src/train.cc21
1 files changed, 20 insertions, 1 deletions
diff --git a/gi/pyp-topics/src/train.cc b/gi/pyp-topics/src/train.cc
index 01ada182..a8fd994c 100644
--- a/gi/pyp-topics/src/train.cc
+++ b/gi/pyp-topics/src/train.cc
@@ -46,6 +46,7 @@ int main(int argc, char **argv)
("samples,s", value<int>()->default_value(10), "number of sampling passes through the data")
("test-corpus", value<string>(), "file containing the test data")
("backoff-paths", value<string>(), "file containing the term backoff paths")
+ ("backoff-type", value<string>(), "backoff type: none|simple")
;
options_description config_options, cmdline_options;
config_options.add(generic);
@@ -91,10 +92,27 @@ int main(int argc, char **argv)
}
else {
+ BackoffGenerator* backoff_gen=0;
+ if (vm.count("backoff-type")) {
+ if (vm["backoff-type"].as<std::string>() == "none") {
+ backoff_gen = 0;
+ }
+ else if (vm["backoff-type"].as<std::string>() == "simple") {
+ backoff_gen = new SimpleBackoffGenerator();
+ }
+ else {
+ std::cerr << "Backoff type (--backoff-type) must be one of none|simple." << std::endl;
+ return(1);
+ }
+ }
+
boost::shared_ptr<ContextsCorpus> contexts_corpus(new ContextsCorpus);
- contexts_corpus->read_contexts(vm["contexts"].as<string>());
+ contexts_corpus->read_contexts(vm["contexts"].as<string>(), backoff_gen);
corpus = contexts_corpus;
model.set_backoff(contexts_corpus->backoff_index());
+
+ if (backoff_gen)
+ delete backoff_gen;
}
// train the sampler
@@ -146,6 +164,7 @@ int main(int argc, char **argv)
}
topics_out.close();
}
+ std::cout << std::endl;
return 0;
}