// Source: root/klm/util/multi_intersection.hh
// blob: 8334d39dfdddc655e682b1f6cad774a034c3e3b6
#ifndef UTIL_MULTI_INTERSECTION__
#define UTIL_MULTI_INTERSECTION__

#include <boost/optional.hpp>
#include <boost/range/iterator_range.hpp>

#include <algorithm>
#include <cassert>
#include <functional>
#include <iterator>
#include <vector>

namespace util {

namespace detail {
/* Strict weak ordering on ranges by their size(); used to sort the sets so the
 * smallest one comes first.
 *
 * This used to derive from std::binary_function, which was deprecated in C++11
 * and removed in C++17.  The typedefs that base class supplied are declared
 * explicitly instead, so code that inspects them keeps compiling.
 */
template <class Range> struct RangeLessBySize {
  typedef const Range &first_argument_type;
  typedef const Range &second_argument_type;
  typedef bool result_type;

  bool operator()(const Range &left, const Range &right) const {
    return left.size() < right.size();
  }
};

/* Finds the smallest value present in every range of sets.  Each range must
 * be sorted in increasing order according to less.  Returns that value, or an
 * empty boost::optional when the ranges share no element.
 *
 * The ranges in sets are trimmed in place from the front.  On success every
 * range is left starting at an element equivalent to the returned value.  On
 * failure at least one range is empty, and the others may be partially
 * trimmed.
 *
 * The first range seeds the search, so callers put the smallest range first
 * (FirstIntersection and AllIntersection sort by size before calling).
 *
 * Precondition: sets is not empty.  An intersection over zero sets would be
 * the whole universe, which this function cannot represent.
 */
template <class Iterator, class Less> boost::optional<typename std::iterator_traits<Iterator>::value_type> FirstIntersectionSorted(std::vector<boost::iterator_range<Iterator> > &sets, const Less &less = std::less<typename std::iterator_traits<Iterator>::value_type>()) {
  typedef std::vector<boost::iterator_range<Iterator> > Sets;
  typedef typename Sets::value_type Range;
  typedef typename std::iterator_traits<Iterator>::value_type Value;
  typedef boost::optional<Value> Result;

  assert(!sets.empty());

  // An empty seed range means there is nothing to intersect.
  if (sets.front().empty()) return Result();

  // Hold a copy of the candidate rather than an iterator.  This may be
  // suboptimal for heavyweight Value types but is slightly faster for
  // unsigned int.
  Value candidate(sets.front().front());

  typename Sets::size_type idx = 0;
  while (idx < sets.size()) {
    Range &current = sets[idx];
    // Discard every element of this range that orders before the candidate.
    Iterator first_not_below = std::lower_bound(current.begin(), current.end(), candidate, less);
    current.advance_begin(first_not_below - current.begin());
    if (current.empty()) return Result();

    if (!less(candidate, current.front())) {
      // This range contains the candidate; check the next one.
      ++idx;
      continue;
    }

    // This range jumped past the candidate, so its front becomes the new
    // candidate.  Earlier ranges were only trimmed up to the old candidate,
    // so rescan from the first range.
    candidate = current.front();
    idx = 0;
  }
  return Result(candidate);
}

} // namespace detail

/* Returns the smallest value common to every range in sets, or an empty
 * boost::optional when there is none.  Each range must be sorted in increasing
 * order according to less.
 *
 * sets is reordered by range size and its ranges are trimmed in place.
 * Precondition: sets is not empty.
 */
template <class Iterator, class Less> boost::optional<typename std::iterator_traits<Iterator>::value_type> FirstIntersection(std::vector<boost::iterator_range<Iterator> > &sets, const Less less) {
  typedef boost::iterator_range<Iterator> Range;
  assert(!sets.empty());

  // Put the smallest range first: it seeds the search in
  // FirstIntersectionSorted and is the first range rechecked on every restart.
  detail::RangeLessBySize<Range> by_size;
  std::sort(sets.begin(), sets.end(), by_size);
  return detail::FirstIntersectionSorted(sets, less);
}

/* Same as FirstIntersection(sets, less), ordering values with std::less
 * (operator<).  Precondition: sets is not empty.
 */
template <class Iterator> boost::optional<typename std::iterator_traits<Iterator>::value_type> FirstIntersection(std::vector<boost::iterator_range<Iterator> > &sets) {
  typedef typename std::iterator_traits<Iterator>::value_type Value;
  std::less<Value> ordering;
  return FirstIntersection(sets, ordering);
}

/* Calls out(value) for each value common to every range in sets, in
 * non-decreasing order.  Each range must be sorted in increasing order
 * according to less.
 *
 * sets is reordered by range size, and its ranges are consumed (trimmed in
 * place) by the search.
 * NOTE(review): ranges are presumably strictly increasing.  If the smallest
 * range repeats a common value, that value is reported once per repeat.
 *
 * Precondition: sets is not empty.
 */
template <class Iterator, class Output, class Less> void AllIntersection(std::vector<boost::iterator_range<Iterator> > &sets, Output &out, const Less less) {
  typedef typename std::iterator_traits<Iterator>::value_type Value;
  assert(!sets.empty());

  // Put the smallest range first so it seeds every search.
  std::sort(sets.begin(), sets.end(), detail::RangeLessBySize<boost::iterator_range<Iterator> >());
  for (;;) {
    boost::optional<Value> match(detail::FirstIntersectionSorted(sets, less));
    if (!match) break;
    out(*match);
    // After a successful search every range starts at the match.  Stepping
    // the first range past it makes the next search find a later value and
    // guarantees the loop terminates.
    sets.front().advance_begin(1);
  }
}

/* Same as AllIntersection(sets, out, less), ordering values with std::less
 * (operator<).  Precondition: sets is not empty.
 */
template <class Iterator, class Output> void AllIntersection(std::vector<boost::iterator_range<Iterator> > &sets, Output &out) {
  typedef typename std::iterator_traits<Iterator>::value_type Value;
  std::less<Value> ordering;
  AllIntersection(sets, out, ordering);
}

} // namespace util

#endif // UTIL_MULTI_INTERSECTION__