summaryrefslogtreecommitdiff
path: root/klm/lm/builder/adjust_counts.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/builder/adjust_counts.cc')
-rw-r--r--klm/lm/builder/adjust_counts.cc164
1 files changed, 117 insertions, 47 deletions
diff --git a/klm/lm/builder/adjust_counts.cc b/klm/lm/builder/adjust_counts.cc
index a6f48011..803c557d 100644
--- a/klm/lm/builder/adjust_counts.cc
+++ b/klm/lm/builder/adjust_counts.cc
@@ -1,8 +1,9 @@
#include "lm/builder/adjust_counts.hh"
-#include "lm/builder/multi_stream.hh"
+#include "lm/builder/ngram_stream.hh"
#include "util/stream/timer.hh"
#include <algorithm>
+#include <iostream>
namespace lm { namespace builder {
@@ -10,56 +11,78 @@ BadDiscountException::BadDiscountException() throw() {}
BadDiscountException::~BadDiscountException() throw() {}
namespace {
-// Return last word in full that is different.
+// Return last word in full that is different.
const WordIndex* FindDifference(const NGram &full, const NGram &lower_last) {
const WordIndex *cur_word = full.end() - 1;
const WordIndex *pre_word = lower_last.end() - 1;
- // Find last difference.
+ // Find last difference.
for (; pre_word >= lower_last.begin() && *pre_word == *cur_word; --cur_word, --pre_word) {}
return cur_word;
}
class StatCollector {
public:
- StatCollector(std::size_t order, std::vector<uint64_t> &counts, std::vector<Discount> &discounts)
- : orders_(order), full_(orders_.back()), counts_(counts), discounts_(discounts) {
+ StatCollector(std::size_t order, std::vector<uint64_t> &counts, std::vector<uint64_t> &counts_pruned, std::vector<Discount> &discounts)
+ : orders_(order), full_(orders_.back()), counts_(counts), counts_pruned_(counts_pruned), discounts_(discounts) {
memset(&orders_[0], 0, sizeof(OrderStat) * order);
}
~StatCollector() {}
- void CalculateDiscounts() {
+ void CalculateDiscounts(const DiscountConfig &config) {
counts_.resize(orders_.size());
- discounts_.resize(orders_.size());
+ counts_pruned_.resize(orders_.size());
for (std::size_t i = 0; i < orders_.size(); ++i) {
const OrderStat &s = orders_[i];
counts_[i] = s.count;
+ counts_pruned_[i] = s.count_pruned;
+ }
- for (unsigned j = 1; j < 4; ++j) {
- // TODO: Specialize error message for j == 3, meaning 3+
- UTIL_THROW_IF(s.n[j] == 0, BadDiscountException, "Could not calculate Kneser-Ney discounts for "
- << (i+1) << "-grams with adjusted count " << (j+1) << " because we didn't observe any "
- << (i+1) << "-grams with adjusted count " << j << "; Is this small or artificial data?");
- }
-
- // See equation (26) in Chen and Goodman.
- discounts_[i].amount[0] = 0.0;
- float y = static_cast<float>(s.n[1]) / static_cast<float>(s.n[1] + 2.0 * s.n[2]);
- for (unsigned j = 1; j < 4; ++j) {
- discounts_[i].amount[j] = static_cast<float>(j) - static_cast<float>(j + 1) * y * static_cast<float>(s.n[j+1]) / static_cast<float>(s.n[j]);
- UTIL_THROW_IF(discounts_[i].amount[j] < 0.0 || discounts_[i].amount[j] > j, BadDiscountException, "ERROR: " << (i+1) << "-gram discount out of range for adjusted count " << j << ": " << discounts_[i].amount[j]);
+ discounts_ = config.overwrite;
+ discounts_.resize(orders_.size());
+ for (std::size_t i = config.overwrite.size(); i < orders_.size(); ++i) {
+ const OrderStat &s = orders_[i];
+ try {
+ for (unsigned j = 1; j < 4; ++j) {
+ // TODO: Specialize error message for j == 3, meaning 3+
+ UTIL_THROW_IF(s.n[j] == 0, BadDiscountException, "Could not calculate Kneser-Ney discounts for "
+ << (i+1) << "-grams with adjusted count " << (j+1) << " because we didn't observe any "
+ << (i+1) << "-grams with adjusted count " << j << "; Is this small or artificial data?");
+ }
+
+ // See equation (26) in Chen and Goodman.
+ discounts_[i].amount[0] = 0.0;
+ float y = static_cast<float>(s.n[1]) / static_cast<float>(s.n[1] + 2.0 * s.n[2]);
+ for (unsigned j = 1; j < 4; ++j) {
+ discounts_[i].amount[j] = static_cast<float>(j) - static_cast<float>(j + 1) * y * static_cast<float>(s.n[j+1]) / static_cast<float>(s.n[j]);
+ UTIL_THROW_IF(discounts_[i].amount[j] < 0.0 || discounts_[i].amount[j] > j, BadDiscountException, "ERROR: " << (i+1) << "-gram discount out of range for adjusted count " << j << ": " << discounts_[i].amount[j]);
+ }
+ } catch (const BadDiscountException &e) {
+ switch (config.bad_action) {
+ case THROW_UP:
+ throw;
+ case COMPLAIN:
+ std::cerr << e.what() << " Substituting fallback discounts D1=" << config.fallback.amount[1] << " D2=" << config.fallback.amount[2] << " D3+=" << config.fallback.amount[3] << std::endl;
+ case SILENT:
+ break;
+ }
+ discounts_[i] = config.fallback;
}
}
}
- void Add(std::size_t order_minus_1, uint64_t count) {
+ void Add(std::size_t order_minus_1, uint64_t count, bool pruned = false) {
OrderStat &stat = orders_[order_minus_1];
++stat.count;
+ if (!pruned)
+ ++stat.count_pruned;
if (count < 5) ++stat.n[count];
}
- void AddFull(uint64_t count) {
+ void AddFull(uint64_t count, bool pruned = false) {
++full_.count;
+ if (!pruned)
+ ++full_.count_pruned;
if (count < 5) ++full_.n[count];
}
@@ -68,24 +91,27 @@ class StatCollector {
// n_1 in equation 26 of Chen and Goodman etc
uint64_t n[5];
uint64_t count;
+ uint64_t count_pruned;
};
std::vector<OrderStat> orders_;
OrderStat &full_;
std::vector<uint64_t> &counts_;
+ std::vector<uint64_t> &counts_pruned_;
std::vector<Discount> &discounts_;
};
-// Reads all entries in order like NGramStream does.
+// Reads all entries in order like NGramStream does.
// But deletes any entries that have <s> in the 1st (not 0th) position on the
// way out by putting other entries in their place. This disrupts the sort
-// order but we don't care because the data is going to be sorted again.
+// order but we don't care because the data is going to be sorted again.
class CollapseStream {
public:
- CollapseStream(const util::stream::ChainPosition &position) :
+ CollapseStream(const util::stream::ChainPosition &position, uint64_t prune_threshold) :
current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())),
- block_(position) {
+ prune_threshold_(prune_threshold),
+ block_(position) {
StartBlock();
}
@@ -96,10 +122,18 @@ class CollapseStream {
CollapseStream &operator++() {
assert(block_);
+
if (current_.begin()[1] == kBOS && current_.Base() < copy_from_) {
memcpy(current_.Base(), copy_from_, current_.TotalSize());
UpdateCopyFrom();
+
+ // Mark highest order n-grams for later pruning
+ if(current_.Count() <= prune_threshold_) {
+ current_.Mark();
+ }
+
}
+
current_.NextInMemory();
uint8_t *block_base = static_cast<uint8_t*>(block_->Get());
if (current_.Base() == block_base + block_->ValidSize()) {
@@ -107,6 +141,12 @@ class CollapseStream {
++block_;
StartBlock();
}
+
+ // Mark highest order n-grams for later pruning
+ if(current_.Count() <= prune_threshold_) {
+ current_.Mark();
+ }
+
return *this;
}
@@ -119,9 +159,15 @@ class CollapseStream {
current_.ReBase(block_->Get());
copy_from_ = static_cast<uint8_t*>(block_->Get()) + block_->ValidSize();
UpdateCopyFrom();
+
+ // Mark highest order n-grams for later pruning
+ if(current_.Count() <= prune_threshold_) {
+ current_.Mark();
+ }
+
}
- // Find last without bos.
+ // Find last without bos.
void UpdateCopyFrom() {
for (copy_from_ -= current_.TotalSize(); copy_from_ >= current_.Base(); copy_from_ -= current_.TotalSize()) {
if (NGram(copy_from_, current_.Order()).begin()[1] != kBOS) break;
@@ -132,83 +178,107 @@ class CollapseStream {
// Goes backwards in the block
uint8_t *copy_from_;
-
+ uint64_t prune_threshold_;
util::stream::Link block_;
};
} // namespace
-void AdjustCounts::Run(const ChainPositions &positions) {
+void AdjustCounts::Run(const util::stream::ChainPositions &positions) {
UTIL_TIMER("(%w s) Adjusted counts\n");
const std::size_t order = positions.size();
- StatCollector stats(order, counts_, discounts_);
+ StatCollector stats(order, counts_, counts_pruned_, discounts_);
if (order == 1) {
+
// Only unigrams. Just collect stats.
for (NGramStream full(positions[0]); full; ++full)
stats.AddFull(full->Count());
- stats.CalculateDiscounts();
+
+ stats.CalculateDiscounts(discount_config_);
return;
}
NGramStreams streams;
streams.Init(positions, positions.size() - 1);
- CollapseStream full(positions[positions.size() - 1]);
+
+ CollapseStream full(positions[positions.size() - 1], prune_thresholds_.back());
- // Initialization: <unk> has count 0 and so does <s>.
+ // Initialization: <unk> has count 0 and so does <s>.
NGramStream *lower_valid = streams.begin();
streams[0]->Count() = 0;
*streams[0]->begin() = kUNK;
stats.Add(0, 0);
(++streams[0])->Count() = 0;
*streams[0]->begin() = kBOS;
- // not in stats because it will get put in later.
+ // not in stats because it will get put in later.
+ std::vector<uint64_t> lower_counts(positions.size(), 0);
+
// iterate over full (the stream of the highest order ngrams)
- for (; full; ++full) {
+ for (; full; ++full) {
const WordIndex *different = FindDifference(*full, **lower_valid);
std::size_t same = full->end() - 1 - different;
- // Increment the adjusted count.
+ // Increment the adjusted count.
if (same) ++streams[same - 1]->Count();
- // Output all the valid ones that changed.
+ // Output all the valid ones that changed.
for (; lower_valid >= &streams[same]; --lower_valid) {
- stats.Add(lower_valid - streams.begin(), (*lower_valid)->Count());
+
+ // mjd: review this!
+ uint64_t order = (*lower_valid)->Order();
+ uint64_t realCount = lower_counts[order - 1];
+ if(order > 1 && prune_thresholds_[order - 1] && realCount <= prune_thresholds_[order - 1])
+ (*lower_valid)->Mark();
+
+ stats.Add(lower_valid - streams.begin(), (*lower_valid)->UnmarkedCount(), (*lower_valid)->IsMarked());
++*lower_valid;
}
+
+ // Count the true occurrences of lower-order n-grams
+ for (std::size_t i = 0; i < lower_counts.size(); ++i) {
+ if (i >= same) {
+ lower_counts[i] = 0;
+ }
+ lower_counts[i] += full->UnmarkedCount();
+ }
// This is here because bos is also const WordIndex *, so copy gets
- // consistent argument types.
+ // consistent argument types.
const WordIndex *full_end = full->end();
- // Initialize and mark as valid up to bos.
+ // Initialize and mark as valid up to bos.
const WordIndex *bos;
for (bos = different; (bos > full->begin()) && (*bos != kBOS); --bos) {
++lower_valid;
std::copy(bos, full_end, (*lower_valid)->begin());
(*lower_valid)->Count() = 1;
}
- // Now bos indicates where <s> is or is the 0th word of full.
+ // Now bos indicates where <s> is or is the 0th word of full.
if (bos != full->begin()) {
- // There is an <s> beyond the 0th word.
+ // There is an <s> beyond the 0th word.
NGramStream &to = *++lower_valid;
std::copy(bos, full_end, to->begin());
- to->Count() = full->Count();
+
+ // mjd: what is this doing?
+ to->Count() = full->UnmarkedCount();
} else {
- stats.AddFull(full->Count());
+ stats.AddFull(full->UnmarkedCount(), full->IsMarked());
}
assert(lower_valid >= &streams[0]);
}
// Output everything valid.
for (NGramStream *s = streams.begin(); s <= lower_valid; ++s) {
- stats.Add(s - streams.begin(), (*s)->Count());
+ if((*s)->Count() <= prune_thresholds_[(*s)->Order() - 1])
+ (*s)->Mark();
+ stats.Add(s - streams.begin(), (*s)->UnmarkedCount(), (*s)->IsMarked());
++*s;
}
- // Poison everyone! Except the N-grams which were already poisoned by the input.
+ // Poison everyone! Except the N-grams which were already poisoned by the input.
for (NGramStream *s = streams.begin(); s != streams.end(); ++s)
s->Poison();
- stats.CalculateDiscounts();
+ stats.CalculateDiscounts(discount_config_);
// NOTE: See special early-return case for unigrams near the top of this function
}