From 3204a77dfd5bff0b5c6d12a272ec939a882c7697 Mon Sep 17 00:00:00 2001 From: redpony Date: Wed, 10 Nov 2010 22:45:13 +0000 Subject: forgotten files git-svn-id: https://ws10smt.googlecode.com/svn/trunk@710 ec762483-ff6d-05da-a07a-a48fb63a330f --- klm/lm/search_trie.hh | 83 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 klm/lm/search_trie.hh (limited to 'klm/lm/search_trie.hh') diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh new file mode 100644 index 00000000..902f6ce6 --- /dev/null +++ b/klm/lm/search_trie.hh @@ -0,0 +1,83 @@ +#ifndef LM_SEARCH_TRIE__ +#define LM_SEARCH_TRIE__ + +#include "lm/binary_format.hh" +#include "lm/trie.hh" +#include "lm/weights.hh" + +#include + +namespace lm { +namespace ngram { +class SortedVocabulary; +namespace trie { + +struct TrieSearch { + typedef NodeRange Node; + + typedef ::lm::ngram::trie::Unigram Unigram; + Unigram unigram; + + typedef trie::BitPackedMiddle Middle; + std::vector middle; + + typedef trie::BitPackedLongest Longest; + Longest longest; + + static const ModelType kModelType = TRIE_SORTED; + + static std::size_t Size(const std::vector &counts, const Config &/*config*/) { + std::size_t ret = Unigram::Size(counts[0]); + for (unsigned char i = 1; i < counts.size() - 1; ++i) { + ret += Middle::Size(counts[i], counts[0], counts[i+1]); + } + return ret + Longest::Size(counts.back(), counts[0]); + } + + uint8_t *SetupMemory(uint8_t *start, const std::vector &counts, const Config &/*config*/) { + unigram.Init(start); + start += Unigram::Size(counts[0]); + middle.resize(counts.size() - 2); + for (unsigned char i = 1; i < counts.size() - 1; ++i) { + middle[i-1].Init(start, counts[0], counts[i+1]); + start += Middle::Size(counts[i], counts[0], counts[i+1]); + } + longest.Init(start, counts[0]); + return start + Longest::Size(counts.back(), counts[0]); + } + + void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector &counts, const Config &config, SortedVocabulary &vocab); + + bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { + return unigram.Find(word, prob, backoff, node); + } + + bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const { + return mid.Find(word, prob, backoff, node); + } + + bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const { + return mid.FindNoProb(word, backoff, node); + } + + bool LookupLongest(WordIndex word, float &prob, const Node &node) const { + return longest.Find(word, prob, node); + } + + bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { + // TODO: don't decode prob. + assert(begin != end); + float ignored_prob, ignored_backoff; + LookupUnigram(*begin, ignored_prob, ignored_backoff, node); + for (const WordIndex *i = begin + 1; i < end; ++i) { + if (!LookupMiddleNoProb(middle[i - begin - 1], *i, ignored_backoff, node)) return false; + } + return true; + } +}; + +} // namespace trie +} // namespace ngram +} // namespace lm + +#endif // LM_SEARCH_TRIE__ -- cgit v1.2.3