From efe0d24fa7dbca47825638a52f51977456153bd0 Mon Sep 17 00:00:00 2001 From: "philblunsom@gmail.com" Date: Tue, 22 Jun 2010 20:34:00 +0000 Subject: Initial ci of gi dir git-svn-id: https://ws10smt.googlecode.com/svn/trunk@5 ec762483-ff6d-05da-a07a-a48fb63a330f --- gi/pyp-topics/src/train.cc | 136 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 136 insertions(+) create mode 100644 gi/pyp-topics/src/train.cc (limited to 'gi/pyp-topics/src/train.cc') diff --git a/gi/pyp-topics/src/train.cc b/gi/pyp-topics/src/train.cc new file mode 100644 index 00000000..0d107f11 --- /dev/null +++ b/gi/pyp-topics/src/train.cc @@ -0,0 +1,136 @@ +// STL +#include +#include + +// Boost +#include +#include +#include + +// Local +#include "pyp-topics.hh" +#include "corpus.hh" +#include "gzstream.hh" +#include "mt19937ar.h" + +static const char *REVISION = "$Revision: 0.1 $"; + +// Namespaces +using namespace boost; +using namespace boost::program_options; +using namespace std; + +int main(int argc, char **argv) +{ + std::cout << "Pitman Yor topic models: Copyright 2010 Phil Blunsom\n"; + std::cout << REVISION << '\n' << std::endl; + + //////////////////////////////////////////////////////////////////////////////////////////// + // Command line processing + variables_map vm; + + // Command line processing + options_description cmdline_specific("Command line specific options"); + cmdline_specific.add_options() + ("help,h", "print help message") + ("config,c", value(), "config file specifying additional command line options") + ; + options_description generic("Allowed options"); + generic.add_options() + ("documents,d", value(), "file containing the documents") + ("topics,t", value()->default_value(50), "number of topics") + ("document-topics-out,o", value(), "file to write the document topics to") + ("topic-words-out,w", value(), "file to write the topic word distribution to") + ("samples,s", value()->default_value(10), "number of sampling passes through the data") + ("test-corpus", value(), "file containing the test data") + ("backoff-paths", value(), "file containing the term backoff paths") + ; + options_description config_options, cmdline_options; + config_options.add(generic); + cmdline_options.add(generic).add(cmdline_specific); + + store(parse_command_line(argc, argv, cmdline_options), vm); + if (vm.count("config") > 0) { + ifstream config(vm["config"].as().c_str()); + store(parse_config_file(config, cmdline_options), vm); + } + notify(vm); + //////////////////////////////////////////////////////////////////////////////////////////// + + if (vm.count("help")) { + cout << cmdline_options << "\n"; + return 1; + } + + if (vm.count("documents") == 0) { + cerr << "Please specify a file containing the documents." << endl; + cout << cmdline_options << "\n"; + return 1; + } + + // seed the random number generator + //mt_init_genrand(time(0)); + + // read the data + Corpus corpus; + corpus.read(vm["documents"].as()); + + // run the sampler + PYPTopics model(vm["topics"].as()); + + // read the backoff dictionary + if (vm.count("backoff-paths")) + model.set_backoff(vm["backoff-paths"].as()); + + // train the sampler + model.sample(corpus, vm["samples"].as()); + + if (vm.count("document-topics-out")) { + ogzstream documents_out(vm["document-topics-out"].as().c_str()); + //model.print_document_topics(documents_out); + + int document_id=0; + for (Corpus::const_iterator corpusIt=corpus.begin(); + corpusIt != corpus.end(); ++corpusIt, ++document_id) { + std::vector unique_terms; + for (Document::const_iterator docIt=corpusIt->begin(); + docIt != corpusIt->end(); ++docIt) { + if (unique_terms.empty() || *docIt != unique_terms.back()) + unique_terms.push_back(*docIt); + } + documents_out << unique_terms.size(); + for (std::vector::const_iterator termIt=unique_terms.begin(); + termIt != unique_terms.end(); ++termIt) + documents_out << " " << *termIt << ":" << model.max(document_id, *termIt); + documents_out << std::endl; + } + documents_out.close(); + } + + if (vm.count("topic-words-out")) { + ogzstream topics_out(vm["topic-words-out"].as().c_str()); + model.print_topic_terms(topics_out); + topics_out.close(); + } + + if (vm.count("test-corpus")) { + TestCorpus test_corpus; + test_corpus.read(vm["test-corpus"].as()); + ogzstream topics_out((vm["test-corpus"].as() + ".topics.gz").c_str()); + + for (TestCorpus::const_iterator corpusIt=test_corpus.begin(); + corpusIt != test_corpus.end(); ++corpusIt) { + int index=0; + for (DocumentTerms::const_iterator instanceIt=corpusIt->begin(); + instanceIt != corpusIt->end(); ++instanceIt, ++index) { + int topic = model.max(instanceIt->doc, instanceIt->term); + if (index != 0) topics_out << " "; + topics_out << topic; + } + topics_out << std::endl; + } + topics_out.close(); + } + + return 0; +} -- cgit v1.2.3