| field | value | date |
|---|---|---|
| author | philblunsom <philblunsom@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-19 18:33:29 +0000 |
| committer | philblunsom <philblunsom@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-19 18:33:29 +0000 |
| commit | 73dbb0343a895345a80d49da9d48edac8858e87a (patch) | |
| tree | d33164f980f79218bb57153daaa563ec6d6cf1cb /gi/pyp-topics/src/mpi-train-contexts.cc | |
| parent | cf868a29d10942c62b4041e5931e68f868a4b96d (diff) | |
Vaguely working distributed implementation. Hierarchical topics doesn't yet work correctly.
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@317 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/pyp-topics/src/mpi-train-contexts.cc')
-rw-r--r-- | gi/pyp-topics/src/mpi-train-contexts.cc | 113
1 file changed, 64 insertions(+), 49 deletions(-)
diff --git a/gi/pyp-topics/src/mpi-train-contexts.cc b/gi/pyp-topics/src/mpi-train-contexts.cc
index 0651ecac..6e1e78a5 100644
--- a/gi/pyp-topics/src/mpi-train-contexts.cc
+++ b/gi/pyp-topics/src/mpi-train-contexts.cc
@@ -10,6 +10,7 @@
 #include <boost/scoped_ptr.hpp>
 #include <boost/mpi/environment.hpp>
 #include <boost/mpi/communicator.hpp>
+#include <boost/lexical_cast.hpp>
 
 // Local
 #include "mpi-pyp-topics.hh"
@@ -28,9 +29,10 @@ int main(int argc, char **argv)
 {
   mpi::environment env(argc, argv);
   mpi::communicator world;
-  bool am_root = (world.rank() == 0);
-  if (am_root) std::cout << "I am process " << world.rank() << " of " << world.size() << "." << std::endl;
+  int rank = world.rank();
+  bool am_root = rank;
   if (am_root) cout << "Pitman Yor topic models: Copyright 2010 Phil Blunsom\n";
+  if (am_root) std::cout << "I am process " << world.rank() << " of " << world.size() << "." << std::endl;
   if (am_root) cout << REVISION << '\n' <<endl;
 
   ////////////////////////////////////////////////////////////////////////////////////////////
@@ -117,60 +119,73 @@ int main(int argc, char **argv)
                vm["freq-cutoff-interval"].as<int>(),
                vm["max-contexts-per-document"].as<int>());
 
-  if (world.rank() == 0) {
-    if (vm.count("document-topics-out")) {
-      ogzstream documents_out(vm["document-topics-out"].as<string>().c_str());
-
-      int document_id=0;
-      map<int,int> all_terms;
-      for (Corpus::const_iterator corpusIt=contexts_corpus.begin();
-           corpusIt != contexts_corpus.end(); ++corpusIt, ++document_id) {
-        vector<int> unique_terms;
-        for (Document::const_iterator docIt=corpusIt->begin();
-             docIt != corpusIt->end(); ++docIt) {
-          if (unique_terms.empty() || *docIt != unique_terms.back())
-            unique_terms.push_back(*docIt);
-          // increment this terms frequency
-          pair<map<int,int>::iterator,bool> insert_result = all_terms.insert(make_pair(*docIt,1));
-          if (!insert_result.second)
-            all_terms[*docIt] = all_terms[*docIt] + 1;
-          //insert_result.first++;
-        }
-        documents_out << contexts_corpus.key(document_id) << '\t';
-        documents_out << model.max(document_id) << " " << corpusIt->size() << " ||| ";
-        for (std::vector<int>::const_iterator termIt=unique_terms.begin();
-             termIt != unique_terms.end(); ++termIt) {
-          if (termIt != unique_terms.begin())
-            documents_out << " ||| ";
-          vector<std::string> strings = contexts_corpus.context2string(*termIt);
-          copy(strings.begin(), strings.end(),ostream_iterator<std::string>(documents_out, " "));
-          documents_out << "||| C=" << model.max(document_id, *termIt);
-
-        }
-        documents_out <<endl;
+  if (vm.count("document-topics-out")) {
+    std::ofstream documents_out((vm["document-topics-out"].as<string>() + ".pyp-process-" + boost::lexical_cast<std::string>(rank)).c_str());
+    int documents = contexts_corpus.num_documents();
+    int mpi_start = 0, mpi_end = documents;
+    if (world.size() != 1) {
+      mpi_start = (documents / world.size()) * rank;
+      if (rank == world.size()-1) mpi_end = documents;
+      else mpi_end = (documents / world.size())*(rank+1);
+    }
+
+    map<int,int> all_terms;
+    for (int document_id=mpi_start; document_id<mpi_end; ++document_id) {
+      assert (document_id < contexts_corpus.num_documents());
+      const Document& doc = contexts_corpus.at(document_id);
+      vector<int> unique_terms;
+      for (Document::const_iterator docIt=doc.begin(); docIt != doc.end(); ++docIt) {
+        if (unique_terms.empty() || *docIt != unique_terms.back())
+          unique_terms.push_back(*docIt);
+        // increment this terms frequency
+        pair<map<int,int>::iterator,bool> insert_result = all_terms.insert(make_pair(*docIt,1));
+        if (!insert_result.second)
+          all_terms[*docIt] = all_terms[*docIt] + 1;
       }
-      documents_out.close();
-
-      if (vm.count("default-topics-out")) {
-        ofstream default_topics(vm["default-topics-out"].as<string>().c_str());
-        default_topics << model.max_topic() <<endl;
-        for (std::map<int,int>::const_iterator termIt=all_terms.begin(); termIt != all_terms.end(); ++termIt) {
-          vector<std::string> strings = contexts_corpus.context2string(termIt->first);
-          default_topics << model.max(-1, termIt->first) << " ||| " << termIt->second << " ||| ";
-          copy(strings.begin(), strings.end(),ostream_iterator<std::string>(default_topics, " "));
-          default_topics <<endl;
-        }
+      documents_out << contexts_corpus.key(document_id) << '\t';
+      documents_out << model.max(document_id) << " " << doc.size() << " ||| ";
+      for (std::vector<int>::const_iterator termIt=unique_terms.begin(); termIt != unique_terms.end(); ++termIt) {
+        if (termIt != unique_terms.begin())
+          documents_out << " ||| ";
+        vector<std::string> strings = contexts_corpus.context2string(*termIt);
+        copy(strings.begin(), strings.end(),ostream_iterator<std::string>(documents_out, " "));
+        documents_out << "||| C=" << model.max(document_id, *termIt);
       }
+      documents_out <<endl;
+    }
+    documents_out.close();
+
+    if (am_root) {
+      ogzstream root_documents_out(vm["document-topics-out"].as<string>().c_str());
+      for (int p=0; p < world.size(); ++p) {
+        std::string rank_p_prefix((vm["document-topics-out"].as<string>() + ".pyp-process-" + boost::lexical_cast<std::string>(p)).c_str());
+        std::ifstream rank_p_trees_istream(rank_p_prefix.c_str(), std::ios_base::binary);
+        root_documents_out << rank_p_trees_istream.rdbuf();
+        rank_p_trees_istream.close();
+        remove((rank_p_prefix).c_str());
+      }
+      root_documents_out.close();
     }
 
-    if (vm.count("topic-words-out")) {
-      ogzstream topics_out(vm["topic-words-out"].as<string>().c_str());
-      model.print_topic_terms(topics_out);
-      topics_out.close();
+    if (am_root && vm.count("default-topics-out")) {
+      ofstream default_topics(vm["default-topics-out"].as<string>().c_str());
+      default_topics << model.max_topic() <<endl;
+      for (std::map<int,int>::const_iterator termIt=all_terms.begin(); termIt != all_terms.end(); ++termIt) {
+        vector<std::string> strings = contexts_corpus.context2string(termIt->first);
+        default_topics << model.max(-1, termIt->first) << " ||| " << termIt->second << " ||| ";
+        copy(strings.begin(), strings.end(),ostream_iterator<std::string>(default_topics, " "));
+        default_topics <<endl;
+      }
     }
+  }
 
-    cout <<endl;
+  if (am_root && vm.count("topic-words-out")) {
+    ogzstream topics_out(vm["topic-words-out"].as<string>().c_str());
+    model.print_topic_terms(topics_out);
+    topics_out.close();
   }
+  cout <<endl;
+
   return 0;
 }
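For readers following the change: the heart of the patch is a simple block partition of the corpus across MPI processes, with each process writing its slice of per-document topic assignments to a `<output>.pyp-process-<rank>` part file that the root process later concatenates and deletes. The sketch below isolates just that scheme; it is not the patch itself. The document count and output name are placeholder assumptions standing in for `contexts_corpus` and the `--document-topics-out` option, and a `barrier()` is added before the merge, which the patch does not do.

```cpp
// Minimal sketch of the rank-based partition + part-file merge used by the patch.
// Assumptions: num_documents and "documents.out" are placeholders; a barrier is
// added before the root merges (the patch has no such synchronization).
// Build (typical): mpic++ sketch.cc -lboost_mpi -lboost_serialization
#include <cstdio>
#include <fstream>
#include <string>
#include <boost/mpi/environment.hpp>
#include <boost/mpi/communicator.hpp>
#include <boost/lexical_cast.hpp>

namespace mpi = boost::mpi;

int main(int argc, char** argv) {
  mpi::environment env(argc, argv);
  mpi::communicator world;
  const int rank = world.rank();

  const int num_documents = 1000;              // placeholder for contexts_corpus.num_documents()
  int mpi_start = 0, mpi_end = num_documents;
  if (world.size() != 1) {
    mpi_start = (num_documents / world.size()) * rank;
    mpi_end = (rank == world.size() - 1)
                ? num_documents                                // last rank absorbs the remainder
                : (num_documents / world.size()) * (rank + 1);
  }

  // Each rank writes its slice to its own part file, as the patch does with
  // "<output>.pyp-process-<rank>".
  std::string part_name = "documents.out.pyp-process-"
                          + boost::lexical_cast<std::string>(rank);
  std::ofstream part(part_name.c_str());
  for (int d = mpi_start; d < mpi_end; ++d)
    part << "document " << d << " handled by rank " << rank << "\n";
  part.close();

  // Rank 0 concatenates the part files in rank order and removes them,
  // mirroring the rdbuf()-based merge in the patch.
  world.barrier();                              // ensure every part file is complete
  if (rank == 0) {
    std::ofstream merged("documents.out");
    for (int p = 0; p < world.size(); ++p) {
      std::string p_name = "documents.out.pyp-process-"
                           + boost::lexical_cast<std::string>(p);
      std::ifstream in(p_name.c_str(), std::ios_base::binary);
      merged << in.rdbuf();
      in.close();
      std::remove(p_name.c_str());
    }
  }
  return 0;
}
```

Dividing `documents / world.size()` evenly and letting the last rank take the remainder keeps the bookkeeping minimal; the trade-off is that the last rank can end up with up to `world.size() - 1` extra documents.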