diff options
Diffstat (limited to 'gi/pyp-topics/src/train.cc')
-rw-r--r-- | gi/pyp-topics/src/train.cc | 51 |
1 files changed, 9 insertions, 42 deletions
diff --git a/gi/pyp-topics/src/train.cc b/gi/pyp-topics/src/train.cc index a8fd994c..759bea1d 100644 --- a/gi/pyp-topics/src/train.cc +++ b/gi/pyp-topics/src/train.cc @@ -39,14 +39,12 @@ int main(int argc, char **argv) options_description generic("Allowed options"); generic.add_options() ("documents,d", value<string>(), "file containing the documents") - ("contexts", value<string>(), "file containing the documents in phrase contexts format") ("topics,t", value<int>()->default_value(50), "number of topics") ("document-topics-out,o", value<string>(), "file to write the document topics to") ("topic-words-out,w", value<string>(), "file to write the topic word distribution to") ("samples,s", value<int>()->default_value(10), "number of sampling passes through the data") ("test-corpus", value<string>(), "file containing the test data") ("backoff-paths", value<string>(), "file containing the term backoff paths") - ("backoff-type", value<string>(), "backoff type: none|simple") ; options_description config_options, cmdline_options; config_options.add(generic); @@ -59,12 +57,8 @@ int main(int argc, char **argv) } notify(vm); //////////////////////////////////////////////////////////////////////////////////////////// - if (vm.count("contexts") > 0 && vm.count("documents") > 0) { - cerr << "Only one of --documents or --contexts must be specified." << std::endl; - return 1; - } - if (vm.count("documents") == 0 && vm.count("contexts") == 0) { + if (vm.count("documents") == 0) { cerr << "Please specify a file containing the documents." << endl; cout << cmdline_options << "\n"; return 1; @@ -81,50 +75,23 @@ int main(int argc, char **argv) PYPTopics model(vm["topics"].as<int>()); // read the data - boost::shared_ptr<Corpus> corpus; - if (vm.count("documents") == 0 && vm.count("contexts") == 0) { - corpus.reset(new Corpus); - corpus->read(vm["documents"].as<string>()); - - // read the backoff dictionary - if (vm.count("backoff-paths")) - model.set_backoff(vm["backoff-paths"].as<string>()); + Corpus corpus; + corpus.read(vm["documents"].as<string>()); - } - else { - BackoffGenerator* backoff_gen=0; - if (vm.count("backoff-type")) { - if (vm["backoff-type"].as<std::string>() == "none") { - backoff_gen = 0; - } - else if (vm["backoff-type"].as<std::string>() == "simple") { - backoff_gen = new SimpleBackoffGenerator(); - } - else { - std::cerr << "Backoff type (--backoff-type) must be one of none|simple." << std::endl; - return(1); - } - } - - boost::shared_ptr<ContextsCorpus> contexts_corpus(new ContextsCorpus); - contexts_corpus->read_contexts(vm["contexts"].as<string>(), backoff_gen); - corpus = contexts_corpus; - model.set_backoff(contexts_corpus->backoff_index()); - - if (backoff_gen) - delete backoff_gen; - } + // read the backoff dictionary + if (vm.count("backoff-paths")) + model.set_backoff(vm["backoff-paths"].as<string>()); // train the sampler - model.sample(*corpus, vm["samples"].as<int>()); + model.sample(corpus, vm["samples"].as<int>()); if (vm.count("document-topics-out")) { ogzstream documents_out(vm["document-topics-out"].as<string>().c_str()); //model.print_document_topics(documents_out); int document_id=0; - for (Corpus::const_iterator corpusIt=corpus->begin(); - corpusIt != corpus->end(); ++corpusIt, ++document_id) { + for (Corpus::const_iterator corpusIt=corpus.begin(); + corpusIt != corpus.end(); ++corpusIt, ++document_id) { std::vector<int> unique_terms; for (Document::const_iterator docIt=corpusIt->begin(); docIt != corpusIt->end(); ++docIt) { |