summaryrefslogtreecommitdiff
path: root/klm/lm/builder/multi_stream.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/builder/multi_stream.hh')
-rw-r--r--klm/lm/builder/multi_stream.hh180
1 files changed, 180 insertions, 0 deletions
diff --git a/klm/lm/builder/multi_stream.hh b/klm/lm/builder/multi_stream.hh
new file mode 100644
index 00000000..707a98c7
--- /dev/null
+++ b/klm/lm/builder/multi_stream.hh
@@ -0,0 +1,180 @@
+#ifndef LM_BUILDER_MULTI_STREAM__
+#define LM_BUILDER_MULTI_STREAM__
+
+#include "lm/builder/ngram_stream.hh"
+#include "util/scoped.hh"
+#include "util/stream/chain.hh"
+
+#include <cstddef>
+#include <new>
+
+#include <assert.h>
+#include <stdlib.h>
+
+namespace lm { namespace builder {
+
+template <class T> class FixedArray {
+ public:
+ explicit FixedArray(std::size_t count) {
+ Init(count);
+ }
+
+ FixedArray() : newed_end_(NULL) {}
+
+ void Init(std::size_t count) {
+ assert(!block_.get());
+ block_.reset(malloc(sizeof(T) * count));
+ if (!block_.get()) throw std::bad_alloc();
+ newed_end_ = begin();
+ }
+
+ FixedArray(const FixedArray &from) {
+ std::size_t size = from.newed_end_ - static_cast<const T*>(from.block_.get());
+ Init(size);
+ for (std::size_t i = 0; i < size; ++i) {
+ new(end()) T(from[i]);
+ Constructed();
+ }
+ }
+
+ ~FixedArray() { clear(); }
+
+ T *begin() { return static_cast<T*>(block_.get()); }
+ const T *begin() const { return static_cast<const T*>(block_.get()); }
+ // Always call Constructed after successful completion of new.
+ T *end() { return newed_end_; }
+ const T *end() const { return newed_end_; }
+
+ T &back() { return *(end() - 1); }
+ const T &back() const { return *(end() - 1); }
+
+ std::size_t size() const { return end() - begin(); }
+ bool empty() const { return begin() == end(); }
+
+ T &operator[](std::size_t i) { return begin()[i]; }
+ const T &operator[](std::size_t i) const { return begin()[i]; }
+
+ template <class C> void push_back(const C &c) {
+ new (end()) T(c);
+ Constructed();
+ }
+
+ void clear() {
+ for (T *i = begin(); i != end(); ++i)
+ i->~T();
+ newed_end_ = begin();
+ }
+
+ protected:
+ void Constructed() {
+ ++newed_end_;
+ }
+
+ private:
+ util::scoped_malloc block_;
+
+ T *newed_end_;
+};
+
+class Chains;
+
+class ChainPositions : public FixedArray<util::stream::ChainPosition> {
+ public:
+ ChainPositions() {}
+
+ void Init(Chains &chains);
+
+ explicit ChainPositions(Chains &chains) {
+ Init(chains);
+ }
+};
+
+class Chains : public FixedArray<util::stream::Chain> {
+ private:
+ template <class T, void (T::*ptr)(const ChainPositions &) = &T::Run> struct CheckForRun {
+ typedef Chains type;
+ };
+
+ public:
+ explicit Chains(std::size_t limit) : FixedArray<util::stream::Chain>(limit) {}
+
+ template <class Worker> typename CheckForRun<Worker>::type &operator>>(const Worker &worker) {
+ threads_.push_back(new util::stream::Thread(ChainPositions(*this), worker));
+ return *this;
+ }
+
+ template <class Worker> typename CheckForRun<Worker>::type &operator>>(const boost::reference_wrapper<Worker> &worker) {
+ threads_.push_back(new util::stream::Thread(ChainPositions(*this), worker));
+ return *this;
+ }
+
+ Chains &operator>>(const util::stream::Recycler &recycler) {
+ for (util::stream::Chain *i = begin(); i != end(); ++i)
+ *i >> recycler;
+ return *this;
+ }
+
+ void Wait(bool release_memory = true) {
+ threads_.clear();
+ for (util::stream::Chain *i = begin(); i != end(); ++i) {
+ i->Wait(release_memory);
+ }
+ }
+
+ private:
+ boost::ptr_vector<util::stream::Thread> threads_;
+
+ Chains(const Chains &);
+ void operator=(const Chains &);
+};
+
+inline void ChainPositions::Init(Chains &chains) {
+ FixedArray<util::stream::ChainPosition>::Init(chains.size());
+ for (util::stream::Chain *i = chains.begin(); i != chains.end(); ++i) {
+ new (end()) util::stream::ChainPosition(i->Add()); Constructed();
+ }
+}
+
+inline Chains &operator>>(Chains &chains, ChainPositions &positions) {
+ positions.Init(chains);
+ return chains;
+}
+
+class NGramStreams : public FixedArray<NGramStream> {
+ public:
+ NGramStreams() {}
+
+ // This puts a dummy NGramStream at the beginning (useful to algorithms that need to reference something at the beginning).
+ void InitWithDummy(const ChainPositions &positions) {
+ FixedArray<NGramStream>::Init(positions.size() + 1);
+ new (end()) NGramStream(); Constructed();
+ for (const util::stream::ChainPosition *i = positions.begin(); i != positions.end(); ++i) {
+ push_back(*i);
+ }
+ }
+
+ // Limit restricts to positions[0,limit)
+ void Init(const ChainPositions &positions, std::size_t limit) {
+ FixedArray<NGramStream>::Init(limit);
+ for (const util::stream::ChainPosition *i = positions.begin(); i != positions.begin() + limit; ++i) {
+ push_back(*i);
+ }
+ }
+ void Init(const ChainPositions &positions) {
+ Init(positions, positions.size());
+ }
+
+ NGramStreams(const ChainPositions &positions) {
+ Init(positions);
+ }
+};
+
+inline Chains &operator>>(Chains &chains, NGramStreams &streams) {
+ ChainPositions positions;
+ chains >> positions;
+ streams.Init(positions);
+ return chains;
+}
+
+}} // namespaces
+#endif // LM_BUILDER_MULTI_STREAM__