summaryrefslogtreecommitdiff
path: root/klm/lm/builder/initial_probabilities.cc
diff options
context:
space:
mode:
authorWu, Ke <wuke@cs.umd.edu>2014-12-17 16:11:38 -0500
committerWu, Ke <wuke@cs.umd.edu>2014-12-17 16:11:38 -0500
commit1613f1fc44ca67820afd7e7b21eb54b316c8ce55 (patch)
treee02b77084f28a18df6b854f87a986124db44d717 /klm/lm/builder/initial_probabilities.cc
parentbd9308e22b5434aa220cc57d82ee867464a011f1 (diff)
parent796768086a687d3f1856fef6489c34fe4d373642 (diff)
Merge with upstream
Diffstat (limited to 'klm/lm/builder/initial_probabilities.cc')
-rw-r--r--klm/lm/builder/initial_probabilities.cc191
1 files changed, 167 insertions, 24 deletions
diff --git a/klm/lm/builder/initial_probabilities.cc b/klm/lm/builder/initial_probabilities.cc
index 58b42a20..5d19a897 100644
--- a/klm/lm/builder/initial_probabilities.cc
+++ b/klm/lm/builder/initial_probabilities.cc
@@ -3,6 +3,8 @@
#include "lm/builder/discount.hh"
#include "lm/builder/ngram_stream.hh"
#include "lm/builder/sort.hh"
+#include "lm/builder/hash_gamma.hh"
+#include "util/murmur_hash.hh"
#include "util/file.hh"
#include "util/stream/chain.hh"
#include "util/stream/io.hh"
@@ -14,55 +16,182 @@ namespace lm { namespace builder {
namespace {
struct BufferEntry {
- // Gamma from page 20 of Chen and Goodman.
+ // Gamma from page 20 of Chen and Goodman.
float gamma;
- // \sum_w a(c w) for all w.
+ // \sum_w a(c w) for all w.
float denominator;
};
-// Extract an array of gamma from an array of BufferEntry.
+struct HashBufferEntry : public BufferEntry {
+ // Hash value of ngram. Used to join contexts with backoffs.
+ uint64_t hash_value;
+};
+
+// Reads all entries in order like NGramStream does.
+// But deletes any entries that have CutoffCount below or equal to pruning
+// threshold.
+class PruneNGramStream {
+ public:
+ PruneNGramStream(const util::stream::ChainPosition &position) :
+ current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())),
+ dest_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())),
+ currentCount_(0),
+ block_(position)
+ {
+ StartBlock();
+ }
+
+ NGram &operator*() { return current_; }
+ NGram *operator->() { return &current_; }
+
+ operator bool() const {
+ return block_;
+ }
+
+ PruneNGramStream &operator++() {
+ assert(block_);
+
+ if (current_.Order() > 1) {
+ if(currentCount_ > 0) {
+ if(dest_.Base() < current_.Base()) {
+ memcpy(dest_.Base(), current_.Base(), current_.TotalSize());
+ }
+ dest_.NextInMemory();
+ }
+ } else {
+ dest_.NextInMemory();
+ }
+
+ current_.NextInMemory();
+
+ uint8_t *block_base = static_cast<uint8_t*>(block_->Get());
+ if (current_.Base() == block_base + block_->ValidSize()) {
+ block_->SetValidSize(dest_.Base() - block_base);
+ ++block_;
+ StartBlock();
+ if (block_) {
+ currentCount_ = current_.CutoffCount();
+ }
+ } else {
+ currentCount_ = current_.CutoffCount();
+ }
+
+ return *this;
+ }
+
+ private:
+ void StartBlock() {
+ for (; ; ++block_) {
+ if (!block_) return;
+ if (block_->ValidSize()) break;
+ }
+ current_.ReBase(block_->Get());
+ currentCount_ = current_.CutoffCount();
+
+ dest_.ReBase(block_->Get());
+ }
+
+ NGram current_; // input iterator
+ NGram dest_; // output iterator
+
+ uint64_t currentCount_;
+
+ util::stream::Link block_;
+};
+
+// Extract an array of HashedGamma from an array of BufferEntry.
class OnlyGamma {
public:
+ OnlyGamma(bool pruning) : pruning_(pruning) {}
+
void Run(const util::stream::ChainPosition &position) {
for (util::stream::Link block_it(position); block_it; ++block_it) {
- float *out = static_cast<float*>(block_it->Get());
- const float *in = out;
- const float *end = static_cast<const float*>(block_it->ValidEnd());
- for (out += 1, in += 2; in < end; out += 1, in += 2) {
- *out = *in;
+ if(pruning_) {
+ const HashBufferEntry *in = static_cast<const HashBufferEntry*>(block_it->Get());
+ const HashBufferEntry *end = static_cast<const HashBufferEntry*>(block_it->ValidEnd());
+
+ // Just make it point to the beginning of the stream so it can be overwritten
+ // With HashGamma values. Do not attempt to interpret the values until set below.
+ HashGamma *out = static_cast<HashGamma*>(block_it->Get());
+ for (; in < end; out += 1, in += 1) {
+ // buffering, otherwise might overwrite values too early
+ float gamma_buf = in->gamma;
+ uint64_t hash_buf = in->hash_value;
+
+ out->gamma = gamma_buf;
+ out->hash_value = hash_buf;
+ }
+ block_it->SetValidSize((block_it->ValidSize() * sizeof(HashGamma)) / sizeof(HashBufferEntry));
+ }
+ else {
+ float *out = static_cast<float*>(block_it->Get());
+ const float *in = out;
+ const float *end = static_cast<const float*>(block_it->ValidEnd());
+ for (out += 1, in += 2; in < end; out += 1, in += 2) {
+ *out = *in;
+ }
+ block_it->SetValidSize(block_it->ValidSize() / 2);
}
- block_it->SetValidSize(block_it->ValidSize() / 2);
}
}
+
+ private:
+ bool pruning_;
};
class AddRight {
public:
- AddRight(const Discount &discount, const util::stream::ChainPosition &input)
- : discount_(discount), input_(input) {}
+ AddRight(const Discount &discount, const util::stream::ChainPosition &input, bool pruning)
+ : discount_(discount), input_(input), pruning_(pruning) {}
void Run(const util::stream::ChainPosition &output) {
NGramStream in(input_);
util::stream::Stream out(output);
std::vector<WordIndex> previous(in->Order() - 1);
+ // Silly windows requires this workaround to just get an invalid pointer when empty.
+ void *const previous_raw = previous.empty() ? NULL : static_cast<void*>(&previous[0]);
const std::size_t size = sizeof(WordIndex) * previous.size();
+
for(; in; ++out) {
- memcpy(&previous[0], in->begin(), size);
+ memcpy(previous_raw, in->begin(), size);
uint64_t denominator = 0;
+ uint64_t normalizer = 0;
+
uint64_t counts[4];
memset(counts, 0, sizeof(counts));
do {
- denominator += in->Count();
- ++counts[std::min(in->Count(), static_cast<uint64_t>(3))];
- } while (++in && !memcmp(&previous[0], in->begin(), size));
+ denominator += in->UnmarkedCount();
+
+ // Collect unused probability mass from pruning.
+ // Becomes 0 for unpruned ngrams.
+ normalizer += in->UnmarkedCount() - in->CutoffCount();
+
+ // Chen&Goodman do not mention counting based on cutoffs, but
+ // backoff becomes larger than 1 otherwise, so probably needs
+ // to count cutoffs. Counts normally without pruning.
+ if(in->CutoffCount() > 0)
+ ++counts[std::min(in->CutoffCount(), static_cast<uint64_t>(3))];
+
+ } while (++in && !memcmp(previous_raw, in->begin(), size));
+
BufferEntry &entry = *reinterpret_cast<BufferEntry*>(out.Get());
entry.denominator = static_cast<float>(denominator);
entry.gamma = 0.0;
for (unsigned i = 1; i <= 3; ++i) {
entry.gamma += discount_.Get(i) * static_cast<float>(counts[i]);
}
+
+ // Makes model sum to 1 with pruning (I hope).
+ entry.gamma += normalizer;
+
entry.gamma /= entry.denominator;
+
+ if(pruning_) {
+ // If pruning is enabled the stream actually contains HashBufferEntry, see InitialProbabilities(...),
+ // so add a hash value that identifies the current ngram.
+ static_cast<HashBufferEntry*>(&entry)->hash_value = util::MurmurHashNative(previous_raw, size);
+ }
}
out.Poison();
}
@@ -70,6 +199,7 @@ class AddRight {
private:
const Discount &discount_;
const util::stream::ChainPosition input_;
+ bool pruning_;
};
class MergeRight {
@@ -82,7 +212,7 @@ class MergeRight {
void Run(const util::stream::ChainPosition &primary) {
util::stream::Stream summed(from_adder_);
- NGramStream grams(primary);
+ PruneNGramStream grams(primary);
// Without interpolation, the interpolation weight goes to <unk>.
if (grams->Order() == 1 && !interpolate_unigrams_) {
@@ -97,15 +227,16 @@ class MergeRight {
++summed;
return;
}
-
+
std::vector<WordIndex> previous(grams->Order() - 1);
const std::size_t size = sizeof(WordIndex) * previous.size();
for (; grams; ++summed) {
memcpy(&previous[0], grams->begin(), size);
const BufferEntry &sums = *static_cast<const BufferEntry*>(summed.Get());
+
do {
Payload &pay = grams->Value();
- pay.uninterp.prob = discount_.Apply(pay.count) / sums.denominator;
+ pay.uninterp.prob = discount_.Apply(grams->UnmarkedCount()) / sums.denominator;
pay.uninterp.gamma = sums.gamma;
} while (++grams && !memcmp(&previous[0], grams->begin(), size));
}
@@ -119,17 +250,29 @@ class MergeRight {
} // namespace
-void InitialProbabilities(const InitialProbabilitiesConfig &config, const std::vector<Discount> &discounts, Chains &primary, Chains &second_in, Chains &gamma_out) {
- util::stream::ChainConfig gamma_config = config.adder_out;
- gamma_config.entry_size = sizeof(BufferEntry);
+void InitialProbabilities(
+ const InitialProbabilitiesConfig &config,
+ const std::vector<Discount> &discounts,
+ util::stream::Chains &primary,
+ util::stream::Chains &second_in,
+ util::stream::Chains &gamma_out,
+ const std::vector<uint64_t> &prune_thresholds) {
for (size_t i = 0; i < primary.size(); ++i) {
+ util::stream::ChainConfig gamma_config = config.adder_out;
+ if(prune_thresholds[i] > 0)
+ gamma_config.entry_size = sizeof(HashBufferEntry);
+ else
+ gamma_config.entry_size = sizeof(BufferEntry);
+
util::stream::ChainPosition second(second_in[i].Add());
second_in[i] >> util::stream::kRecycle;
gamma_out.push_back(gamma_config);
- gamma_out[i] >> AddRight(discounts[i], second);
+ gamma_out[i] >> AddRight(discounts[i], second, prune_thresholds[i] > 0);
+
primary[i] >> MergeRight(config.interpolate_unigrams, gamma_out[i].Add(), discounts[i]);
- // Don't bother with the OnlyGamma thread for something to discard.
- if (i) gamma_out[i] >> OnlyGamma();
+
+ // Don't bother with the OnlyGamma thread for something to discard.
+ if (i) gamma_out[i] >> OnlyGamma(prune_thresholds[i] > 0);
}
}