summaryrefslogtreecommitdiff
path: root/klm/util/joint_sort.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/util/joint_sort.hh')
-rw-r--r--klm/util/joint_sort.hh145
1 files changed, 145 insertions, 0 deletions
diff --git a/klm/util/joint_sort.hh b/klm/util/joint_sort.hh
new file mode 100644
index 00000000..a2f1c01d
--- /dev/null
+++ b/klm/util/joint_sort.hh
@@ -0,0 +1,145 @@
+#ifndef UTIL_JOINT_SORT__
+#define UTIL_JOINT_SORT__
+
+/* A terrifying amount of C++ to coax std::sort into soring one range while
+ * also permuting another range the same way.
+ */
+
+#include "util/proxy_iterator.hh"
+
+#include <algorithm>
+#include <functional>
+#include <iostream>
+
+namespace util {
+
+namespace detail {
+
+template <class KeyIter, class ValueIter> class JointProxy;
+
+template <class KeyIter, class ValueIter> class JointIter {
+ public:
+ JointIter() {}
+
+ JointIter(const KeyIter &key_iter, const ValueIter &value_iter) : key_(key_iter), value_(value_iter) {}
+
+ bool operator==(const JointIter<KeyIter, ValueIter> &other) const { return key_ == other.key_; }
+
+ bool operator<(const JointIter<KeyIter, ValueIter> &other) const { return (key_ < other.key_); }
+
+ std::ptrdiff_t operator-(const JointIter<KeyIter, ValueIter> &other) const { return key_ - other.key_; }
+
+ JointIter<KeyIter, ValueIter> &operator+=(std::ptrdiff_t amount) {
+ key_ += amount;
+ value_ += amount;
+ return *this;
+ }
+
+ void swap(const JointIter &other) {
+ std::swap(key_, other.key_);
+ std::swap(value_, other.value_);
+ }
+
+ private:
+ friend class JointProxy<KeyIter, ValueIter>;
+ KeyIter key_;
+ ValueIter value_;
+};
+
+template <class KeyIter, class ValueIter> class JointProxy {
+ private:
+ typedef JointIter<KeyIter, ValueIter> InnerIterator;
+
+ public:
+ typedef struct {
+ typename std::iterator_traits<KeyIter>::value_type key;
+ typename std::iterator_traits<ValueIter>::value_type value;
+ const typename std::iterator_traits<KeyIter>::value_type &GetKey() const { return key; }
+ } value_type;
+
+ JointProxy(const KeyIter &key_iter, const ValueIter &value_iter) : inner_(key_iter, value_iter) {}
+ JointProxy(const JointProxy<KeyIter, ValueIter> &other) : inner_(other.inner_) {}
+
+ operator const value_type() const {
+ value_type ret;
+ ret.key = *inner_.key_;
+ ret.value = *inner_.value_;
+ return ret;
+ }
+
+ JointProxy &operator=(const JointProxy &other) {
+ *inner_.key_ = *other.inner_.key_;
+ *inner_.value_ = *other.inner_.value_;
+ return *this;
+ }
+
+ JointProxy &operator=(const value_type &other) {
+ *inner_.key_ = other.key;
+ *inner_.value_ = other.value;
+ return *this;
+ }
+
+ typename std::iterator_traits<KeyIter>::reference GetKey() const {
+ return *(inner_.key_);
+ }
+
+ void swap(JointProxy<KeyIter, ValueIter> &other) {
+ std::swap(*inner_.key_, *other.inner_.key_);
+ std::swap(*inner_.value_, *other.inner_.value_);
+ }
+
+ private:
+ friend class ProxyIterator<JointProxy<KeyIter, ValueIter> >;
+
+ InnerIterator &Inner() { return inner_; }
+ const InnerIterator &Inner() const { return inner_; }
+ InnerIterator inner_;
+};
+
+template <class Proxy, class Less> class LessWrapper : public std::binary_function<const typename Proxy::value_type &, const typename Proxy::value_type &, bool> {
+ public:
+ explicit LessWrapper(const Less &less) : less_(less) {}
+
+ bool operator()(const Proxy &left, const Proxy &right) const {
+ return less_(left.GetKey(), right.GetKey());
+ }
+ bool operator()(const Proxy &left, const typename Proxy::value_type &right) const {
+ return less_(left.GetKey(), right.GetKey());
+ }
+ bool operator()(const typename Proxy::value_type &left, const Proxy &right) const {
+ return less_(left.GetKey(), right.GetKey());
+ }
+ bool operator()(const typename Proxy::value_type &left, const typename Proxy::value_type &right) const {
+ return less_(left.GetKey(), right.GetKey());
+ }
+
+ private:
+ const Less less_;
+};
+
+} // namespace detail
+
+template <class KeyIter, class ValueIter, class Less> void JointSort(const KeyIter &key_begin, const KeyIter &key_end, const ValueIter &value_begin, const Less &less) {
+ ProxyIterator<detail::JointProxy<KeyIter, ValueIter> > full_begin(detail::JointProxy<KeyIter, ValueIter>(key_begin, value_begin));
+ detail::LessWrapper<detail::JointProxy<KeyIter, ValueIter>, Less> less_wrap(less);
+ std::sort(full_begin, full_begin + (key_end - key_begin), less_wrap);
+}
+
+
+template <class KeyIter, class ValueIter> void JointSort(const KeyIter &key_begin, const KeyIter &key_end, const ValueIter &value_begin) {
+ JointSort(key_begin, key_end, value_begin, std::less<typename std::iterator_traits<KeyIter>::value_type>());
+}
+
+} // namespace util
+
+namespace std {
+template <class KeyIter, class ValueIter> void swap(util::detail::JointIter<KeyIter, ValueIter> &left, util::detail::JointIter<KeyIter, ValueIter> &right) {
+ left.swap(right);
+}
+
+template <class KeyIter, class ValueIter> void swap(util::detail::JointProxy<KeyIter, ValueIter> &left, util::detail::JointProxy<KeyIter, ValueIter> &right) {
+ left.swap(right);
+}
+} // namespace std
+
+#endif // UTIL_JOINT_SORT__