diff options
| author | Patrick Simianer <p@simianer.de> | 2011-10-19 14:02:34 +0200 | 
|---|---|---|
| committer | Patrick Simianer <p@simianer.de> | 2011-10-19 14:02:34 +0200 | 
| commit | eb14e36d0b29f19321d44dd7dfa73cc703838d86 (patch) | |
| tree | 1285e9e56959bc3a4b506e36bbc3b49f4e938fa0 /klm/lm/search_hashed.hh | |
| parent | 68f158b11df9f4072699fe6a4c8022ea54102b28 (diff) | |
| parent | 04e38a57b19ea012895ac2efb39382c2e77833a9 (diff) | |
merge upstream/master
Diffstat (limited to 'klm/lm/search_hashed.hh')
| -rw-r--r-- | klm/lm/search_hashed.hh | 43 | 
1 file changed, 38 insertions, 5 deletions
```diff
diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh
index c62985e4..e289fd11 100644
--- a/klm/lm/search_hashed.hh
+++ b/klm/lm/search_hashed.hh
@@ -1,15 +1,18 @@
 #ifndef LM_SEARCH_HASHED__
 #define LM_SEARCH_HASHED__
 
-#include "lm/binary_format.hh"
+#include "lm/model_type.hh"
 #include "lm/config.hh"
 #include "lm/read_arpa.hh"
+#include "lm/return.hh"
 #include "lm/weights.hh"
 
+#include "util/bit_packing.hh"
 #include "util/key_value_packing.hh"
 #include "util/probing_hash_table.hh"
 
 #include <algorithm>
+#include <iostream>
 #include <vector>
 
 namespace util { class FilePiece; }
@@ -52,9 +55,14 @@ struct HashedSearch {
 
   Unigram unigram;
 
-  void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const {
+  void LookupUnigram(WordIndex word, float &backoff, Node &next, FullScoreReturn &ret) const {
     const ProbBackoff &entry = unigram.Lookup(word);
-    prob = entry.prob;
+    util::FloatEnc val;
+    val.f = entry.prob;
+    ret.independent_left = (val.i & util::kSignBit);
+    ret.extend_left = static_cast<uint64_t>(word);
+    val.i |= util::kSignBit;
+    ret.prob = val.f;
     backoff = entry.backoff;
     next = static_cast<Node>(word);
   }
@@ -67,6 +75,8 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
     typedef LongestT Longest;
     Longest longest;
 
+    static const unsigned int kVersion = 0;
+
     // TODO: move probing_multiplier here with next binary file format update.
     static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}
 
@@ -85,11 +95,33 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
     const Middle *MiddleBegin() const { return &*middle_.begin(); }
     const Middle *MiddleEnd() const { return &*middle_.end(); }
 
-    bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const {
+    Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const {
+      util::FloatEnc val;
+      if (extend_length == 1) {
+        val.f = unigram.Lookup(static_cast<uint64_t>(extend_pointer)).prob;
+      } else {
+        typename Middle::ConstIterator found;
+        if (!middle_[extend_length - 2].Find(extend_pointer, found)) {
+          std::cerr << "Extend pointer " << extend_pointer << " should have been found for length " << (unsigned) extend_length << std::endl;
+          abort();
+        }
+        val.f = found->GetValue().prob;
+      }
+      val.i |= util::kSignBit;
+      prob = val.f;
+      return extend_pointer;
+    }
+
+    bool LookupMiddle(const Middle &middle, WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const {
       node = CombineWordHash(node, word);
       typename Middle::ConstIterator found;
       if (!middle.Find(node, found)) return false;
-      prob = found->GetValue().prob;
+      util::FloatEnc enc;
+      enc.f = found->GetValue().prob;
+      ret.independent_left = (enc.i & util::kSignBit);
+      ret.extend_left = node;
+      enc.i |= util::kSignBit;
+      ret.prob = enc.f;
       backoff = found->GetValue().backoff;
       return true;
     }
@@ -105,6 +137,7 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
     }
 
     bool LookupLongest(WordIndex word, float &prob, Node &node) const {
+      // Sign bit is always on because longest n-grams do not extend left.
       node = CombineWordHash(node, word);
       typename Longest::ConstIterator found;
       if (!longest.Find(node, found)) return false;
```
