diff options
Diffstat (limited to 'gi/pyp-topics')
-rw-r--r-- | gi/pyp-topics/src/Makefile.am | 5 | ||||
-rw-r--r-- | gi/pyp-topics/src/contexts_corpus.cc | 1 | ||||
-rw-r--r-- | gi/pyp-topics/src/contexts_corpus.hh | 12 | ||||
-rw-r--r-- | gi/pyp-topics/src/train-contexts.cc | 127 | ||||
-rw-r--r-- | gi/pyp-topics/src/train.cc | 51 |
5 files changed, 150 insertions, 46 deletions
diff --git a/gi/pyp-topics/src/Makefile.am b/gi/pyp-topics/src/Makefile.am index 3d62a334..7ca269a5 100644 --- a/gi/pyp-topics/src/Makefile.am +++ b/gi/pyp-topics/src/Makefile.am @@ -1,4 +1,4 @@ -bin_PROGRAMS = pyp-topics-train +bin_PROGRAMS = pyp-topics-train pyp-contexts-train contexts_lexer.cc: contexts_lexer.l $(LEX) -s -CF -8 -o$@ $< @@ -6,5 +6,8 @@ contexts_lexer.cc: contexts_lexer.l pyp_topics_train_SOURCES = corpus.cc gammadist.c gzstream.cc mt19937ar.c pyp-topics.cc train.cc contexts_lexer.cc contexts_corpus.cc pyp_topics_train_LDADD = -lz +pyp_contexts_train_SOURCES = corpus.cc gammadist.c gzstream.cc mt19937ar.c pyp-topics.cc contexts_lexer.cc contexts_corpus.cc train-contexts.cc +pyp_contexts_train_LDADD = -lz + AM_CPPFLAGS = -W -Wall -Wno-sign-compare -funroll-loops diff --git a/gi/pyp-topics/src/contexts_corpus.cc b/gi/pyp-topics/src/contexts_corpus.cc index afa1e19a..f3d3c92e 100644 --- a/gi/pyp-topics/src/contexts_corpus.cc +++ b/gi/pyp-topics/src/contexts_corpus.cc @@ -73,6 +73,7 @@ void read_callback(const ContextsLexer::PhraseContextsType& new_contexts, void* //std::cout << std::endl; corpus_ptr->m_documents.push_back(doc); + corpus_ptr->m_keys.push_back(new_contexts.phrase); } unsigned ContextsCorpus::read_contexts(const std::string &filename, diff --git a/gi/pyp-topics/src/contexts_corpus.hh b/gi/pyp-topics/src/contexts_corpus.hh index bd0cd34c..9614e7e3 100644 --- a/gi/pyp-topics/src/contexts_corpus.hh +++ b/gi/pyp-topics/src/contexts_corpus.hh @@ -49,9 +49,6 @@ class ContextsCorpus : public Corpus { friend void read_callback(const ContextsLexer::PhraseContextsType&, void*); public: - typedef boost::ptr_vector<Document>::const_iterator const_iterator; - -public: ContextsCorpus() : m_backoff(new TermBackoff) {} virtual ~ContextsCorpus() {} @@ -62,9 +59,18 @@ public: return m_backoff; } + std::vector<std::string> context2string(const WordID& id) const { + return m_dict.AsVector(id); + } + + const std::string& key(const int& i) const { + return m_keys.at(i); + } + private: TermBackoffPtr m_backoff; Dict m_dict; + std::vector<std::string> m_keys; }; #endif // _CONTEXTS_CORPUS_HH diff --git a/gi/pyp-topics/src/train-contexts.cc b/gi/pyp-topics/src/train-contexts.cc new file mode 100644 index 00000000..3ad8828f --- /dev/null +++ b/gi/pyp-topics/src/train-contexts.cc @@ -0,0 +1,127 @@ +// STL +#include <iostream> +#include <fstream> +#include <algorithm> +#include <iterator> + +// Boost +#include <boost/program_options/parsers.hpp> +#include <boost/program_options/variables_map.hpp> +#include <boost/scoped_ptr.hpp> + +// Local +#include "pyp-topics.hh" +#include "corpus.hh" +#include "contexts_corpus.hh" +#include "gzstream.hh" +#include "mt19937ar.h" + +static const char *REVISION = "$Revision: 0.1 $"; + +// Namespaces +using namespace boost; +using namespace boost::program_options; +using namespace std; + +int main(int argc, char **argv) +{ + std::cout << "Pitman Yor topic models: Copyright 2010 Phil Blunsom\n"; + std::cout << REVISION << '\n' << std::endl; + + //////////////////////////////////////////////////////////////////////////////////////////// + // Command line processing + variables_map vm; + + // Command line processing + { + options_description cmdline_options("Allowed options"); + cmdline_options.add_options() + ("help,h", "print help message") + ("data,d", value<string>(), "file containing the documents and context terms") + ("topics,t", value<int>()->default_value(50), "number of topics") + ("document-topics-out,o", value<string>(), "file to write the document topics to") + ("topic-words-out,w", value<string>(), "file to write the topic word distribution to") + ("samples,s", value<int>()->default_value(10), "number of sampling passes through the data") + ("backoff-type", value<string>(), "backoff type: none|simple") + ; + store(parse_command_line(argc, argv, cmdline_options), vm); + notify(vm); + + if (vm.count("help")) { + cout << cmdline_options << "\n"; + return 1; + } + } + //////////////////////////////////////////////////////////////////////////////////////////// + + if (!vm.count("data")) { + cerr << "Please specify a file containing the data." << endl; + return 1; + } + + // seed the random number generator + //mt_init_genrand(time(0)); + + PYPTopics model(vm["topics"].as<int>()); + + // read the data + BackoffGenerator* backoff_gen=0; + if (vm.count("backoff-type")) { + if (vm["backoff-type"].as<std::string>() == "none") { + backoff_gen = 0; + } + else if (vm["backoff-type"].as<std::string>() == "simple") { + backoff_gen = new SimpleBackoffGenerator(); + } + else { + std::cerr << "Backoff type (--backoff-type) must be one of none|simple." << std::endl; + return(1); + } + } + + ContextsCorpus contexts_corpus; + contexts_corpus.read_contexts(vm["data"].as<string>(), backoff_gen); + model.set_backoff(contexts_corpus.backoff_index()); + + if (backoff_gen) + delete backoff_gen; + + // train the sampler + model.sample(contexts_corpus, vm["samples"].as<int>()); + + if (vm.count("document-topics-out")) { + ogzstream documents_out(vm["document-topics-out"].as<string>().c_str()); + + int document_id=0; + for (Corpus::const_iterator corpusIt=contexts_corpus.begin(); + corpusIt != contexts_corpus.end(); ++corpusIt, ++document_id) { + std::vector<int> unique_terms; + for (Document::const_iterator docIt=corpusIt->begin(); + docIt != corpusIt->end(); ++docIt) { + if (unique_terms.empty() || *docIt != unique_terms.back()) + unique_terms.push_back(*docIt); + } + documents_out << contexts_corpus.key(document_id) << '\t'; + for (std::vector<int>::const_iterator termIt=unique_terms.begin(); + termIt != unique_terms.end(); ++termIt) { + if (termIt != unique_terms.begin()) + documents_out << " ||| "; + std::vector<std::string> strings = contexts_corpus.context2string(*termIt); + std::copy(strings.begin(), strings.end(), std::ostream_iterator<std::string>(documents_out, " ")); + documents_out << "||| C=" << model.max(document_id, *termIt); + } + documents_out << std::endl; + } + documents_out.close(); + } + + if (vm.count("topic-words-out")) { + ogzstream topics_out(vm["topic-words-out"].as<string>().c_str()); + model.print_topic_terms(topics_out); + topics_out.close(); + } + + std::cout << std::endl; + + return 0; +} diff --git a/gi/pyp-topics/src/train.cc b/gi/pyp-topics/src/train.cc index a8fd994c..759bea1d 100644 --- a/gi/pyp-topics/src/train.cc +++ b/gi/pyp-topics/src/train.cc @@ -39,14 +39,12 @@ int main(int argc, char **argv) options_description generic("Allowed options"); generic.add_options() ("documents,d", value<string>(), "file containing the documents") - ("contexts", value<string>(), "file containing the documents in phrase contexts format") ("topics,t", value<int>()->default_value(50), "number of topics") ("document-topics-out,o", value<string>(), "file to write the document topics to") ("topic-words-out,w", value<string>(), "file to write the topic word distribution to") ("samples,s", value<int>()->default_value(10), "number of sampling passes through the data") ("test-corpus", value<string>(), "file containing the test data") ("backoff-paths", value<string>(), "file containing the term backoff paths") - ("backoff-type", value<string>(), "backoff type: none|simple") ; options_description config_options, cmdline_options; config_options.add(generic); @@ -59,12 +57,8 @@ int main(int argc, char **argv) } notify(vm); //////////////////////////////////////////////////////////////////////////////////////////// - if (vm.count("contexts") > 0 && vm.count("documents") > 0) { - cerr << "Only one of --documents or --contexts must be specified." << std::endl; - return 1; - } - if (vm.count("documents") == 0 && vm.count("contexts") == 0) { + if (vm.count("documents") == 0) { cerr << "Please specify a file containing the documents." << endl; cout << cmdline_options << "\n"; return 1; @@ -81,50 +75,23 @@ int main(int argc, char **argv) PYPTopics model(vm["topics"].as<int>()); // read the data - boost::shared_ptr<Corpus> corpus; - if (vm.count("documents") == 0 && vm.count("contexts") == 0) { - corpus.reset(new Corpus); - corpus->read(vm["documents"].as<string>()); - - // read the backoff dictionary - if (vm.count("backoff-paths")) - model.set_backoff(vm["backoff-paths"].as<string>()); + Corpus corpus; + corpus.read(vm["documents"].as<string>()); - } - else { - BackoffGenerator* backoff_gen=0; - if (vm.count("backoff-type")) { - if (vm["backoff-type"].as<std::string>() == "none") { - backoff_gen = 0; - } - else if (vm["backoff-type"].as<std::string>() == "simple") { - backoff_gen = new SimpleBackoffGenerator(); - } - else { - std::cerr << "Backoff type (--backoff-type) must be one of none|simple." << std::endl; - return(1); - } - } - - boost::shared_ptr<ContextsCorpus> contexts_corpus(new ContextsCorpus); - contexts_corpus->read_contexts(vm["contexts"].as<string>(), backoff_gen); - corpus = contexts_corpus; - model.set_backoff(contexts_corpus->backoff_index()); - - if (backoff_gen) - delete backoff_gen; - } + // read the backoff dictionary + if (vm.count("backoff-paths")) + model.set_backoff(vm["backoff-paths"].as<string>()); // train the sampler - model.sample(*corpus, vm["samples"].as<int>()); + model.sample(corpus, vm["samples"].as<int>()); if (vm.count("document-topics-out")) { ogzstream documents_out(vm["document-topics-out"].as<string>().c_str()); //model.print_document_topics(documents_out); int document_id=0; - for (Corpus::const_iterator corpusIt=corpus->begin(); - corpusIt != corpus->end(); ++corpusIt, ++document_id) { + for (Corpus::const_iterator corpusIt=corpus.begin(); + corpusIt != corpus.end(); ++corpusIt, ++document_id) { std::vector<int> unique_terms; for (Document::const_iterator docIt=corpusIt->begin(); docIt != corpusIt->end(); ++docIt) { |