From edb0cc0cbae1e75e4aeedb6360eab325effe6573 Mon Sep 17 00:00:00 2001
From: Patrick Simianer
Date: Fri, 9 Sep 2011 15:33:35 +0200
Subject: partial merge, ruleid feature
---
klm/lm/search_trie.hh | 20 +++++++++++---------
1 file changed, 11 insertions(+), 9 deletions(-)
(limited to 'klm/lm/search_trie.hh')
diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh
index 0a52acb5..2f39c09f 100644
--- a/klm/lm/search_trie.hh
+++ b/klm/lm/search_trie.hh
@@ -13,31 +13,33 @@ struct Backing;
class SortedVocabulary;
namespace trie {
-template class TrieSearch;
-template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing);
+template class TrieSearch;
+template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
-template class TrieSearch {
+template class TrieSearch {
public:
typedef NodeRange Node;
typedef ::lm::ngram::trie::Unigram Unigram;
Unigram unigram;
- typedef trie::BitPackedMiddle Middle;
+ typedef trie::BitPackedMiddle Middle;
typedef trie::BitPackedLongest Longest;
Longest longest;
- static const ModelType kModelType = Quant::kModelType;
+ static const ModelType kModelType = static_cast(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd);
static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config) {
Quant::UpdateConfigFromBinary(fd, counts, config);
+ AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0]));
+ Bhiksha::UpdateConfigFromBinary(fd, config);
}
static std::size_t Size(const std::vector &counts, const Config &config) {
std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]);
for (unsigned char i = 1; i < counts.size() - 1; ++i) {
- ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1]);
+ ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config);
}
return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
}
@@ -55,8 +57,8 @@ template class TrieSearch {
void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing);
- bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const {
- return unigram.Find(word, prob, backoff, node);
+ void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const {
+ unigram.Find(word, prob, backoff, node);
}
bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const {
@@ -83,7 +85,7 @@ template class TrieSearch {
}
private:
- friend void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing);
+ friend void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
// Middles are managed manually so we can delay construction and they don't have to be copyable.
void FreeMiddles() {
--
cgit v1.2.3