summaryrefslogtreecommitdiff
path: root/klm/lm/builder/sort.hh
blob: 712bb8e3537d37ea1272c1ede238337fc59f32e4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
#ifndef LM_BUILDER_SORT_H
#define LM_BUILDER_SORT_H

#include "lm/builder/ngram_stream.hh"
#include "lm/builder/ngram.hh"
#include "lm/word_index.hh"
#include "util/stream/sort.hh"

#include "util/stream/timer.hh"

#include <cstring>
#include <functional>
#include <string>

namespace lm {
namespace builder {

/**
 * CRTP base class for n-gram comparators.
 *
 * A concrete comparator derives from Comparator<Itself> and provides
 *   bool Compare(const WordIndex *lhs, const WordIndex *rhs) const;
 * operator() casts the untyped record pointers to WordIndex pointers and
 * forwards to that method through a static downcast, so no virtual call is made.
 *
 * The adaptable-functor typedefs (first_argument_type, second_argument_type,
 * result_type) are declared directly instead of being inherited from
 * std::binary_function, which was deprecated in C++11 and removed in C++17.
 *
 * @tparam Child The derived comparator type that implements Compare
 */
template <class Child> class Comparator {
  public:

    /** Type of the left-hand argument accepted by operator(). */
    typedef const void *first_argument_type;
    /** Type of the right-hand argument accepted by operator(). */
    typedef const void *second_argument_type;
    /** Type returned by operator(). */
    typedef bool result_type;

    /**
     * Constructs a comparator capable of comparing two n-grams.
     *
     * @param order Number of words in each n-gram
     */
    explicit Comparator(std::size_t order) : order_(order) {}

    /**
     * Applies the comparator by delegating to Child::Compare.
     *
     * @param lhs A pointer to the n-gram on the left-hand side of the comparison
     * @param rhs A pointer to the n-gram on the right-hand side of the comparison
     * @return Whatever Child::Compare returns for the two n-grams
     *
     * @see ContextOrder::Compare
     * @see PrefixOrder::Compare
     * @see SuffixOrder::Compare
     */
    inline bool operator()(const void *lhs, const void *rhs) const {
      return static_cast<const Child*>(this)->Compare(static_cast<const WordIndex*>(lhs), static_cast<const WordIndex*>(rhs));
    }

    /** Gets the n-gram order defined for this comparator. */
    std::size_t Order() const { return order_; }

  protected:
    /** Number of words in each n-gram; read directly by derived comparators. */
    std::size_t order_;
};

/**
 * N-gram comparator that orders n-grams by their words read from last to first
 * (reverse, or suffix, order).
 *
 * Words are compared lexicographically, one at a time, starting at the final
 * word of each n-gram and moving backwards to the first word.
 *
 * Examples of the resulting ordering:
 * - a b c == a b c
 * - a b c < a b d
 * - a b c > a d b
 * - a b c > a b b
 * - a b c > x a c
 * - a b c < x y z
 */
class SuffixOrder : public Comparator<SuffixOrder> {
  public:

    /**
     * Builds a suffix-order comparator for n-grams of a fixed length.
     *
     * @param order Number of words in each n-gram
     */
    explicit SuffixOrder(std::size_t order) : Comparator<SuffixOrder>(order) {}

    /**
     * Strict less-than test in suffix order.
     *
     * @param lhs A pointer to the n-gram on the left-hand side of the comparison
     * @param rhs A pointer to the n-gram on the right-hand side of the comparison
     * @return true when lhs sorts strictly before rhs; false when it sorts after or the two are equal
     */
    inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const {
      // Step backwards over matching words.  The scan stops at the first
      // mismatch, or at index 0, whose comparison always settles the result.
      std::size_t pos = order_ - 1;
      while (pos != 0 && lhs[pos] == rhs[pos]) {
        --pos;
      }
      return lhs[pos] < rhs[pos];
    }

    // NOTE(review): consumed outside this header; its meaning is not visible here.
    static const unsigned kMatchOffset = 1;
};

  
/**
 * N-gram comparator that compares n-grams according to the reverse (suffix) order of the n-gram context.
 *
 * This comparator compares n-grams lexicographically, one word at a time,
 * beginning with the penultimate word of each n-gram and ending with the first word of each n-gram;
 * finally, this comparator compares the last word of each n-gram.
 *
 * Some examples of n-gram comparisons as defined by this comparator:
 * - a b c == a b c
 * - a b c < a b d
 * - a b c < a d b
 * - a b c > a b b
 * - a b c > x a c
 * - a b c < x y z
 */
class ContextOrder : public Comparator<ContextOrder> {
  public:

    /**
     * Constructs a comparator capable of comparing two n-grams.
     *
     * @param order Number of words in each n-gram; must be at least 1, because Compare indexes word order - 1
     */
    explicit ContextOrder(std::size_t order) : Comparator<ContextOrder>(order) {}

    /**
     * Compares two n-grams lexicographically, one word at a time,
     * beginning with the penultimate word of each n-gram and ending with the first word of each n-gram;
     * finally, this comparator compares the last word of each n-gram.
     *
     * @param lhs A pointer to the n-gram on the left-hand side of the comparison
     * @param rhs A pointer to the n-gram on the right-hand side of the comparison
     * @return true when lhs sorts strictly before rhs; false when it sorts after or the two are equal
     */
    inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const {
      // Visit the context words at indices order_ - 2 down to 0.  The index is
      // kept unsigned and decremented at the top of the body, so a unigram
      // (order_ == 1) simply skips the loop without relying on a size_t -> int
      // conversion of a wrapped-around value.
      for (std::size_t i = order_ - 1; i != 0; ) {
        --i;
        if (lhs[i] != rhs[i])
          return lhs[i] < rhs[i];
      }
      // Contexts are identical: the final word breaks the tie.
      return lhs[order_ - 1] < rhs[order_ - 1];
    }
};

/**
 * N-gram comparator that orders n-grams by their words read from first to last
 * (natural, or prefix, order).
 *
 * Words are compared lexicographically, one at a time, starting at the first
 * word of each n-gram and moving forwards to the last word.
 *
 * Examples of the resulting ordering:
 * - a b c == a b c
 * - a b c < a b d
 * - a b c < a d b
 * - a b c > a b b
 * - a b c < x a c
 * - a b c < x y z
 */
class PrefixOrder : public Comparator<PrefixOrder> {
  public:

    /**
     * Builds a prefix-order comparator for n-grams of a fixed length.
     *
     * @param order Number of words in each n-gram
     */
    explicit PrefixOrder(std::size_t order) : Comparator<PrefixOrder>(order) {}

    /**
     * Strict less-than test in prefix order.
     *
     * @param lhs A pointer to the n-gram on the left-hand side of the comparison
     * @param rhs A pointer to the n-gram on the right-hand side of the comparison
     * @return true when lhs sorts strictly before rhs; false when it sorts after or the two are equal
     */
    inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const {
      // Advance both cursors together; the first differing word decides.
      const WordIndex *const lhs_end = lhs + order_;
      for (; lhs != lhs_end; ++lhs, ++rhs) {
        if (*lhs != *rhs) {
          return *lhs < *rhs;
        }
      }
      // Every word matched, so neither n-gram precedes the other.
      return false;
    }

    // NOTE(review): consumed outside this header; its meaning is not visible here.
    static const unsigned kMatchOffset = 0;
};

/**
 * Combiner that sums the counts of two records holding the same n-gram.
 *
 * Both records are viewed through NGram using the order reported by the
 * comparator.  If their word sequences differ byte-for-byte, nothing is
 * modified and false is returned.  Otherwise the count of the second record
 * is added into the first record and true is returned.
 */
struct AddCombiner {
  /**
   * @param first_void Record that receives the summed count when the n-grams match
   * @param second_void Record whose count is added; it is not modified here
   * @param compare Comparator consulted only for the n-gram order
   * @return true if the two records held the same n-gram and were merged, false otherwise
   */
  bool operator()(void *first_void, const void *second_void, const SuffixOrder &compare) const {
    const std::size_t order = compare.Order();
    NGram first(first_void, order);
    // There isn't a const version of NGram, hence the const_cast; second is only read.
    NGram second(const_cast<void*>(second_void), order);
    if (std::memcmp(first.begin(), second.begin(), sizeof(WordIndex) * order) != 0) return false;
    first.Count() += second.Count();
    return true;
  }
};

// The combiner is only used on a single chain, so Sorts does not bother to
// expose it as a template parameter.
/**
 * An @ref util::FixedArray "array" whose elements are @ref util::stream::Sort "Sort" objects.
 *
 * In the anticipated use case, an instance of this class holds one @ref util::stream::Sort "Sort" object
 * for each n-gram order (ranging from 1 up to the maximum n-gram order being processed).
 * Used in this manner, the n-grams of each n-gram order can be sorted in parallel.
 *
 * @tparam Compare An @ref Comparator "ngram comparator" to use during sorting.
 */
template <class Compare> class Sorts : public util::FixedArray<util::stream::Sort<Compare> > {
  private:
    typedef util::stream::Sort<Compare> Sorter;
    typedef util::FixedArray<Sorter> Base;

  public:

    /**
     * Default-constructs the container without initializing it.
     *
     * @ref util::FixedArray::Init() "Init" must be called before use.
     *
     * @see util::FixedArray::Init()
     */
    Sorts() {}

    /**
     * Constructs an @ref util::FixedArray "array" with room for a fixed number of @ref util::stream::Sort "Sort" objects.
     *
     * @param number The maximum number of @ref util::stream::Sort "sorters" that can be held by this @ref util::FixedArray "array"
     * @see util::FixedArray::FixedArray()
     */
    explicit Sorts(std::size_t number) : Base(number) {}

    /**
     * Constructs a new @ref util::stream::Sort "Sort" object in place at the end of this @ref util::FixedArray "array".
     *
     * The new @ref util::stream::Sort "Sort" object is built from the provided @ref util::stream::Chain "chain",
     * @ref util::stream::SortConfig "SortConfig" and @ref Comparator "ngram comparator";
     * once constructed, a worker @ref util::stream::Thread "thread" (owned by the @ref util::stream::Chain "chain") sorts the n-gram data stored
     * in the @ref util::stream::Block "blocks" of that @ref util::stream::Chain "chain".
     *
     * @param chain Chain handed to the new sorter
     * @param config Sort configuration handed to the new sorter
     * @param compare Comparator handed to the new sorter
     *
     * @see util::stream::Sort::Sort()
     * @see util::stream::Chain::operator>>()
     */
    void push_back(util::stream::Chain &chain, const util::stream::SortConfig &config, const Compare &compare) {
      // Placement new: initialize the sorter directly in the array's next already-allocated slot.
      new (Base::end()) Sorter(chain, config, compare);
      // Tell the array that the slot now holds a constructed element.
      Base::Constructed();
    }
};

} // namespace builder
} // namespace lm

#endif // LM_BUILDER_SORT_H