summaryrefslogtreecommitdiff
path: root/gi/pyp-topics/src/train.cc
diff options
context:
space:
mode:
Diffstat (limited to 'gi/pyp-topics/src/train.cc')
-rw-r--r--gi/pyp-topics/src/train.cc43
1 files changed, 29 insertions, 14 deletions
diff --git a/gi/pyp-topics/src/train.cc b/gi/pyp-topics/src/train.cc
index 0d107f11..01ada182 100644
--- a/gi/pyp-topics/src/train.cc
+++ b/gi/pyp-topics/src/train.cc
@@ -10,6 +10,7 @@
// Local
#include "pyp-topics.hh"
#include "corpus.hh"
+#include "contexts_corpus.hh"
#include "gzstream.hh"
#include "mt19937ar.h"
@@ -38,6 +39,7 @@ int main(int argc, char **argv)
options_description generic("Allowed options");
generic.add_options()
("documents,d", value<string>(), "file containing the documents")
+ ("contexts", value<string>(), "file containing the documents in phrase contexts format")
("topics,t", value<int>()->default_value(50), "number of topics")
("document-topics-out,o", value<string>(), "file to write the document topics to")
("topic-words-out,w", value<string>(), "file to write the topic word distribution to")
@@ -56,42 +58,55 @@ int main(int argc, char **argv)
}
notify(vm);
////////////////////////////////////////////////////////////////////////////////////////////
-
- if (vm.count("help")) {
- cout << cmdline_options << "\n";
+ if (vm.count("contexts") > 0 && vm.count("documents") > 0) {
+ cerr << "Only one of --documents or --contexts must be specified." << std::endl;
return 1;
}
- if (vm.count("documents") == 0) {
+ if (vm.count("documents") == 0 && vm.count("contexts") == 0) {
cerr << "Please specify a file containing the documents." << endl;
cout << cmdline_options << "\n";
return 1;
}
+ if (vm.count("help")) {
+ cout << cmdline_options << "\n";
+ return 1;
+ }
+
// seed the random number generator
//mt_init_genrand(time(0));
+ PYPTopics model(vm["topics"].as<int>());
+
// read the data
- Corpus corpus;
- corpus.read(vm["documents"].as<string>());
+ boost::shared_ptr<Corpus> corpus;
+ if (vm.count("documents") == 0 && vm.count("contexts") == 0) {
+ corpus.reset(new Corpus);
+ corpus->read(vm["documents"].as<string>());
- // run the sampler
- PYPTopics model(vm["topics"].as<int>());
+ // read the backoff dictionary
+ if (vm.count("backoff-paths"))
+ model.set_backoff(vm["backoff-paths"].as<string>());
- // read the backoff dictionary
- if (vm.count("backoff-paths"))
- model.set_backoff(vm["backoff-paths"].as<string>());
+ }
+ else {
+ boost::shared_ptr<ContextsCorpus> contexts_corpus(new ContextsCorpus);
+ contexts_corpus->read_contexts(vm["contexts"].as<string>());
+ corpus = contexts_corpus;
+ model.set_backoff(contexts_corpus->backoff_index());
+ }
// train the sampler
- model.sample(corpus, vm["samples"].as<int>());
+ model.sample(*corpus, vm["samples"].as<int>());
if (vm.count("document-topics-out")) {
ogzstream documents_out(vm["document-topics-out"].as<string>().c_str());
//model.print_document_topics(documents_out);
int document_id=0;
- for (Corpus::const_iterator corpusIt=corpus.begin();
- corpusIt != corpus.end(); ++corpusIt, ++document_id) {
+ for (Corpus::const_iterator corpusIt=corpus->begin();
+ corpusIt != corpus->end(); ++corpusIt, ++document_id) {
std::vector<int> unique_terms;
for (Document::const_iterator docIt=corpusIt->begin();
docIt != corpusIt->end(); ++docIt) {