summaryrefslogtreecommitdiff
path: root/klm/alone/threading.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/alone/threading.cc')
-rw-r--r--klm/alone/threading.cc80
1 files changed, 80 insertions, 0 deletions
diff --git a/klm/alone/threading.cc b/klm/alone/threading.cc
new file mode 100644
index 00000000..475386b6
--- /dev/null
+++ b/klm/alone/threading.cc
@@ -0,0 +1,80 @@
+#include "alone/threading.hh"
+
+#include "alone/assemble.hh"
+#include "alone/graph.hh"
+#include "alone/read.hh"
+#include "alone/vocab.hh"
+#include "lm/model.hh"
+#include "search/context.hh"
+#include "search/vertex_generator.hh"
+
+#include <boost/ref.hpp>
+#include <boost/scoped_ptr.hpp>
+#include <boost/utility/in_place_factory.hpp>
+
+#include <sstream>
+
+namespace alone {
+template <class Model> void Decode(const search::Config &config, const Model &model, util::FilePiece *in_ptr, std::ostream &out) {
+ search::Context<Model> context(config, model);
+ Graph graph;
+ Vocab vocab(model.GetVocabulary());
+ {
+ boost::scoped_ptr<util::FilePiece> in(in_ptr);
+ ReadCDec(context, *in, graph, vocab);
+ }
+
+ for (std::size_t i = 0; i < graph.VertexSize(); ++i) {
+ search::VertexGenerator(context, graph.MutableVertex(i));
+ }
+ search::PartialVertex top = graph.Root().RootPartial();
+ if (top.Empty()) {
+ out << "NO PATH FOUND";
+ } else {
+ search::PartialVertex continuation;
+ while (!top.Complete()) {
+ top.Split(continuation);
+ top = continuation;
+ }
+ out << top.End() << " ||| " << top.End().Bound() << std::endl;
+ }
+}
+
+template void Decode(const search::Config &config, const lm::ngram::ProbingModel &model, util::FilePiece *in_ptr, std::ostream &out);
+template void Decode(const search::Config &config, const lm::ngram::RestProbingModel &model, util::FilePiece *in_ptr, std::ostream &out);
+
+#ifdef WITH_THREADS
+template <class Model> void DecodeHandler<Model>::operator()(Input message) {
+ std::stringstream assemble;
+ Decode(config_, model_, message.file, assemble);
+ Produce(message.sentence_id, assemble.str());
+}
+
+template <class Model> void DecodeHandler<Model>::Produce(unsigned int sentence_id, const std::string &str) {
+ Output out;
+ out.sentence_id = sentence_id;
+ out.str = new std::string(str);
+ out_.Produce(out);
+}
+
+void PrintHandler::operator()(Output message) {
+ unsigned int relative = message.sentence_id - done_;
+ if (waiting_.size() <= relative) waiting_.resize(relative + 1);
+ waiting_[relative] = message.str;
+ for (std::string *lead; !waiting_.empty() && (lead = waiting_[0]); waiting_.pop_front(), ++done_) {
+ out_ << *lead;
+ delete lead;
+ }
+}
+
+template <class Model> Controller<Model>::Controller(const search::Config &config, const Model &model, size_t decode_workers, std::ostream &to) :
+ sentence_id_(0),
+ printer_(decode_workers, 1, boost::ref(to), Output::Poison()),
+ decoder_(3, decode_workers, boost::in_place(boost::ref(config), boost::ref(model), boost::ref(printer_.In())), Input::Poison()) {}
+
+template class Controller<lm::ngram::RestProbingModel>;
+template class Controller<lm::ngram::ProbingModel>;
+
+#endif
+
+} // namespace alone