summary refs log tree commit diff
path: root/klm/lm/search_trie.hh
diff options
context:
space:
mode:
author Kenneth Heafield <github@kheafield.com> 2011-08-18 12:14:01 +0100
committer Kenneth Heafield <github@kheafield.com> 2011-08-18 12:14:01 +0100
commit 7607b0a7873f52d6e3ea387bf88c773cbb55f8ee (patch)
tree 908fd94fea8d09725bc86ec9b3752b89c78338e5 /klm/lm/search_trie.hh
parent d92124ccc866192e4cdc689f2b41f0324d35dd3b (diff)
KenLM update: Bhiksha's trick, simple test for lms without unk, auto-detect binary files instead of requiring them to be specified at runtime.
Diffstat (limited to 'klm/lm/search_trie.hh')
-rw-r--r-- klm/lm/search_trie.hh 20
1 files changed, 11 insertions, 9 deletions
diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh
index 0a52acb5..2f39c09f 100644
--- a/klm/lm/search_trie.hh
+++ b/klm/lm/search_trie.hh
@@ -13,31 +13,33 @@ struct Backing;
class SortedVocabulary;
namespace trie {
-template <class Quant> class TrieSearch;
-template <class Quant> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant> &out, Quant &quant, Backing &backing);
+template <class Quant, class Bhiksha> class TrieSearch;
+template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
-template <class Quant> class TrieSearch {
+template <class Quant, class Bhiksha> class TrieSearch {
public:
typedef NodeRange Node;
typedef ::lm::ngram::trie::Unigram Unigram;
Unigram unigram;
- typedef trie::BitPackedMiddle<typename Quant::Middle> Middle;
+ typedef trie::BitPackedMiddle<typename Quant::Middle, Bhiksha> Middle;
typedef trie::BitPackedLongest<typename Quant::Longest> Longest;
Longest longest;
- static const ModelType kModelType = Quant::kModelType;
+ static const ModelType kModelType = static_cast<ModelType>(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd);
static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
Quant::UpdateConfigFromBinary(fd, counts, config);
+ AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0]));
+ Bhiksha::UpdateConfigFromBinary(fd, config);
}
static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) {
std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]);
for (unsigned char i = 1; i < counts.size() - 1; ++i) {
- ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1]);
+ ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config);
}
return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
}
@@ -55,8 +57,8 @@ template <class Quant> class TrieSearch {
void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing);
- bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const {
- return unigram.Find(word, prob, backoff, node);
+ void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const {
+ unigram.Find(word, prob, backoff, node);
}
bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const {
@@ -83,7 +85,7 @@ template <class Quant> class TrieSearch {
}
private:
- friend void BuildTrie<Quant>(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant> &out, Quant &quant, Backing &backing);
+ friend void BuildTrie<Quant, Bhiksha>(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
// Middles are managed manually so we can delay construction and they don't have to be copyable.
void FreeMiddles() {