diff options
Diffstat (limited to 'klm/util/joint_sort.hh')
-rw-r--r-- | klm/util/joint_sort.hh | 145 |
1 files changed, 145 insertions, 0 deletions
diff --git a/klm/util/joint_sort.hh b/klm/util/joint_sort.hh new file mode 100644 index 00000000..a2f1c01d --- /dev/null +++ b/klm/util/joint_sort.hh @@ -0,0 +1,145 @@ +#ifndef UTIL_JOINT_SORT__ +#define UTIL_JOINT_SORT__ + +/* A terrifying amount of C++ to coax std::sort into soring one range while + * also permuting another range the same way. + */ + +#include "util/proxy_iterator.hh" + +#include <algorithm> +#include <functional> +#include <iostream> + +namespace util { + +namespace detail { + +template <class KeyIter, class ValueIter> class JointProxy; + +template <class KeyIter, class ValueIter> class JointIter { + public: + JointIter() {} + + JointIter(const KeyIter &key_iter, const ValueIter &value_iter) : key_(key_iter), value_(value_iter) {} + + bool operator==(const JointIter<KeyIter, ValueIter> &other) const { return key_ == other.key_; } + + bool operator<(const JointIter<KeyIter, ValueIter> &other) const { return (key_ < other.key_); } + + std::ptrdiff_t operator-(const JointIter<KeyIter, ValueIter> &other) const { return key_ - other.key_; } + + JointIter<KeyIter, ValueIter> &operator+=(std::ptrdiff_t amount) { + key_ += amount; + value_ += amount; + return *this; + } + + void swap(const JointIter &other) { + std::swap(key_, other.key_); + std::swap(value_, other.value_); + } + + private: + friend class JointProxy<KeyIter, ValueIter>; + KeyIter key_; + ValueIter value_; +}; + +template <class KeyIter, class ValueIter> class JointProxy { + private: + typedef JointIter<KeyIter, ValueIter> InnerIterator; + + public: + typedef struct { + typename std::iterator_traits<KeyIter>::value_type key; + typename std::iterator_traits<ValueIter>::value_type value; + const typename std::iterator_traits<KeyIter>::value_type &GetKey() const { return key; } + } value_type; + + JointProxy(const KeyIter &key_iter, const ValueIter &value_iter) : inner_(key_iter, value_iter) {} + JointProxy(const JointProxy<KeyIter, ValueIter> &other) : inner_(other.inner_) {} + + operator const value_type() const { + value_type ret; + ret.key = *inner_.key_; + ret.value = *inner_.value_; + return ret; + } + + JointProxy &operator=(const JointProxy &other) { + *inner_.key_ = *other.inner_.key_; + *inner_.value_ = *other.inner_.value_; + return *this; + } + + JointProxy &operator=(const value_type &other) { + *inner_.key_ = other.key; + *inner_.value_ = other.value; + return *this; + } + + typename std::iterator_traits<KeyIter>::reference GetKey() const { + return *(inner_.key_); + } + + void swap(JointProxy<KeyIter, ValueIter> &other) { + std::swap(*inner_.key_, *other.inner_.key_); + std::swap(*inner_.value_, *other.inner_.value_); + } + + private: + friend class ProxyIterator<JointProxy<KeyIter, ValueIter> >; + + InnerIterator &Inner() { return inner_; } + const InnerIterator &Inner() const { return inner_; } + InnerIterator inner_; +}; + +template <class Proxy, class Less> class LessWrapper : public std::binary_function<const typename Proxy::value_type &, const typename Proxy::value_type &, bool> { + public: + explicit LessWrapper(const Less &less) : less_(less) {} + + bool operator()(const Proxy &left, const Proxy &right) const { + return less_(left.GetKey(), right.GetKey()); + } + bool operator()(const Proxy &left, const typename Proxy::value_type &right) const { + return less_(left.GetKey(), right.GetKey()); + } + bool operator()(const typename Proxy::value_type &left, const Proxy &right) const { + return less_(left.GetKey(), right.GetKey()); + } + bool operator()(const typename Proxy::value_type &left, const typename Proxy::value_type &right) const { + return less_(left.GetKey(), right.GetKey()); + } + + private: + const Less less_; +}; + +} // namespace detail + +template <class KeyIter, class ValueIter, class Less> void JointSort(const KeyIter &key_begin, const KeyIter &key_end, const ValueIter &value_begin, const Less &less) { + ProxyIterator<detail::JointProxy<KeyIter, ValueIter> > full_begin(detail::JointProxy<KeyIter, ValueIter>(key_begin, value_begin)); + detail::LessWrapper<detail::JointProxy<KeyIter, ValueIter>, Less> less_wrap(less); + std::sort(full_begin, full_begin + (key_end - key_begin), less_wrap); +} + + +template <class KeyIter, class ValueIter> void JointSort(const KeyIter &key_begin, const KeyIter &key_end, const ValueIter &value_begin) { + JointSort(key_begin, key_end, value_begin, std::less<typename std::iterator_traits<KeyIter>::value_type>()); +} + +} // namespace util + +namespace std { +template <class KeyIter, class ValueIter> void swap(util::detail::JointIter<KeyIter, ValueIter> &left, util::detail::JointIter<KeyIter, ValueIter> &right) { + left.swap(right); +} + +template <class KeyIter, class ValueIter> void swap(util::detail::JointProxy<KeyIter, ValueIter> &left, util::detail::JointProxy<KeyIter, ValueIter> &right) { + left.swap(right); +} +} // namespace std + +#endif // UTIL_JOINT_SORT__ |