diff options
Diffstat (limited to 'klm/lm')
| -rw-r--r-- | klm/lm/left.hh | 70 | ||||
| -rw-r--r-- | klm/lm/model.cc | 5 | ||||
| -rw-r--r-- | klm/lm/model.hh | 25 | 
3 files changed, 76 insertions, 24 deletions
| diff --git a/klm/lm/left.hh b/klm/lm/left.hh index 837be765..effa0560 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -1,3 +1,40 @@ +/* Efficient left and right language model state for sentence fragments. + * Intended usage: + * Store ChartState with every chart entry.   + * To do a rule application: + * 1. Make a ChartState object for your new entry.   + * 2. Construct RuleScore.   + * 3. Going from left to right, call Terminal or NonTerminal.  + *   For terminals, just pass the vocab id.   + *   For non-terminals, pass that non-terminal's ChartState. + *     If your decoder expects scores inclusive of subtree scores (i.e. you + *     label entries with the highest-scoring path), pass the non-terminal's + *     score as prob.   + *     If your decoder expects relative scores and will walk the chart later, + *     pass prob = 0.0.   + *     In other words, the only effect of prob is that it gets added to the + *     returned log probability.   + * 4. Call Finish.  It returns the log probability.    + * + * There's a couple more details:  + * Do not pass <s> to Terminal as it is formally not a word in the sentence, + * only context.  Instead, call BeginSentence.  If called, it should be the + * first call after RuleScore is constructed (since <s> is always the + * leftmost). + * + * If the leftmost RHS is a non-terminal, it's faster to call BeginNonTerminal. + * + * Hashing and sorting comparison operators are provided.   All state objects + * are POD.  If you intend to use memcmp on raw state objects, you must call + * ZeroRemaining first, as the value of array entries beyond length is + * otherwise undefined.   + * + * Usage is of course not limited to chart decoding.  Anything that generates + * sentence fragments missing left context could benefit.  For example, a + * phrase-based decoder could pre-score phrases, storing ChartState with each + * phrase, even if hypotheses are generated left-to-right.   + */ +  #ifndef LM_LEFT__  #define LM_LEFT__ @@ -5,6 +42,8 @@  #include "lm/model.hh"  #include "lm/return.hh" +#include "util/murmur_hash.hh" +  #include <algorithm>  namespace lm { @@ -18,23 +57,30 @@ struct Left {    }    int Compare(const Left &other) const { -    if (length != other.length) { -      return (int)length - (int)other.length; -    } +    if (length != other.length) return length < other.length ? -1 : 1;      if (pointers[length - 1] > other.pointers[length - 1]) return 1;      if (pointers[length - 1] < other.pointers[length - 1]) return -1;      return 0;    } +  bool operator<(const Left &other) const { +    if (length != other.length) return length < other.length; +    return pointers[length - 1] < other.pointers[length - 1]; +  } +    void ZeroRemaining() {      for (uint64_t * i = pointers + length; i < pointers + kMaxOrder - 1; ++i)        *i = 0;    } -  uint64_t pointers[kMaxOrder - 1];    unsigned char length; +  uint64_t pointers[kMaxOrder - 1];  }; +inline size_t hash_value(const Left &left) { +  return util::MurmurHashNative(&left.length, 1, left.pointers[left.length - 1]); +} +  struct ChartState {    bool operator==(const ChartState &other) {      return (left == other.left) && (right == other.right) && (full == other.full); @@ -48,16 +94,27 @@ struct ChartState {      return (int)full - (int)other.full;    } +  bool operator<(const ChartState &other) const { +    return Compare(other) == -1; +  } +    void ZeroRemaining() {      left.ZeroRemaining();      right.ZeroRemaining();    }    Left left; -  State right;    bool full; +  State right;  }; +inline size_t hash_value(const ChartState &state) { +  size_t hashes[2]; +  hashes[0] = hash_value(state.left); +  hashes[1] = hash_value(state.right); +  return util::MurmurHashNative(hashes, sizeof(size_t), state.full); +} +  template <class M> class RuleScore {    public:      explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), left_write_(out.left.pointers), prob_(0.0) { @@ -73,8 +130,7 @@ template <class M> class RuleScore {      void Terminal(WordIndex word) {        State copy(out_.right); -      FullScoreReturn ret = model_.FullScore(copy, word, out_.right); -      ProcessRet(ret); +      ProcessRet(model_.FullScore(copy, word, out_.right));        if (out_.right.length != copy.length + 1) left_done_ = true;      } diff --git a/klm/lm/model.cc b/klm/lm/model.cc index ca581d8a..25f1ab7c 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -14,11 +14,6 @@  namespace lm {  namespace ngram { - -size_t hash_value(const State &state) { -  return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length); -} -  namespace detail {  template <class Search, class VocabularyT> const ModelType GenericModel<Search, VocabularyT>::kModelType = Search::kModelType; diff --git a/klm/lm/model.hh b/klm/lm/model.hh index fe91af2e..c278acd6 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -12,6 +12,8 @@  #include "lm/vocab.hh"  #include "lm/weights.hh" +#include "util/murmur_hash.hh" +  #include <algorithm>  #include <vector> @@ -28,21 +30,18 @@ class State {    public:      bool operator==(const State &other) const {        if (length != other.length) return false; -      const WordIndex *end = words + length; -      for (const WordIndex *first = words, *second = other.words; -          first != end; ++first, ++second) { -        if (*first != *second) return false; -      } -      // If the histories are equal, so are the backoffs.   -      return true; +      return !memcmp(words, other.words, length * sizeof(WordIndex));      }      // Three way comparison function.        int Compare(const State &other) const { -      if (length == other.length) { -        return memcmp(words, other.words, length * sizeof(WordIndex)); -      } -      return (length < other.length) ? -1 : 1; +      if (length != other.length) return length < other.length ? -1 : 1; +      return memcmp(words, other.words, length * sizeof(WordIndex)); +    } + +    bool operator<(const State &other) const { +      if (length != other.length) return length < other.length; +      return memcmp(words, other.words, length * sizeof(WordIndex)) < 0;      }      // Call this before using raw memcmp.   @@ -62,7 +61,9 @@ class State {      unsigned char length;  }; -size_t hash_value(const State &state); +inline size_t hash_value(const State &state) { +  return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length); +}  namespace detail { | 
