From 1070f63e3978b9b26df46ad80fe1f40f2ce83a23 Mon Sep 17 00:00:00 2001 From: "philblunsom@gmail.com" Date: Mon, 28 Jun 2010 15:01:17 +0000 Subject: Added contexts_corpus for reading text data files. git-svn-id: https://ws10smt.googlecode.com/svn/trunk@36 ec762483-ff6d-05da-a07a-a48fb63a330f --- gi/pyp-topics/src/train.cc | 43 +++++++++++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 14 deletions(-) (limited to 'gi/pyp-topics/src/train.cc') diff --git a/gi/pyp-topics/src/train.cc b/gi/pyp-topics/src/train.cc index 0d107f11..01ada182 100644 --- a/gi/pyp-topics/src/train.cc +++ b/gi/pyp-topics/src/train.cc @@ -10,6 +10,7 @@ // Local #include "pyp-topics.hh" #include "corpus.hh" +#include "contexts_corpus.hh" #include "gzstream.hh" #include "mt19937ar.h" @@ -38,6 +39,7 @@ int main(int argc, char **argv) options_description generic("Allowed options"); generic.add_options() ("documents,d", value(), "file containing the documents") + ("contexts", value(), "file containing the documents in phrase contexts format") ("topics,t", value()->default_value(50), "number of topics") ("document-topics-out,o", value(), "file to write the document topics to") ("topic-words-out,w", value(), "file to write the topic word distribution to") @@ -56,42 +58,55 @@ int main(int argc, char **argv) } notify(vm); //////////////////////////////////////////////////////////////////////////////////////////// - - if (vm.count("help")) { - cout << cmdline_options << "\n"; + if (vm.count("contexts") > 0 && vm.count("documents") > 0) { + cerr << "Only one of --documents or --contexts must be specified." << std::endl; return 1; } - if (vm.count("documents") == 0) { + if (vm.count("documents") == 0 && vm.count("contexts") == 0) { cerr << "Please specify a file containing the documents." << endl; cout << cmdline_options << "\n"; return 1; } + if (vm.count("help")) { + cout << cmdline_options << "\n"; + return 1; + } + // seed the random number generator //mt_init_genrand(time(0)); + PYPTopics model(vm["topics"].as()); + // read the data - Corpus corpus; - corpus.read(vm["documents"].as()); + boost::shared_ptr corpus; + if (vm.count("documents") == 0 && vm.count("contexts") == 0) { + corpus.reset(new Corpus); + corpus->read(vm["documents"].as()); - // run the sampler - PYPTopics model(vm["topics"].as()); + // read the backoff dictionary + if (vm.count("backoff-paths")) + model.set_backoff(vm["backoff-paths"].as()); - // read the backoff dictionary - if (vm.count("backoff-paths")) - model.set_backoff(vm["backoff-paths"].as()); + } + else { + boost::shared_ptr contexts_corpus(new ContextsCorpus); + contexts_corpus->read_contexts(vm["contexts"].as()); + corpus = contexts_corpus; + model.set_backoff(contexts_corpus->backoff_index()); + } // train the sampler - model.sample(corpus, vm["samples"].as()); + model.sample(*corpus, vm["samples"].as()); if (vm.count("document-topics-out")) { ogzstream documents_out(vm["document-topics-out"].as().c_str()); //model.print_document_topics(documents_out); int document_id=0; - for (Corpus::const_iterator corpusIt=corpus.begin(); - corpusIt != corpus.end(); ++corpusIt, ++document_id) { + for (Corpus::const_iterator corpusIt=corpus->begin(); + corpusIt != corpus->end(); ++corpusIt, ++document_id) { std::vector unique_terms; for (Document::const_iterator docIt=corpusIt->begin(); docIt != corpusIt->end(); ++docIt) { -- cgit v1.2.3