summary refs log tree commit diff
path: root/klm/lm/search_trie.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/search_trie.hh')
-rw-r--r-- klm/lm/search_trie.hh 11
1 files changed, 8 insertions, 3 deletions
diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh
index 902f6ce6..0f720217 100644
--- a/klm/lm/search_trie.hh
+++ b/klm/lm/search_trie.hh
@@ -9,6 +9,7 @@
namespace lm {
namespace ngram {
+struct Backing;
class SortedVocabulary;
namespace trie {
@@ -39,14 +40,18 @@ struct TrieSearch {
start += Unigram::Size(counts[0]);
middle.resize(counts.size() - 2);
for (unsigned char i = 1; i < counts.size() - 1; ++i) {
- middle[i-1].Init(start, counts[0], counts[i+1]);
+ middle[i-1].Init(
+ start,
+ counts[0],
+ counts[i+1],
+ (i == counts.size() - 2) ? static_cast<const BitPacked&>(longest) : static_cast<const BitPacked &>(middle[i]));
start += Middle::Size(counts[i], counts[0], counts[i+1]);
}
longest.Init(start, counts[0]);
return start + Longest::Size(counts.back(), counts[0]);
}
- void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab);
+ void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing);
bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const {
return unigram.Find(word, prob, backoff, node);
@@ -65,7 +70,7 @@ struct TrieSearch {
}
bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
- // TODO: don't decode prob.
+ // TODO: don't decode backoff.
assert(begin != end);
float ignored_prob, ignored_backoff;
LookupUnigram(*begin, ignored_prob, ignored_backoff, node);