summaryrefslogtreecommitdiff
path: root/klm/lm/builder/sort.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/builder/sort.hh')
-rw-r--r--klm/lm/builder/sort.hh103
1 files changed, 103 insertions, 0 deletions
diff --git a/klm/lm/builder/sort.hh b/klm/lm/builder/sort.hh
new file mode 100644
index 00000000..9989389b
--- /dev/null
+++ b/klm/lm/builder/sort.hh
@@ -0,0 +1,103 @@
+#ifndef LM_BUILDER_SORT__
+#define LM_BUILDER_SORT__
+
+#include "lm/builder/multi_stream.hh"
+#include "lm/builder/ngram.hh"
+#include "lm/word_index.hh"
+#include "util/stream/sort.hh"
+
+#include "util/stream/timer.hh"
+
+#include <functional>
+#include <string>
+
+namespace lm {
+namespace builder {
+
+template <class Child> class Comparator : public std::binary_function<const void *, const void *, bool> {
+ public:
+ explicit Comparator(std::size_t order) : order_(order) {}
+
+ inline bool operator()(const void *lhs, const void *rhs) const {
+ return static_cast<const Child*>(this)->Compare(static_cast<const WordIndex*>(lhs), static_cast<const WordIndex*>(rhs));
+ }
+
+ std::size_t Order() const { return order_; }
+
+ protected:
+ std::size_t order_;
+};
+
+class SuffixOrder : public Comparator<SuffixOrder> {
+ public:
+ explicit SuffixOrder(std::size_t order) : Comparator<SuffixOrder>(order) {}
+
+ inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const {
+ for (std::size_t i = order_ - 1; i != 0; --i) {
+ if (lhs[i] != rhs[i])
+ return lhs[i] < rhs[i];
+ }
+ return lhs[0] < rhs[0];
+ }
+
+ static const unsigned kMatchOffset = 1;
+};
+
+class ContextOrder : public Comparator<ContextOrder> {
+ public:
+ explicit ContextOrder(std::size_t order) : Comparator<ContextOrder>(order) {}
+
+ inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const {
+ for (int i = order_ - 2; i >= 0; --i) {
+ if (lhs[i] != rhs[i])
+ return lhs[i] < rhs[i];
+ }
+ return lhs[order_ - 1] < rhs[order_ - 1];
+ }
+};
+
+class PrefixOrder : public Comparator<PrefixOrder> {
+ public:
+ explicit PrefixOrder(std::size_t order) : Comparator<PrefixOrder>(order) {}
+
+ inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const {
+ for (std::size_t i = 0; i < order_; ++i) {
+ if (lhs[i] != rhs[i])
+ return lhs[i] < rhs[i];
+ }
+ return false;
+ }
+
+ static const unsigned kMatchOffset = 0;
+};
+
+// Sum counts for the same n-gram.
+struct AddCombiner {
+ bool operator()(void *first_void, const void *second_void, const SuffixOrder &compare) const {
+ NGram first(first_void, compare.Order());
+ // There isn't a const version of NGram.
+ NGram second(const_cast<void*>(second_void), compare.Order());
+ if (memcmp(first.begin(), second.begin(), sizeof(WordIndex) * compare.Order())) return false;
+ first.Count() += second.Count();
+ return true;
+ }
+};
+
+// The combiner is only used on a single chain, so I didn't bother to allow
+// that template.
+template <class Compare> class Sorts : public FixedArray<util::stream::Sort<Compare> > {
+ private:
+ typedef util::stream::Sort<Compare> S;
+ typedef FixedArray<S> P;
+
+ public:
+ void push_back(util::stream::Chain &chain, const util::stream::SortConfig &config, const Compare &compare) {
+ new (P::end()) S(chain, config, compare);
+ P::Constructed();
+ }
+};
+
+} // namespace builder
+} // namespace lm
+
+#endif // LM_BUILDER_SORT__