summaryrefslogtreecommitdiff
path: root/klm/lm/trie_sort.hh
blob: 3036319df0328808dfcb8b6e40d465d2f19e06bd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
// Step of the trie builder: create sorted files.

#ifndef LM_TRIE_SORT__
#define LM_TRIE_SORT__

#include "lm/max_order.hh"
#include "lm/word_index.hh"

#include "util/file.hh"
#include "util/scoped.hh"

#include <cassert>
#include <cstddef>
#include <cstdio>
#include <functional>
#include <string>
#include <vector>

namespace util {
class FilePiece;
class TempMaker;
} // namespace util

namespace lm {
class PositiveProbWarn;
namespace ngram {
class SortedVocabulary;
class Config;

namespace trie {

// Write `size` bytes starting at `data` to the stream `to`.
// NOTE(review): per the name, failures are presumably reported by throwing
// rather than by a return value -- confirm against the out-of-line definition.
void WriteOrThrow(FILE *to, const void *data, std::size_t size);

/**
 * Comparator over raw records that begin with `order` WordIndex values.
 * Records are ordered lexicographically on those leading words only; any
 * bytes after them are ignored.  Two records whose leading words are all
 * equal compare as equivalent (operator() returns false both ways).
 *
 * This class does not derive from std::binary_function, which was
 * deprecated in C++11 and removed in C++17.  The adaptable-function typedefs
 * that base used to supply are declared explicitly, so adaptors that rely on
 * them (e.g. std::not2) keep working.
 */
class EntryCompare {
  public:
    typedef const void *first_argument_type;
    typedef const void *second_argument_type;
    typedef bool result_type;

    // order: number of leading WordIndex values to compare in each record.
    explicit EntryCompare(unsigned char order) : order_(order) {}

    // True iff the first order_ words of *first_void sort lexicographically
    // before the first order_ words of *second_void.
    bool operator()(const void *first_void, const void *second_void) const {
      const WordIndex *first = static_cast<const WordIndex*>(first_void);
      const WordIndex *second = static_cast<const WordIndex*>(second_void);
      const WordIndex *end = first + order_;
      for (; first != end; ++first, ++second) {
        if (*first < *second) return true;
        if (*first > *second) return false;
      }
      // All compared words equal: neither record is less than the other.
      return false;
    }
  private:
    unsigned char order_;
};

class RecordReader {
  public:
    RecordReader() : remains_(true) {}

    void Init(FILE *file, std::size_t entry_size);

    void *Data() { return data_.get(); }
    const void *Data() const { return data_.get(); }

    RecordReader &operator++() {
      std::size_t ret = fread(data_.get(), entry_size_, 1, file_);
      if (!ret) {
        UTIL_THROW_IF(!feof(file_), util::ErrnoException, "Error reading temporary file");
        remains_ = false;
      }
      return *this;
    }

    operator bool() const { return remains_; }

    void Rewind();

    std::size_t EntrySize() const { return entry_size_; }

    void Overwrite(const void *start, std::size_t amount);

  private:
    FILE *file_;

    util::scoped_malloc data_;

    bool remains_;

    std::size_t entry_size_;
};

class SortedFiles {
  public:
    // Build from ARPA
    SortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, std::size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab);

    int StealUnigram() {
      return unigram_.release();
    }

    FILE *Full(unsigned char order) {
      return full_[order - 2].get();
    }

    FILE *Context(unsigned char of_order) {
      return context_[of_order - 2].get();
    }

  private:
    void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const util::TempMaker &maker, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size);
    
    util::scoped_fd unigram_;

    util::scoped_FILE full_[kMaxOrder - 1], context_[kMaxOrder - 1];
};

} // namespace trie
} // namespace ngram
} // namespace lm

#endif // LM_TRIE_SORT__