#include "lm/builder/adjust_counts.hh"
#include "lm/builder/multi_stream.hh"
#include "util/stream/timer.hh"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

#include <stdint.h>

namespace lm { namespace builder {

BadDiscountException::BadDiscountException() throw() {}
BadDiscountException::~BadDiscountException() throw() {}

namespace {

// Return the last word in full that is different.
//
// Both n-grams are scanned backwards from their final words for as long as
// the words agree.  The result points into `full`: either at the last word
// that differs from the aligned word of `lower_last`, or, when every word of
// `lower_last` matched, at the word of `full` just before the matched suffix.
const WordIndex* FindDifference(const NGram &full, const NGram &lower_last) {
  const WordIndex *cur_word = full.end() - 1;
  const WordIndex *pre_word = lower_last.end() - 1;
  // Find last difference.
  for (; pre_word >= lower_last.begin() && *pre_word == *cur_word; --cur_word, --pre_word) {}
  return cur_word;
}

// Accumulates, for every n-gram order, the number of entries seen and the
// count-of-counts n[1]..n[4], then turns them into discounts.
//
// Results are written to the `counts` and `discounts` vectors passed to the
// constructor; both are held by reference and resized by CalculateDiscounts().
class StatCollector {
  public:
    // order: number of n-gram orders to track (the last one is the "full" order).
    // NOTE(review): order is assumed to be at least 1; orders_.back() and
    // &orders_[0] are not valid for an empty vector.
    StatCollector(std::size_t order, std::vector<uint64_t> &counts, std::vector<Discount> &discounts)
      : orders_(order), full_(orders_.back()), counts_(counts), discounts_(discounts) {
      // Explicitly zero every counter.
      memset(&orders_[0], 0, sizeof(OrderStat) * order);
    }

    // Fill counts_ with the number of entries per order and discounts_ with
    // the discounts for adjusted counts 1, 2 and 3.
    //
    // Throws BadDiscountException when n[1], n[2] or n[3] is zero for some
    // order, or when a computed discount falls outside [0, j].
    void CalculateDiscounts() {
      counts_.resize(orders_.size());
      discounts_.resize(orders_.size());
      for (std::size_t i = 0; i < orders_.size(); ++i) {
        const OrderStat &s = orders_[i];
        counts_[i] = s.count;

        // n[1]..n[3] are used as divisors below, so none of them may be zero.
        for (unsigned j = 1; j < 4; ++j) {
          // TODO: Specialize error message for j == 3, meaning 3+
          UTIL_THROW_IF(s.n[j] == 0, BadDiscountException, "Could not calculate Kneser-Ney discounts for "
              << (i+1) << "-grams with adjusted count " << (j+1) << " because we didn't observe any "
              << (i+1) << "-grams with adjusted count " << j << "; Is this small or artificial data?");
        }

        // See equation (26) in Chen and Goodman.
        discounts_[i].amount[0] = 0.0;
        float y = static_cast<float>(s.n[1]) / static_cast<float>(s.n[1] + 2.0 * s.n[2]);
        for (unsigned j = 1; j < 4; ++j) {
          // D_j = j - (j + 1) * y * n_{j+1} / n_j
          discounts_[i].amount[j] = static_cast<float>(j) - static_cast<float>(j + 1) * y * static_cast<float>(s.n[j+1]) / static_cast<float>(s.n[j]);
          UTIL_THROW_IF(discounts_[i].amount[j] < 0.0 || discounts_[i].amount[j] > j, BadDiscountException, "ERROR: " << (i+1) << "-gram discount out of range for adjusted count " << j << ": " << discounts_[i].amount[j]);
        }
      }
    }

    // Record one entry of order (order_minus_1 + 1) with the given adjusted count.
    void Add(std::size_t order_minus_1, uint64_t count) {
      Tally(orders_[order_minus_1], count);
    }

    // Record one entry of the highest order with the given count.
    void AddFull(uint64_t count) {
      Tally(full_, count);
    }

  private:
    struct OrderStat {
      // n_1 in equation 26 of Chen and Goodman etc
      uint64_t n[5];
      // Number of entries recorded for this order.
      uint64_t count;
    };

    // Shared bookkeeping for Add and AddFull: bump the entry total and, for
    // counts below 5, the matching count-of-counts bucket.
    static void Tally(OrderStat &stat, uint64_t count) {
      ++stat.count;
      if (count < 5) ++stat.n[count];
    }

    std::vector<OrderStat> orders_;
    // Alias for orders_.back(), i.e. the highest order.
    OrderStat &full_;

    std::vector<uint64_t> &counts_;
    std::vector<Discount> &discounts_;
};

// Reads all entries in order like NGramStream does.
// But deletes any entries that have <s> in the 1st (not 0th) position on the
// way out by putting other entries in their place.  This disrupts the sort
// order but we don't care because the data is going to be sorted again.
class CollapseStream { public: CollapseStream(const util::stream::ChainPosition &position) : current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())), block_(position) { StartBlock(); } const NGram &operator*() const { return current_; } const NGram *operator->() const { return ¤t_; } operator bool() const { return block_; } CollapseStream &operator++() { assert(block_); if (current_.begin()[1] == kBOS && current_.Base() < copy_from_) { memcpy(current_.Base(), copy_from_, current_.TotalSize()); UpdateCopyFrom(); } current_.NextInMemory(); uint8_t *block_base = static_cast<uint8_t*>(block_->Get()); if (current_.Base() == block_base + block_->ValidSize()) { block_->SetValidSize(copy_from_ + current_.TotalSize() - block_base); ++block_; StartBlock(); } return *this; } private: void StartBlock() { for (; ; ++block_) { if (!block_) return; if (block_->ValidSize()) break; } current_.ReBase(block_->Get()); copy_from_ = static_cast<uint8_t*>(block_->Get()) + block_->ValidSize(); UpdateCopyFrom(); } // Find last without bos. void UpdateCopyFrom() { for (copy_from_ -= current_.TotalSize(); copy_from_ >= current_.Base(); copy_from_ -= current_.TotalSize()) { if (NGram(copy_from_, current_.Order()).begin()[1] != kBOS) break; } } NGram current_; // Goes backwards in the block uint8_t *copy_from_; util::stream::Link block_; }; } // namespace void AdjustCounts::Run(const ChainPositions &positions) { UTIL_TIMER("(%w s) Adjusted counts\n"); const std::size_t order = positions.size(); StatCollector stats(order, counts_, discounts_); if (order == 1) { // Only unigrams. Just collect stats. for (NGramStream full(positions[0]); full; ++full) stats.AddFull(full->Count()); stats.CalculateDiscounts(); return; } NGramStreams streams; streams.Init(positions, positions.size() - 1); CollapseStream full(positions[positions.size() - 1]); // Initialization: <unk> has count 0 and so does <s>. 
NGramStream *lower_valid = streams.begin(); streams[0]->Count() = 0; *streams[0]->begin() = kUNK; stats.Add(0, 0); (++streams[0])->Count() = 0; *streams[0]->begin() = kBOS; // not in stats because it will get put in later. // iterate over full (the stream of the highest order ngrams) for (; full; ++full) { const WordIndex *different = FindDifference(*full, **lower_valid); std::size_t same = full->end() - 1 - different; // Increment the adjusted count. if (same) ++streams[same - 1]->Count(); // Output all the valid ones that changed. for (; lower_valid >= &streams[same]; --lower_valid) { stats.Add(lower_valid - streams.begin(), (*lower_valid)->Count()); ++*lower_valid; } // This is here because bos is also const WordIndex *, so copy gets // consistent argument types. const WordIndex *full_end = full->end(); // Initialize and mark as valid up to bos. const WordIndex *bos; for (bos = different; (bos > full->begin()) && (*bos != kBOS); --bos) { ++lower_valid; std::copy(bos, full_end, (*lower_valid)->begin()); (*lower_valid)->Count() = 1; } // Now bos indicates where <s> is or is the 0th word of full. if (bos != full->begin()) { // There is an <s> beyond the 0th word. NGramStream &to = *++lower_valid; std::copy(bos, full_end, to->begin()); to->Count() = full->Count(); } else { stats.AddFull(full->Count()); } assert(lower_valid >= &streams[0]); } // Output everything valid. for (NGramStream *s = streams.begin(); s <= lower_valid; ++s) { stats.Add(s - streams.begin(), (*s)->Count()); ++*s; } // Poison everyone! Except the N-grams which were already poisoned by the input. for (NGramStream *s = streams.begin(); s != streams.end(); ++s) s->Poison(); stats.CalculateDiscounts(); // NOTE: See special early-return case for unigrams near the top of this function } }} // namespaces