// STL #include <iostream> #include <fstream> #include <algorithm> #include <iterator> // Boost #include <boost/program_options/parsers.hpp> #include <boost/program_options/variables_map.hpp> #include <boost/scoped_ptr.hpp> // Local #include "pyp-topics.hh" #include "corpus.hh" #include "contexts_corpus.hh" #include "gzstream.hh" static const char *REVISION = "$Rev$"; // Namespaces using namespace boost; using namespace boost::program_options; using namespace std; int main(int argc, char **argv) { cout << "Pitman Yor topic models: Copyright 2010 Phil Blunsom\n"; cout << REVISION << '\n' <<endl; //////////////////////////////////////////////////////////////////////////////////////////// // Command line processing variables_map vm; // Command line processing { options_description cmdline_specific("Command line specific options"); cmdline_specific.add_options() ("help,h", "print help message") ("config,c", value<string>(), "config file specifying additional command line options") ; options_description config_options("Allowed options"); config_options.add_options() ("data,d", value<string>(), "file containing the documents and context terms") ("topics,t", value<int>()->default_value(50), "number of topics") ("document-topics-out,o", value<string>(), "file to write the document topics to") ("default-topics-out", value<string>(), "file to write default term topic assignments.") ("topic-words-out,w", value<string>(), "file to write the topic word distribution to") ("samples,s", value<int>()->default_value(10), "number of sampling passes through the data") ("backoff-type", value<string>(), "backoff type: none|simple") // ("filter-singleton-contexts", "filter singleton contexts") ("hierarchical-topics", "Use a backoff hierarchical PYP as the P0 for the document topics distribution.") ("freq-cutoff-start", value<int>()->default_value(0), "initial frequency cutoff.") ("freq-cutoff-end", value<int>()->default_value(0), "final frequency cutoff.") ("freq-cutoff-interval", value<int>()->default_value(0), "number of iterations between frequency decrement.") ("max-threads", value<int>()->default_value(1), "maximum number of simultaneous threads allowed") ("max-contexts-per-document", value<int>()->default_value(0), "Only sample the n most frequent contexts for a document.") ("num-jobs", value<int>()->default_value(1), "allows finer control over parallelization") ("temp-start", value<double>()->default_value(1.0), "starting annealing temperature.") ("temp-end", value<double>()->default_value(1.0), "end annealing temperature.") ; cmdline_specific.add(config_options); store(parse_command_line(argc, argv, cmdline_specific), vm); notify(vm); if (vm.count("config") > 0) { ifstream config(vm["config"].as<string>().c_str()); store(parse_config_file(config, config_options), vm); } if (vm.count("help")) { cout << cmdline_specific << "\n"; return 1; } } //////////////////////////////////////////////////////////////////////////////////////////// if (!vm.count("data")) { cerr << "Please specify a file containing the data." << endl; return 1; } assert(vm["max-threads"].as<int>() > 0); assert(vm["num-jobs"].as<int>() > -1); // seed the random number generator: 0 = automatic, specify value otherwise unsigned long seed = 0; PYPTopics model(vm["topics"].as<int>(), vm.count("hierarchical-topics"), seed, vm["max-threads"].as<int>(), vm["num-jobs"].as<int>()); // read the data BackoffGenerator* backoff_gen=0; if (vm.count("backoff-type")) { if (vm["backoff-type"].as<std::string>() == "none") { backoff_gen = 0; } else if (vm["backoff-type"].as<std::string>() == "simple") { backoff_gen = new SimpleBackoffGenerator(); } else { cerr << "Backoff type (--backoff-type) must be one of none|simple." <<endl; return(1); } } ContextsCorpus contexts_corpus; contexts_corpus.read_contexts(vm["data"].as<string>(), backoff_gen, /*vm.count("filter-singleton-contexts")*/ false); model.set_backoff(contexts_corpus.backoff_index()); if (backoff_gen) delete backoff_gen; // train the sampler model.sample_corpus(contexts_corpus, vm["samples"].as<int>(), vm["freq-cutoff-start"].as<int>(), vm["freq-cutoff-end"].as<int>(), vm["freq-cutoff-interval"].as<int>(), vm["max-contexts-per-document"].as<int>(), vm["temp-start"].as<double>(), vm["temp-end"].as<double>()); if (vm.count("document-topics-out")) { ogzstream documents_out(vm["document-topics-out"].as<string>().c_str()); int document_id=0; map<int,int> all_terms; for (Corpus::const_iterator corpusIt=contexts_corpus.begin(); corpusIt != contexts_corpus.end(); ++corpusIt, ++document_id) { vector<int> unique_terms; for (Document::const_iterator docIt=corpusIt->begin(); docIt != corpusIt->end(); ++docIt) { if (unique_terms.empty() || *docIt != unique_terms.back()) unique_terms.push_back(*docIt); // increment this terms frequency pair<map<int,int>::iterator,bool> insert_result = all_terms.insert(make_pair(*docIt,1)); if (!insert_result.second) all_terms[*docIt] = all_terms[*docIt] + 1; //insert_result.first++; } documents_out << contexts_corpus.key(document_id) << '\t'; documents_out << model.max(document_id).first << " " << corpusIt->size() << " ||| "; for (std::vector<int>::const_iterator termIt=unique_terms.begin(); termIt != unique_terms.end(); ++termIt) { if (termIt != unique_terms.begin()) documents_out << " ||| "; vector<std::string> strings = contexts_corpus.context2string(*termIt); copy(strings.begin(), strings.end(),ostream_iterator<std::string>(documents_out, " ")); std::pair<int,PYPTopics::F> maxinfo = model.max(document_id, *termIt); documents_out << "||| C=" << maxinfo.first << " P=" << maxinfo.second; } documents_out <<endl; } documents_out.close(); if (vm.count("default-topics-out")) { ofstream default_topics(vm["default-topics-out"].as<string>().c_str()); default_topics << model.max_topic() <<endl; for (std::map<int,int>::const_iterator termIt=all_terms.begin(); termIt != all_terms.end(); ++termIt) { vector<std::string> strings = contexts_corpus.context2string(termIt->first); default_topics << model.max(-1, termIt->first).first << " ||| " << termIt->second << " ||| "; copy(strings.begin(), strings.end(),ostream_iterator<std::string>(default_topics, " ")); default_topics <<endl; } } } if (vm.count("topic-words-out")) { ogzstream topics_out(vm["topic-words-out"].as<string>().c_str()); model.print_topic_terms(topics_out); topics_out.close(); } cout <<endl; return 0; }