path: root/klm/lm/search_trie.hh
#ifndef LM_SEARCH_TRIE__
#define LM_SEARCH_TRIE__

#include "lm/config.hh"
#include "lm/model_type.hh"
#include "lm/return.hh"
#include "lm/trie.hh"
#include "lm/weights.hh"

#include "util/file_piece.hh"

#include <vector>

#include <assert.h>
#include <stdlib.h>  // free() used by FreeMiddles

namespace lm {
namespace ngram {
struct Backing;
class SortedVocabulary;
namespace trie {

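// BuildTrie, defined in search_trie.cc, reads the sorted n-gram files named by
// file_prefix and fills in the TrieSearch structures.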
template <class Quant, class Bhiksha> class TrieSearch;
template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);

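// Trie-based search structure, templated on probability/backoff quantization
// (Quant) and pointer compression (Bhiksha).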
template <class Quant, class Bhiksha> class TrieSearch {
  public:
    typedef NodeRange Node;

    typedef ::lm::ngram::trie::Unigram Unigram;
    Unigram unigram;

    typedef trie::BitPackedMiddle<typename Quant::Middle, Bhiksha> Middle;

    typedef trie::BitPackedLongest<typename Quant::Longest> Longest;
    Longest longest;

    static const ModelType kModelType = static_cast<ModelType>(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd);

    static const unsigned int kVersion = 1;

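    // Recover settings stored in a binary model: read the quantizer's header,
    // skip past the quantization tables and the unigram array, then read the
    // Bhiksha header.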
    static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
      Quant::UpdateConfigFromBinary(fd, counts, config);
      AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0]));
      Bhiksha::UpdateConfigFromBinary(fd, config);
    }

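    // Total bytes required for the quantization tables, unigram, middle, and
    // longest arrays, given the n-gram counts and config.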
    static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) {
      std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]);
      for (unsigned char i = 1; i < counts.size() - 1; ++i) {
        ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config);
      }
      return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
    }

    TrieSearch() : middle_begin_(NULL), middle_end_(NULL) {}

    ~TrieSearch() { FreeMiddles(); }

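    // Lay the structures out in a preallocated block of Size() bytes starting
    // at start; returns the first byte past what was used.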
    uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config);

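    // Finish initialization after the structures have been mapped from a binary file.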
    void LoadedBinary();

    const Middle *MiddleBegin() const { return middle_begin_; }
    const Middle *MiddleEnd() const { return middle_end_; }

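    // Build the trie from an ARPA file: n-grams are sorted in temporary files
    // on disk, then BuildTrie assembles the final structures.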
    void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing);

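    // Unigram lookup: sets ret.prob and backoff; node becomes the range of
    // bigrams extending word, so independent_left is true when nothing extends it.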
    void LookupUnigram(WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const {
      unigram.Find(word, ret.prob, backoff, node);
      ret.independent_left = (node.begin == node.end);
      ret.extend_left = static_cast<uint64_t>(word);
    }

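    // Middle-order lookup within the range given by node.  Returns false if
    // word does not follow the context; on success node narrows to the
    // continuations and ret.extend_left records a pointer usable by Unpack.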
    bool LookupMiddle(const Middle &mid, WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const {
      if (!mid.Find(word, ret.prob, backoff, node, ret.extend_left)) return false;
      ret.independent_left = (node.begin == node.end);
      return true;
    }

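    // Like LookupMiddle but retrieves only the backoff, for callers that do
    // not need the probability.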
    bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const {
      return mid.FindNoProb(word, backoff, node);
    }

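    // Highest-order lookup: probability only; there is no backoff or further
    // extension at the maximum order.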
    bool LookupLongest(WordIndex word, float &prob, const Node &node) const {
      return longest.Find(word, prob, node);
    }

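    // Construct the Node for the context [begin, end) by walking down the
    // trie; returns false if the full context is not present.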
    bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
      // TODO: don't decode backoff.
      assert(begin != end);
      FullScoreReturn ignored;
      float ignored_backoff;
      LookupUnigram(*begin, ignored_backoff, node, ignored);
      for (const WordIndex *i = begin + 1; i < end; ++i) {
        if (!LookupMiddleNoProb(middle_begin_[i - begin - 1], *i, ignored_backoff, node)) return false;
      }
      return true;
    }

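    // Convert an extend_left pointer (and the length of the n-gram it came
    // from) back into a Node, also recovering the stored probability.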
    Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const {
      if (extend_length == 1) {
        float ignored;
        Node ret;
        unigram.Find(static_cast<WordIndex>(extend_pointer), prob, ignored, ret);
        return ret;
      }
      return middle_begin_[extend_length - 2].ReadEntry(extend_pointer, prob);
    }

  private:
    friend void BuildTrie<Quant, Bhiksha>(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);

    // Middles are managed manually so we can delay construction and they don't have to be copyable.
    void FreeMiddles() {
      for (const Middle *i = middle_begin_; i != middle_end_; ++i) {
        i->~Middle();
      }
      free(middle_begin_);
    }

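    // One Middle per order from 2 through n-1, malloc'd and constructed in place.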
    Middle *middle_begin_, *middle_end_;
    Quant quant_;
};

} // namespace trie
} // namespace ngram
} // namespace lm

#endif // LM_SEARCH_TRIE__