From 0b9031042500d45a098762f0a930bd6a66a58fac Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Fri, 18 Jan 2013 17:12:51 +0000 Subject: KenLM dffafbf with lmplz source (but not built) --- klm/lm/builder/sort.hh | 103 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 klm/lm/builder/sort.hh (limited to 'klm/lm/builder/sort.hh') diff --git a/klm/lm/builder/sort.hh b/klm/lm/builder/sort.hh new file mode 100644 index 00000000..9989389b --- /dev/null +++ b/klm/lm/builder/sort.hh @@ -0,0 +1,103 @@ +#ifndef LM_BUILDER_SORT__ +#define LM_BUILDER_SORT__ + +#include "lm/builder/multi_stream.hh" +#include "lm/builder/ngram.hh" +#include "lm/word_index.hh" +#include "util/stream/sort.hh" + +#include "util/stream/timer.hh" + +#include +#include + +namespace lm { +namespace builder { + +template class Comparator : public std::binary_function { + public: + explicit Comparator(std::size_t order) : order_(order) {} + + inline bool operator()(const void *lhs, const void *rhs) const { + return static_cast(this)->Compare(static_cast(lhs), static_cast(rhs)); + } + + std::size_t Order() const { return order_; } + + protected: + std::size_t order_; +}; + +class SuffixOrder : public Comparator { + public: + explicit SuffixOrder(std::size_t order) : Comparator(order) {} + + inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const { + for (std::size_t i = order_ - 1; i != 0; --i) { + if (lhs[i] != rhs[i]) + return lhs[i] < rhs[i]; + } + return lhs[0] < rhs[0]; + } + + static const unsigned kMatchOffset = 1; +}; + +class ContextOrder : public Comparator { + public: + explicit ContextOrder(std::size_t order) : Comparator(order) {} + + inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const { + for (int i = order_ - 2; i >= 0; --i) { + if (lhs[i] != rhs[i]) + return lhs[i] < rhs[i]; + } + return lhs[order_ - 1] < rhs[order_ - 1]; + } +}; + +class PrefixOrder : public Comparator { + public: + explicit PrefixOrder(std::size_t order) : Comparator(order) {} + + inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const { + for (std::size_t i = 0; i < order_; ++i) { + if (lhs[i] != rhs[i]) + return lhs[i] < rhs[i]; + } + return false; + } + + static const unsigned kMatchOffset = 0; +}; + +// Sum counts for the same n-gram. +struct AddCombiner { + bool operator()(void *first_void, const void *second_void, const SuffixOrder &compare) const { + NGram first(first_void, compare.Order()); + // There isn't a const version of NGram. + NGram second(const_cast(second_void), compare.Order()); + if (memcmp(first.begin(), second.begin(), sizeof(WordIndex) * compare.Order())) return false; + first.Count() += second.Count(); + return true; + } +}; + +// The combiner is only used on a single chain, so I didn't bother to allow +// that template. +template class Sorts : public FixedArray > { + private: + typedef util::stream::Sort S; + typedef FixedArray P; + + public: + void push_back(util::stream::Chain &chain, const util::stream::SortConfig &config, const Compare &compare) { + new (P::end()) S(chain, config, compare); + P::Constructed(); + } +}; + +} // namespace builder +} // namespace lm + +#endif // LM_BUILDER_SORT__ -- cgit v1.2.3