Diffstat (limited to 'klm')

 -rw-r--r--  klm/lm/enumerate_vocab.hh |  29
 -rw-r--r--  klm/lm/model.hh           | 126
 -rw-r--r--  klm/lm/search_hashed.hh   | 156
 -rw-r--r--  klm/lm/search_trie.hh     |  83
 -rw-r--r--  klm/lm/trie.hh            | 129
 -rw-r--r--  klm/lm/vocab.hh           | 138

6 files changed, 661 insertions, 0 deletions
diff --git a/klm/lm/enumerate_vocab.hh b/klm/lm/enumerate_vocab.hh
new file mode 100644
index 00000000..7a2f7d12
--- /dev/null
+++ b/klm/lm/enumerate_vocab.hh
@@ -0,0 +1,29 @@
+#ifndef LM_ENUMERATE_VOCAB__
+#define LM_ENUMERATE_VOCAB__
+
+#include "lm/word_index.hh"
+#include "util/string_piece.hh"
+
+namespace lm {
+namespace ngram {
+
+/* If you need the actual strings in the vocabulary, inherit from this class
+ * and implement Add.  Then put a pointer in Config.enumerate_vocab.
+ * Add is called once per vocabulary word.  index starts at 0 and increases
+ * by 1 each time.
+ */
+class EnumerateVocab {
+  public:
+    virtual ~EnumerateVocab() {}
+
+    virtual void Add(WordIndex index, const StringPiece &str) = 0;
+
+  protected:
+    EnumerateVocab() {}
+};
+
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_ENUMERATE_VOCAB__
+
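For orientation, a minimal sketch of a subclass, assuming only what the comment above states (the pointer goes in Config.enumerate_vocab); VocabCollector is a hypothetical name, not part of this commit:

// Hypothetical example: collect vocabulary strings while the model loads.
#include <string>
#include <vector>

#include "lm/enumerate_vocab.hh"

class VocabCollector : public lm::ngram::EnumerateVocab {
  public:
    // Called once per vocabulary word; index counts up from 0.
    void Add(lm::WordIndex index, const StringPiece &str) {
      words_.push_back(std::string(str.data(), str.size()));
    }
    const std::vector<std::string> &Words() const { return words_; }
  private:
    std::vector<std::string> words_;
};

// Usage sketch: set config.enumerate_vocab = &collector before constructing
// the model, per the comment above.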
diff --git a/klm/lm/model.hh b/klm/lm/model.hh
new file mode 100644
index 00000000..e0eeee17
--- /dev/null
+++ b/klm/lm/model.hh
@@ -0,0 +1,126 @@
+#ifndef LM_MODEL__
+#define LM_MODEL__
+
+#include "lm/binary_format.hh"
+#include "lm/config.hh"
+#include "lm/facade.hh"
+#include "lm/search_hashed.hh"
+#include "lm/search_trie.hh"
+#include "lm/vocab.hh"
+#include "lm/weights.hh"
+
+#include <algorithm>
+#include <vector>
+
+namespace util { class FilePiece; }
+
+namespace lm {
+namespace ngram {
+
+// If you need higher order, change this and recompile.
+// Having this limit means that State can be
+// (kMaxOrder - 1) * sizeof(float) bytes instead of
+// sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead
+const std::size_t kMaxOrder = 6;
+
+// This is a POD.
+class State {
+  public:
+    bool operator==(const State &other) const {
+      if (valid_length_ != other.valid_length_) return false;
+      const WordIndex *end = history_ + valid_length_;
+      for (const WordIndex *first = history_, *second = other.history_;
+          first != end; ++first, ++second) {
+        if (*first != *second) return false;
+      }
+      // If the histories are equal, so are the backoffs.
+      return true;
+    }
+
+    // You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD.
+    // This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit.
+    WordIndex history_[kMaxOrder - 1];
+    float backoff_[kMaxOrder - 1];
+    unsigned char valid_length_;
+};
+
+size_t hash_value(const State &state);
+
+namespace detail {
+
+// Should return the same results as SRI.
+// Why VocabularyT instead of just Vocabulary?  ModelFacade defines Vocabulary.
+template <class Search, class VocabularyT> class GenericModel : public base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> {
+  private:
+    typedef base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> P;
+  public:
+    // Get the size of memory that will be mapped given ngram counts.  This
+    // does not include small non-mapped control structures, such as this class
+    // itself.
+    static size_t Size(const std::vector<uint64_t> &counts, const Config &config = Config());
+
+    GenericModel(const char *file, const Config &config = Config());
+
+    FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const;
+
+    /* Slower call without in_state.  Don't use this if you can avoid it.  This
+     * is mostly a hack for Hieu to integrate it into Moses, which sometimes
+     * forgets LM state (i.e. it doesn't store it with the phrase).  Sigh.
+     * The context indices should be in an array.
+     * If context_rbegin != context_rend then *context_rbegin is the word
+     * before new_word.
+     */
+    FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const;
+
+    /* Get the state for a context.  Don't use this if you can avoid it.  Use
+     * BeginSentenceState or EmptyContextState and extend from those.  If
+     * you're only going to use this state to call FullScore once, use
+     * FullScoreForgotState. */
+    void GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const;
+
+  private:
+    friend void LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to);
+
+    float SlowBackoffLookup(const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const;
+
+    FullScoreReturn ScoreExceptBackoff(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, unsigned char &backoff_start, State &out_state) const;
+
+    // Appears after Size in the cc file.
+    void SetupMemory(void *start, const std::vector<uint64_t> &counts, const Config &config);
+
+    void InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd);
+
+    void InitializeFromARPA(const char *file, util::FilePiece &f, void *start, const Parameters &params, const Config &config);
+
+    Backing &MutableBacking() { return backing_; }
+
+    static const ModelType kModelType = Search::kModelType;
+
+    Backing backing_;
+
+    VocabularyT vocab_;
+
+    typedef typename Search::Unigram Unigram;
+    typedef typename Search::Middle Middle;
+    typedef typename Search::Longest Longest;
+
+    Search search_;
+};
+
+} // namespace detail
+
+// These must also be instantiated in the cc file.
+typedef ::lm::ngram::ProbingVocabulary Vocabulary;
+typedef detail::GenericModel<detail::ProbingHashedSearch, Vocabulary> ProbingModel;
+// Default implementation.  No real reason for it to be the default.
+typedef ProbingModel Model;
+
+typedef ::lm::ngram::SortedVocabulary SortedVocabulary;
+typedef detail::GenericModel<detail::SortedHashedSearch, SortedVocabulary> SortedModel;
+
+typedef detail::GenericModel<trie::TrieSearch, SortedVocabulary> TrieModel;
+
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_MODEL__
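A usage sketch for the interface above: thread State through successive FullScore calls. This assumes the facade accessors BeginSentenceState() and GetVocabulary() and the prob field of FullScoreReturn, which are declared elsewhere (facade.hh / virtual_interface.hh), not in this header:

// Sketch: sum log10 probabilities of a tokenized sentence left to right.
#include <string>
#include <vector>

#include "lm/model.hh"

float ScoreSentence(const lm::ngram::Model &model, const std::vector<std::string> &words) {
  lm::ngram::State state(model.BeginSentenceState()), out_state;
  float total = 0.0f;
  for (std::vector<std::string>::const_iterator i = words.begin(); i != words.end(); ++i) {
    // Unknown words map to <unk>, which is index 0.
    total += model.FullScore(state, model.GetVocabulary().Index(*i), out_state).prob;
    state = out_state;
  }
  return total;
}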
diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh
new file mode 100644
index 00000000..1ee2b9e9
--- /dev/null
+++ b/klm/lm/search_hashed.hh
@@ -0,0 +1,156 @@
+#ifndef LM_SEARCH_HASHED__
+#define LM_SEARCH_HASHED__
+
+#include "lm/binary_format.hh"
+#include "lm/config.hh"
+#include "lm/read_arpa.hh"
+#include "lm/weights.hh"
+
+#include "util/key_value_packing.hh"
+#include "util/probing_hash_table.hh"
+#include "util/sorted_uniform.hh"
+
+#include <algorithm>
+#include <vector>
+
+namespace util { class FilePiece; }
+
+namespace lm {
+namespace ngram {
+namespace detail {
+
+inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) {
+  uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(next) * 17894857484156487943ULL);
+  return ret;
+}
+
+struct HashedSearch {
+  typedef uint64_t Node;
+
+  class Unigram {
+    public:
+      Unigram() {}
+
+      Unigram(void *start, std::size_t /*allocated*/) : unigram_(static_cast<ProbBackoff*>(start)) {}
+
+      static std::size_t Size(uint64_t count) {
+        return (count + 1) * sizeof(ProbBackoff); // +1 for hallucinate <unk>
+      }
+
+      const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index]; }
+
+      ProbBackoff &Unknown() { return unigram_[0]; }
+
+      void LoadedBinary() {}
+
+      // For building.
+      ProbBackoff *Raw() { return unigram_; }
+
+    private:
+      ProbBackoff *unigram_;
+  };
+
+  Unigram unigram;
+
+  bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const {
+    const ProbBackoff &entry = unigram.Lookup(word);
+    prob = entry.prob;
+    backoff = entry.backoff;
+    next = static_cast<Node>(word);
+    return true;
+  }
+};
+
+template <class MiddleT, class LongestT> struct TemplateHashedSearch : public HashedSearch {
+  typedef MiddleT Middle;
+  std::vector<Middle> middle;
+
+  typedef LongestT Longest;
+  Longest longest;
+
+  static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) {
+    std::size_t ret = Unigram::Size(counts[0]);
+    for (unsigned char n = 1; n < counts.size() - 1; ++n) {
+      ret += Middle::Size(counts[n], config.probing_multiplier);
+    }
+    return ret + Longest::Size(counts.back(), config.probing_multiplier);
+  }
+
+  uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
+    std::size_t allocated = Unigram::Size(counts[0]);
+    unigram = Unigram(start, allocated);
+    start += allocated;
+    for (unsigned int n = 2; n < counts.size(); ++n) {
+      allocated = Middle::Size(counts[n - 1], config.probing_multiplier);
+      middle.push_back(Middle(start, allocated));
+      start += allocated;
+    }
+    allocated = Longest::Size(counts.back(), config.probing_multiplier);
+    longest = Longest(start, allocated);
+    start += allocated;
+    return start;
+  }
+
+  template <class Voc> void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab);
+
+  bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const {
+    node = CombineWordHash(node, word);
+    typename Middle::ConstIterator found;
+    if (!middle.Find(node, found)) return false;
+    prob = found->GetValue().prob;
+    backoff = found->GetValue().backoff;
+    return true;
+  }
+
+  bool LookupMiddleNoProb(const Middle &middle, WordIndex word, float &backoff, Node &node) const {
+    node = CombineWordHash(node, word);
+    typename Middle::ConstIterator found;
+    if (!middle.Find(node, found)) return false;
+    backoff = found->GetValue().backoff;
+    return true;
+  }
+
+  bool LookupLongest(WordIndex word, float &prob, Node &node) const {
+    node = CombineWordHash(node, word);
+    typename Longest::ConstIterator found;
+    if (!longest.Find(node, found)) return false;
+    prob = found->GetValue().prob;
+    return true;
+  }
+
+  // Generate a node without necessarily checking that it actually exists.
+  // Optionally return false if it's known not to exist.
+  bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
+    assert(begin != end);
+    node = static_cast<Node>(*begin);
+    for (const WordIndex *i = begin + 1; i < end; ++i) {
+      node = CombineWordHash(node, *i);
+    }
+    return true;
+  }
+};
+
+// std::identity is an SGI extension :-(
+struct IdentityHash : public std::unary_function<uint64_t, size_t> {
+  size_t operator()(uint64_t arg) const { return static_cast<size_t>(arg); }
+};
+
+struct ProbingHashedSearch : public TemplateHashedSearch<
+  util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, ProbBackoff>, IdentityHash>,
+  util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, Prob>, IdentityHash> > {

+  static const ModelType kModelType = HASH_PROBING;
+};
+
+struct SortedHashedSearch : public TemplateHashedSearch<
+  util::SortedUniformMap<util::ByteAlignedPacking<uint64_t, ProbBackoff> >,
+  util::SortedUniformMap<util::ByteAlignedPacking<uint64_t, Prob> > > {
+
+  static const ModelType kModelType = HASH_SORTED;
+};
+
+} // namespace detail
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_SEARCH_HASHED__
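To make the Node type above concrete: the hashed search names an n-gram by folding its word indices into a single 64-bit key, exactly as FastMakeNode does. A minimal sketch for a trigram (the function name is illustrative):

#include "lm/search_hashed.hh"

// Sketch: how a trigram (w1 w2 w3) becomes the single uint64_t key probed in
// the middle/longest hash tables; this mirrors FastMakeNode above.
uint64_t TrigramNode(lm::WordIndex w1, lm::WordIndex w2, lm::WordIndex w3) {
  using lm::ngram::detail::CombineWordHash;
  uint64_t node = static_cast<uint64_t>(w1);  // unigram node is just the word
  node = CombineWordHash(node, w2);           // key for the bigram "w1 w2"
  return CombineWordHash(node, w3);           // key for the trigram "w1 w2 w3"
}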
diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh
new file mode 100644
index 00000000..902f6ce6
--- /dev/null
+++ b/klm/lm/search_trie.hh
@@ -0,0 +1,83 @@
+#ifndef LM_SEARCH_TRIE__
+#define LM_SEARCH_TRIE__
+
+#include "lm/binary_format.hh"
+#include "lm/trie.hh"
+#include "lm/weights.hh"
+
+#include <assert.h>
+
+namespace lm {
+namespace ngram {
+class SortedVocabulary;
+namespace trie {
+
+struct TrieSearch {
+  typedef NodeRange Node;
+
+  typedef ::lm::ngram::trie::Unigram Unigram;
+  Unigram unigram;
+
+  typedef trie::BitPackedMiddle Middle;
+  std::vector<Middle> middle;
+
+  typedef trie::BitPackedLongest Longest;
+  Longest longest;
+
+  static const ModelType kModelType = TRIE_SORTED;
+
+  static std::size_t Size(const std::vector<uint64_t> &counts, const Config &/*config*/) {
+    std::size_t ret = Unigram::Size(counts[0]);
+    for (unsigned char i = 1; i < counts.size() - 1; ++i) {
+      ret += Middle::Size(counts[i], counts[0], counts[i+1]);
+    }
+    return ret + Longest::Size(counts.back(), counts[0]);
+  }
+
+  uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &/*config*/) {
+    unigram.Init(start);
+    start += Unigram::Size(counts[0]);
+    middle.resize(counts.size() - 2);
+    for (unsigned char i = 1; i < counts.size() - 1; ++i) {
+      middle[i-1].Init(start, counts[0], counts[i+1]);
+      start += Middle::Size(counts[i], counts[0], counts[i+1]);
+    }
+    longest.Init(start, counts[0]);
+    return start + Longest::Size(counts.back(), counts[0]);
+  }
+
+  void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab);
+
+  bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const {
+    return unigram.Find(word, prob, backoff, node);
+  }
+
+  bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const {
+    return mid.Find(word, prob, backoff, node);
+  }
+
+  bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const {
+    return mid.FindNoProb(word, backoff, node);
+  }
+
+  bool LookupLongest(WordIndex word, float &prob, const Node &node) const {
+    return longest.Find(word, prob, node);
+  }
+
+  bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
+    // TODO: don't decode prob.
+    assert(begin != end);
+    float ignored_prob, ignored_backoff;
+    LookupUnigram(*begin, ignored_prob, ignored_backoff, node);
+    for (const WordIndex *i = begin + 1; i < end; ++i) {
+      if (!LookupMiddleNoProb(middle[i - begin - 1], *i, ignored_backoff, node)) return false;
+    }
+    return true;
+  }
+};
+
+} // namespace trie
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_SEARCH_TRIE__
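In the trie search a Node is a NodeRange rather than a hash, and each lookup narrows the range to continuations of the words matched so far. A sketch of one trigram lookup, assuming the trigram exists (real scoring checks each boolean return and backs off on a miss; the function name is illustrative):

#include "lm/search_trie.hh"

// Sketch: each Lookup* narrows node to continuations of the prefix so far.
float TrigramProb(const lm::ngram::trie::TrieSearch &search,
                  lm::WordIndex w1, lm::WordIndex w2, lm::WordIndex w3) {
  lm::ngram::trie::NodeRange node;
  float prob, backoff;
  search.LookupUnigram(w1, prob, backoff, node);                    // node = bigrams "w1 *"
  search.LookupMiddle(search.middle[0], w2, prob, backoff, node);   // node = trigrams "w1 w2 *"
  search.LookupLongest(w3, prob, node);                             // prob of "w1 w2 w3"
  return prob;
}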
diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh
new file mode 100644
index 00000000..35dc2c96
--- /dev/null
+++ b/klm/lm/trie.hh
@@ -0,0 +1,129 @@
+#ifndef LM_TRIE__
+#define LM_TRIE__
+
+#include <inttypes.h>
+
+#include <cstddef>
+
+#include "lm/word_index.hh"
+#include "lm/weights.hh"
+
+namespace lm {
+namespace ngram {
+namespace trie {
+
+struct NodeRange {
+  uint64_t begin, end;
+};
+
+// TODO: if the number of unigrams is a concern, also bit pack these records.
+struct UnigramValue {
+  ProbBackoff weights;
+  uint64_t next;
+  uint64_t Next() const { return next; }
+};
+
+class Unigram {
+  public:
+    Unigram() {}
+
+    void Init(void *start) {
+      unigram_ = static_cast<UnigramValue*>(start);
+    }
+
+    static std::size_t Size(uint64_t count) {
+      // +1 in case unknown doesn't appear.  +1 for the final next.
+      return (count + 2) * sizeof(UnigramValue);
+    }
+
+    const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index].weights; }
+
+    ProbBackoff &Unknown() { return unigram_[0].weights; }
+
+    UnigramValue *Raw() {
+      return unigram_;
+    }
+
+    void LoadedBinary() {}
+
+    bool Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const {
+      UnigramValue *val = unigram_ + word;
+      prob = val->weights.prob;
+      backoff = val->weights.backoff;
+      next.begin = val->next;
+      next.end = (val+1)->next;
+      return true;
+    }
+
+  private:
+    UnigramValue *unigram_;
+};
+
+class BitPacked {
+  public:
+    BitPacked() {}
+
+    uint64_t InsertIndex() const {
+      return insert_index_;
+    }
+
+    void LoadedBinary() {}
+
+  protected:
+    static std::size_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits);
+
+    void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits);
+
+    uint8_t word_bits_, prob_bits_;
+    uint8_t total_bits_;
+    uint64_t word_mask_;
+
+    uint8_t *base_;
+
+    uint64_t insert_index_;
+};
+
+class BitPackedMiddle : public BitPacked {
+  public:
+    BitPackedMiddle() {}
+
+    static std::size_t Size(uint64_t entries, uint64_t max_vocab, uint64_t max_next);
+
+    void Init(void *base, uint64_t max_vocab, uint64_t max_next);
+
+    void Insert(WordIndex word, float prob, float backoff, uint64_t next);
+
+    bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const;
+
+    bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const;
+
+    void FinishedLoading(uint64_t next_end);
+
+  private:
+    uint8_t backoff_bits_, next_bits_;
+    uint64_t next_mask_;
+};
+
+class BitPackedLongest : public BitPacked {
+  public:
+    BitPackedLongest() {}
+
+    static std::size_t Size(uint64_t entries, uint64_t max_vocab) {
+      return BaseSize(entries, max_vocab, 0);
+    }
+
+    void Init(void *base, uint64_t max_vocab) {
+      return BaseInit(base, max_vocab, 0);
+    }
+
+    void Insert(WordIndex word, float prob);
+
+    bool Find(WordIndex word, float &prob, const NodeRange &node) const;
+};
+
+} // namespace trie
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_TRIE__
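The next fields read by Unigram::Find act as a CSR-style offset index: record i's next and record i+1's next bracket the block of bigrams beginning with word i, which is why Size allocates an extra sentinel record. A toy illustration with plain (non-bit-packed) arrays and made-up offsets:

#include <cstdint>
#include <cstdio>

// Toy version of the layout; values are illustrative, not real model data.
struct ToyUnigram { float prob; uint64_t next; };

int main() {
  // next[i] .. next[i+1] bracket the bigram slots for word i; the final
  // record is the sentinel, matching the extra record in Unigram::Size above.
  ToyUnigram unigrams[] = {{-1.0f, 0}, {-2.0f, 3}, {-1.5f, 5}, {0.0f, 7}};
  unsigned word = 1;
  std::printf("bigrams for word %u live in slots [%llu, %llu)\n", word,
              (unsigned long long)unigrams[word].next,
              (unsigned long long)unigrams[word + 1].next);
  return 0;
}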
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh
new file mode 100644
index 00000000..bb5d789b
--- /dev/null
+++ b/klm/lm/vocab.hh
@@ -0,0 +1,138 @@
+#ifndef LM_VOCAB__
+#define LM_VOCAB__
+
+#include "lm/enumerate_vocab.hh"
+#include "lm/virtual_interface.hh"
+#include "util/key_value_packing.hh"
+#include "util/probing_hash_table.hh"
+#include "util/sorted_uniform.hh"
+#include "util/string_piece.hh"
+
+#include <string>
+#include <vector>
+
+namespace lm {
+class ProbBackoff;
+
+namespace ngram {
+class Config;
+class EnumerateVocab;
+
+namespace detail {
+uint64_t HashForVocab(const char *str, std::size_t len);
+inline uint64_t HashForVocab(const StringPiece &str) {
+  return HashForVocab(str.data(), str.length());
+}
+} // namespace detail
+
+class WriteWordsWrapper : public EnumerateVocab {
+  public:
+    WriteWordsWrapper(EnumerateVocab *inner, int fd);
+
+    ~WriteWordsWrapper();
+
+    void Add(WordIndex index, const StringPiece &str);
+
+  private:
+    EnumerateVocab *inner_;
+    int fd_;
+};
+
+// Vocabulary based on sorted uniform find storing only uint64_t values and using their offsets as indices.
+class SortedVocabulary : public base::Vocabulary {
+  private:
+    // Sorted uniform requires a GetKey function.
+    struct Entry {
+      uint64_t GetKey() const { return key; }
+      uint64_t key;
+      bool operator<(const Entry &other) const {
+        return key < other.key;
+      }
+    };
+
+  public:
+    SortedVocabulary();
+
+    WordIndex Index(const StringPiece &str) const {
+      const Entry *found;
+      if (util::SortedUniformFind<const Entry *, uint64_t>(begin_, end_, detail::HashForVocab(str), found)) {
+        return found - begin_ + 1; // +1 because <unk> is 0 and does not appear in the lookup table.
+      } else {
+        return 0;
+      }
+    }
+
+    // Ignores second argument for consistency with probing hash, which has a float here.
+    static size_t Size(std::size_t entries, const Config &config);
+
+    // Everything else is for populating.  I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
+    void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);
+
+    void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
+
+    WordIndex Insert(const StringPiece &str);
+
+    // Reorders reorder_vocab so that the IDs are sorted.
+    void FinishedLoading(ProbBackoff *reorder_vocab);
+
+    bool SawUnk() const { return saw_unk_; }
+
+    void LoadedBinary(int fd, EnumerateVocab *to);
+
+  private:
+    Entry *begin_, *end_;
+
+    bool saw_unk_;
+
+    EnumerateVocab *enumerate_;
+
+    // Actual strings.  Used only when loading from ARPA and enumerate_ != NULL
+    std::vector<std::string> strings_to_enumerate_;
+};
+
+// Vocabulary storing a map from uint64_t to WordIndex.
+class ProbingVocabulary : public base::Vocabulary {
+  public:
+    ProbingVocabulary();
+
+    WordIndex Index(const StringPiece &str) const {
+      Lookup::ConstIterator i;
+      return lookup_.Find(detail::HashForVocab(str), i) ? i->GetValue() : 0;
+    }
+
+    static size_t Size(std::size_t entries, const Config &config);
+
+    // Everything else is for populating.  I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
+    void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);
+
+    void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
+
+    WordIndex Insert(const StringPiece &str);
+
+    void FinishedLoading(ProbBackoff *reorder_vocab);
+
+    bool SawUnk() const { return saw_unk_; }
+
+    void LoadedBinary(int fd, EnumerateVocab *to);
+
+  private:
+    // std::identity is an SGI extension :-(
+    struct IdentityHash : public std::unary_function<uint64_t, std::size_t> {
+      std::size_t operator()(uint64_t arg) const { return static_cast<std::size_t>(arg); }
+    };
+
+    typedef util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, WordIndex>, IdentityHash> Lookup;
+
+    Lookup lookup_;
+
+    WordIndex available_;
+
+    bool saw_unk_;
+
+    EnumerateVocab *enumerate_;
+};
+
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_VOCAB__
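SortedVocabulary::Index above searches sorted 64-bit hashes and offsets the result by one because <unk> is pinned to index 0. A self-contained toy of that convention, using std::lower_bound in place of util::SortedUniformFind (which does interpolation search over Entry records):

#include <algorithm>
#include <cstdint>
#include <vector>

typedef unsigned int ToyWordIndex;

// Toy: sorted_hashes holds the word hashes in ascending order.
ToyWordIndex ToyIndex(const std::vector<uint64_t> &sorted_hashes, uint64_t h) {
  std::vector<uint64_t>::const_iterator it =
      std::lower_bound(sorted_hashes.begin(), sorted_hashes.end(), h);
  if (it != sorted_hashes.end() && *it == h)
    return (it - sorted_hashes.begin()) + 1;  // +1: index 0 is reserved for <unk>
  return 0;  // unknown word maps to <unk>
}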
