summaryrefslogtreecommitdiff
path: root/klm/lm/search_trie.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/search_trie.hh')
-rw-r--r--klm/lm/search_trie.hh37
1 files changed, 30 insertions, 7 deletions
diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh
index 2f39c09f..33ae8cff 100644
--- a/klm/lm/search_trie.hh
+++ b/klm/lm/search_trie.hh
@@ -1,10 +1,16 @@
#ifndef LM_SEARCH_TRIE__
#define LM_SEARCH_TRIE__
-#include "lm/binary_format.hh"
+#include "lm/config.hh"
+#include "lm/model_type.hh"
+#include "lm/return.hh"
#include "lm/trie.hh"
#include "lm/weights.hh"
+#include "util/file_piece.hh"
+
+#include <vector>
+
#include <assert.h>
namespace lm {
@@ -30,6 +36,8 @@ template <class Quant, class Bhiksha> class TrieSearch {
static const ModelType kModelType = static_cast<ModelType>(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd);
+ static const unsigned int kVersion = 1;
+
static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
Quant::UpdateConfigFromBinary(fd, counts, config);
AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0]));
@@ -57,12 +65,16 @@ template <class Quant, class Bhiksha> class TrieSearch {
void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing);
- void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const {
- unigram.Find(word, prob, backoff, node);
+ void LookupUnigram(WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const {
+ unigram.Find(word, ret.prob, backoff, node);
+ ret.independent_left = (node.begin == node.end);
+ ret.extend_left = static_cast<uint64_t>(word);
}
- bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const {
- return mid.Find(word, prob, backoff, node);
+ bool LookupMiddle(const Middle &mid, WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const {
+ if (!mid.Find(word, ret.prob, backoff, node, ret.extend_left)) return false;
+ ret.independent_left = (node.begin == node.end);
+ return true;
}
bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const {
@@ -76,14 +88,25 @@ template <class Quant, class Bhiksha> class TrieSearch {
bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
// TODO: don't decode backoff.
assert(begin != end);
- float ignored_prob, ignored_backoff;
- LookupUnigram(*begin, ignored_prob, ignored_backoff, node);
+ FullScoreReturn ignored;
+ float ignored_backoff;
+ LookupUnigram(*begin, ignored_backoff, node, ignored);
for (const WordIndex *i = begin + 1; i < end; ++i) {
if (!LookupMiddleNoProb(middle_begin_[i - begin - 1], *i, ignored_backoff, node)) return false;
}
return true;
}
+ Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const {
+ if (extend_length == 1) {
+ float ignored;
+ Node ret;
+ unigram.Find(static_cast<WordIndex>(extend_pointer), prob, ignored, ret);
+ return ret;
+ }
+ return middle_begin_[extend_length - 2].ReadEntry(extend_pointer, prob);
+ }
+
private:
friend void BuildTrie<Quant, Bhiksha>(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);