summaryrefslogtreecommitdiff
path: root/klm/lm/search_trie.hh
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2011-09-09 15:33:35 +0200
committerPatrick Simianer <p@simianer.de>2011-09-23 19:13:58 +0200
commitedb0cc0cbae1e75e4aeedb6360eab325effe6573 (patch)
treea2fed4614b88f177f91e88fef3b269fa75e80188 /klm/lm/search_trie.hh
parent2e6ef7cbec77b22ce3d64416a5ada3a6c081f9e2 (diff)
partial merge, ruleid feature
Diffstat (limited to 'klm/lm/search_trie.hh')
-rw-r--r--klm/lm/search_trie.hh20
1 files changed, 11 insertions, 9 deletions
diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh
index 0a52acb5..2f39c09f 100644
--- a/klm/lm/search_trie.hh
+++ b/klm/lm/search_trie.hh
@@ -13,31 +13,33 @@ struct Backing;
class SortedVocabulary;
namespace trie {
-template <class Quant> class TrieSearch;
-template <class Quant> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant> &out, Quant &quant, Backing &backing);
+template <class Quant, class Bhiksha> class TrieSearch;
+template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
-template <class Quant> class TrieSearch {
+template <class Quant, class Bhiksha> class TrieSearch {
public:
typedef NodeRange Node;
typedef ::lm::ngram::trie::Unigram Unigram;
Unigram unigram;
- typedef trie::BitPackedMiddle<typename Quant::Middle> Middle;
+ typedef trie::BitPackedMiddle<typename Quant::Middle, Bhiksha> Middle;
typedef trie::BitPackedLongest<typename Quant::Longest> Longest;
Longest longest;
- static const ModelType kModelType = Quant::kModelType;
+ static const ModelType kModelType = static_cast<ModelType>(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd);
static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
Quant::UpdateConfigFromBinary(fd, counts, config);
+ AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0]));
+ Bhiksha::UpdateConfigFromBinary(fd, config);
}
static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) {
std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]);
for (unsigned char i = 1; i < counts.size() - 1; ++i) {
- ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1]);
+ ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config);
}
return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
}
@@ -55,8 +57,8 @@ template <class Quant> class TrieSearch {
void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing);
- bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const {
- return unigram.Find(word, prob, backoff, node);
+ void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const {
+ unigram.Find(word, prob, backoff, node);
}
bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const {
@@ -83,7 +85,7 @@ template <class Quant> class TrieSearch {
}
private:
- friend void BuildTrie<Quant>(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant> &out, Quant &quant, Backing &backing);
+ friend void BuildTrie<Quant, Bhiksha>(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
// Middles are managed manually so we can delay construction and they don't have to be copyable.
void FreeMiddles() {