summaryrefslogtreecommitdiff
path: root/gi/pyp-topics/src/train.cc
diff options
context:
space:
mode:
Diffstat (limited to 'gi/pyp-topics/src/train.cc')
-rw-r--r--gi/pyp-topics/src/train.cc51
1 files changed, 9 insertions, 42 deletions
diff --git a/gi/pyp-topics/src/train.cc b/gi/pyp-topics/src/train.cc
index a8fd994c..759bea1d 100644
--- a/gi/pyp-topics/src/train.cc
+++ b/gi/pyp-topics/src/train.cc
@@ -39,14 +39,12 @@ int main(int argc, char **argv)
options_description generic("Allowed options");
generic.add_options()
("documents,d", value<string>(), "file containing the documents")
- ("contexts", value<string>(), "file containing the documents in phrase contexts format")
("topics,t", value<int>()->default_value(50), "number of topics")
("document-topics-out,o", value<string>(), "file to write the document topics to")
("topic-words-out,w", value<string>(), "file to write the topic word distribution to")
("samples,s", value<int>()->default_value(10), "number of sampling passes through the data")
("test-corpus", value<string>(), "file containing the test data")
("backoff-paths", value<string>(), "file containing the term backoff paths")
- ("backoff-type", value<string>(), "backoff type: none|simple")
;
options_description config_options, cmdline_options;
config_options.add(generic);
@@ -59,12 +57,8 @@ int main(int argc, char **argv)
}
notify(vm);
////////////////////////////////////////////////////////////////////////////////////////////
- if (vm.count("contexts") > 0 && vm.count("documents") > 0) {
- cerr << "Only one of --documents or --contexts must be specified." << std::endl;
- return 1;
- }
- if (vm.count("documents") == 0 && vm.count("contexts") == 0) {
+ if (vm.count("documents") == 0) {
cerr << "Please specify a file containing the documents." << endl;
cout << cmdline_options << "\n";
return 1;
@@ -81,50 +75,23 @@ int main(int argc, char **argv)
PYPTopics model(vm["topics"].as<int>());
// read the data
- boost::shared_ptr<Corpus> corpus;
- if (vm.count("documents") == 0 && vm.count("contexts") == 0) {
- corpus.reset(new Corpus);
- corpus->read(vm["documents"].as<string>());
-
- // read the backoff dictionary
- if (vm.count("backoff-paths"))
- model.set_backoff(vm["backoff-paths"].as<string>());
+ Corpus corpus;
+ corpus.read(vm["documents"].as<string>());
- }
- else {
- BackoffGenerator* backoff_gen=0;
- if (vm.count("backoff-type")) {
- if (vm["backoff-type"].as<std::string>() == "none") {
- backoff_gen = 0;
- }
- else if (vm["backoff-type"].as<std::string>() == "simple") {
- backoff_gen = new SimpleBackoffGenerator();
- }
- else {
- std::cerr << "Backoff type (--backoff-type) must be one of none|simple." << std::endl;
- return(1);
- }
- }
-
- boost::shared_ptr<ContextsCorpus> contexts_corpus(new ContextsCorpus);
- contexts_corpus->read_contexts(vm["contexts"].as<string>(), backoff_gen);
- corpus = contexts_corpus;
- model.set_backoff(contexts_corpus->backoff_index());
-
- if (backoff_gen)
- delete backoff_gen;
- }
+ // read the backoff dictionary
+ if (vm.count("backoff-paths"))
+ model.set_backoff(vm["backoff-paths"].as<string>());
// train the sampler
- model.sample(*corpus, vm["samples"].as<int>());
+ model.sample(corpus, vm["samples"].as<int>());
if (vm.count("document-topics-out")) {
ogzstream documents_out(vm["document-topics-out"].as<string>().c_str());
//model.print_document_topics(documents_out);
int document_id=0;
- for (Corpus::const_iterator corpusIt=corpus->begin();
- corpusIt != corpus->end(); ++corpusIt, ++document_id) {
+ for (Corpus::const_iterator corpusIt=corpus.begin();
+ corpusIt != corpus.end(); ++corpusIt, ++document_id) {
std::vector<int> unique_terms;
for (Document::const_iterator docIt=corpusIt->begin();
docIt != corpusIt->end(); ++docIt) {