diff options
Diffstat (limited to 'gi/pyp-topics/src/mpi-train-contexts.cc')
-rw-r--r-- | gi/pyp-topics/src/mpi-train-contexts.cc | 20 |
1 file changed, 14 insertions, 6 deletions
diff --git a/gi/pyp-topics/src/mpi-train-contexts.cc b/gi/pyp-topics/src/mpi-train-contexts.cc index 4f966a65..7bb890d2 100644 --- a/gi/pyp-topics/src/mpi-train-contexts.cc +++ b/gi/pyp-topics/src/mpi-train-contexts.cc @@ -15,7 +15,7 @@ // Local #include "mpi-pyp-topics.hh" #include "corpus.hh" -#include "contexts_corpus.hh" +#include "mpi-corpus.hh" #include "gzstream.hh" static const char *REVISION = "$Rev: 170 $"; @@ -105,8 +105,13 @@ int main(int argc, char **argv) } } - ContextsCorpus contexts_corpus; + //ContextsCorpus contexts_corpus; + MPICorpus contexts_corpus; contexts_corpus.read_contexts(vm["data"].as<string>(), backoff_gen, /*vm.count("filter-singleton-contexts")*/ false); + int mpi_start = 0, mpi_end = 0; + contexts_corpus.bounds(&mpi_start, &mpi_end); + std::cerr << "\tProcess " << rank << " has documents " << mpi_start << " -> " << mpi_end << "." << std::endl; + model.set_backoff(contexts_corpus.backoff_index()); if (backoff_gen) @@ -121,13 +126,15 @@ int main(int argc, char **argv) if (vm.count("document-topics-out")) { std::ofstream documents_out((vm["document-topics-out"].as<string>() + ".pyp-process-" + boost::lexical_cast<std::string>(rank)).c_str()); - int documents = contexts_corpus.num_documents(); + //int documents = contexts_corpus.num_documents(); + /* int mpi_start = 0, mpi_end = documents; if (world.size() != 1) { mpi_start = (documents / world.size()) * rank; if (rank == world.size()-1) mpi_end = documents; else mpi_end = (documents / world.size())*(rank+1); } + */ map<int,int> all_terms; for (int document_id=mpi_start; document_id<mpi_end; ++document_id) { @@ -143,13 +150,14 @@ int main(int argc, char **argv) all_terms[*docIt] = all_terms[*docIt] + 1; } documents_out << contexts_corpus.key(document_id) << '\t'; - documents_out << model.max(document_id) << " " << doc.size() << " ||| "; + documents_out << model.max(document_id).first << " " << doc.size() << " ||| "; for (std::vector<int>::const_iterator termIt=unique_terms.begin(); 
termIt != unique_terms.end(); ++termIt) { if (termIt != unique_terms.begin()) documents_out << " ||| "; vector<std::string> strings = contexts_corpus.context2string(*termIt); copy(strings.begin(), strings.end(),ostream_iterator<std::string>(documents_out, " ")); - documents_out << "||| C=" << model.max(document_id, *termIt); + std::pair<int,MPIPYPTopics::F> maxinfo = model.max(document_id, *termIt); + documents_out << "||| C=" << maxinfo.first << " P=" << maxinfo.second; } documents_out <<endl; } @@ -173,7 +181,7 @@ int main(int argc, char **argv) default_topics << model.max_topic() <<endl; for (std::map<int,int>::const_iterator termIt=all_terms.begin(); termIt != all_terms.end(); ++termIt) { vector<std::string> strings = contexts_corpus.context2string(termIt->first); - default_topics << model.max(-1, termIt->first) << " ||| " << termIt->second << " ||| "; + default_topics << model.max(-1, termIt->first).first << " ||| " << termIt->second << " ||| "; copy(strings.begin(), strings.end(),ostream_iterator<std::string>(default_topics, " ")); default_topics <<endl; } |