diff options
Diffstat (limited to 'gi/pyp-topics/src/train.cc')
-rw-r--r-- | gi/pyp-topics/src/train.cc | 43 |
1 file changed, 29 insertions, 14 deletions
diff --git a/gi/pyp-topics/src/train.cc b/gi/pyp-topics/src/train.cc index 0d107f11..01ada182 100644 --- a/gi/pyp-topics/src/train.cc +++ b/gi/pyp-topics/src/train.cc @@ -10,6 +10,7 @@ // Local #include "pyp-topics.hh" #include "corpus.hh" +#include "contexts_corpus.hh" #include "gzstream.hh" #include "mt19937ar.h" @@ -38,6 +39,7 @@ int main(int argc, char **argv) options_description generic("Allowed options"); generic.add_options() ("documents,d", value<string>(), "file containing the documents") + ("contexts", value<string>(), "file containing the documents in phrase contexts format") ("topics,t", value<int>()->default_value(50), "number of topics") ("document-topics-out,o", value<string>(), "file to write the document topics to") ("topic-words-out,w", value<string>(), "file to write the topic word distribution to") @@ -56,42 +58,55 @@ int main(int argc, char **argv) } notify(vm); //////////////////////////////////////////////////////////////////////////////////////////// - - if (vm.count("help")) { - cout << cmdline_options << "\n"; + if (vm.count("contexts") > 0 && vm.count("documents") > 0) { + cerr << "Only one of --documents or --contexts must be specified." << std::endl; return 1; } - if (vm.count("documents") == 0) { + if (vm.count("documents") == 0 && vm.count("contexts") == 0) { cerr << "Please specify a file containing the documents." 
<< endl; cout << cmdline_options << "\n"; return 1; } + if (vm.count("help")) { + cout << cmdline_options << "\n"; + return 1; + } + // seed the random number generator //mt_init_genrand(time(0)); + PYPTopics model(vm["topics"].as<int>()); + // read the data - Corpus corpus; - corpus.read(vm["documents"].as<string>()); + boost::shared_ptr<Corpus> corpus; + if (vm.count("documents") == 0 && vm.count("contexts") == 0) { + corpus.reset(new Corpus); + corpus->read(vm["documents"].as<string>()); - // run the sampler - PYPTopics model(vm["topics"].as<int>()); + // read the backoff dictionary + if (vm.count("backoff-paths")) + model.set_backoff(vm["backoff-paths"].as<string>()); - // read the backoff dictionary - if (vm.count("backoff-paths")) - model.set_backoff(vm["backoff-paths"].as<string>()); + } + else { + boost::shared_ptr<ContextsCorpus> contexts_corpus(new ContextsCorpus); + contexts_corpus->read_contexts(vm["contexts"].as<string>()); + corpus = contexts_corpus; + model.set_backoff(contexts_corpus->backoff_index()); + } // train the sampler - model.sample(corpus, vm["samples"].as<int>()); + model.sample(*corpus, vm["samples"].as<int>()); if (vm.count("document-topics-out")) { ogzstream documents_out(vm["document-topics-out"].as<string>().c_str()); //model.print_document_topics(documents_out); int document_id=0; - for (Corpus::const_iterator corpusIt=corpus.begin(); - corpusIt != corpus.end(); ++corpusIt, ++document_id) { + for (Corpus::const_iterator corpusIt=corpus->begin(); + corpusIt != corpus->end(); ++corpusIt, ++document_id) { std::vector<int> unique_terms; for (Document::const_iterator docIt=corpusIt->begin(); docIt != corpusIt->end(); ++docIt) { |