summaryrefslogtreecommitdiff
path: root/klm/lm/search_trie.hh
diff options
context:
space:
mode:
authorredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-11-10 22:45:13 +0000
committerredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-11-10 22:45:13 +0000
commit3204a77dfd5bff0b5c6d12a272ec939a882c7697 (patch)
tree75e901641477cb65e55fe9ac0a2745802b81ab23 /klm/lm/search_trie.hh
parentfd02041c3f1bd1157ecf7c0dbd1c444fb02aa313 (diff)
forgotten files
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@710 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'klm/lm/search_trie.hh')
-rw-r--r--klm/lm/search_trie.hh83
1 files changed, 83 insertions, 0 deletions
diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh
new file mode 100644
index 00000000..902f6ce6
--- /dev/null
+++ b/klm/lm/search_trie.hh
@@ -0,0 +1,83 @@
+#ifndef LM_SEARCH_TRIE__
+#define LM_SEARCH_TRIE__
+
+#include "lm/binary_format.hh"
+#include "lm/trie.hh"
+#include "lm/weights.hh"
+
+#include <assert.h>
+
+namespace lm {
+namespace ngram {
+class SortedVocabulary;
+namespace trie {
+
+struct TrieSearch {
+ typedef NodeRange Node;
+
+ typedef ::lm::ngram::trie::Unigram Unigram;
+ Unigram unigram;
+
+ typedef trie::BitPackedMiddle Middle;
+ std::vector<Middle> middle;
+
+ typedef trie::BitPackedLongest Longest;
+ Longest longest;
+
+ static const ModelType kModelType = TRIE_SORTED;
+
+ static std::size_t Size(const std::vector<uint64_t> &counts, const Config &/*config*/) {
+ std::size_t ret = Unigram::Size(counts[0]);
+ for (unsigned char i = 1; i < counts.size() - 1; ++i) {
+ ret += Middle::Size(counts[i], counts[0], counts[i+1]);
+ }
+ return ret + Longest::Size(counts.back(), counts[0]);
+ }
+
+ uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &/*config*/) {
+ unigram.Init(start);
+ start += Unigram::Size(counts[0]);
+ middle.resize(counts.size() - 2);
+ for (unsigned char i = 1; i < counts.size() - 1; ++i) {
+ middle[i-1].Init(start, counts[0], counts[i+1]);
+ start += Middle::Size(counts[i], counts[0], counts[i+1]);
+ }
+ longest.Init(start, counts[0]);
+ return start + Longest::Size(counts.back(), counts[0]);
+ }
+
+ void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab);
+
+ bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const {
+ return unigram.Find(word, prob, backoff, node);
+ }
+
+ bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const {
+ return mid.Find(word, prob, backoff, node);
+ }
+
+ bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const {
+ return mid.FindNoProb(word, backoff, node);
+ }
+
+ bool LookupLongest(WordIndex word, float &prob, const Node &node) const {
+ return longest.Find(word, prob, node);
+ }
+
+ bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
+ // TODO: don't decode prob.
+ assert(begin != end);
+ float ignored_prob, ignored_backoff;
+ LookupUnigram(*begin, ignored_prob, ignored_backoff, node);
+ for (const WordIndex *i = begin + 1; i < end; ++i) {
+ if (!LookupMiddleNoProb(middle[i - begin - 1], *i, ignored_backoff, node)) return false;
+ }
+ return true;
+ }
+};
+
+} // namespace trie
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_SEARCH_TRIE__