diff options
| author | Kenneth Heafield <github@kheafield.com> | 2011-08-18 12:14:01 +0100 | 
|---|---|---|
| committer | Kenneth Heafield <github@kheafield.com> | 2011-08-18 12:14:01 +0100 | 
| commit | 7607b0a7873f52d6e3ea387bf88c773cbb55f8ee (patch) | |
| tree | 908fd94fea8d09725bc86ec9b3752b89c78338e5 /klm/lm/search_trie.hh | |
| parent | d92124ccc866192e4cdc689f2b41f0324d35dd3b (diff) | |
KenLM update: Bhiksha's trick, simple test for lms without unk, auto-detect binary files instead of requiring them to be specified at runtime.
Diffstat (limited to 'klm/lm/search_trie.hh')
| -rw-r--r-- | klm/lm/search_trie.hh | 20 | 
1 files changed, 11 insertions, 9 deletions
| diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 0a52acb5..2f39c09f 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -13,31 +13,33 @@ struct Backing;  class SortedVocabulary;  namespace trie { -template <class Quant> class TrieSearch; -template <class Quant> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant> &out, Quant &quant, Backing &backing); +template <class Quant, class Bhiksha> class TrieSearch; +template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); -template <class Quant> class TrieSearch { +template <class Quant, class Bhiksha> class TrieSearch {    public:      typedef NodeRange Node;      typedef ::lm::ngram::trie::Unigram Unigram;      Unigram unigram; -    typedef trie::BitPackedMiddle<typename Quant::Middle> Middle; +    typedef trie::BitPackedMiddle<typename Quant::Middle, Bhiksha> Middle;      typedef trie::BitPackedLongest<typename Quant::Longest> Longest;      Longest longest; -    static const ModelType kModelType = Quant::kModelType; +    static const ModelType kModelType = static_cast<ModelType>(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd);      static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {        Quant::UpdateConfigFromBinary(fd, counts, config); +      AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0])); +      Bhiksha::UpdateConfigFromBinary(fd, config);      }      static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) {        std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]);        for (unsigned char i = 1; i < counts.size() - 1; ++i) { -        ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1]); +        ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config);        }        return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);      } @@ -55,8 +57,8 @@ template <class Quant> class TrieSearch {      void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); -    bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { -      return unigram.Find(word, prob, backoff, node); +    void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { +      unigram.Find(word, prob, backoff, node);      }      bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const { @@ -83,7 +85,7 @@ template <class Quant> class TrieSearch {      }    private: -    friend void BuildTrie<Quant>(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant> &out, Quant &quant, Backing &backing); +    friend void BuildTrie<Quant, Bhiksha>(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);      // Middles are managed manually so we can delay construction and they don't have to be copyable.        void FreeMiddles() { | 
