diff options
| author | armatthews <armatthe@cmu.edu> | 2014-10-13 14:59:23 -0400 | 
|---|---|---|
| committer | armatthews <armatthe@cmu.edu> | 2014-10-13 14:59:23 -0400 | 
| commit | 9a06ff1465eb3477ac3d1e92ab52e7eae40316a8 (patch) | |
| tree | 808c266a3f510d00f37cd19c3f1da91d8fc683f7 /klm/lm/builder/initial_probabilities.cc | |
| parent | e51da099233df0a384b04fe5908b30e44040d13e (diff) | |
| parent | d3e2ec203a5cf550320caa8023ac3dd103b0be7d (diff) | |
Merge branch 'master' of github.com:redpony/cdec
Diffstat (limited to 'klm/lm/builder/initial_probabilities.cc')
| -rw-r--r-- | klm/lm/builder/initial_probabilities.cc | 191 | 
1 files changed, 167 insertions, 24 deletions
| diff --git a/klm/lm/builder/initial_probabilities.cc b/klm/lm/builder/initial_probabilities.cc index 58b42a20..5d19a897 100644 --- a/klm/lm/builder/initial_probabilities.cc +++ b/klm/lm/builder/initial_probabilities.cc @@ -3,6 +3,8 @@  #include "lm/builder/discount.hh"  #include "lm/builder/ngram_stream.hh"  #include "lm/builder/sort.hh" +#include "lm/builder/hash_gamma.hh" +#include "util/murmur_hash.hh"  #include "util/file.hh"  #include "util/stream/chain.hh"  #include "util/stream/io.hh" @@ -14,55 +16,182 @@ namespace lm { namespace builder {  namespace {  struct BufferEntry { -  // Gamma from page 20 of Chen and Goodman.   +  // Gamma from page 20 of Chen and Goodman.    float gamma; -  // \sum_w a(c w) for all w.   +  // \sum_w a(c w) for all w.    float denominator;  }; -// Extract an array of gamma from an array of BufferEntry.   +struct HashBufferEntry : public BufferEntry { +  // Hash value of ngram. Used to join contexts with backoffs. +  uint64_t hash_value; +}; + +// Reads all entries in order like NGramStream does.   +// But deletes any entries that have CutoffCount below or equal to pruning +// threshold.    
+class PruneNGramStream { +  public: +    PruneNGramStream(const util::stream::ChainPosition &position) : +      current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())), +      dest_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())), +      currentCount_(0), +      block_(position) +    {  +      StartBlock(); +    } + +    NGram &operator*() { return current_; } +    NGram *operator->() { return &current_; } + +    operator bool() const { +      return block_; +    } + +    PruneNGramStream &operator++() { +      assert(block_); +       +      if (current_.Order() > 1) { +        if(currentCount_ > 0) { +          if(dest_.Base() < current_.Base()) { +            memcpy(dest_.Base(), current_.Base(), current_.TotalSize()); +          } +          dest_.NextInMemory(); +        } +      } else { +        dest_.NextInMemory();           +      } +       +      current_.NextInMemory(); +       +      uint8_t *block_base = static_cast<uint8_t*>(block_->Get()); +      if (current_.Base() == block_base + block_->ValidSize()) { +        block_->SetValidSize(dest_.Base() - block_base); +        ++block_; +        StartBlock(); +        if (block_) { +          currentCount_ = current_.CutoffCount(); +        } +      } else {                 +        currentCount_ = current_.CutoffCount(); +      } +       +      return *this; +    } + +  private: +    void StartBlock() { +      for (; ; ++block_) { +        if (!block_) return; +        if (block_->ValidSize()) break; +      } +      current_.ReBase(block_->Get()); +      currentCount_ = current_.CutoffCount(); +       +      dest_.ReBase(block_->Get()); +    } + +    NGram current_; // input iterator +    NGram dest_;    // output iterator +     +    uint64_t currentCount_; + +    util::stream::Link block_; +}; + +// Extract an array of HashedGamma from an array of BufferEntry.  
class OnlyGamma {    public: +    OnlyGamma(bool pruning) : pruning_(pruning) {} +      void Run(const util::stream::ChainPosition &position) {        for (util::stream::Link block_it(position); block_it; ++block_it) { -        float *out = static_cast<float*>(block_it->Get()); -        const float *in = out; -        const float *end = static_cast<const float*>(block_it->ValidEnd()); -        for (out += 1, in += 2; in < end; out += 1, in += 2) { -          *out = *in; +        if(pruning_) { +          const HashBufferEntry *in = static_cast<const HashBufferEntry*>(block_it->Get()); +          const HashBufferEntry *end = static_cast<const HashBufferEntry*>(block_it->ValidEnd()); + +          // Just make it point to the beginning of the stream so it can be overwritten +          // With HashGamma values. Do not attempt to interpret the values until set below. +          HashGamma *out = static_cast<HashGamma*>(block_it->Get()); +          for (; in < end; out += 1, in += 1) { +            // buffering, otherwise might overwrite values too early +            float gamma_buf = in->gamma; +            uint64_t hash_buf = in->hash_value; + +            out->gamma = gamma_buf; +            out->hash_value = hash_buf; +          } +          block_it->SetValidSize((block_it->ValidSize() * sizeof(HashGamma)) / sizeof(HashBufferEntry)); +        } +        else { +          float *out = static_cast<float*>(block_it->Get()); +          const float *in = out; +          const float *end = static_cast<const float*>(block_it->ValidEnd()); +          for (out += 1, in += 2; in < end; out += 1, in += 2) { +            *out = *in; +          } +          block_it->SetValidSize(block_it->ValidSize() / 2);          } -        block_it->SetValidSize(block_it->ValidSize() / 2);        }      } + +    private: +      bool pruning_;  };  class AddRight {    public: -    AddRight(const Discount &discount, const util::stream::ChainPosition &input)  -      : discount_(discount), 
input_(input) {} +    AddRight(const Discount &discount, const util::stream::ChainPosition &input, bool pruning) +      : discount_(discount), input_(input), pruning_(pruning) {}      void Run(const util::stream::ChainPosition &output) {        NGramStream in(input_);        util::stream::Stream out(output);        std::vector<WordIndex> previous(in->Order() - 1); +      // Silly windows requires this workaround to just get an invalid pointer when empty. +      void *const previous_raw = previous.empty() ? NULL : static_cast<void*>(&previous[0]);        const std::size_t size = sizeof(WordIndex) * previous.size(); +        for(; in; ++out) { -        memcpy(&previous[0], in->begin(), size); +        memcpy(previous_raw, in->begin(), size);          uint64_t denominator = 0; +        uint64_t normalizer = 0; +                  uint64_t counts[4];          memset(counts, 0, sizeof(counts));          do { -          denominator += in->Count(); -          ++counts[std::min(in->Count(), static_cast<uint64_t>(3))]; -        } while (++in && !memcmp(&previous[0], in->begin(), size)); +          denominator += in->UnmarkedCount(); +           +          // Collect unused probability mass from pruning. +          // Becomes 0 for unpruned ngrams. +          normalizer += in->UnmarkedCount() - in->CutoffCount(); +           +          // Chen&Goodman do not mention counting based on cutoffs, but +          // backoff becomes larger than 1 otherwise, so probably needs +          // to count cutoffs. Counts normally without pruning. 
+          if(in->CutoffCount() > 0) +            ++counts[std::min(in->CutoffCount(), static_cast<uint64_t>(3))]; +         +        } while (++in && !memcmp(previous_raw, in->begin(), size)); +                  BufferEntry &entry = *reinterpret_cast<BufferEntry*>(out.Get());          entry.denominator = static_cast<float>(denominator);          entry.gamma = 0.0;          for (unsigned i = 1; i <= 3; ++i) {            entry.gamma += discount_.Get(i) * static_cast<float>(counts[i]);          } + +        // Makes model sum to 1 with pruning (I hope). +        entry.gamma += normalizer; +                  entry.gamma /= entry.denominator; +         +        if(pruning_) { +          // If pruning is enabled the stream actually contains HashBufferEntry, see InitialProbabilities(...), +          // so add a hash value that identifies the current ngram. +          static_cast<HashBufferEntry*>(&entry)->hash_value = util::MurmurHashNative(previous_raw, size); +        }        }        out.Poison();      } @@ -70,6 +199,7 @@ class AddRight {    private:      const Discount &discount_;      const util::stream::ChainPosition input_; +    bool pruning_;  };  class MergeRight { @@ -82,7 +212,7 @@ class MergeRight {      void Run(const util::stream::ChainPosition &primary) {        util::stream::Stream summed(from_adder_); -      NGramStream grams(primary); +      PruneNGramStream grams(primary);        // Without interpolation, the interpolation weight goes to <unk>.        
if (grams->Order() == 1 && !interpolate_unigrams_) { @@ -97,15 +227,16 @@ class MergeRight {          ++summed;          return;        } - +              std::vector<WordIndex> previous(grams->Order() - 1);        const std::size_t size = sizeof(WordIndex) * previous.size();        for (; grams; ++summed) {          memcpy(&previous[0], grams->begin(), size);          const BufferEntry &sums = *static_cast<const BufferEntry*>(summed.Get()); +                  do {            Payload &pay = grams->Value(); -          pay.uninterp.prob = discount_.Apply(pay.count) / sums.denominator; +          pay.uninterp.prob = discount_.Apply(grams->UnmarkedCount()) / sums.denominator;            pay.uninterp.gamma = sums.gamma;          } while (++grams && !memcmp(&previous[0], grams->begin(), size));        } @@ -119,17 +250,29 @@ class MergeRight {  } // namespace -void InitialProbabilities(const InitialProbabilitiesConfig &config, const std::vector<Discount> &discounts, Chains &primary, Chains &second_in, Chains &gamma_out) { -  util::stream::ChainConfig gamma_config = config.adder_out; -  gamma_config.entry_size = sizeof(BufferEntry); +void InitialProbabilities( +    const InitialProbabilitiesConfig &config, +    const std::vector<Discount> &discounts, +    util::stream::Chains &primary, +    util::stream::Chains &second_in, +    util::stream::Chains &gamma_out, +    const std::vector<uint64_t> &prune_thresholds) {    for (size_t i = 0; i < primary.size(); ++i) { +    util::stream::ChainConfig gamma_config = config.adder_out; +    if(prune_thresholds[i] > 0) +      gamma_config.entry_size = sizeof(HashBufferEntry); +    else +      gamma_config.entry_size = sizeof(BufferEntry); +      util::stream::ChainPosition second(second_in[i].Add());      second_in[i] >> util::stream::kRecycle;      gamma_out.push_back(gamma_config); -    gamma_out[i] >> AddRight(discounts[i], second); +    gamma_out[i] >> AddRight(discounts[i], second, prune_thresholds[i] > 0); +      primary[i] >> 
MergeRight(config.interpolate_unigrams, gamma_out[i].Add(), discounts[i]); -    // Don't bother with the OnlyGamma thread for something to discard.   -    if (i) gamma_out[i] >> OnlyGamma(); + +    // Don't bother with the OnlyGamma thread for something to discard. +    if (i) gamma_out[i] >> OnlyGamma(prune_thresholds[i] > 0);    }  } | 
