summaryrefslogtreecommitdiff
path: root/gi/pyp-topics/src
diff options
context:
space:
mode:
Diffstat (limited to 'gi/pyp-topics/src')
-rw-r--r--gi/pyp-topics/src/Makefile.am5
-rw-r--r--gi/pyp-topics/src/contexts_corpus.cc1
-rw-r--r--gi/pyp-topics/src/contexts_corpus.hh12
-rw-r--r--gi/pyp-topics/src/train-contexts.cc127
-rw-r--r--gi/pyp-topics/src/train.cc51
5 files changed, 150 insertions, 46 deletions
diff --git a/gi/pyp-topics/src/Makefile.am b/gi/pyp-topics/src/Makefile.am
index 3d62a334..7ca269a5 100644
--- a/gi/pyp-topics/src/Makefile.am
+++ b/gi/pyp-topics/src/Makefile.am
@@ -1,4 +1,4 @@
-bin_PROGRAMS = pyp-topics-train
+bin_PROGRAMS = pyp-topics-train pyp-contexts-train
contexts_lexer.cc: contexts_lexer.l
$(LEX) -s -CF -8 -o$@ $<
@@ -6,5 +6,8 @@ contexts_lexer.cc: contexts_lexer.l
pyp_topics_train_SOURCES = corpus.cc gammadist.c gzstream.cc mt19937ar.c pyp-topics.cc train.cc contexts_lexer.cc contexts_corpus.cc
pyp_topics_train_LDADD = -lz
+pyp_contexts_train_SOURCES = corpus.cc gammadist.c gzstream.cc mt19937ar.c pyp-topics.cc contexts_lexer.cc contexts_corpus.cc train-contexts.cc
+pyp_contexts_train_LDADD = -lz
+
AM_CPPFLAGS = -W -Wall -Wno-sign-compare -funroll-loops
diff --git a/gi/pyp-topics/src/contexts_corpus.cc b/gi/pyp-topics/src/contexts_corpus.cc
index afa1e19a..f3d3c92e 100644
--- a/gi/pyp-topics/src/contexts_corpus.cc
+++ b/gi/pyp-topics/src/contexts_corpus.cc
@@ -73,6 +73,7 @@ void read_callback(const ContextsLexer::PhraseContextsType& new_contexts, void*
//std::cout << std::endl;
corpus_ptr->m_documents.push_back(doc);
+ corpus_ptr->m_keys.push_back(new_contexts.phrase);
}
unsigned ContextsCorpus::read_contexts(const std::string &filename,
diff --git a/gi/pyp-topics/src/contexts_corpus.hh b/gi/pyp-topics/src/contexts_corpus.hh
index bd0cd34c..9614e7e3 100644
--- a/gi/pyp-topics/src/contexts_corpus.hh
+++ b/gi/pyp-topics/src/contexts_corpus.hh
@@ -49,9 +49,6 @@ class ContextsCorpus : public Corpus {
friend void read_callback(const ContextsLexer::PhraseContextsType&, void*);
public:
- typedef boost::ptr_vector<Document>::const_iterator const_iterator;
-
-public:
ContextsCorpus() : m_backoff(new TermBackoff) {}
virtual ~ContextsCorpus() {}
@@ -62,9 +59,18 @@ public:
return m_backoff;
}
+ std::vector<std::string> context2string(const WordID& id) const {
+ return m_dict.AsVector(id);
+ }
+
+ const std::string& key(const int& i) const {
+ return m_keys.at(i);
+ }
+
private:
TermBackoffPtr m_backoff;
Dict m_dict;
+ std::vector<std::string> m_keys;
};
#endif // _CONTEXTS_CORPUS_HH
diff --git a/gi/pyp-topics/src/train-contexts.cc b/gi/pyp-topics/src/train-contexts.cc
new file mode 100644
index 00000000..3ad8828f
--- /dev/null
+++ b/gi/pyp-topics/src/train-contexts.cc
@@ -0,0 +1,127 @@
+// STL
+#include <iostream>
+#include <fstream>
+#include <algorithm>
+#include <iterator>
+
+// Boost
+#include <boost/program_options/parsers.hpp>
+#include <boost/program_options/variables_map.hpp>
+#include <boost/scoped_ptr.hpp>
+
+// Local
+#include "pyp-topics.hh"
+#include "corpus.hh"
+#include "contexts_corpus.hh"
+#include "gzstream.hh"
+#include "mt19937ar.h"
+
+static const char *REVISION = "$Revision: 0.1 $";
+
+// Namespaces
+using namespace boost;
+using namespace boost::program_options;
+using namespace std;
+
+int main(int argc, char **argv)
+{
+ std::cout << "Pitman Yor topic models: Copyright 2010 Phil Blunsom\n";
+ std::cout << REVISION << '\n' << std::endl;
+
+ ////////////////////////////////////////////////////////////////////////////////////////////
+ // Command line processing
+ variables_map vm;
+
+ // Command line processing
+ {
+ options_description cmdline_options("Allowed options");
+ cmdline_options.add_options()
+ ("help,h", "print help message")
+ ("data,d", value<string>(), "file containing the documents and context terms")
+ ("topics,t", value<int>()->default_value(50), "number of topics")
+ ("document-topics-out,o", value<string>(), "file to write the document topics to")
+ ("topic-words-out,w", value<string>(), "file to write the topic word distribution to")
+ ("samples,s", value<int>()->default_value(10), "number of sampling passes through the data")
+ ("backoff-type", value<string>(), "backoff type: none|simple")
+ ;
+ store(parse_command_line(argc, argv, cmdline_options), vm);
+ notify(vm);
+
+ if (vm.count("help")) {
+ cout << cmdline_options << "\n";
+ return 1;
+ }
+ }
+ ////////////////////////////////////////////////////////////////////////////////////////////
+
+ if (!vm.count("data")) {
+ cerr << "Please specify a file containing the data." << endl;
+ return 1;
+ }
+
+ // seed the random number generator
+ //mt_init_genrand(time(0));
+
+ PYPTopics model(vm["topics"].as<int>());
+
+ // read the data
+ BackoffGenerator* backoff_gen=0;
+ if (vm.count("backoff-type")) {
+ if (vm["backoff-type"].as<std::string>() == "none") {
+ backoff_gen = 0;
+ }
+ else if (vm["backoff-type"].as<std::string>() == "simple") {
+ backoff_gen = new SimpleBackoffGenerator();
+ }
+ else {
+ std::cerr << "Backoff type (--backoff-type) must be one of none|simple." << std::endl;
+ return(1);
+ }
+ }
+
+ ContextsCorpus contexts_corpus;
+ contexts_corpus.read_contexts(vm["data"].as<string>(), backoff_gen);
+ model.set_backoff(contexts_corpus.backoff_index());
+
+ if (backoff_gen)
+ delete backoff_gen;
+
+ // train the sampler
+ model.sample(contexts_corpus, vm["samples"].as<int>());
+
+ if (vm.count("document-topics-out")) {
+ ogzstream documents_out(vm["document-topics-out"].as<string>().c_str());
+
+ int document_id=0;
+ for (Corpus::const_iterator corpusIt=contexts_corpus.begin();
+ corpusIt != contexts_corpus.end(); ++corpusIt, ++document_id) {
+ std::vector<int> unique_terms;
+ for (Document::const_iterator docIt=corpusIt->begin();
+ docIt != corpusIt->end(); ++docIt) {
+ if (unique_terms.empty() || *docIt != unique_terms.back())
+ unique_terms.push_back(*docIt);
+ }
+ documents_out << contexts_corpus.key(document_id) << '\t';
+ for (std::vector<int>::const_iterator termIt=unique_terms.begin();
+ termIt != unique_terms.end(); ++termIt) {
+ if (termIt != unique_terms.begin())
+ documents_out << " ||| ";
+ std::vector<std::string> strings = contexts_corpus.context2string(*termIt);
+ std::copy(strings.begin(), strings.end(), std::ostream_iterator<std::string>(documents_out, " "));
+ documents_out << "||| C=" << model.max(document_id, *termIt);
+ }
+ documents_out << std::endl;
+ }
+ documents_out.close();
+ }
+
+ if (vm.count("topic-words-out")) {
+ ogzstream topics_out(vm["topic-words-out"].as<string>().c_str());
+ model.print_topic_terms(topics_out);
+ topics_out.close();
+ }
+
+ std::cout << std::endl;
+
+ return 0;
+}
diff --git a/gi/pyp-topics/src/train.cc b/gi/pyp-topics/src/train.cc
index a8fd994c..759bea1d 100644
--- a/gi/pyp-topics/src/train.cc
+++ b/gi/pyp-topics/src/train.cc
@@ -39,14 +39,12 @@ int main(int argc, char **argv)
options_description generic("Allowed options");
generic.add_options()
("documents,d", value<string>(), "file containing the documents")
- ("contexts", value<string>(), "file containing the documents in phrase contexts format")
("topics,t", value<int>()->default_value(50), "number of topics")
("document-topics-out,o", value<string>(), "file to write the document topics to")
("topic-words-out,w", value<string>(), "file to write the topic word distribution to")
("samples,s", value<int>()->default_value(10), "number of sampling passes through the data")
("test-corpus", value<string>(), "file containing the test data")
("backoff-paths", value<string>(), "file containing the term backoff paths")
- ("backoff-type", value<string>(), "backoff type: none|simple")
;
options_description config_options, cmdline_options;
config_options.add(generic);
@@ -59,12 +57,8 @@ int main(int argc, char **argv)
}
notify(vm);
////////////////////////////////////////////////////////////////////////////////////////////
- if (vm.count("contexts") > 0 && vm.count("documents") > 0) {
- cerr << "Only one of --documents or --contexts must be specified." << std::endl;
- return 1;
- }
- if (vm.count("documents") == 0 && vm.count("contexts") == 0) {
+ if (vm.count("documents") == 0) {
cerr << "Please specify a file containing the documents." << endl;
cout << cmdline_options << "\n";
return 1;
@@ -81,50 +75,23 @@ int main(int argc, char **argv)
PYPTopics model(vm["topics"].as<int>());
// read the data
- boost::shared_ptr<Corpus> corpus;
- if (vm.count("documents") == 0 && vm.count("contexts") == 0) {
- corpus.reset(new Corpus);
- corpus->read(vm["documents"].as<string>());
-
- // read the backoff dictionary
- if (vm.count("backoff-paths"))
- model.set_backoff(vm["backoff-paths"].as<string>());
+ Corpus corpus;
+ corpus.read(vm["documents"].as<string>());
- }
- else {
- BackoffGenerator* backoff_gen=0;
- if (vm.count("backoff-type")) {
- if (vm["backoff-type"].as<std::string>() == "none") {
- backoff_gen = 0;
- }
- else if (vm["backoff-type"].as<std::string>() == "simple") {
- backoff_gen = new SimpleBackoffGenerator();
- }
- else {
- std::cerr << "Backoff type (--backoff-type) must be one of none|simple." << std::endl;
- return(1);
- }
- }
-
- boost::shared_ptr<ContextsCorpus> contexts_corpus(new ContextsCorpus);
- contexts_corpus->read_contexts(vm["contexts"].as<string>(), backoff_gen);
- corpus = contexts_corpus;
- model.set_backoff(contexts_corpus->backoff_index());
-
- if (backoff_gen)
- delete backoff_gen;
- }
+ // read the backoff dictionary
+ if (vm.count("backoff-paths"))
+ model.set_backoff(vm["backoff-paths"].as<string>());
// train the sampler
- model.sample(*corpus, vm["samples"].as<int>());
+ model.sample(corpus, vm["samples"].as<int>());
if (vm.count("document-topics-out")) {
ogzstream documents_out(vm["document-topics-out"].as<string>().c_str());
//model.print_document_topics(documents_out);
int document_id=0;
- for (Corpus::const_iterator corpusIt=corpus->begin();
- corpusIt != corpus->end(); ++corpusIt, ++document_id) {
+ for (Corpus::const_iterator corpusIt=corpus.begin();
+ corpusIt != corpus.end(); ++corpusIt, ++document_id) {
std::vector<int> unique_terms;
for (Document::const_iterator docIt=corpusIt->begin();
docIt != corpusIt->end(); ++docIt) {