From b1ed81ef3216b212295afa76c5d20a56fb647204 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Mon, 13 Oct 2014 00:42:37 -0400 Subject: new kenlm --- klm/lm/bhiksha.hh | 29 +++-- klm/lm/binary_format.cc | 2 +- klm/lm/binary_format.hh | 6 +- klm/lm/blank.hh | 6 +- klm/lm/builder/Makefile.am | 8 +- klm/lm/builder/adjust_counts.cc | 164 +++++++++++++++++++-------- klm/lm/builder/adjust_counts.hh | 41 +++++-- klm/lm/builder/adjust_counts_test.cc | 13 ++- klm/lm/builder/corpus_count.cc | 100 ++++++----------- klm/lm/builder/corpus_count.hh | 11 +- klm/lm/builder/corpus_count_test.cc | 2 +- klm/lm/builder/discount.hh | 6 +- klm/lm/builder/dump_counts_main.cc | 36 ++++++ klm/lm/builder/hash_gamma.hh | 19 ++++ klm/lm/builder/header_info.hh | 4 +- klm/lm/builder/initial_probabilities.cc | 191 ++++++++++++++++++++++++++++---- klm/lm/builder/initial_probabilities.hh | 17 ++- klm/lm/builder/interpolate.cc | 122 +++++++++++++++++--- klm/lm/builder/interpolate.hh | 22 ++-- klm/lm/builder/joint_order.hh | 10 +- klm/lm/builder/lmplz_main.cc | 97 +++++++++++++++- klm/lm/builder/ngram.hh | 39 +++++-- klm/lm/builder/ngram_stream.hh | 9 +- klm/lm/builder/pipeline.cc | 103 ++++++++++------- klm/lm/builder/pipeline.hh | 38 ++++++- klm/lm/builder/print.cc | 12 +- klm/lm/builder/print.hh | 12 +- klm/lm/builder/sort.hh | 157 ++++++++++++++++++++++++-- klm/lm/config.hh | 6 +- klm/lm/enumerate_vocab.hh | 6 +- klm/lm/facade.hh | 6 +- klm/lm/filter/arpa_io.hh | 6 +- klm/lm/filter/count_io.hh | 6 +- klm/lm/filter/format.hh | 6 +- klm/lm/filter/phrase.hh | 6 +- klm/lm/filter/thread.hh | 6 +- klm/lm/filter/vocab.hh | 6 +- klm/lm/filter/wrapper.hh | 6 +- klm/lm/interpolate/arpa_to_stream.cc | 47 ++++++++ klm/lm/interpolate/arpa_to_stream.hh | 38 +++++++ klm/lm/interpolate/example_sort_main.cc | 144 ++++++++++++++++++++++++ klm/lm/left.hh | 6 +- klm/lm/lm_exception.hh | 4 +- klm/lm/max_order.hh | 8 +- klm/lm/model.hh | 6 +- klm/lm/model_test.cc | 2 +- klm/lm/model_type.hh | 6 +- klm/lm/neural/wordvecs.cc | 23 ++++ klm/lm/neural/wordvecs.hh | 38 +++++++ klm/lm/ngram_query.hh | 91 ++++++++++----- klm/lm/partial.hh | 6 +- klm/lm/quantize.hh | 6 +- klm/lm/query_main.cc | 75 +++++++++---- klm/lm/read_arpa.hh | 31 +++--- klm/lm/return.hh | 6 +- klm/lm/search_hashed.cc | 2 +- klm/lm/search_hashed.hh | 6 +- klm/lm/search_trie.cc | 1 + klm/lm/search_trie.hh | 6 +- klm/lm/sizes.hh | 6 +- klm/lm/state.hh | 6 +- klm/lm/test.arpa | 2 +- klm/lm/test_nounk.arpa | 2 +- klm/lm/trie.cc | 7 +- klm/lm/trie.hh | 6 +- klm/lm/trie_sort.cc | 21 +++- klm/lm/trie_sort.hh | 6 +- klm/lm/value.hh | 6 +- klm/lm/value_build.hh | 6 +- klm/lm/virtual_interface.hh | 6 +- klm/lm/vocab.cc | 12 +- klm/lm/vocab.hh | 90 ++++++++++++--- klm/lm/weights.hh | 6 +- klm/lm/word_index.hh | 4 +- klm/lm/wrappers/README | 3 + klm/lm/wrappers/nplm.cc | 90 +++++++++++++++ klm/lm/wrappers/nplm.hh | 83 ++++++++++++++ 77 files changed, 1781 insertions(+), 469 deletions(-) create mode 100644 klm/lm/builder/dump_counts_main.cc create mode 100644 klm/lm/builder/hash_gamma.hh create mode 100644 klm/lm/interpolate/arpa_to_stream.cc create mode 100644 klm/lm/interpolate/arpa_to_stream.hh create mode 100644 klm/lm/interpolate/example_sort_main.cc create mode 100644 klm/lm/neural/wordvecs.cc create mode 100644 klm/lm/neural/wordvecs.hh create mode 100644 klm/lm/wrappers/README create mode 100644 klm/lm/wrappers/nplm.cc create mode 100644 klm/lm/wrappers/nplm.hh (limited to 'klm/lm') diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh index 350571a6..134beb2f 100644 --- a/klm/lm/bhiksha.hh +++ b/klm/lm/bhiksha.hh @@ -10,17 +10,19 @@ * Currently only used for next pointers. */ -#ifndef LM_BHIKSHA__ -#define LM_BHIKSHA__ - -#include -#include +#ifndef LM_BHIKSHA_H +#define LM_BHIKSHA_H #include "lm/model_type.hh" #include "lm/trie.hh" #include "util/bit_packing.hh" #include "util/sorted_uniform.hh" +#include + +#include +#include + namespace lm { namespace ngram { struct Config; @@ -73,15 +75,24 @@ class ArrayBhiksha { ArrayBhiksha(void *base, uint64_t max_offset, uint64_t max_value, const Config &config); void ReadNext(const void *base, uint64_t bit_offset, uint64_t index, uint8_t total_bits, NodeRange &out) const { - const uint64_t *begin_it = util::BinaryBelow(util::IdentityAccessor(), offset_begin_, offset_end_, index); + // Some assertions are commented out because they are expensive. + // assert(*offset_begin_ == 0); + // std::upper_bound returns the first element that is greater. Want the + // last element that is <= to the index. + const uint64_t *begin_it = std::upper_bound(offset_begin_, offset_end_, index) - 1; + // Since *offset_begin_ == 0, the position should be in range. + // assert(begin_it >= offset_begin_); const uint64_t *end_it; - for (end_it = begin_it; (end_it < offset_end_) && (*end_it <= index + 1); ++end_it) {} + for (end_it = begin_it + 1; (end_it < offset_end_) && (*end_it <= index + 1); ++end_it) {} + // assert(end_it == std::upper_bound(offset_begin_, offset_end_, index + 1)); --end_it; + // assert(end_it >= begin_it); out.begin = ((begin_it - offset_begin_) << next_inline_.bits) | util::ReadInt57(base, bit_offset, next_inline_.bits, next_inline_.mask); out.end = ((end_it - offset_begin_) << next_inline_.bits) | util::ReadInt57(base, bit_offset + total_bits, next_inline_.bits, next_inline_.mask); - //assert(out.end >= out.begin); + // If this fails, consider rebuilding your model using KenLM after 1e333d786b748555e8f368d2bbba29a016c98052 + assert(out.end >= out.begin); } void WriteNext(void *base, uint64_t bit_offset, uint64_t index, uint64_t value) { @@ -109,4 +120,4 @@ class ArrayBhiksha { } // namespace ngram } // namespace lm -#endif // LM_BHIKSHA__ +#endif // LM_BHIKSHA_H diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 9c744b13..48117404 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -149,7 +149,7 @@ void BinaryFormat::InitializeBinary(int fd, ModelType model_type, unsigned int s void BinaryFormat::ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const { assert(header_size_ != kInvalidSize); - util::PReadOrThrow(file_.get(), to, amount, offset_excluding_header + header_size_); + util::ErsatzPRead(file_.get(), to, amount, offset_excluding_header + header_size_); } void *BinaryFormat::LoadBinary(std::size_t size) { diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index f33f88d7..136d6b1a 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -1,5 +1,5 @@ -#ifndef LM_BINARY_FORMAT__ -#define LM_BINARY_FORMAT__ +#ifndef LM_BINARY_FORMAT_H +#define LM_BINARY_FORMAT_H #include "lm/config.hh" #include "lm/model_type.hh" @@ -103,4 +103,4 @@ bool IsBinaryFormat(int fd); } // namespace ngram } // namespace lm -#endif // LM_BINARY_FORMAT__ +#endif // LM_BINARY_FORMAT_H diff --git a/klm/lm/blank.hh b/klm/lm/blank.hh index 4da81209..94a71ad2 100644 --- a/klm/lm/blank.hh +++ b/klm/lm/blank.hh @@ -1,5 +1,5 @@ -#ifndef LM_BLANK__ -#define LM_BLANK__ +#ifndef LM_BLANK_H +#define LM_BLANK_H #include @@ -40,4 +40,4 @@ inline bool HasExtension(const float &backoff) { } // namespace ngram } // namespace lm -#endif // LM_BLANK__ +#endif // LM_BLANK_H diff --git a/klm/lm/builder/Makefile.am b/klm/lm/builder/Makefile.am index 38259c51..bb15ff04 100644 --- a/klm/lm/builder/Makefile.am +++ b/klm/lm/builder/Makefile.am @@ -1,4 +1,8 @@ -bin_PROGRAMS = lmplz +bin_PROGRAMS = lmplz dump_counts + +dump_counts_SOURCES = \ + print.cc \ + dump_counts_main.cc lmplz_SOURCES = \ lmplz_main.cc \ @@ -7,6 +11,7 @@ lmplz_SOURCES = \ corpus_count.cc \ corpus_count.hh \ discount.hh \ + hash_gamma.hh \ header_info.hh \ initial_probabilities.cc \ initial_probabilities.hh \ @@ -22,6 +27,7 @@ lmplz_SOURCES = \ print.hh \ sort.hh +dump_counts_LDADD = ../libklm.a ../../util/double-conversion/libklm_util_double.a ../../util/stream/libklm_util_stream.a ../../util/libklm_util.a $(BOOST_THREAD_LIBS) lmplz_LDADD = ../libklm.a ../../util/double-conversion/libklm_util_double.a ../../util/stream/libklm_util_stream.a ../../util/libklm_util.a $(BOOST_THREAD_LIBS) AM_CPPFLAGS = -W -Wall -I$(top_srcdir)/klm diff --git a/klm/lm/builder/adjust_counts.cc b/klm/lm/builder/adjust_counts.cc index a6f48011..803c557d 100644 --- a/klm/lm/builder/adjust_counts.cc +++ b/klm/lm/builder/adjust_counts.cc @@ -1,8 +1,9 @@ #include "lm/builder/adjust_counts.hh" -#include "lm/builder/multi_stream.hh" +#include "lm/builder/ngram_stream.hh" #include "util/stream/timer.hh" #include +#include namespace lm { namespace builder { @@ -10,56 +11,78 @@ BadDiscountException::BadDiscountException() throw() {} BadDiscountException::~BadDiscountException() throw() {} namespace { -// Return last word in full that is different. +// Return last word in full that is different. const WordIndex* FindDifference(const NGram &full, const NGram &lower_last) { const WordIndex *cur_word = full.end() - 1; const WordIndex *pre_word = lower_last.end() - 1; - // Find last difference. + // Find last difference. for (; pre_word >= lower_last.begin() && *pre_word == *cur_word; --cur_word, --pre_word) {} return cur_word; } class StatCollector { public: - StatCollector(std::size_t order, std::vector &counts, std::vector &discounts) - : orders_(order), full_(orders_.back()), counts_(counts), discounts_(discounts) { + StatCollector(std::size_t order, std::vector &counts, std::vector &counts_pruned, std::vector &discounts) + : orders_(order), full_(orders_.back()), counts_(counts), counts_pruned_(counts_pruned), discounts_(discounts) { memset(&orders_[0], 0, sizeof(OrderStat) * order); } ~StatCollector() {} - void CalculateDiscounts() { + void CalculateDiscounts(const DiscountConfig &config) { counts_.resize(orders_.size()); - discounts_.resize(orders_.size()); + counts_pruned_.resize(orders_.size()); for (std::size_t i = 0; i < orders_.size(); ++i) { const OrderStat &s = orders_[i]; counts_[i] = s.count; + counts_pruned_[i] = s.count_pruned; + } - for (unsigned j = 1; j < 4; ++j) { - // TODO: Specialize error message for j == 3, meaning 3+ - UTIL_THROW_IF(s.n[j] == 0, BadDiscountException, "Could not calculate Kneser-Ney discounts for " - << (i+1) << "-grams with adjusted count " << (j+1) << " because we didn't observe any " - << (i+1) << "-grams with adjusted count " << j << "; Is this small or artificial data?"); - } - - // See equation (26) in Chen and Goodman. - discounts_[i].amount[0] = 0.0; - float y = static_cast(s.n[1]) / static_cast(s.n[1] + 2.0 * s.n[2]); - for (unsigned j = 1; j < 4; ++j) { - discounts_[i].amount[j] = static_cast(j) - static_cast(j + 1) * y * static_cast(s.n[j+1]) / static_cast(s.n[j]); - UTIL_THROW_IF(discounts_[i].amount[j] < 0.0 || discounts_[i].amount[j] > j, BadDiscountException, "ERROR: " << (i+1) << "-gram discount out of range for adjusted count " << j << ": " << discounts_[i].amount[j]); + discounts_ = config.overwrite; + discounts_.resize(orders_.size()); + for (std::size_t i = config.overwrite.size(); i < orders_.size(); ++i) { + const OrderStat &s = orders_[i]; + try { + for (unsigned j = 1; j < 4; ++j) { + // TODO: Specialize error message for j == 3, meaning 3+ + UTIL_THROW_IF(s.n[j] == 0, BadDiscountException, "Could not calculate Kneser-Ney discounts for " + << (i+1) << "-grams with adjusted count " << (j+1) << " because we didn't observe any " + << (i+1) << "-grams with adjusted count " << j << "; Is this small or artificial data?"); + } + + // See equation (26) in Chen and Goodman. + discounts_[i].amount[0] = 0.0; + float y = static_cast(s.n[1]) / static_cast(s.n[1] + 2.0 * s.n[2]); + for (unsigned j = 1; j < 4; ++j) { + discounts_[i].amount[j] = static_cast(j) - static_cast(j + 1) * y * static_cast(s.n[j+1]) / static_cast(s.n[j]); + UTIL_THROW_IF(discounts_[i].amount[j] < 0.0 || discounts_[i].amount[j] > j, BadDiscountException, "ERROR: " << (i+1) << "-gram discount out of range for adjusted count " << j << ": " << discounts_[i].amount[j]); + } + } catch (const BadDiscountException &e) { + switch (config.bad_action) { + case THROW_UP: + throw; + case COMPLAIN: + std::cerr << e.what() << " Substituting fallback discounts D1=" << config.fallback.amount[1] << " D2=" << config.fallback.amount[2] << " D3+=" << config.fallback.amount[3] << std::endl; + case SILENT: + break; + } + discounts_[i] = config.fallback; } } } - void Add(std::size_t order_minus_1, uint64_t count) { + void Add(std::size_t order_minus_1, uint64_t count, bool pruned = false) { OrderStat &stat = orders_[order_minus_1]; ++stat.count; + if (!pruned) + ++stat.count_pruned; if (count < 5) ++stat.n[count]; } - void AddFull(uint64_t count) { + void AddFull(uint64_t count, bool pruned = false) { ++full_.count; + if (!pruned) + ++full_.count_pruned; if (count < 5) ++full_.n[count]; } @@ -68,24 +91,27 @@ class StatCollector { // n_1 in equation 26 of Chen and Goodman etc uint64_t n[5]; uint64_t count; + uint64_t count_pruned; }; std::vector orders_; OrderStat &full_; std::vector &counts_; + std::vector &counts_pruned_; std::vector &discounts_; }; -// Reads all entries in order like NGramStream does. +// Reads all entries in order like NGramStream does. // But deletes any entries that have in the 1st (not 0th) position on the // way out by putting other entries in their place. This disrupts the sort -// order but we don't care because the data is going to be sorted again. +// order but we don't care because the data is going to be sorted again. class CollapseStream { public: - CollapseStream(const util::stream::ChainPosition &position) : + CollapseStream(const util::stream::ChainPosition &position, uint64_t prune_threshold) : current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())), - block_(position) { + prune_threshold_(prune_threshold), + block_(position) { StartBlock(); } @@ -96,10 +122,18 @@ class CollapseStream { CollapseStream &operator++() { assert(block_); + if (current_.begin()[1] == kBOS && current_.Base() < copy_from_) { memcpy(current_.Base(), copy_from_, current_.TotalSize()); UpdateCopyFrom(); + + // Mark highest order n-grams for later pruning + if(current_.Count() <= prune_threshold_) { + current_.Mark(); + } + } + current_.NextInMemory(); uint8_t *block_base = static_cast(block_->Get()); if (current_.Base() == block_base + block_->ValidSize()) { @@ -107,6 +141,12 @@ class CollapseStream { ++block_; StartBlock(); } + + // Mark highest order n-grams for later pruning + if(current_.Count() <= prune_threshold_) { + current_.Mark(); + } + return *this; } @@ -119,9 +159,15 @@ class CollapseStream { current_.ReBase(block_->Get()); copy_from_ = static_cast(block_->Get()) + block_->ValidSize(); UpdateCopyFrom(); + + // Mark highest order n-grams for later pruning + if(current_.Count() <= prune_threshold_) { + current_.Mark(); + } + } - // Find last without bos. + // Find last without bos. void UpdateCopyFrom() { for (copy_from_ -= current_.TotalSize(); copy_from_ >= current_.Base(); copy_from_ -= current_.TotalSize()) { if (NGram(copy_from_, current_.Order()).begin()[1] != kBOS) break; @@ -132,83 +178,107 @@ class CollapseStream { // Goes backwards in the block uint8_t *copy_from_; - + uint64_t prune_threshold_; util::stream::Link block_; }; } // namespace -void AdjustCounts::Run(const ChainPositions &positions) { +void AdjustCounts::Run(const util::stream::ChainPositions &positions) { UTIL_TIMER("(%w s) Adjusted counts\n"); const std::size_t order = positions.size(); - StatCollector stats(order, counts_, discounts_); + StatCollector stats(order, counts_, counts_pruned_, discounts_); if (order == 1) { + // Only unigrams. Just collect stats. for (NGramStream full(positions[0]); full; ++full) stats.AddFull(full->Count()); - stats.CalculateDiscounts(); + + stats.CalculateDiscounts(discount_config_); return; } NGramStreams streams; streams.Init(positions, positions.size() - 1); - CollapseStream full(positions[positions.size() - 1]); + + CollapseStream full(positions[positions.size() - 1], prune_thresholds_.back()); - // Initialization: has count 0 and so does . + // Initialization: has count 0 and so does . NGramStream *lower_valid = streams.begin(); streams[0]->Count() = 0; *streams[0]->begin() = kUNK; stats.Add(0, 0); (++streams[0])->Count() = 0; *streams[0]->begin() = kBOS; - // not in stats because it will get put in later. + // not in stats because it will get put in later. + std::vector lower_counts(positions.size(), 0); + // iterate over full (the stream of the highest order ngrams) - for (; full; ++full) { + for (; full; ++full) { const WordIndex *different = FindDifference(*full, **lower_valid); std::size_t same = full->end() - 1 - different; - // Increment the adjusted count. + // Increment the adjusted count. if (same) ++streams[same - 1]->Count(); - // Output all the valid ones that changed. + // Output all the valid ones that changed. for (; lower_valid >= &streams[same]; --lower_valid) { - stats.Add(lower_valid - streams.begin(), (*lower_valid)->Count()); + + // mjd: review this! + uint64_t order = (*lower_valid)->Order(); + uint64_t realCount = lower_counts[order - 1]; + if(order > 1 && prune_thresholds_[order - 1] && realCount <= prune_thresholds_[order - 1]) + (*lower_valid)->Mark(); + + stats.Add(lower_valid - streams.begin(), (*lower_valid)->UnmarkedCount(), (*lower_valid)->IsMarked()); ++*lower_valid; } + + // Count the true occurrences of lower-order n-grams + for (std::size_t i = 0; i < lower_counts.size(); ++i) { + if (i >= same) { + lower_counts[i] = 0; + } + lower_counts[i] += full->UnmarkedCount(); + } // This is here because bos is also const WordIndex *, so copy gets - // consistent argument types. + // consistent argument types. const WordIndex *full_end = full->end(); - // Initialize and mark as valid up to bos. + // Initialize and mark as valid up to bos. const WordIndex *bos; for (bos = different; (bos > full->begin()) && (*bos != kBOS); --bos) { ++lower_valid; std::copy(bos, full_end, (*lower_valid)->begin()); (*lower_valid)->Count() = 1; } - // Now bos indicates where is or is the 0th word of full. + // Now bos indicates where is or is the 0th word of full. if (bos != full->begin()) { - // There is an beyond the 0th word. + // There is an beyond the 0th word. NGramStream &to = *++lower_valid; std::copy(bos, full_end, to->begin()); - to->Count() = full->Count(); + + // mjd: what is this doing? + to->Count() = full->UnmarkedCount(); } else { - stats.AddFull(full->Count()); + stats.AddFull(full->UnmarkedCount(), full->IsMarked()); } assert(lower_valid >= &streams[0]); } // Output everything valid. for (NGramStream *s = streams.begin(); s <= lower_valid; ++s) { - stats.Add(s - streams.begin(), (*s)->Count()); + if((*s)->Count() <= prune_thresholds_[(*s)->Order() - 1]) + (*s)->Mark(); + stats.Add(s - streams.begin(), (*s)->UnmarkedCount(), (*s)->IsMarked()); ++*s; } - // Poison everyone! Except the N-grams which were already poisoned by the input. + // Poison everyone! Except the N-grams which were already poisoned by the input. for (NGramStream *s = streams.begin(); s != streams.end(); ++s) s->Poison(); - stats.CalculateDiscounts(); + stats.CalculateDiscounts(discount_config_); // NOTE: See special early-return case for unigrams near the top of this function } diff --git a/klm/lm/builder/adjust_counts.hh b/klm/lm/builder/adjust_counts.hh index f38ff79d..a5435c28 100644 --- a/klm/lm/builder/adjust_counts.hh +++ b/klm/lm/builder/adjust_counts.hh @@ -1,24 +1,35 @@ -#ifndef LM_BUILDER_ADJUST_COUNTS__ -#define LM_BUILDER_ADJUST_COUNTS__ +#ifndef LM_BUILDER_ADJUST_COUNTS_H +#define LM_BUILDER_ADJUST_COUNTS_H #include "lm/builder/discount.hh" +#include "lm/lm_exception.hh" #include "util/exception.hh" #include #include +namespace util { namespace stream { class ChainPositions; } } + namespace lm { namespace builder { -class ChainPositions; - class BadDiscountException : public util::Exception { public: BadDiscountException() throw(); ~BadDiscountException() throw(); }; +struct DiscountConfig { + // Overrides discounts for orders [1,discount_override.size()]. + std::vector overwrite; + // If discounting fails for an order, copy them from here. + Discount fallback; + // What to do when discounts are out of range or would trigger divison by + // zero. It it does something other than THROW_UP, use fallback_discount. + WarningAction bad_action; +}; + /* Compute adjusted counts. * Input: unique suffix sorted N-grams (and just the N-grams) with raw counts. * Output: [1,N]-grams with adjusted counts. @@ -27,18 +38,32 @@ class BadDiscountException : public util::Exception { */ class AdjustCounts { public: - AdjustCounts(std::vector &counts, std::vector &discounts) - : counts_(counts), discounts_(discounts) {} + // counts: output + // counts_pruned: output + // discounts: mostly output. If the input already has entries, they will be kept. + // prune_thresholds: input. n-grams with normal (not adjusted) count below this will be pruned. + AdjustCounts( + const std::vector &prune_thresholds, + std::vector &counts, + std::vector &counts_pruned, + const DiscountConfig &discount_config, + std::vector &discounts) + : prune_thresholds_(prune_thresholds), counts_(counts), counts_pruned_(counts_pruned), discount_config_(discount_config), discounts_(discounts) + {} - void Run(const ChainPositions &positions); + void Run(const util::stream::ChainPositions &positions); private: + const std::vector &prune_thresholds_; std::vector &counts_; + std::vector &counts_pruned_; + + DiscountConfig discount_config_; std::vector &discounts_; }; } // namespace builder } // namespace lm -#endif // LM_BUILDER_ADJUST_COUNTS__ +#endif // LM_BUILDER_ADJUST_COUNTS_H diff --git a/klm/lm/builder/adjust_counts_test.cc b/klm/lm/builder/adjust_counts_test.cc index 68b5f33e..073c5dfe 100644 --- a/klm/lm/builder/adjust_counts_test.cc +++ b/klm/lm/builder/adjust_counts_test.cc @@ -1,6 +1,6 @@ #include "lm/builder/adjust_counts.hh" -#include "lm/builder/multi_stream.hh" +#include "lm/builder/ngram_stream.hh" #include "util/scoped.hh" #include @@ -61,19 +61,24 @@ BOOST_AUTO_TEST_CASE(Simple) { util::stream::ChainConfig config; config.total_memory = 100; config.block_count = 1; - Chains chains(4); + util::stream::Chains chains(4); for (unsigned i = 0; i < 4; ++i) { config.entry_size = NGram::TotalSize(i + 1); chains.push_back(config); } chains[3] >> WriteInput(); - ChainPositions for_adjust(chains); + util::stream::ChainPositions for_adjust(chains); for (unsigned i = 0; i < 4; ++i) { chains[i] >> boost::ref(outputs[i]); } chains >> util::stream::kRecycle; - BOOST_CHECK_THROW(AdjustCounts(counts, discount).Run(for_adjust), BadDiscountException); + std::vector counts_pruned(4); + std::vector prune_thresholds(4); + DiscountConfig discount_config; + discount_config.fallback = Discount(); + discount_config.bad_action = THROW_UP; + BOOST_CHECK_THROW(AdjustCounts(prune_thresholds, counts, counts_pruned, discount_config, discount).Run(for_adjust), BadDiscountException); } BOOST_REQUIRE_EQUAL(4UL, counts.size()); BOOST_CHECK_EQUAL(4UL, counts[0]); diff --git a/klm/lm/builder/corpus_count.cc b/klm/lm/builder/corpus_count.cc index ccc06efc..590e79fa 100644 --- a/klm/lm/builder/corpus_count.cc +++ b/klm/lm/builder/corpus_count.cc @@ -2,6 +2,7 @@ #include "lm/builder/ngram.hh" #include "lm/lm_exception.hh" +#include "lm/vocab.hh" #include "lm/word_index.hh" #include "util/fake_ofstream.hh" #include "util/file.hh" @@ -37,60 +38,6 @@ struct VocabEntry { }; #pragma pack(pop) -const float kProbingMultiplier = 1.5; - -class VocabHandout { - public: - static std::size_t MemUsage(WordIndex initial_guess) { - if (initial_guess < 2) initial_guess = 2; - return util::CheckOverflow(Table::Size(initial_guess, kProbingMultiplier)); - } - - explicit VocabHandout(int fd, WordIndex initial_guess) : - table_backing_(util::CallocOrThrow(MemUsage(initial_guess))), - table_(table_backing_.get(), MemUsage(initial_guess)), - double_cutoff_(std::max(initial_guess * 1.1, 1)), - word_list_(fd) { - Lookup(""); // Force 0 - Lookup(""); // Force 1 - Lookup(""); // Force 2 - } - - WordIndex Lookup(const StringPiece &word) { - VocabEntry entry; - entry.key = util::MurmurHashNative(word.data(), word.size()); - entry.value = table_.SizeNoSerialization(); - - Table::MutableIterator it; - if (table_.FindOrInsert(entry, it)) - return it->value; - word_list_ << word << '\0'; - UTIL_THROW_IF(Size() >= std::numeric_limits::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh."); - if (Size() >= double_cutoff_) { - table_backing_.call_realloc(table_.DoubleTo()); - table_.Double(table_backing_.get()); - double_cutoff_ *= 2; - } - return entry.value; - } - - WordIndex Size() const { - return table_.SizeNoSerialization(); - } - - private: - // TODO: factor out a resizable probing hash table. - // TODO: use mremap on linux to get all zeros on resizes. - util::scoped_malloc table_backing_; - - typedef util::ProbingHashTable Table; - Table table_; - - std::size_t double_cutoff_; - - util::FakeOFStream word_list_; -}; - class DedupeHash : public std::unary_function { public: explicit DedupeHash(std::size_t order) : size_(order * sizeof(WordIndex)) {} @@ -127,6 +74,10 @@ struct DedupeEntry { } }; + +// TODO: don't have this here, should be with probing hash table defaults? +const float kProbingMultiplier = 1.5; + typedef util::ProbingHashTable Dedupe; class Writer { @@ -220,37 +171,50 @@ float CorpusCount::DedupeMultiplier(std::size_t order) { } std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) { - return VocabHandout::MemUsage(vocab_estimate); + return ngram::GrowableVocab::MemUsage(vocab_estimate); } -CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block) +CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block, WarningAction disallowed_symbol) : from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count), dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)), - dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) { + dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)), + disallowed_symbol_action_(disallowed_symbol) { } -void CorpusCount::Run(const util::stream::ChainPosition &position) { - UTIL_TIMER("(%w s) Counted n-grams\n"); +namespace { + void ComplainDisallowed(StringPiece word, WarningAction &action) { + switch (action) { + case SILENT: + return; + case COMPLAIN: + std::cerr << "Warning: " << word << " appears in the input. All instances of , , and will be interpreted as whitespace." << std::endl; + action = SILENT; + return; + case THROW_UP: + UTIL_THROW(FormatLoadException, "Special word " << word << " is not allowed in the corpus. I plan to support models containing in the future. Pass --skip_symbols to convert these symbols to whitespace."); + } + } +} // namespace - VocabHandout vocab(vocab_write_, type_count_); +void CorpusCount::Run(const util::stream::ChainPosition &position) { + ngram::GrowableVocab vocab(type_count_, vocab_write_); token_count_ = 0; type_count_ = 0; - const WordIndex end_sentence = vocab.Lookup(""); + const WordIndex end_sentence = vocab.FindOrInsert(""); Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_); uint64_t count = 0; bool delimiters[256]; - memset(delimiters, 0, sizeof(delimiters)); - const char kDelimiterSet[] = "\0\t\n\r "; - for (const char *i = kDelimiterSet; i < kDelimiterSet + sizeof(kDelimiterSet); ++i) { - delimiters[static_cast(*i)] = true; - } + util::BoolCharacter::Build("\0\t\n\r ", delimiters); try { while(true) { StringPiece line(from_.ReadLine()); writer.StartSentence(); for (util::TokenIter w(line, delimiters); w; ++w) { - WordIndex word = vocab.Lookup(*w); - UTIL_THROW_IF(word <= 2, FormatLoadException, "Special word " << *w << " is not allowed in the corpus. I plan to support models containing in the future."); + WordIndex word = vocab.FindOrInsert(*w); + if (word <= 2) { + ComplainDisallowed(*w, disallowed_symbol_action_); + continue; + } writer.Append(word); ++count; } diff --git a/klm/lm/builder/corpus_count.hh b/klm/lm/builder/corpus_count.hh index aa0ed8ed..da4ff9fc 100644 --- a/klm/lm/builder/corpus_count.hh +++ b/klm/lm/builder/corpus_count.hh @@ -1,6 +1,7 @@ -#ifndef LM_BUILDER_CORPUS_COUNT__ -#define LM_BUILDER_CORPUS_COUNT__ +#ifndef LM_BUILDER_CORPUS_COUNT_H +#define LM_BUILDER_CORPUS_COUNT_H +#include "lm/lm_exception.hh" #include "lm/word_index.hh" #include "util/scoped.hh" @@ -28,7 +29,7 @@ class CorpusCount { // token_count: out. // type_count aka vocabulary size. Initialize to an estimate. It is set to the exact value. - CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block); + CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block, WarningAction disallowed_symbol); void Run(const util::stream::ChainPosition &position); @@ -40,8 +41,10 @@ class CorpusCount { std::size_t dedupe_mem_size_; util::scoped_malloc dedupe_mem_; + + WarningAction disallowed_symbol_action_; }; } // namespace builder } // namespace lm -#endif // LM_BUILDER_CORPUS_COUNT__ +#endif // LM_BUILDER_CORPUS_COUNT_H diff --git a/klm/lm/builder/corpus_count_test.cc b/klm/lm/builder/corpus_count_test.cc index 6d325ef5..26cb6346 100644 --- a/klm/lm/builder/corpus_count_test.cc +++ b/klm/lm/builder/corpus_count_test.cc @@ -45,7 +45,7 @@ BOOST_AUTO_TEST_CASE(Short) { NGramStream stream; uint64_t token_count; WordIndex type_count = 10; - CorpusCount counter(input_piece, vocab.get(), token_count, type_count, chain.BlockSize() / chain.EntrySize()); + CorpusCount counter(input_piece, vocab.get(), token_count, type_count, chain.BlockSize() / chain.EntrySize(), SILENT); chain >> boost::ref(counter) >> stream >> util::stream::kRecycle; const char *v[] = {"", "", "", "looking", "on", "a", "little", "more", "loin", "foo", "bar"}; diff --git a/klm/lm/builder/discount.hh b/klm/lm/builder/discount.hh index 4d0aa4fd..e2f40846 100644 --- a/klm/lm/builder/discount.hh +++ b/klm/lm/builder/discount.hh @@ -1,5 +1,5 @@ -#ifndef BUILDER_DISCOUNT__ -#define BUILDER_DISCOUNT__ +#ifndef LM_BUILDER_DISCOUNT_H +#define LM_BUILDER_DISCOUNT_H #include @@ -23,4 +23,4 @@ struct Discount { } // namespace builder } // namespace lm -#endif // BUILDER_DISCOUNT__ +#endif // LM_BUILDER_DISCOUNT_H diff --git a/klm/lm/builder/dump_counts_main.cc b/klm/lm/builder/dump_counts_main.cc new file mode 100644 index 00000000..fa001679 --- /dev/null +++ b/klm/lm/builder/dump_counts_main.cc @@ -0,0 +1,36 @@ +#include "lm/builder/print.hh" +#include "lm/word_index.hh" +#include "util/file.hh" +#include "util/read_compressed.hh" + +#include + +#include +#include + +int main(int argc, char *argv[]) { + if (argc != 4) { + std::cerr << "Usage: " << argv[0] << " counts vocabulary order\n" + "The counts file contains records with 4-byte vocabulary ids followed by 8-byte\n" + "counts. Each record has order many vocabulary ids.\n" + "The vocabulary file contains the words delimited by NULL in order of id.\n" + "The vocabulary file may not be compressed because it is mmapped but the counts\n" + "file can be compressed.\n"; + return 1; + } + util::ReadCompressed counts(util::OpenReadOrThrow(argv[1])); + util::scoped_fd vocab_file(util::OpenReadOrThrow(argv[2])); + lm::builder::VocabReconstitute vocab(vocab_file.get()); + unsigned int order = boost::lexical_cast(argv[3]); + std::vector record(sizeof(uint32_t) * order + sizeof(uint64_t)); + while (std::size_t got = counts.ReadOrEOF(&*record.begin(), record.size())) { + UTIL_THROW_IF(got != record.size(), util::Exception, "Read " << got << " bytes at the end of file, which is not a complete record of length " << record.size()); + const lm::WordIndex *words = reinterpret_cast(&*record.begin()); + for (const lm::WordIndex *i = words; i != words + order; ++i) { + UTIL_THROW_IF(*i >= vocab.Size(), util::Exception, "Vocab ID " << *i << " is larger than the vocab file's maximum of " << vocab.Size() << ". Are you sure you have the right order and vocab file for these counts?"); + std::cout << vocab.Lookup(*i) << ' '; + } + // TODO don't use std::cout because it is slow. Add fast uint64_t printing support to FakeOFStream. + std::cout << *reinterpret_cast(words + order) << '\n'; + } +} diff --git a/klm/lm/builder/hash_gamma.hh b/klm/lm/builder/hash_gamma.hh new file mode 100644 index 00000000..4bef47e8 --- /dev/null +++ b/klm/lm/builder/hash_gamma.hh @@ -0,0 +1,19 @@ +#ifndef LM_BUILDER_HASH_GAMMA__ +#define LM_BUILDER_HASH_GAMMA__ + +#include + +namespace lm { namespace builder { + +#pragma pack(push) +#pragma pack(4) + +struct HashGamma { + uint64_t hash_value; + float gamma; +}; + +#pragma pack(pop) + +}} // namespaces +#endif // LM_BUILDER_HASH_GAMMA__ diff --git a/klm/lm/builder/header_info.hh b/klm/lm/builder/header_info.hh index ccca1456..16f3f609 100644 --- a/klm/lm/builder/header_info.hh +++ b/klm/lm/builder/header_info.hh @@ -1,5 +1,5 @@ -#ifndef LM_BUILDER_HEADER_INFO__ -#define LM_BUILDER_HEADER_INFO__ +#ifndef LM_BUILDER_HEADER_INFO_H +#define LM_BUILDER_HEADER_INFO_H #include #include diff --git a/klm/lm/builder/initial_probabilities.cc b/klm/lm/builder/initial_probabilities.cc index 58b42a20..5d19a897 100644 --- a/klm/lm/builder/initial_probabilities.cc +++ b/klm/lm/builder/initial_probabilities.cc @@ -3,6 +3,8 @@ #include "lm/builder/discount.hh" #include "lm/builder/ngram_stream.hh" #include "lm/builder/sort.hh" +#include "lm/builder/hash_gamma.hh" +#include "util/murmur_hash.hh" #include "util/file.hh" #include "util/stream/chain.hh" #include "util/stream/io.hh" @@ -14,55 +16,182 @@ namespace lm { namespace builder { namespace { struct BufferEntry { - // Gamma from page 20 of Chen and Goodman. + // Gamma from page 20 of Chen and Goodman. float gamma; - // \sum_w a(c w) for all w. + // \sum_w a(c w) for all w. float denominator; }; -// Extract an array of gamma from an array of BufferEntry. +struct HashBufferEntry : public BufferEntry { + // Hash value of ngram. Used to join contexts with backoffs. + uint64_t hash_value; +}; + +// Reads all entries in order like NGramStream does. +// But deletes any entries that have CutoffCount below or equal to pruning +// threshold. +class PruneNGramStream { + public: + PruneNGramStream(const util::stream::ChainPosition &position) : + current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())), + dest_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())), + currentCount_(0), + block_(position) + { + StartBlock(); + } + + NGram &operator*() { return current_; } + NGram *operator->() { return ¤t_; } + + operator bool() const { + return block_; + } + + PruneNGramStream &operator++() { + assert(block_); + + if (current_.Order() > 1) { + if(currentCount_ > 0) { + if(dest_.Base() < current_.Base()) { + memcpy(dest_.Base(), current_.Base(), current_.TotalSize()); + } + dest_.NextInMemory(); + } + } else { + dest_.NextInMemory(); + } + + current_.NextInMemory(); + + uint8_t *block_base = static_cast(block_->Get()); + if (current_.Base() == block_base + block_->ValidSize()) { + block_->SetValidSize(dest_.Base() - block_base); + ++block_; + StartBlock(); + if (block_) { + currentCount_ = current_.CutoffCount(); + } + } else { + currentCount_ = current_.CutoffCount(); + } + + return *this; + } + + private: + void StartBlock() { + for (; ; ++block_) { + if (!block_) return; + if (block_->ValidSize()) break; + } + current_.ReBase(block_->Get()); + currentCount_ = current_.CutoffCount(); + + dest_.ReBase(block_->Get()); + } + + NGram current_; // input iterator + NGram dest_; // output iterator + + uint64_t currentCount_; + + util::stream::Link block_; +}; + +// Extract an array of HashedGamma from an array of BufferEntry. class OnlyGamma { public: + OnlyGamma(bool pruning) : pruning_(pruning) {} + void Run(const util::stream::ChainPosition &position) { for (util::stream::Link block_it(position); block_it; ++block_it) { - float *out = static_cast(block_it->Get()); - const float *in = out; - const float *end = static_cast(block_it->ValidEnd()); - for (out += 1, in += 2; in < end; out += 1, in += 2) { - *out = *in; + if(pruning_) { + const HashBufferEntry *in = static_cast(block_it->Get()); + const HashBufferEntry *end = static_cast(block_it->ValidEnd()); + + // Just make it point to the beginning of the stream so it can be overwritten + // With HashGamma values. Do not attempt to interpret the values until set below. + HashGamma *out = static_cast(block_it->Get()); + for (; in < end; out += 1, in += 1) { + // buffering, otherwise might overwrite values too early + float gamma_buf = in->gamma; + uint64_t hash_buf = in->hash_value; + + out->gamma = gamma_buf; + out->hash_value = hash_buf; + } + block_it->SetValidSize((block_it->ValidSize() * sizeof(HashGamma)) / sizeof(HashBufferEntry)); + } + else { + float *out = static_cast(block_it->Get()); + const float *in = out; + const float *end = static_cast(block_it->ValidEnd()); + for (out += 1, in += 2; in < end; out += 1, in += 2) { + *out = *in; + } + block_it->SetValidSize(block_it->ValidSize() / 2); } - block_it->SetValidSize(block_it->ValidSize() / 2); } } + + private: + bool pruning_; }; class AddRight { public: - AddRight(const Discount &discount, const util::stream::ChainPosition &input) - : discount_(discount), input_(input) {} + AddRight(const Discount &discount, const util::stream::ChainPosition &input, bool pruning) + : discount_(discount), input_(input), pruning_(pruning) {} void Run(const util::stream::ChainPosition &output) { NGramStream in(input_); util::stream::Stream out(output); std::vector previous(in->Order() - 1); + // Silly windows requires this workaround to just get an invalid pointer when empty. + void *const previous_raw = previous.empty() ? NULL : static_cast(&previous[0]); const std::size_t size = sizeof(WordIndex) * previous.size(); + for(; in; ++out) { - memcpy(&previous[0], in->begin(), size); + memcpy(previous_raw, in->begin(), size); uint64_t denominator = 0; + uint64_t normalizer = 0; + uint64_t counts[4]; memset(counts, 0, sizeof(counts)); do { - denominator += in->Count(); - ++counts[std::min(in->Count(), static_cast(3))]; - } while (++in && !memcmp(&previous[0], in->begin(), size)); + denominator += in->UnmarkedCount(); + + // Collect unused probability mass from pruning. + // Becomes 0 for unpruned ngrams. + normalizer += in->UnmarkedCount() - in->CutoffCount(); + + // Chen&Goodman do not mention counting based on cutoffs, but + // backoff becomes larger than 1 otherwise, so probably needs + // to count cutoffs. Counts normally without pruning. + if(in->CutoffCount() > 0) + ++counts[std::min(in->CutoffCount(), static_cast(3))]; + + } while (++in && !memcmp(previous_raw, in->begin(), size)); + BufferEntry &entry = *reinterpret_cast(out.Get()); entry.denominator = static_cast(denominator); entry.gamma = 0.0; for (unsigned i = 1; i <= 3; ++i) { entry.gamma += discount_.Get(i) * static_cast(counts[i]); } + + // Makes model sum to 1 with pruning (I hope). + entry.gamma += normalizer; + entry.gamma /= entry.denominator; + + if(pruning_) { + // If pruning is enabled the stream actually contains HashBufferEntry, see InitialProbabilities(...), + // so add a hash value that identifies the current ngram. + static_cast(&entry)->hash_value = util::MurmurHashNative(previous_raw, size); + } } out.Poison(); } @@ -70,6 +199,7 @@ class AddRight { private: const Discount &discount_; const util::stream::ChainPosition input_; + bool pruning_; }; class MergeRight { @@ -82,7 +212,7 @@ class MergeRight { void Run(const util::stream::ChainPosition &primary) { util::stream::Stream summed(from_adder_); - NGramStream grams(primary); + PruneNGramStream grams(primary); // Without interpolation, the interpolation weight goes to . if (grams->Order() == 1 && !interpolate_unigrams_) { @@ -97,15 +227,16 @@ class MergeRight { ++summed; return; } - + std::vector previous(grams->Order() - 1); const std::size_t size = sizeof(WordIndex) * previous.size(); for (; grams; ++summed) { memcpy(&previous[0], grams->begin(), size); const BufferEntry &sums = *static_cast(summed.Get()); + do { Payload &pay = grams->Value(); - pay.uninterp.prob = discount_.Apply(pay.count) / sums.denominator; + pay.uninterp.prob = discount_.Apply(grams->UnmarkedCount()) / sums.denominator; pay.uninterp.gamma = sums.gamma; } while (++grams && !memcmp(&previous[0], grams->begin(), size)); } @@ -119,17 +250,29 @@ class MergeRight { } // namespace -void InitialProbabilities(const InitialProbabilitiesConfig &config, const std::vector &discounts, Chains &primary, Chains &second_in, Chains &gamma_out) { - util::stream::ChainConfig gamma_config = config.adder_out; - gamma_config.entry_size = sizeof(BufferEntry); +void InitialProbabilities( + const InitialProbabilitiesConfig &config, + const std::vector &discounts, + util::stream::Chains &primary, + util::stream::Chains &second_in, + util::stream::Chains &gamma_out, + const std::vector &prune_thresholds) { for (size_t i = 0; i < primary.size(); ++i) { + util::stream::ChainConfig gamma_config = config.adder_out; + if(prune_thresholds[i] > 0) + gamma_config.entry_size = sizeof(HashBufferEntry); + else + gamma_config.entry_size = sizeof(BufferEntry); + util::stream::ChainPosition second(second_in[i].Add()); second_in[i] >> util::stream::kRecycle; gamma_out.push_back(gamma_config); - gamma_out[i] >> AddRight(discounts[i], second); + gamma_out[i] >> AddRight(discounts[i], second, prune_thresholds[i] > 0); + primary[i] >> MergeRight(config.interpolate_unigrams, gamma_out[i].Add(), discounts[i]); - // Don't bother with the OnlyGamma thread for something to discard. - if (i) gamma_out[i] >> OnlyGamma(); + + // Don't bother with the OnlyGamma thread for something to discard. + if (i) gamma_out[i] >> OnlyGamma(prune_thresholds[i] > 0); } } diff --git a/klm/lm/builder/initial_probabilities.hh b/klm/lm/builder/initial_probabilities.hh index 626388eb..c1010e08 100644 --- a/klm/lm/builder/initial_probabilities.hh +++ b/klm/lm/builder/initial_probabilities.hh @@ -1,14 +1,15 @@ -#ifndef LM_BUILDER_INITIAL_PROBABILITIES__ -#define LM_BUILDER_INITIAL_PROBABILITIES__ +#ifndef LM_BUILDER_INITIAL_PROBABILITIES_H +#define LM_BUILDER_INITIAL_PROBABILITIES_H #include "lm/builder/discount.hh" #include "util/stream/config.hh" #include +namespace util { namespace stream { class Chains; } } + namespace lm { namespace builder { -class Chains; struct InitialProbabilitiesConfig { // These should be small buffers to keep the adder from getting too far ahead @@ -26,9 +27,15 @@ struct InitialProbabilitiesConfig { * The values are bare floats and should be buffered for interpolation to * use. */ -void InitialProbabilities(const InitialProbabilitiesConfig &config, const std::vector &discounts, Chains &primary, Chains &second_in, Chains &gamma_out); +void InitialProbabilities( + const InitialProbabilitiesConfig &config, + const std::vector &discounts, + util::stream::Chains &primary, + util::stream::Chains &second_in, + util::stream::Chains &gamma_out, + const std::vector &prune_thresholds); } // namespace builder } // namespace lm -#endif // LM_BUILDER_INITIAL_PROBABILITIES__ +#endif // LM_BUILDER_INITIAL_PROBABILITIES_H diff --git a/klm/lm/builder/interpolate.cc b/klm/lm/builder/interpolate.cc index 50026806..a7947a42 100644 --- a/klm/lm/builder/interpolate.cc +++ b/klm/lm/builder/interpolate.cc @@ -1,18 +1,74 @@ #include "lm/builder/interpolate.hh" +#include "lm/builder/hash_gamma.hh" #include "lm/builder/joint_order.hh" -#include "lm/builder/multi_stream.hh" +#include "lm/builder/ngram_stream.hh" #include "lm/builder/sort.hh" #include "lm/lm_exception.hh" +#include "util/fixed_array.hh" +#include "util/murmur_hash.hh" #include +#include namespace lm { namespace builder { namespace { -class Callback { +/* Calculate q, the collapsed probability and backoff, as defined in + * @inproceedings{Heafield-rest, + * author = {Kenneth Heafield and Philipp Koehn and Alon Lavie}, + * title = {Language Model Rest Costs and Space-Efficient Storage}, + * year = {2012}, + * month = {July}, + * booktitle = {Proceedings of the Joint Conference on Empirical Methods in Natural Language Processing and Computational Natural Language Learning}, + * address = {Jeju Island, Korea}, + * pages = {1169--1178}, + * url = {http://kheafield.com/professional/edinburgh/rest\_paper.pdf}, + * } + * This is particularly convenient to calculate during interpolation because + * the needed backoff terms are already accessed at the same time. + */ +class OutputQ { public: - Callback(float uniform_prob, const ChainPositions &backoffs) : backoffs_(backoffs.size()), probs_(backoffs.size() + 2) { + explicit OutputQ(std::size_t order) : q_delta_(order) {} + + void Gram(unsigned order_minus_1, float full_backoff, ProbBackoff &out) { + float &q_del = q_delta_[order_minus_1]; + if (order_minus_1) { + // Divide by context's backoff (which comes in as out.backoff) + q_del = q_delta_[order_minus_1 - 1] / out.backoff * full_backoff; + } else { + q_del = full_backoff; + } + out.prob = log10f(out.prob * q_del); + // TODO: stop wastefully outputting this! + out.backoff = 0.0; + } + + private: + // Product of backoffs in the numerator divided by backoffs in the + // denominator. Does not include + std::vector q_delta_; +}; + +/* Default: output probability and backoff */ +class OutputProbBackoff { + public: + explicit OutputProbBackoff(std::size_t /*order*/) {} + + void Gram(unsigned /*order_minus_1*/, float full_backoff, ProbBackoff &out) const { + // Correcting for numerical precision issues. Take that IRST. + out.prob = std::min(0.0f, log10f(out.prob)); + out.backoff = log10f(full_backoff); + } +}; + +template class Callback { + public: + Callback(float uniform_prob, const util::stream::ChainPositions &backoffs, const std::vector &prune_thresholds) + : backoffs_(backoffs.size()), probs_(backoffs.size() + 2), + prune_thresholds_(prune_thresholds), + output_(backoffs.size() + 1 /* order */) { probs_[0] = uniform_prob; for (std::size_t i = 0; i < backoffs.size(); ++i) { backoffs_.push_back(backoffs[i]); @@ -21,6 +77,10 @@ class Callback { ~Callback() { for (std::size_t i = 0; i < backoffs_.size(); ++i) { + if(prune_thresholds_[i + 1] > 0) + while(backoffs_[i]) + ++backoffs_[i]; + if (backoffs_[i]) { std::cerr << "Backoffs do not match for order " << (i + 1) << std::endl; abort(); @@ -32,34 +92,66 @@ class Callback { Payload &pay = gram.Value(); pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1]; probs_[order_minus_1 + 1] = pay.complete.prob; - pay.complete.prob = log10(pay.complete.prob); - // TODO: this is a hack to skip n-grams that don't appear as context. Pruning will require some different handling. + + float out_backoff; if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) { - pay.complete.backoff = log10(*static_cast(backoffs_[order_minus_1].Get())); - ++backoffs_[order_minus_1]; + if(prune_thresholds_[order_minus_1 + 1] > 0) { + //Compute hash value for current context + uint64_t current_hash = util::MurmurHashNative(gram.begin(), gram.Order() * sizeof(WordIndex)); + + const HashGamma *hashed_backoff = static_cast(backoffs_[order_minus_1].Get()); + while(current_hash != hashed_backoff->hash_value && ++backoffs_[order_minus_1]) + hashed_backoff = static_cast(backoffs_[order_minus_1].Get()); + + if(current_hash == hashed_backoff->hash_value) { + out_backoff = hashed_backoff->gamma; + ++backoffs_[order_minus_1]; + } else { + // Has been pruned away so it is not a context anymore + out_backoff = 1.0; + } + } else { + out_backoff = *static_cast(backoffs_[order_minus_1].Get()); + ++backoffs_[order_minus_1]; + } } else { - // Not a context. - pay.complete.backoff = 0.0; + // Not a context. + out_backoff = 1.0; } + + output_.Gram(order_minus_1, out_backoff, pay.complete); } void Exit(unsigned, const NGram &) const {} private: - FixedArray backoffs_; + util::FixedArray backoffs_; std::vector probs_; + const std::vector& prune_thresholds_; + + Output output_; }; } // namespace -Interpolate::Interpolate(uint64_t unigram_count, const ChainPositions &backoffs) - : uniform_prob_(1.0 / static_cast(unigram_count - 1)), backoffs_(backoffs) {} +Interpolate::Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector& prune_thresholds, bool output_q) + : uniform_prob_(1.0 / static_cast(vocab_size)), // Includes but excludes . + backoffs_(backoffs), + prune_thresholds_(prune_thresholds), + output_q_(output_q) {} // perform order-wise interpolation -void Interpolate::Run(const ChainPositions &positions) { +void Interpolate::Run(const util::stream::ChainPositions &positions) { assert(positions.size() == backoffs_.size() + 1); - Callback callback(uniform_prob_, backoffs_); - JointOrder(positions, callback); + if (output_q_) { + typedef Callback C; + C callback(uniform_prob_, backoffs_, prune_thresholds_); + JointOrder(positions, callback); + } else { + typedef Callback C; + C callback(uniform_prob_, backoffs_, prune_thresholds_); + JointOrder(positions, callback); + } } }} // namespaces diff --git a/klm/lm/builder/interpolate.hh b/klm/lm/builder/interpolate.hh index 9268d404..0acece92 100644 --- a/klm/lm/builder/interpolate.hh +++ b/klm/lm/builder/interpolate.hh @@ -1,9 +1,11 @@ -#ifndef LM_BUILDER_INTERPOLATE__ -#define LM_BUILDER_INTERPOLATE__ +#ifndef LM_BUILDER_INTERPOLATE_H +#define LM_BUILDER_INTERPOLATE_H -#include +#include "util/stream/multi_stream.hh" + +#include -#include "lm/builder/multi_stream.hh" +#include namespace lm { namespace builder { @@ -14,14 +16,18 @@ namespace lm { namespace builder { */ class Interpolate { public: - explicit Interpolate(uint64_t unigram_count, const ChainPositions &backoffs); + // Normally vocab_size is the unigram count-1 (since p() = 0) but might + // be larger when the user specifies a consistent vocabulary size. + explicit Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector &prune_thresholds, bool output_q_); - void Run(const ChainPositions &positions); + void Run(const util::stream::ChainPositions &positions); private: float uniform_prob_; - ChainPositions backoffs_; + util::stream::ChainPositions backoffs_; + const std::vector prune_thresholds_; + bool output_q_; }; }} // namespaces -#endif // LM_BUILDER_INTERPOLATE__ +#endif // LM_BUILDER_INTERPOLATE_H diff --git a/klm/lm/builder/joint_order.hh b/klm/lm/builder/joint_order.hh index b5620144..7235d4f7 100644 --- a/klm/lm/builder/joint_order.hh +++ b/klm/lm/builder/joint_order.hh @@ -1,14 +1,14 @@ -#ifndef LM_BUILDER_JOINT_ORDER__ -#define LM_BUILDER_JOINT_ORDER__ +#ifndef LM_BUILDER_JOINT_ORDER_H +#define LM_BUILDER_JOINT_ORDER_H -#include "lm/builder/multi_stream.hh" +#include "lm/builder/ngram_stream.hh" #include "lm/lm_exception.hh" #include namespace lm { namespace builder { -template void JointOrder(const ChainPositions &positions, Callback &callback) { +template void JointOrder(const util::stream::ChainPositions &positions, Callback &callback) { // Allow matching to reference streams[-1]. NGramStreams streams_with_dummy; streams_with_dummy.InitWithDummy(positions); @@ -40,4 +40,4 @@ template void JointOrder(const ChainPositions &p }} // namespaces -#endif // LM_BUILDER_JOINT_ORDER__ +#endif // LM_BUILDER_JOINT_ORDER_H diff --git a/klm/lm/builder/lmplz_main.cc b/klm/lm/builder/lmplz_main.cc index 2563deed..265dd216 100644 --- a/klm/lm/builder/lmplz_main.cc +++ b/klm/lm/builder/lmplz_main.cc @@ -1,4 +1,5 @@ #include "lm/builder/pipeline.hh" +#include "lm/lm_exception.hh" #include "util/file.hh" #include "util/file_piece.hh" #include "util/usage.hh" @@ -7,6 +8,7 @@ #include #include +#include namespace { class SizeNotify { @@ -25,6 +27,57 @@ boost::program_options::typed_value *SizeOption(std::size_t &to, co return boost::program_options::value()->notifier(SizeNotify(to))->default_value(default_value); } +// Parse and validate pruning thresholds then return vector of threshold counts +// for each n-grams order. +std::vector ParsePruning(const std::vector ¶m, std::size_t order) { + // convert to vector of integers + std::vector prune_thresholds; + prune_thresholds.reserve(order); + for (std::vector::const_iterator it(param.begin()); it != param.end(); ++it) { + try { + prune_thresholds.push_back(boost::lexical_cast(*it)); + } catch(const boost::bad_lexical_cast &) { + UTIL_THROW(util::Exception, "Bad pruning threshold " << *it); + } + } + + // Fill with zeros by default. + if (prune_thresholds.empty()) { + prune_thresholds.resize(order, 0); + return prune_thresholds; + } + + // validate pruning threshold if specified + // throw if each n-gram order has not threshold specified + UTIL_THROW_IF(prune_thresholds.size() > order, util::Exception, "You specified pruning thresholds for orders 1 through " << prune_thresholds.size() << " but the model only has order " << order); + // threshold for unigram can only be 0 (no pruning) + UTIL_THROW_IF(prune_thresholds[0] != 0, util::Exception, "Unigram pruning is not implemented, so the first pruning threshold must be 0."); + + // check if threshold are not in decreasing order + uint64_t lower_threshold = 0; + for (std::vector::iterator it = prune_thresholds.begin(); it != prune_thresholds.end(); ++it) { + UTIL_THROW_IF(lower_threshold > *it, util::Exception, "Pruning thresholds should be in non-decreasing order. Otherwise substrings would be removed, which is bad for query-time data structures."); + lower_threshold = *it; + } + + // Pad to all orders using the last value. + prune_thresholds.resize(order, prune_thresholds.back()); + return prune_thresholds; +} + +lm::builder::Discount ParseDiscountFallback(const std::vector ¶m) { + lm::builder::Discount ret; + UTIL_THROW_IF(param.size() > 3, util::Exception, "Specify at most three fallback discounts: 1, 2, and 3+"); + UTIL_THROW_IF(param.empty(), util::Exception, "Fallback discounting enabled, but no discount specified"); + ret.amount[0] = 0.0; + for (unsigned i = 0; i < 3; ++i) { + float discount = boost::lexical_cast(param[i < param.size() ? i : (param.size() - 1)]); + UTIL_THROW_IF(discount < 0.0 || discount > static_cast(i+1), util::Exception, "The discount for count " << (i+1) << " was parsed as " << discount << " which is not in the range [0, " << (i+1) << "]."); + ret.amount[i + 1] = discount; + } + return ret; +} + } // namespace int main(int argc, char *argv[]) { @@ -34,25 +87,36 @@ int main(int argc, char *argv[]) { lm::builder::PipelineConfig pipeline; std::string text, arpa; + std::vector pruning; + std::vector discount_fallback; + std::vector discount_fallback_default; + discount_fallback_default.push_back("0.5"); + discount_fallback_default.push_back("1"); + discount_fallback_default.push_back("1.5"); options.add_options() - ("help", po::bool_switch(), "Show this help message") + ("help,h", po::bool_switch(), "Show this help message") ("order,o", po::value(&pipeline.order) #if BOOST_VERSION >= 104200 ->required() #endif , "Order of the model") - ("interpolate_unigrams", po::bool_switch(&pipeline.initial_probs.interpolate_unigrams), "Interpolate the unigrams (default: emulate SRILM by not interpolating)") + ("interpolate_unigrams", po::value(&pipeline.initial_probs.interpolate_unigrams)->default_value(true)->implicit_value(true), "Interpolate the unigrams (default) as opposed to giving lots of mass to like SRI. If you want SRI's behavior with a large and the old lmplz default, use --interpolate_unigrams 0.") + ("skip_symbols", po::bool_switch(), "Treat , , and as whitespace instead of throwing an exception") ("temp_prefix,T", po::value(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix") ("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory") ("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow") ("sort_block", SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)") - ("vocab_estimate", po::value(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table") ("block_count", po::value(&pipeline.block_count)->default_value(2), "Block count (per order)") - ("vocab_file", po::value(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file") + ("vocab_estimate", po::value(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table") + ("vocab_file", po::value(&pipeline.vocab_file)->default_value(""), "Location to write a file containing the unique vocabulary strings delimited by null bytes") + ("vocab_pad", po::value(&pipeline.vocab_size_for_unk)->default_value(0), "If the vocabulary is smaller than this value, pad with to reach this size. Requires --interpolate_unigrams") ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.") ("text", po::value(&text), "Read text from a file instead of stdin") - ("arpa", po::value(&arpa), "Write ARPA to a file instead of stdout"); + ("arpa", po::value(&arpa), "Write ARPA to a file instead of stdout") + ("collapse_values", po::bool_switch(&pipeline.output_q), "Collapse probability and backoff into a single value, q that yields the same sentence-level probabilities. See http://kheafield.com/professional/edinburgh/rest_paper.pdf for more details, including a proof.") + ("prune", po::value >(&pruning)->multitoken(), "Prune n-grams with count less than or equal to the given threshold. Specify one value for each order i.e. 0 0 1 to prune singleton trigrams and above. The sequence of values must be non-decreasing and the last value applies to any remaining orders. Unigram pruning is not implemented, so the first value must be zero. Default is to not prune, which is equivalent to --prune 0.") + ("discount_fallback", po::value >(&discount_fallback)->multitoken()->implicit_value(discount_fallback_default, "0.5 1 1.5"), "The closed-form estimate for Kneser-Ney discounts does not work without singletons or doubletons. It can also fail if these values are out of range. This option falls back to user-specified discounts when the closed-form estimate fails. Note that this option is generally a bad idea: you should deduplicate your corpus instead. However, class-based models need custom discounts because they lack singleton unigrams. Provide up to three discounts (for adjusted counts 1, 2, and 3+), which will be applied to all orders where the closed-form estimates fail."); po::variables_map vm; po::store(po::parse_command_line(argc, argv, options), vm); @@ -95,6 +159,29 @@ int main(int argc, char *argv[]) { } #endif + if (pipeline.vocab_size_for_unk && !pipeline.initial_probs.interpolate_unigrams) { + std::cerr << "--vocab_pad requires --interpolate_unigrams be on" << std::endl; + return 1; + } + + if (vm["skip_symbols"].as()) { + pipeline.disallowed_symbol_action = lm::COMPLAIN; + } else { + pipeline.disallowed_symbol_action = lm::THROW_UP; + } + + if (vm.count("discount_fallback")) { + pipeline.discount.fallback = ParseDiscountFallback(discount_fallback); + pipeline.discount.bad_action = lm::COMPLAIN; + } else { + // Unused, just here to prevent the compiler from complaining about uninitialized. + pipeline.discount.fallback = lm::builder::Discount(); + pipeline.discount.bad_action = lm::THROW_UP; + } + + // parse pruning thresholds. These depend on order, so it is not done as a notifier. + pipeline.prune_thresholds = ParsePruning(pruning, pipeline.order); + util::NormalizeTempPrefix(pipeline.sort.temp_prefix); lm::builder::InitialProbabilitiesConfig &initial = pipeline.initial_probs; diff --git a/klm/lm/builder/ngram.hh b/klm/lm/builder/ngram.hh index f5681516..0472bcb1 100644 --- a/klm/lm/builder/ngram.hh +++ b/klm/lm/builder/ngram.hh @@ -1,5 +1,5 @@ -#ifndef LM_BUILDER_NGRAM__ -#define LM_BUILDER_NGRAM__ +#ifndef LM_BUILDER_NGRAM_H +#define LM_BUILDER_NGRAM_H #include "lm/weights.hh" #include "lm/word_index.hh" @@ -26,7 +26,7 @@ union Payload { class NGram { public: - NGram(void *begin, std::size_t order) + NGram(void *begin, std::size_t order) : begin_(static_cast(begin)), end_(begin_ + order) {} const uint8_t *Base() const { return reinterpret_cast(begin_); } @@ -38,12 +38,12 @@ class NGram { end_ = begin_ + difference; } - // Would do operator++ but that can get confusing for a stream. + // Would do operator++ but that can get confusing for a stream. void NextInMemory() { ReBase(&Value() + 1); } - // Lower-case in deference to STL. + // Lower-case in deference to STL. const WordIndex *begin() const { return begin_; } WordIndex *begin() { return begin_; } const WordIndex *end() const { return end_; } @@ -61,7 +61,7 @@ class NGram { return order * sizeof(WordIndex) + sizeof(Payload); } std::size_t TotalSize() const { - // Compiler should optimize this. + // Compiler should optimize this. return TotalSize(Order()); } static std::size_t OrderFromSize(std::size_t size) { @@ -69,6 +69,31 @@ class NGram { assert(size == TotalSize(ret)); return ret; } + + // manipulate msb to signal that ngram can be pruned + /*mjd**********************************************************************/ + + bool IsMarked() const { + return Value().count >> (sizeof(Value().count) * 8 - 1); + } + + void Mark() { + Value().count |= (1ul << (sizeof(Value().count) * 8 - 1)); + } + + void Unmark() { + Value().count &= ~(1ul << (sizeof(Value().count) * 8 - 1)); + } + + uint64_t UnmarkedCount() const { + return Value().count & ~(1ul << (sizeof(Value().count) * 8 - 1)); + } + + uint64_t CutoffCount() const { + return IsMarked() ? 0 : UnmarkedCount(); + } + + /*mjd**********************************************************************/ private: WordIndex *begin_, *end_; @@ -81,4 +106,4 @@ const WordIndex kEOS = 2; } // namespace builder } // namespace lm -#endif // LM_BUILDER_NGRAM__ +#endif // LM_BUILDER_NGRAM_H diff --git a/klm/lm/builder/ngram_stream.hh b/klm/lm/builder/ngram_stream.hh index 3c994664..ab42734c 100644 --- a/klm/lm/builder/ngram_stream.hh +++ b/klm/lm/builder/ngram_stream.hh @@ -1,8 +1,9 @@ -#ifndef LM_BUILDER_NGRAM_STREAM__ -#define LM_BUILDER_NGRAM_STREAM__ +#ifndef LM_BUILDER_NGRAM_STREAM_H +#define LM_BUILDER_NGRAM_STREAM_H #include "lm/builder/ngram.hh" #include "util/stream/chain.hh" +#include "util/stream/multi_stream.hh" #include "util/stream/stream.hh" #include @@ -51,5 +52,7 @@ inline util::stream::Chain &operator>>(util::stream::Chain &chain, NGramStream & return chain; } +typedef util::stream::GenericStreams NGramStreams; + }} // namespaces -#endif // LM_BUILDER_NGRAM_STREAM__ +#endif // LM_BUILDER_NGRAM_STREAM_H diff --git a/klm/lm/builder/pipeline.cc b/klm/lm/builder/pipeline.cc index 44a2313c..21064ab3 100644 --- a/klm/lm/builder/pipeline.cc +++ b/klm/lm/builder/pipeline.cc @@ -2,6 +2,7 @@ #include "lm/builder/adjust_counts.hh" #include "lm/builder/corpus_count.hh" +#include "lm/builder/hash_gamma.hh" #include "lm/builder/initial_probabilities.hh" #include "lm/builder/interpolate.hh" #include "lm/builder/print.hh" @@ -20,10 +21,13 @@ namespace lm { namespace builder { namespace { -void PrintStatistics(const std::vector &counts, const std::vector &discounts) { +void PrintStatistics(const std::vector &counts, const std::vector &counts_pruned, const std::vector &discounts) { std::cerr << "Statistics:\n"; for (size_t i = 0; i < counts.size(); ++i) { - std::cerr << (i + 1) << ' ' << counts[i]; + std::cerr << (i + 1) << ' ' << counts_pruned[i]; + if(counts[i] != counts_pruned[i]) + std::cerr << "/" << counts[i]; + for (size_t d = 1; d <= 3; ++d) std::cerr << " D" << d << (d == 3 ? "+=" : "=") << discounts[i].amount[d]; std::cerr << '\n'; @@ -39,7 +43,7 @@ class Master { const PipelineConfig &Config() const { return config_; } - Chains &MutableChains() { return chains_; } + util::stream::Chains &MutableChains() { return chains_; } template Master &operator>>(const T &worker) { chains_ >> worker; @@ -64,7 +68,7 @@ class Master { } // For initial probabilities, but this is generic. - void SortAndReadTwice(const std::vector &counts, Sorts &sorts, Chains &second, util::stream::ChainConfig second_config) { + void SortAndReadTwice(const std::vector &counts, Sorts &sorts, util::stream::Chains &second, util::stream::ChainConfig second_config) { // Do merge first before allocating chain memory. for (std::size_t i = 1; i < config_.order; ++i) { sorts[i - 1].Merge(0); @@ -198,9 +202,9 @@ class Master { PipelineConfig config_; - Chains chains_; + util::stream::Chains chains_; // Often only unigrams, but sometimes all orders. - FixedArray files_; + util::FixedArray files_; }; void CountText(int text_file /* input */, int vocab_file /* output */, Master &master, uint64_t &token_count, std::string &text_file_name) { @@ -221,7 +225,7 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m WordIndex type_count = config.vocab_estimate; util::FilePiece text(text_file, NULL, &std::cerr); text_file_name = text.FileName(); - CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize()); + CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize(), config.disallowed_symbol_action); chain >> boost::ref(counter); util::stream::Sort sorter(chain, config.sort, SuffixOrder(config.order), AddCombiner()); @@ -231,21 +235,22 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m master.InitForAdjust(sorter, type_count); } -void InitialProbabilities(const std::vector &counts, const std::vector &discounts, Master &master, Sorts &primary, FixedArray &gammas) { +void InitialProbabilities(const std::vector &counts, const std::vector &counts_pruned, const std::vector &discounts, Master &master, Sorts &primary, + util::FixedArray &gammas, const std::vector &prune_thresholds) { const PipelineConfig &config = master.Config(); - Chains second(config.order); + util::stream::Chains second(config.order); { Sorts sorts; master.SetupSorts(sorts); - PrintStatistics(counts, discounts); - lm::ngram::ShowSizes(counts); + PrintStatistics(counts, counts_pruned, discounts); + lm::ngram::ShowSizes(counts_pruned); std::cerr << "=== 3/5 Calculating and sorting initial probabilities ===" << std::endl; - master.SortAndReadTwice(counts, sorts, second, config.initial_probs.adder_in); + master.SortAndReadTwice(counts_pruned, sorts, second, config.initial_probs.adder_in); } - Chains gamma_chains(config.order); - InitialProbabilities(config.initial_probs, discounts, master.MutableChains(), second, gamma_chains); + util::stream::Chains gamma_chains(config.order); + InitialProbabilities(config.initial_probs, discounts, master.MutableChains(), second, gamma_chains, prune_thresholds); // Don't care about gamma for 0. gamma_chains[0] >> util::stream::kRecycle; gammas.Init(config.order - 1); @@ -257,19 +262,25 @@ void InitialProbabilities(const std::vector &counts, const std::vector master.SetupSorts(primary); } -void InterpolateProbabilities(const std::vector &counts, Master &master, Sorts &primary, FixedArray &gammas) { +void InterpolateProbabilities(const std::vector &counts, Master &master, Sorts &primary, util::FixedArray &gammas) { std::cerr << "=== 4/5 Calculating and writing order-interpolated probabilities ===" << std::endl; const PipelineConfig &config = master.Config(); master.MaximumLazyInput(counts, primary); - Chains gamma_chains(config.order - 1); - util::stream::ChainConfig read_backoffs(config.read_backoffs); - read_backoffs.entry_size = sizeof(float); + util::stream::Chains gamma_chains(config.order - 1); for (std::size_t i = 0; i < config.order - 1; ++i) { + util::stream::ChainConfig read_backoffs(config.read_backoffs); + + // Add 1 because here we are skipping unigrams + if(config.prune_thresholds[i + 1] > 0) + read_backoffs.entry_size = sizeof(HashGamma); + else + read_backoffs.entry_size = sizeof(float); + gamma_chains.push_back(read_backoffs); gamma_chains.back() >> gammas[i].Source(); } - master >> Interpolate(counts[0], ChainPositions(gamma_chains)); + master >> Interpolate(std::max(master.Config().vocab_size_for_unk, counts[0] - 1 /* is not included */), util::stream::ChainPositions(gamma_chains), config.prune_thresholds, config.output_q); gamma_chains >> util::stream::kRecycle; master.BufferFinal(counts); } @@ -291,32 +302,40 @@ void Pipeline(PipelineConfig config, int text_file, int out_arpa) { "Not enough memory to fit " << (config.order * config.block_count) << " blocks with minimum size " << config.minimum_block << ". Increase memory to " << (config.minimum_block * config.order * config.block_count) << " bytes or decrease the minimum block size."); UTIL_TIMER("(%w s) Total wall time elapsed\n"); - Master master(config); - - util::scoped_fd vocab_file(config.vocab_file.empty() ? - util::MakeTemp(config.TempPrefix()) : - util::CreateOrThrow(config.vocab_file.c_str())); - uint64_t token_count; - std::string text_file_name; - CountText(text_file, vocab_file.get(), master, token_count, text_file_name); - std::vector counts; - std::vector discounts; - master >> AdjustCounts(counts, discounts); + Master master(config); + // master's destructor will wait for chains. But they might be deadlocked if + // this thread dies because e.g. it ran out of memory. + try { + util::scoped_fd vocab_file(config.vocab_file.empty() ? + util::MakeTemp(config.TempPrefix()) : + util::CreateOrThrow(config.vocab_file.c_str())); + uint64_t token_count; + std::string text_file_name; + CountText(text_file, vocab_file.get(), master, token_count, text_file_name); + + std::vector counts; + std::vector counts_pruned; + std::vector discounts; + master >> AdjustCounts(config.prune_thresholds, counts, counts_pruned, config.discount, discounts); + + { + util::FixedArray gammas; + Sorts primary; + InitialProbabilities(counts, counts_pruned, discounts, master, primary, gammas, config.prune_thresholds); + InterpolateProbabilities(counts_pruned, master, primary, gammas); + } - { - FixedArray gammas; - Sorts primary; - InitialProbabilities(counts, discounts, master, primary, gammas); - InterpolateProbabilities(counts, master, primary, gammas); + std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl; + VocabReconstitute vocab(vocab_file.get()); + UTIL_THROW_IF(vocab.Size() != counts[0], util::Exception, "Vocab words don't match up. Is there a null byte in the input?"); + HeaderInfo header_info(text_file_name, token_count); + master >> PrintARPA(vocab, counts_pruned, (config.verbose_header ? &header_info : NULL), out_arpa) >> util::stream::kRecycle; + master.MutableChains().Wait(true); + } catch (const util::Exception &e) { + std::cerr << e.what() << std::endl; + abort(); } - - std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl; - VocabReconstitute vocab(vocab_file.get()); - UTIL_THROW_IF(vocab.Size() != counts[0], util::Exception, "Vocab words don't match up. Is there a null byte in the input?"); - HeaderInfo header_info(text_file_name, token_count); - master >> PrintARPA(vocab, counts, (config.verbose_header ? &header_info : NULL), out_arpa) >> util::stream::kRecycle; - master.MutableChains().Wait(true); } }} // namespaces diff --git a/klm/lm/builder/pipeline.hh b/klm/lm/builder/pipeline.hh index 845e5481..09e1a4d5 100644 --- a/klm/lm/builder/pipeline.hh +++ b/klm/lm/builder/pipeline.hh @@ -1,8 +1,10 @@ -#ifndef LM_BUILDER_PIPELINE__ -#define LM_BUILDER_PIPELINE__ +#ifndef LM_BUILDER_PIPELINE_H +#define LM_BUILDER_PIPELINE_H +#include "lm/builder/adjust_counts.hh" #include "lm/builder/initial_probabilities.hh" #include "lm/builder/header_info.hh" +#include "lm/lm_exception.hh" #include "lm/word_index.hh" #include "util/stream/config.hh" #include "util/file_piece.hh" @@ -18,6 +20,8 @@ struct PipelineConfig { util::stream::SortConfig sort; InitialProbabilitiesConfig initial_probs; util::stream::ChainConfig read_backoffs; + + // Include a header in the ARPA with some statistics? bool verbose_header; // Estimated vocabulary size. Used for sizing CorpusCount memory and @@ -30,6 +34,34 @@ struct PipelineConfig { // Number of blocks to use. This will be overridden to 1 if everything fits. std::size_t block_count; + // n-gram count thresholds for pruning. 0 values means no pruning for + // corresponding n-gram order + std::vector prune_thresholds; //mjd + + // What to do with discount failures. + DiscountConfig discount; + + // Compute collapsed q values instead of probability and backoff + bool output_q; + + /* Computing the perplexity of LMs with different vocabularies is hard. For + * example, the lowest perplexity is attained by a unigram model that + * predicts p() = 1 and has no other vocabulary. Also, linearly + * interpolated models will sum to more than 1 because is duplicated + * (SRI just pretends p() = 0 for these purposes, which makes it sum to + * 1 but comes with its own problems). This option will make the vocabulary + * a particular size by replicating multiple times for purposes of + * computing vocabulary size. It has no effect if the actual vocabulary is + * larger. This parameter serves the same purpose as IRSTLM's "dub". + */ + uint64_t vocab_size_for_unk; + + /* What to do the first time , , or appears in the input. If + * this is anything but THROW_UP, then the symbol will always be treated as + * whitespace. + */ + WarningAction disallowed_symbol_action; + const std::string &TempPrefix() const { return sort.temp_prefix; } std::size_t TotalMemory() const { return sort.total_memory; } }; @@ -38,4 +70,4 @@ struct PipelineConfig { void Pipeline(PipelineConfig config, int text_file, int out_arpa); }} // namespaces -#endif // LM_BUILDER_PIPELINE__ +#endif // LM_BUILDER_PIPELINE_H diff --git a/klm/lm/builder/print.cc b/klm/lm/builder/print.cc index 84bd81ca..aee6e134 100644 --- a/klm/lm/builder/print.cc +++ b/klm/lm/builder/print.cc @@ -42,22 +42,22 @@ PrintARPA::PrintARPA(const VocabReconstitute &vocab, const std::vector util::WriteOrThrow(out_fd, as_string.data(), as_string.size()); } -void PrintARPA::Run(const ChainPositions &positions) { +void PrintARPA::Run(const util::stream::ChainPositions &positions) { util::scoped_fd closer(out_fd_); UTIL_TIMER("(%w s) Wrote ARPA file\n"); util::FakeOFStream out(out_fd_); for (unsigned order = 1; order <= positions.size(); ++order) { out << "\\" << order << "-grams:" << '\n'; for (NGramStream stream(positions[order - 1]); stream; ++stream) { - // Correcting for numerical precision issues. Take that IRST. - out << std::min(0.0f, stream->Value().complete.prob) << '\t' << vocab_.Lookup(*stream->begin()); + // Correcting for numerical precision issues. Take that IRST. + out << stream->Value().complete.prob << '\t' << vocab_.Lookup(*stream->begin()); for (const WordIndex *i = stream->begin() + 1; i != stream->end(); ++i) { out << ' ' << vocab_.Lookup(*i); } - float backoff = stream->Value().complete.backoff; - if (backoff != 0.0) - out << '\t' << backoff; + if (order != positions.size()) + out << '\t' << stream->Value().complete.backoff; out << '\n'; + } out << '\n'; } diff --git a/klm/lm/builder/print.hh b/klm/lm/builder/print.hh index adbbb94a..9856cea8 100644 --- a/klm/lm/builder/print.hh +++ b/klm/lm/builder/print.hh @@ -1,8 +1,8 @@ -#ifndef LM_BUILDER_PRINT__ -#define LM_BUILDER_PRINT__ +#ifndef LM_BUILDER_PRINT_H +#define LM_BUILDER_PRINT_H #include "lm/builder/ngram.hh" -#include "lm/builder/multi_stream.hh" +#include "lm/builder/ngram_stream.hh" #include "lm/builder/header_info.hh" #include "util/file.hh" #include "util/mmap.hh" @@ -59,7 +59,7 @@ template class Print { public: explicit Print(const VocabReconstitute &vocab, std::ostream &to) : vocab_(vocab), to_(to) {} - void Run(const ChainPositions &chains) { + void Run(const util::stream::ChainPositions &chains) { NGramStreams streams(chains); for (NGramStream *s = streams.begin(); s != streams.end(); ++s) { DumpStream(*s); @@ -92,7 +92,7 @@ class PrintARPA { // Takes ownership of out_fd upon Run(). explicit PrintARPA(const VocabReconstitute &vocab, const std::vector &counts, const HeaderInfo* header_info, int out_fd); - void Run(const ChainPositions &positions); + void Run(const util::stream::ChainPositions &positions); private: const VocabReconstitute &vocab_; @@ -100,4 +100,4 @@ class PrintARPA { }; }} // namespaces -#endif // LM_BUILDER_PRINT__ +#endif // LM_BUILDER_PRINT_H diff --git a/klm/lm/builder/sort.hh b/klm/lm/builder/sort.hh index 9989389b..712bb8e3 100644 --- a/klm/lm/builder/sort.hh +++ b/klm/lm/builder/sort.hh @@ -1,7 +1,7 @@ -#ifndef LM_BUILDER_SORT__ -#define LM_BUILDER_SORT__ +#ifndef LM_BUILDER_SORT_H +#define LM_BUILDER_SORT_H -#include "lm/builder/multi_stream.hh" +#include "lm/builder/ngram_stream.hh" #include "lm/builder/ngram.hh" #include "lm/word_index.hh" #include "util/stream/sort.hh" @@ -14,24 +14,71 @@ namespace lm { namespace builder { +/** + * Abstract parent class for defining custom n-gram comparators. + */ template class Comparator : public std::binary_function { public: + + /** + * Constructs a comparator capable of comparing two n-grams. + * + * @param order Number of words in each n-gram + */ explicit Comparator(std::size_t order) : order_(order) {} + /** + * Applies the comparator using the Compare method that must be defined in any class that inherits from this class. + * + * @param lhs A pointer to the n-gram on the left-hand side of the comparison + * @param rhs A pointer to the n-gram on the right-hand side of the comparison + * + * @see ContextOrder::Compare + * @see PrefixOrder::Compare + * @see SuffixOrder::Compare + */ inline bool operator()(const void *lhs, const void *rhs) const { return static_cast(this)->Compare(static_cast(lhs), static_cast(rhs)); } + /** Gets the n-gram order defined for this comparator. */ std::size_t Order() const { return order_; } protected: std::size_t order_; }; +/** + * N-gram comparator that compares n-grams according to their reverse (suffix) order. + * + * This comparator compares n-grams lexicographically, one word at a time, + * beginning with the last word of each n-gram and ending with the first word of each n-gram. + * + * Some examples of n-gram comparisons as defined by this comparator: + * - a b c == a b c + * - a b c < a b d + * - a b c > a d b + * - a b c > a b b + * - a b c > x a c + * - a b c < x y z + */ class SuffixOrder : public Comparator { public: + + /** + * Constructs a comparator capable of comparing two n-grams. + * + * @param order Number of words in each n-gram + */ explicit SuffixOrder(std::size_t order) : Comparator(order) {} + /** + * Compares two n-grams lexicographically, one word at a time, + * beginning with the last word of each n-gram and ending with the first word of each n-gram. + * + * @param lhs A pointer to the n-gram on the left-hand side of the comparison + * @param rhs A pointer to the n-gram on the right-hand side of the comparison + */ inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const { for (std::size_t i = order_ - 1; i != 0; --i) { if (lhs[i] != rhs[i]) @@ -43,10 +90,40 @@ class SuffixOrder : public Comparator { static const unsigned kMatchOffset = 1; }; + +/** + * N-gram comparator that compares n-grams according to the reverse (suffix) order of the n-gram context. + * + * This comparator compares n-grams lexicographically, one word at a time, + * beginning with the penultimate word of each n-gram and ending with the first word of each n-gram; + * finally, this comparator compares the last word of each n-gram. + * + * Some examples of n-gram comparisons as defined by this comparator: + * - a b c == a b c + * - a b c < a b d + * - a b c < a d b + * - a b c > a b b + * - a b c > x a c + * - a b c < x y z + */ class ContextOrder : public Comparator { public: + + /** + * Constructs a comparator capable of comparing two n-grams. + * + * @param order Number of words in each n-gram + */ explicit ContextOrder(std::size_t order) : Comparator(order) {} + /** + * Compares two n-grams lexicographically, one word at a time, + * beginning with the penultimate word of each n-gram and ending with the first word of each n-gram; + * finally, this comparator compares the last word of each n-gram. + * + * @param lhs A pointer to the n-gram on the left-hand side of the comparison + * @param rhs A pointer to the n-gram on the right-hand side of the comparison + */ inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const { for (int i = order_ - 2; i >= 0; --i) { if (lhs[i] != rhs[i]) @@ -56,10 +133,37 @@ class ContextOrder : public Comparator { } }; +/** + * N-gram comparator that compares n-grams according to their natural (prefix) order. + * + * This comparator compares n-grams lexicographically, one word at a time, + * beginning with the first word of each n-gram and ending with the last word of each n-gram. + * + * Some examples of n-gram comparisons as defined by this comparator: + * - a b c == a b c + * - a b c < a b d + * - a b c < a d b + * - a b c > a b b + * - a b c < x a c + * - a b c < x y z + */ class PrefixOrder : public Comparator { public: + + /** + * Constructs a comparator capable of comparing two n-grams. + * + * @param order Number of words in each n-gram + */ explicit PrefixOrder(std::size_t order) : Comparator(order) {} + /** + * Compares two n-grams lexicographically, one word at a time, + * beginning with the first word of each n-gram and ending with the last word of each n-gram. + * + * @param lhs A pointer to the n-gram on the left-hand side of the comparison + * @param rhs A pointer to the n-gram on the right-hand side of the comparison + */ inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const { for (std::size_t i = 0; i < order_; ++i) { if (lhs[i] != rhs[i]) @@ -84,15 +188,52 @@ struct AddCombiner { }; // The combiner is only used on a single chain, so I didn't bother to allow -// that template. -template class Sorts : public FixedArray > { +// that template. +/** + * Represents an @ref util::FixedArray "array" capable of storing @ref util::stream::Sort "Sort" objects. + * + * In the anticipated use case, an instance of this class will maintain one @ref util::stream::Sort "Sort" object + * for each n-gram order (ranging from 1 up to the maximum n-gram order being processed). + * Use in this manner would enable the n-grams each n-gram order to be sorted, in parallel. + * + * @tparam Compare An @ref Comparator "ngram comparator" to use during sorting. + */ +template class Sorts : public util::FixedArray > { private: typedef util::stream::Sort S; - typedef FixedArray P; + typedef util::FixedArray P; public: + + /** + * Constructs, but does not initialize. + * + * @ref util::FixedArray::Init() "Init" must be called before use. + * + * @see util::FixedArray::Init() + */ + Sorts() {} + + /** + * Constructs an @ref util::FixedArray "array" capable of storing a fixed number of @ref util::stream::Sort "Sort" objects. + * + * @param number The maximum number of @ref util::stream::Sort "sorters" that can be held by this @ref util::FixedArray "array" + * @see util::FixedArray::FixedArray() + */ + explicit Sorts(std::size_t number) : util::FixedArray >(number) {} + + /** + * Constructs a new @ref util::stream::Sort "Sort" object which is stored in this @ref util::FixedArray "array". + * + * The new @ref util::stream::Sort "Sort" object is constructed using the provided @ref util::stream::SortConfig "SortConfig" and @ref Comparator "ngram comparator"; + * once constructed, a new worker @ref util::stream::Thread "thread" (owned by the @ref util::stream::Chain "chain") will sort the n-gram data stored + * in the @ref util::stream::Block "blocks" of the provided @ref util::stream::Chain "chain". + * + * @see util::stream::Sort::Sort() + * @see util::stream::Chain::operator>>() + */ void push_back(util::stream::Chain &chain, const util::stream::SortConfig &config, const Compare &compare) { - new (P::end()) S(chain, config, compare); + new (P::end()) S(chain, config, compare); // use "placement new" syntax to initalize S in an already-allocated memory location P::Constructed(); } }; @@ -100,4 +241,4 @@ template class Sorts : public FixedArray class ModelFacade : publ } // mamespace base } // namespace lm -#endif // LM_FACADE__ +#endif // LM_FACADE_H diff --git a/klm/lm/filter/arpa_io.hh b/klm/lm/filter/arpa_io.hh index 602b5b31..99c97b11 100644 --- a/klm/lm/filter/arpa_io.hh +++ b/klm/lm/filter/arpa_io.hh @@ -1,5 +1,5 @@ -#ifndef LM_FILTER_ARPA_IO__ -#define LM_FILTER_ARPA_IO__ +#ifndef LM_FILTER_ARPA_IO_H +#define LM_FILTER_ARPA_IO_H /* Input and output for ARPA format language model files. */ #include "lm/read_arpa.hh" @@ -111,4 +111,4 @@ template void ReadARPA(util::FilePiece &in_lm, Output &out) { } // namespace lm -#endif // LM_FILTER_ARPA_IO__ +#endif // LM_FILTER_ARPA_IO_H diff --git a/klm/lm/filter/count_io.hh b/klm/lm/filter/count_io.hh index d992026f..de894baf 100644 --- a/klm/lm/filter/count_io.hh +++ b/klm/lm/filter/count_io.hh @@ -1,5 +1,5 @@ -#ifndef LM_FILTER_COUNT_IO__ -#define LM_FILTER_COUNT_IO__ +#ifndef LM_FILTER_COUNT_IO_H +#define LM_FILTER_COUNT_IO_H #include #include @@ -86,4 +86,4 @@ template void ReadCount(util::FilePiece &in_file, Output &out) { } // namespace lm -#endif // LM_FILTER_COUNT_IO__ +#endif // LM_FILTER_COUNT_IO_H diff --git a/klm/lm/filter/format.hh b/klm/lm/filter/format.hh index 7d8c28db..5a2e2db3 100644 --- a/klm/lm/filter/format.hh +++ b/klm/lm/filter/format.hh @@ -1,5 +1,5 @@ -#ifndef LM_FILTER_FORMAT_H__ -#define LM_FILTER_FORMAT_H__ +#ifndef LM_FILTER_FORMAT_H +#define LM_FILTER_FORMAT_H #include "lm/filter/arpa_io.hh" #include "lm/filter/count_io.hh" @@ -247,4 +247,4 @@ class MultipleOutputBuffer { } // namespace lm -#endif // LM_FILTER_FORMAT_H__ +#endif // LM_FILTER_FORMAT_H diff --git a/klm/lm/filter/phrase.hh b/klm/lm/filter/phrase.hh index e8e85835..e5898c9a 100644 --- a/klm/lm/filter/phrase.hh +++ b/klm/lm/filter/phrase.hh @@ -1,5 +1,5 @@ -#ifndef LM_FILTER_PHRASE_H__ -#define LM_FILTER_PHRASE_H__ +#ifndef LM_FILTER_PHRASE_H +#define LM_FILTER_PHRASE_H #include "util/murmur_hash.hh" #include "util/string_piece.hh" @@ -165,4 +165,4 @@ class Multiple : public detail::ConditionCommon { } // namespace phrase } // namespace lm -#endif // LM_FILTER_PHRASE_H__ +#endif // LM_FILTER_PHRASE_H diff --git a/klm/lm/filter/thread.hh b/klm/lm/filter/thread.hh index e785b263..6a6523f9 100644 --- a/klm/lm/filter/thread.hh +++ b/klm/lm/filter/thread.hh @@ -1,5 +1,5 @@ -#ifndef LM_FILTER_THREAD_H__ -#define LM_FILTER_THREAD_H__ +#ifndef LM_FILTER_THREAD_H +#define LM_FILTER_THREAD_H #include "util/thread_pool.hh" @@ -164,4 +164,4 @@ template class Controller : } // namespace lm -#endif // LM_FILTER_THREAD_H__ +#endif // LM_FILTER_THREAD_H diff --git a/klm/lm/filter/vocab.hh b/klm/lm/filter/vocab.hh index 7f0fadaa..2ee6e1f8 100644 --- a/klm/lm/filter/vocab.hh +++ b/klm/lm/filter/vocab.hh @@ -1,5 +1,5 @@ -#ifndef LM_FILTER_VOCAB_H__ -#define LM_FILTER_VOCAB_H__ +#ifndef LM_FILTER_VOCAB_H +#define LM_FILTER_VOCAB_H // Vocabulary-based filters for language models. @@ -130,4 +130,4 @@ class Multiple { } // namespace vocab } // namespace lm -#endif // LM_FILTER_VOCAB_H__ +#endif // LM_FILTER_VOCAB_H diff --git a/klm/lm/filter/wrapper.hh b/klm/lm/filter/wrapper.hh index eb657501..822c5c27 100644 --- a/klm/lm/filter/wrapper.hh +++ b/klm/lm/filter/wrapper.hh @@ -1,5 +1,5 @@ -#ifndef LM_FILTER_WRAPPER_H__ -#define LM_FILTER_WRAPPER_H__ +#ifndef LM_FILTER_WRAPPER_H +#define LM_FILTER_WRAPPER_H #include "util/string_piece.hh" @@ -53,4 +53,4 @@ template class ContextFilter { } // namespace lm -#endif // LM_FILTER_WRAPPER_H__ +#endif // LM_FILTER_WRAPPER_H diff --git a/klm/lm/interpolate/arpa_to_stream.cc b/klm/lm/interpolate/arpa_to_stream.cc new file mode 100644 index 00000000..f2696f39 --- /dev/null +++ b/klm/lm/interpolate/arpa_to_stream.cc @@ -0,0 +1,47 @@ +#include "lm/interpolate/arpa_to_stream.hh" + +// TODO: should this move out of builder? +#include "lm/builder/ngram_stream.hh" +#include "lm/read_arpa.hh" +#include "lm/vocab.hh" + +namespace lm { namespace interpolate { + +ARPAToStream::ARPAToStream(int fd, ngram::GrowableVocab &vocab) + : in_(fd), vocab_(vocab) { + + // Read the ARPA file header. + // + // After the following call, counts_ will be correctly initialized, + // and in_ will be positioned for reading the body of the ARPA file. + ReadARPACounts(in_, counts_); + +} + +void ARPAToStream::Run(const util::stream::ChainPositions &positions) { + // Make one stream for each order. + builder::NGramStreams streams(positions); + PositiveProbWarn warn; + + // Unigrams are handled specially because they're being inserted into the vocab. + ReadNGramHeader(in_, 1); + for (uint64_t i = 0; i < counts_[0]; ++i, ++streams[0]) { + streams[0]->begin()[0] = vocab_.FindOrInsert(Read1Gram(in_, streams[0]->Value().complete, warn)); + } + // Finish off the unigram stream. + streams[0].Poison(); + + // TODO: don't waste backoff field for highest order. + for (unsigned char n = 2; n <= counts_.size(); ++n) { + ReadNGramHeader(in_, n); + builder::NGramStream &stream = streams[n - 1]; + const uint64_t end = counts_[n - 1]; + for (std::size_t i = 0; i < end; ++i, ++stream) { + ReadNGram(in_, n, vocab_, stream->begin(), stream->Value().complete, warn); + } + // Finish the stream for n-grams.. + stream.Poison(); + } +} + +}} // namespaces diff --git a/klm/lm/interpolate/arpa_to_stream.hh b/klm/lm/interpolate/arpa_to_stream.hh new file mode 100644 index 00000000..4613998d --- /dev/null +++ b/klm/lm/interpolate/arpa_to_stream.hh @@ -0,0 +1,38 @@ +#include "lm/read_arpa.hh" +#include "util/file_piece.hh" + +#include + +#include + +namespace util { namespace stream { class ChainPositions; } } + +namespace lm { + +namespace ngram { +template class GrowableVocab; +class WriteUniqueWords; +} // namespace ngram + +namespace interpolate { + +class ARPAToStream { + public: + // Takes ownership of fd. + explicit ARPAToStream(int fd, ngram::GrowableVocab &vocab); + + std::size_t Order() const { return counts_.size(); } + + const std::vector &Counts() const { return counts_; } + + void Run(const util::stream::ChainPositions &positions); + + private: + util::FilePiece in_; + + std::vector counts_; + + ngram::GrowableVocab &vocab_; +}; + +}} // namespaces diff --git a/klm/lm/interpolate/example_sort_main.cc b/klm/lm/interpolate/example_sort_main.cc new file mode 100644 index 00000000..4282255e --- /dev/null +++ b/klm/lm/interpolate/example_sort_main.cc @@ -0,0 +1,144 @@ +#include "lm/interpolate/arpa_to_stream.hh" + +#include "lm/builder/print.hh" +#include "lm/builder/sort.hh" +#include "lm/vocab.hh" +#include "util/file.hh" +#include "util/unistd.hh" + + +int main() { + + // TODO: Make these all command-line parameters + const std::size_t ONE_GB = 1 << 30; + const std::size_t SIXTY_FOUR_MB = 1 << 26; + const std::size_t NUMBER_OF_BLOCKS = 2; + + // Vocab strings will be written to this file, forgotten, and reconstituted + // later. This saves memory. + util::scoped_fd vocab_file(util::MakeTemp("/tmp/")); + std::vector counts; + util::stream::Chains chains; + { + // Use consistent vocab ids across models. + lm::ngram::GrowableVocab vocab(10, vocab_file.get()); + lm::interpolate::ARPAToStream reader(STDIN_FILENO, vocab); + counts = reader.Counts(); + + // Configure a chain for each order. TODO: extract chain balance heuristics from lm/builder/pipeline.cc + chains.Init(reader.Order()); + + for (std::size_t i = 0; i < reader.Order(); ++i) { + + // The following call to chains.push_back() invokes the Chain constructor + // and appends the newly created Chain object to the chains array + chains.push_back(util::stream::ChainConfig(lm::builder::NGram::TotalSize(i + 1), NUMBER_OF_BLOCKS, ONE_GB)); + + } + + // The following call to the >> method of chains + // constructs a ChainPosition for each chain in chains using Chain::Add(); + // that function begins with a call to Chain::Start() + // that allocates memory for the chain. + // + // After the following call to the >> method of chains, + // a new thread will be running + // and will be executing the reader.Run() method + // to read through the body of the ARPA file from standard input. + // + // For each n-gram line in the ARPA file, + // the thread executing reader.Run() + // will write the probability, the n-gram, and the backoff + // to the appropriate location in the appropriate chain + // (for details, see the ReadNGram() method in read_arpa.hh). + // + // Normally >> copies then runs so inline >> works. But here we want a ref. + chains >> boost::ref(reader); + + + util::stream::SortConfig sort_config; + sort_config.temp_prefix = "/tmp/"; + sort_config.buffer_size = SIXTY_FOUR_MB; + sort_config.total_memory = ONE_GB; + + // Parallel sorts across orders (though somewhat limited because ARPA files are not being read in parallel across orders) + lm::builder::Sorts sorts(reader.Order()); + for (std::size_t i = 0; i < reader.Order(); ++i) { + + // The following call to sorts.push_back() invokes the Sort constructor + // and appends the newly constructed Sort object to the sorts array. + // + // After the construction of the Sort object, + // two new threads will be running (each owned by the chains[i] object). + // + // The first new thread will execute BlockSorter.Run() to sort the n-gram entries of order (i+1) + // that were previously read into chains[i] by the ARPA input reader thread. + // + // The second new thread will execute WriteAndRecycle.Run() + // to write each sorted block of data to disk as a temporary file. + sorts.push_back(chains[i], sort_config, lm::builder::SuffixOrder(i + 1)); + + } + + // Output to the same chains. + for (std::size_t i = 0; i < reader.Order(); ++i) { + + // The following call to Chain::Wait() + // joins the threads owned by chains[i]. + // + // As such the following call won't return + // until all threads owned by chains[i] have completed. + // + // The following call also resets chain[i] + // so that it can be reused + // (including free'ing the memory previously used by the chain) + chains[i].Wait(); + + + // In an ideal world (without memory restrictions) + // we could merge all of the previously sorted blocks + // by reading them all completely into memory + // and then running merge sort over them. + // + // In the real world, we have memory restrictions; + // depending on how many blocks we have, + // and how much memory we can use to read from each block (sort_config.buffer_size) + // it may be the case that we have insufficient memory + // to read sort_config.buffer_size of data from each block from disk. + // + // If this occurs, then it will be necessary to perform one or more rounds of merge sort on disk; + // doing so will reduce the number of blocks that we will eventually need to read from + // when performing the final round of merge sort in memory. + // + // So, the following call determines whether it is necessary + // to perform one or more rounds of merge sort on disk; + // if such on-disk merge sorting is required, such sorting is performed. + // + // Finally, the following method launches a thread that calls OwningMergingReader.Run() + // to perform the final round of merge sort in memory. + // + // Merge sort could have be invoked directly + // so that merge sort memory doesn't coexist with Chain memory. + sorts[i].Output(chains[i]); + } + + // sorts can go out of scope even though it's still writing to the chains. + // note that vocab going out of scope flushes to vocab_file. + } + + + // Get the vocabulary mapping used for this ARPA file + lm::builder::VocabReconstitute reconstitute(vocab_file.get()); + + // After the following call to the << method of chains, + // a new thread will be running + // and will be executing the Run() method of PrintARPA + // to print the final sorted ARPA file to standard output. + chains >> lm::builder::PrintARPA(reconstitute, counts, NULL, STDOUT_FILENO); + + // Joins all threads that chains owns, + // and does a for loop over each chain object in chains, + // calling chain.Wait() on each such chain object + chains.Wait(true); + +} diff --git a/klm/lm/left.hh b/klm/lm/left.hh index 85c1ea37..36d61369 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -35,8 +35,8 @@ * phrase, even if hypotheses are generated left-to-right. */ -#ifndef LM_LEFT__ -#define LM_LEFT__ +#ifndef LM_LEFT_H +#define LM_LEFT_H #include "lm/max_order.hh" #include "lm/state.hh" @@ -213,4 +213,4 @@ template class RuleScore { } // namespace ngram } // namespace lm -#endif // LM_LEFT__ +#endif // LM_LEFT_H diff --git a/klm/lm/lm_exception.hh b/klm/lm/lm_exception.hh index f607ced1..8bb61081 100644 --- a/klm/lm/lm_exception.hh +++ b/klm/lm/lm_exception.hh @@ -1,5 +1,5 @@ -#ifndef LM_LM_EXCEPTION__ -#define LM_LM_EXCEPTION__ +#ifndef LM_LM_EXCEPTION_H +#define LM_LM_EXCEPTION_H // Named to avoid conflict with util/exception.hh. diff --git a/klm/lm/max_order.hh b/klm/lm/max_order.hh index 3eb97ccd..f7344cde 100644 --- a/klm/lm/max_order.hh +++ b/klm/lm/max_order.hh @@ -1,9 +1,13 @@ -/* IF YOUR BUILD SYSTEM PASSES -DKENLM_MAX_ORDER, THEN CHANGE THE BUILD SYSTEM. +#ifndef LM_MAX_ORDER_H +#define LM_MAX_ORDER_H +/* IF YOUR BUILD SYSTEM PASSES -DKENLM_MAX_ORDER_H, THEN CHANGE THE BUILD SYSTEM. * If not, this is the default maximum order. * Having this limit means that State can be * (kMaxOrder - 1) * sizeof(float) bytes instead of * sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead */ #ifndef KENLM_ORDER_MESSAGE -#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'. Otherwise, edit lm/max_order.hh." +#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER_H, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'. Otherwise, edit lm/max_order.hh." #endif + +#endif // LM_MAX_ORDER_H diff --git a/klm/lm/model.hh b/klm/lm/model.hh index e75da93b..6925a56d 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -1,5 +1,5 @@ -#ifndef LM_MODEL__ -#define LM_MODEL__ +#ifndef LM_MODEL_H +#define LM_MODEL_H #include "lm/bhiksha.hh" #include "lm/binary_format.hh" @@ -153,4 +153,4 @@ base::Model *LoadVirtual(const char *file_name, const Config &config = Config(), } // namespace ngram } // namespace lm -#endif // LM_MODEL__ +#endif // LM_MODEL_H diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 7005b05e..0f54724b 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -176,7 +176,7 @@ template void MinimalState(const M &model) { AppendTest("to", 1, -1.687872, false); AppendTest("look", 2, -0.2922095, true); BOOST_CHECK_EQUAL(2, state.length); - AppendTest("good", 3, -7, true); + AppendTest("a", 3, -7, true); } template void ExtendLeftTest(const M &model) { diff --git a/klm/lm/model_type.hh b/klm/lm/model_type.hh index 8b35c793..fbe1117a 100644 --- a/klm/lm/model_type.hh +++ b/klm/lm/model_type.hh @@ -1,5 +1,5 @@ -#ifndef LM_MODEL_TYPE__ -#define LM_MODEL_TYPE__ +#ifndef LM_MODEL_TYPE_H +#define LM_MODEL_TYPE_H namespace lm { namespace ngram { @@ -20,4 +20,4 @@ const static ModelType kArrayAdd = static_cast(ARRAY_TRIE - TRIE); } // namespace ngram } // namespace lm -#endif // LM_MODEL_TYPE__ +#endif // LM_MODEL_TYPE_H diff --git a/klm/lm/neural/wordvecs.cc b/klm/lm/neural/wordvecs.cc new file mode 100644 index 00000000..09bb4260 --- /dev/null +++ b/klm/lm/neural/wordvecs.cc @@ -0,0 +1,23 @@ +#include "lm/neural/wordvecs.hh" + +#include "util/file_piece.hh" + +namespace lm { namespace neural { + +WordVecs::WordVecs(util::FilePiece &f) { + const unsigned long lines = f.ReadULong(); + const std::size_t vocab_mem = ngram::ProbingVocabulary::Size(lines, 1.5); + vocab_backing_.reset(util::CallocOrThrow(vocab_mem)); + vocab_.SetupMemory(vocab_backing_.get(), vocab_mem); + const unsigned long width = f.ReadULong(); + vecs_.resize(width, lines); + for (unsigned long i = 0; i < lines; ++i) { + WordIndex column = vocab_.Insert(f.ReadDelimited()); + for (unsigned int row = 0; row < width; ++row) { + vecs_(row,column) = f.ReadFloat(); + } + } + vocab_.FinishedLoading(); +} + +}} // namespaces diff --git a/klm/lm/neural/wordvecs.hh b/klm/lm/neural/wordvecs.hh new file mode 100644 index 00000000..921a2b22 --- /dev/null +++ b/klm/lm/neural/wordvecs.hh @@ -0,0 +1,38 @@ +#ifndef LM_NEURAL_WORDVECS_H +#define LM_NEURAL_WORDVECS_H + +#include "util/scoped.hh" +#include "lm/vocab.hh" + +#include + +namespace util { class FilePiece; } + +namespace lm { +namespace neural { + +class WordVecs { + public: + // Columns of the matrix are word vectors. The column index is the word. + typedef Eigen::Matrix Storage; + + /* The file should begin with a line stating the number of word vectors and + * the length of the vectors. Then it's followed by lines containing a + * word followed by floating-point values. + */ + explicit WordVecs(util::FilePiece &in); + + const Storage &Vectors() const { return vecs_; } + + WordIndex Index(StringPiece str) const { return vocab_.Index(str); } + + private: + util::scoped_malloc vocab_backing_; + ngram::ProbingVocabulary vocab_; + + Storage vecs_; +}; + +}} // namespaces + +#endif // LM_NEURAL_WORDVECS_H diff --git a/klm/lm/ngram_query.hh b/klm/lm/ngram_query.hh index ec2590f4..5f330c5c 100644 --- a/klm/lm/ngram_query.hh +++ b/klm/lm/ngram_query.hh @@ -1,8 +1,9 @@ -#ifndef LM_NGRAM_QUERY__ -#define LM_NGRAM_QUERY__ +#ifndef LM_NGRAM_QUERY_H +#define LM_NGRAM_QUERY_H #include "lm/enumerate_vocab.hh" #include "lm/model.hh" +#include "util/file_piece.hh" #include "util/usage.hh" #include @@ -16,64 +17,94 @@ namespace lm { namespace ngram { -template void Query(const Model &model, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) { +struct BasicPrint { + void Word(StringPiece, WordIndex, const FullScoreReturn &) const {} + void Line(uint64_t oov, float total) const { + std::cout << "Total: " << total << " OOV: " << oov << '\n'; + } + void Summary(double, double, uint64_t, uint64_t) {} + +}; + +struct FullPrint : public BasicPrint { + void Word(StringPiece surface, WordIndex vocab, const FullScoreReturn &ret) const { + std::cout << surface << '=' << vocab << ' ' << static_cast(ret.ngram_length) << ' ' << ret.prob << '\t'; + } + + void Summary(double ppl_including_oov, double ppl_excluding_oov, uint64_t corpus_oov, uint64_t corpus_tokens) { + std::cout << + "Perplexity including OOVs:\t" << ppl_including_oov << "\n" + "Perplexity excluding OOVs:\t" << ppl_excluding_oov << "\n" + "OOVs:\t" << corpus_oov << "\n" + "Tokens:\t" << corpus_tokens << '\n' + ; + } +}; + +template void Query(const Model &model, bool sentence_context) { + Printer printer; typename Model::State state, out; lm::FullScoreReturn ret; - std::string word; + StringPiece word; + + util::FilePiece in(0); double corpus_total = 0.0; + double corpus_total_oov_only = 0.0; uint64_t corpus_oov = 0; uint64_t corpus_tokens = 0; - while (in_stream) { + while (true) { state = sentence_context ? model.BeginSentenceState() : model.NullContextState(); float total = 0.0; - bool got = false; uint64_t oov = 0; - while (in_stream >> word) { - got = true; + + while (in.ReadWordSameLine(word)) { lm::WordIndex vocab = model.GetVocabulary().Index(word); - if (vocab == 0) ++oov; ret = model.FullScore(state, vocab, out); + if (vocab == model.GetVocabulary().NotFound()) { + ++oov; + corpus_total_oov_only += ret.prob; + } total += ret.prob; - out_stream << word << '=' << vocab << ' ' << static_cast(ret.ngram_length) << ' ' << ret.prob << '\t'; + printer.Word(word, vocab, ret); ++corpus_tokens; state = out; - char c; - while (true) { - c = in_stream.get(); - if (!in_stream) break; - if (c == '\n') break; - if (!isspace(c)) { - in_stream.unget(); - break; - } - } - if (c == '\n') break; } - if (!got && !in_stream) break; + // If people don't have a newline after their last query, this won't add a . + // Sue me. + try { + UTIL_THROW_IF('\n' != in.get(), util::Exception, "FilePiece is confused."); + } catch (const util::EndOfFileException &e) { break; } if (sentence_context) { ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out); total += ret.prob; ++corpus_tokens; - out_stream << "=" << model.GetVocabulary().EndSentence() << ' ' << static_cast(ret.ngram_length) << ' ' << ret.prob << '\t'; + printer.Word("", model.GetVocabulary().EndSentence(), ret); } - out_stream << "Total: " << total << " OOV: " << oov << '\n'; + printer.Line(oov, total); corpus_total += total; corpus_oov += oov; } - out_stream << "Perplexity " << pow(10.0, -(corpus_total / static_cast(corpus_tokens))) << std::endl; + printer.Summary( + pow(10.0, -(corpus_total / static_cast(corpus_tokens))), // PPL including OOVs + pow(10.0, -((corpus_total - corpus_total_oov_only) / static_cast(corpus_tokens - corpus_oov))), // PPL excluding OOVs + corpus_oov, + corpus_tokens); } -template void Query(const char *file, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) { - Config config; - M model(file, config); - Query(model, sentence_context, in_stream, out_stream); +template void Query(const char *file, const Config &config, bool sentence_context, bool show_words) { + Model model(file, config); + if (show_words) { + Query(model, sentence_context); + } else { + Query(model, sentence_context); + } } } // namespace ngram } // namespace lm -#endif // LM_NGRAM_QUERY__ +#endif // LM_NGRAM_QUERY_H diff --git a/klm/lm/partial.hh b/klm/lm/partial.hh index 1dede359..d8adc696 100644 --- a/klm/lm/partial.hh +++ b/klm/lm/partial.hh @@ -1,5 +1,5 @@ -#ifndef LM_PARTIAL__ -#define LM_PARTIAL__ +#ifndef LM_PARTIAL_H +#define LM_PARTIAL_H #include "lm/return.hh" #include "lm/state.hh" @@ -164,4 +164,4 @@ template float Subsume(const Model &model, Left &first_left, const } // namespace ngram } // namespace lm -#endif // LM_PARTIAL__ +#endif // LM_PARTIAL_H diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh index 9d3a2f43..84a30872 100644 --- a/klm/lm/quantize.hh +++ b/klm/lm/quantize.hh @@ -1,5 +1,5 @@ -#ifndef LM_QUANTIZE_H__ -#define LM_QUANTIZE_H__ +#ifndef LM_QUANTIZE_H +#define LM_QUANTIZE_H #include "lm/blank.hh" #include "lm/config.hh" @@ -230,4 +230,4 @@ class SeparatelyQuantize { } // namespace ngram } // namespace lm -#endif // LM_QUANTIZE_H__ +#endif // LM_QUANTIZE_H diff --git a/klm/lm/query_main.cc b/klm/lm/query_main.cc index bd4fde62..3013ff21 100644 --- a/klm/lm/query_main.cc +++ b/klm/lm/query_main.cc @@ -1,4 +1,5 @@ #include "lm/ngram_query.hh" +#include "util/getopt.hh" #ifdef WITH_NPLM #include "lm/wrappers/nplm.hh" @@ -7,47 +8,76 @@ #include void Usage(const char *name) { - std::cerr << "KenLM was compiled with maximum order " << KENLM_MAX_ORDER << "." << std::endl; - std::cerr << "Usage: " << name << " [-n] lm_file" << std::endl; - std::cerr << "Input is wrapped in and unless -n is passed." << std::endl; + std::cerr << + "KenLM was compiled with maximum order " << KENLM_MAX_ORDER << ".\n" + "Usage: " << name << " [-n] [-s] lm_file\n" + "-n: Do not wrap the input in and .\n" + "-s: Sentence totals only.\n" + "-l lazy|populate|read|parallel: Load lazily, with populate, or malloc+read\n" + "The default loading method is populate on Linux and read on others.\n"; exit(1); } int main(int argc, char *argv[]) { + if (argc == 1 || (argc == 2 && !strcmp(argv[1], "--help"))) + Usage(argv[0]); + + lm::ngram::Config config; bool sentence_context = true; - const char *file = NULL; - for (char **arg = argv + 1; arg != argv + argc; ++arg) { - if (!strcmp(*arg, "-n")) { - sentence_context = false; - } else if (!strcmp(*arg, "-h") || !strcmp(*arg, "--help") || file) { - Usage(argv[0]); - } else { - file = *arg; + bool show_words = true; + + int opt; + while ((opt = getopt(argc, argv, "hnsl:")) != -1) { + switch (opt) { + case 'n': + sentence_context = false; + break; + case 's': + show_words = false; + break; + case 'l': + if (!strcmp(optarg, "lazy")) { + config.load_method = util::LAZY; + } else if (!strcmp(optarg, "populate")) { + config.load_method = util::POPULATE_OR_READ; + } else if (!strcmp(optarg, "read")) { + config.load_method = util::READ; + } else if (!strcmp(optarg, "parallel")) { + config.load_method = util::PARALLEL_READ; + } else { + Usage(argv[0]); + } + break; + case 'h': + default: + Usage(argv[0]); } } - if (!file) Usage(argv[0]); + if (optind + 1 != argc) + Usage(argv[0]); + const char *file = argv[optind]; try { using namespace lm::ngram; ModelType model_type; if (RecognizeBinary(file, model_type)) { switch(model_type) { case PROBING: - Query(file, sentence_context, std::cin, std::cout); + Query(file, config, sentence_context, show_words); break; case REST_PROBING: - Query(file, sentence_context, std::cin, std::cout); + Query(file, config, sentence_context, show_words); break; case TRIE: - Query(file, sentence_context, std::cin, std::cout); + Query(file, config, sentence_context, show_words); break; case QUANT_TRIE: - Query(file, sentence_context, std::cin, std::cout); + Query(file, config, sentence_context, show_words); break; case ARRAY_TRIE: - Query(file, sentence_context, std::cin, std::cout); + Query(file, config, sentence_context, show_words); break; case QUANT_ARRAY_TRIE: - Query(file, sentence_context, std::cin, std::cout); + Query(file, config, sentence_context, show_words); break; default: std::cerr << "Unrecognized kenlm model type " << model_type << std::endl; @@ -56,12 +86,15 @@ int main(int argc, char *argv[]) { #ifdef WITH_NPLM } else if (lm::np::Model::Recognize(file)) { lm::np::Model model(file); - Query(model, sentence_context, std::cin, std::cout); + if (show_words) { + Query(model, sentence_context); + } else { + Query(model, sentence_context); + } #endif } else { - Query(file, sentence_context, std::cin, std::cout); + Query(file, config, sentence_context, show_words); } - std::cerr << "Total time including destruction:\n"; util::PrintUsage(std::cerr); } catch (const std::exception &e) { std::cerr << e.what() << std::endl; diff --git a/klm/lm/read_arpa.hh b/klm/lm/read_arpa.hh index 234d130c..64eeef30 100644 --- a/klm/lm/read_arpa.hh +++ b/klm/lm/read_arpa.hh @@ -1,5 +1,5 @@ -#ifndef LM_READ_ARPA__ -#define LM_READ_ARPA__ +#ifndef LM_READ_ARPA_H +#define LM_READ_ARPA_H #include "lm/lm_exception.hh" #include "lm/word_index.hh" @@ -28,7 +28,7 @@ void ReadEnd(util::FilePiece &in); extern const bool kARPASpaces[256]; -// Positive log probability warning. +// Positive log probability warning. class PositiveProbWarn { public: PositiveProbWarn() : action_(THROW_UP) {} @@ -48,17 +48,17 @@ template void Read1Gram(util::FilePiece &f, Voc &voca warn.Warn(prob); prob = 0.0; } - if (f.get() != '\t') UTIL_THROW(FormatLoadException, "Expected tab after probability"); - Weights &value = unigrams[vocab.Insert(f.ReadDelimited(kARPASpaces))]; - value.prob = prob; - ReadBackoff(f, value); + UTIL_THROW_IF(f.get() != '\t', FormatLoadException, "Expected tab after probability"); + WordIndex word = vocab.Insert(f.ReadDelimited(kARPASpaces)); + Weights &w = unigrams[word]; + w.prob = prob; + ReadBackoff(f, w); } catch(util::Exception &e) { e << " in the 1-gram at byte " << f.Offset(); throw; } } -// Return true if a positive log probability came out. template void Read1Grams(util::FilePiece &f, std::size_t count, Voc &vocab, Weights *unigrams, PositiveProbWarn &warn) { ReadNGramHeader(f, 1); for (std::size_t i = 0; i < count; ++i) { @@ -67,16 +67,21 @@ template void Read1Grams(util::FilePiece &f, std::siz vocab.FinishedLoading(unigrams); } -// Return true if a positive log probability came out. -template void ReadNGram(util::FilePiece &f, const unsigned char n, const Voc &vocab, WordIndex *const reverse_indices, Weights &weights, PositiveProbWarn &warn) { +// Read ngram, write vocab ids to indices_out. +template void ReadNGram(util::FilePiece &f, const unsigned char n, const Voc &vocab, Iterator indices_out, Weights &weights, PositiveProbWarn &warn) { try { weights.prob = f.ReadFloat(); if (weights.prob > 0.0) { warn.Warn(weights.prob); weights.prob = 0.0; } - for (WordIndex *vocab_out = reverse_indices + n - 1; vocab_out >= reverse_indices; --vocab_out) { - *vocab_out = vocab.Index(f.ReadDelimited(kARPASpaces)); + for (unsigned char i = 0; i < n; ++i, ++indices_out) { + StringPiece word(f.ReadDelimited(kARPASpaces)); + WordIndex index = vocab.Index(word); + *indices_out = index; + // Check for words mapped to that are not the string . + UTIL_THROW_IF(index == 0 /* mapped to */ && (word != StringPiece("", 5)) && (word != StringPiece("", 5)), + FormatLoadException, "Word " << word << " was not seen in the unigrams (which are supposed to list the entire vocabulary) but appears"); } ReadBackoff(f, weights); } catch(util::Exception &e) { @@ -87,4 +92,4 @@ template void ReadNGram(util::FilePiece &f, const uns } // namespace lm -#endif // LM_READ_ARPA__ +#endif // LM_READ_ARPA_H diff --git a/klm/lm/return.hh b/klm/lm/return.hh index 622320ce..982ffd66 100644 --- a/klm/lm/return.hh +++ b/klm/lm/return.hh @@ -1,5 +1,5 @@ -#ifndef LM_RETURN__ -#define LM_RETURN__ +#ifndef LM_RETURN_H +#define LM_RETURN_H #include @@ -39,4 +39,4 @@ struct FullScoreReturn { }; } // namespace lm -#endif // LM_RETURN__ +#endif // LM_RETURN_H diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index 354a56b4..7e63e006 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -178,7 +178,7 @@ template void ReadNGrams( typename Store::Entry entry; std::vector between; for (size_t i = 0; i < count; ++i) { - ReadNGram(f, n, vocab, &*vocab_ids.begin(), entry.value, warn); + ReadNGram(f, n, vocab, vocab_ids.rbegin(), entry.value, warn); build.SetRest(&*vocab_ids.begin(), n, entry.value); keys[0] = detail::CombineWordHash(static_cast(vocab_ids.front()), vocab_ids[1]); diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index 8193262b..9dc84454 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -1,5 +1,5 @@ -#ifndef LM_SEARCH_HASHED__ -#define LM_SEARCH_HASHED__ +#ifndef LM_SEARCH_HASHED_H +#define LM_SEARCH_HASHED_H #include "lm/model_type.hh" #include "lm/config.hh" @@ -189,4 +189,4 @@ template class HashedSearch { } // namespace ngram } // namespace lm -#endif // LM_SEARCH_HASHED__ +#endif // LM_SEARCH_HASHED_H diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 4a88194e..7fc70f4e 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -561,6 +561,7 @@ template uint8_t *TrieSearch::Setup } // Crazy backwards thing so we initialize using pointers to ones that have already been initialized for (unsigned char i = counts.size() - 1; i >= 2; --i) { + // use "placement new" syntax to initalize Middle in an already-allocated memory location new (middle_begin_ + i - 2) Middle( middle_starts[i-2], quant_.MiddleBits(config), diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 299262a5..d8838d2b 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -1,5 +1,5 @@ -#ifndef LM_SEARCH_TRIE__ -#define LM_SEARCH_TRIE__ +#ifndef LM_SEARCH_TRIE_H +#define LM_SEARCH_TRIE_H #include "lm/config.hh" #include "lm/model_type.hh" @@ -127,4 +127,4 @@ template class TrieSearch { } // namespace ngram } // namespace lm -#endif // LM_SEARCH_TRIE__ +#endif // LM_SEARCH_TRIE_H diff --git a/klm/lm/sizes.hh b/klm/lm/sizes.hh index 85abade7..eb7e99de 100644 --- a/klm/lm/sizes.hh +++ b/klm/lm/sizes.hh @@ -1,5 +1,5 @@ -#ifndef LM_SIZES__ -#define LM_SIZES__ +#ifndef LM_SIZES_H +#define LM_SIZES_H #include @@ -14,4 +14,4 @@ void ShowSizes(const std::vector &counts); void ShowSizes(const char *file, const lm::ngram::Config &config); }} // namespaces -#endif // LM_SIZES__ +#endif // LM_SIZES_H diff --git a/klm/lm/state.hh b/klm/lm/state.hh index 543df37c..f6c51d6f 100644 --- a/klm/lm/state.hh +++ b/klm/lm/state.hh @@ -1,5 +1,5 @@ -#ifndef LM_STATE__ -#define LM_STATE__ +#ifndef LM_STATE_H +#define LM_STATE_H #include "lm/max_order.hh" #include "lm/word_index.hh" @@ -122,4 +122,4 @@ inline uint64_t hash_value(const ChartState &state) { } // namespace ngram } // namespace lm -#endif // LM_STATE__ +#endif // LM_STATE_H diff --git a/klm/lm/test.arpa b/klm/lm/test.arpa index ef214eae..c4d2e6df 100644 --- a/klm/lm/test.arpa +++ b/klm/lm/test.arpa @@ -105,7 +105,7 @@ ngram 5=4 -0.04835128 looking on a -0.4771212 -3 also would consider -7 -6 however -12 --7 to look good +-7 to look a \4-grams: -0.009249173 looking on a little -0.4771212 diff --git a/klm/lm/test_nounk.arpa b/klm/lm/test_nounk.arpa index 060733d9..e38fc854 100644 --- a/klm/lm/test_nounk.arpa +++ b/klm/lm/test_nounk.arpa @@ -101,7 +101,7 @@ ngram 5=4 -0.1892331 little more loin -0.04835128 looking on a -0.4771212 -3 also would consider -7 --7 to look good +-7 to look a \4-grams: -0.009249173 looking on a little -0.4771212 diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index d9895f89..5f8e7ce7 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -99,8 +99,11 @@ template util::BitAddress BitPackedMiddle::Find(WordInd } template void BitPackedMiddle::FinishedLoading(uint64_t next_end, const Config &config) { - uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - bhiksha_.InlineBits(); - bhiksha_.WriteNext(base_, last_next_write, insert_index_ + 1, next_end); + // Write at insert_index. . . + uint64_t last_next_write = insert_index_ * total_bits_ + + // at the offset where the next pointers are stored. + (total_bits_ - bhiksha_.InlineBits()); + bhiksha_.WriteNext(base_, last_next_write, insert_index_, next_end); bhiksha_.FinishedLoading(config); } diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index d858ab5e..cd39298b 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -1,5 +1,5 @@ -#ifndef LM_TRIE__ -#define LM_TRIE__ +#ifndef LM_TRIE_H +#define LM_TRIE_H #include "lm/weights.hh" #include "lm/word_index.hh" @@ -143,4 +143,4 @@ class BitPackedLongest : public BitPacked { } // namespace ngram } // namespace lm -#endif // LM_TRIE__ +#endif // LM_TRIE_H diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc index 126d43ab..c3f46874 100644 --- a/klm/lm/trie_sort.cc +++ b/klm/lm/trie_sort.cc @@ -16,6 +16,7 @@ #include #include #include +#include #include #include @@ -106,14 +107,20 @@ FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &temp_pre } struct ThrowCombine { - void operator()(std::size_t /*entry_size*/, const void * /*first*/, const void * /*second*/, FILE * /*out*/) const { - UTIL_THROW(FormatLoadException, "Duplicate n-gram detected."); + void operator()(std::size_t entry_size, unsigned char order, const void *first, const void *second, FILE * /*out*/) const { + const WordIndex *base = reinterpret_cast(first); + FormatLoadException e; + e << "Duplicate n-gram detected with vocab ids"; + for (const WordIndex *i = base; i != base + order; ++i) { + e << ' ' << *i; + } + throw e; } }; // Useful for context files that just contain records with no value. struct FirstCombine { - void operator()(std::size_t entry_size, const void *first, const void * /*second*/, FILE *out) const { + void operator()(std::size_t entry_size, unsigned char /*order*/, const void *first, const void * /*second*/, FILE *out) const { util::WriteOrThrow(out, first, entry_size); } }; @@ -133,7 +140,7 @@ template FILE *MergeSortedFiles(FILE *first_file, FILE *second_f util::WriteOrThrow(out_file.get(), second.Data(), entry_size); ++second; } else { - combine(entry_size, first.Data(), second.Data(), out_file.get()); + combine(entry_size, order, first.Data(), second.Data(), out_file.get()); ++first; ++second; } } @@ -248,11 +255,13 @@ void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vo uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size; if (order == counts.size()) { for (; out != out_end; out += entry_size) { - ReadNGram(f, order, vocab, reinterpret_cast(out), *reinterpret_cast(out + words_size), warn); + std::reverse_iterator it(reinterpret_cast(out) + order); + ReadNGram(f, order, vocab, it, *reinterpret_cast(out + words_size), warn); } } else { for (; out != out_end; out += entry_size) { - ReadNGram(f, order, vocab, reinterpret_cast(out), *reinterpret_cast(out + words_size), warn); + std::reverse_iterator it(reinterpret_cast(out) + order); + ReadNGram(f, order, vocab, it, *reinterpret_cast(out + words_size), warn); } } // Sort full records by full n-gram. diff --git a/klm/lm/trie_sort.hh b/klm/lm/trie_sort.hh index 1afd9562..e5406d9b 100644 --- a/klm/lm/trie_sort.hh +++ b/klm/lm/trie_sort.hh @@ -1,7 +1,7 @@ // Step of trie builder: create sorted files. -#ifndef LM_TRIE_SORT__ -#define LM_TRIE_SORT__ +#ifndef LM_TRIE_SORT_H +#define LM_TRIE_SORT_H #include "lm/max_order.hh" #include "lm/word_index.hh" @@ -111,4 +111,4 @@ class SortedFiles { } // namespace ngram } // namespace lm -#endif // LM_TRIE_SORT__ +#endif // LM_TRIE_SORT_H diff --git a/klm/lm/value.hh b/klm/lm/value.hh index ba716713..36e87084 100644 --- a/klm/lm/value.hh +++ b/klm/lm/value.hh @@ -1,5 +1,5 @@ -#ifndef LM_VALUE__ -#define LM_VALUE__ +#ifndef LM_VALUE_H +#define LM_VALUE_H #include "lm/model_type.hh" #include "lm/value_build.hh" @@ -154,4 +154,4 @@ struct RestValue { } // namespace ngram } // namespace lm -#endif // LM_VALUE__ +#endif // LM_VALUE_H diff --git a/klm/lm/value_build.hh b/klm/lm/value_build.hh index 461e6a5c..6fd26ef8 100644 --- a/klm/lm/value_build.hh +++ b/klm/lm/value_build.hh @@ -1,5 +1,5 @@ -#ifndef LM_VALUE_BUILD__ -#define LM_VALUE_BUILD__ +#ifndef LM_VALUE_BUILD_H +#define LM_VALUE_BUILD_H #include "lm/weights.hh" #include "lm/word_index.hh" @@ -94,4 +94,4 @@ template class LowerRestBuild { } // namespace ngram } // namespace lm -#endif // LM_VALUE_BUILD__ +#endif // LM_VALUE_BUILD_H diff --git a/klm/lm/virtual_interface.hh b/klm/lm/virtual_interface.hh index 7a3e2379..2a2690e1 100644 --- a/klm/lm/virtual_interface.hh +++ b/klm/lm/virtual_interface.hh @@ -1,5 +1,5 @@ -#ifndef LM_VIRTUAL_INTERFACE__ -#define LM_VIRTUAL_INTERFACE__ +#ifndef LM_VIRTUAL_INTERFACE_H +#define LM_VIRTUAL_INTERFACE_H #include "lm/return.hh" #include "lm/word_index.hh" @@ -157,4 +157,4 @@ class Model { } // mamespace base } // namespace lm -#endif // LM_VIRTUAL_INTERFACE__ +#endif // LM_VIRTUAL_INTERFACE_H diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 7f0878f4..2285d279 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -170,11 +170,15 @@ struct ProbingVocabularyHeader { ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {} +uint64_t ProbingVocabulary::Size(uint64_t entries, float probing_multiplier) { + return ALIGN8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, probing_multiplier); +} + uint64_t ProbingVocabulary::Size(uint64_t entries, const Config &config) { - return ALIGN8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, config.probing_multiplier); + return Size(entries, config.probing_multiplier); } -void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) { +void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated) { header_ = static_cast(start); lookup_ = Lookup(static_cast(start) + ALIGN8(sizeof(detail::ProbingVocabularyHeader)), allocated); bound_ = 1; @@ -201,12 +205,12 @@ WordIndex ProbingVocabulary::Insert(const StringPiece &str) { return 0; } else { if (enumerate_) enumerate_->Add(bound_, str); - lookup_.Insert(ProbingVocabuaryEntry::Make(hashed, bound_)); + lookup_.Insert(ProbingVocabularyEntry::Make(hashed, bound_)); return bound_++; } } -void ProbingVocabulary::InternalFinishedLoading() { +void ProbingVocabulary::FinishedLoading() { lookup_.FinishedInserting(); header_->bound = bound_; header_->version = kProbingVocabularyVersion; diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 074b74d8..d6ae07b8 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -1,9 +1,11 @@ -#ifndef LM_VOCAB__ -#define LM_VOCAB__ +#ifndef LM_VOCAB_H +#define LM_VOCAB_H #include "lm/enumerate_vocab.hh" #include "lm/lm_exception.hh" #include "lm/virtual_interface.hh" +#include "util/fake_ofstream.hh" +#include "util/murmur_hash.hh" #include "util/pool.hh" #include "util/probing_hash_table.hh" #include "util/sorted_uniform.hh" @@ -104,17 +106,16 @@ class SortedVocabulary : public base::Vocabulary { #pragma pack(push) #pragma pack(4) -struct ProbingVocabuaryEntry { +struct ProbingVocabularyEntry { uint64_t key; WordIndex value; typedef uint64_t Key; - uint64_t GetKey() const { - return key; - } + uint64_t GetKey() const { return key; } + void SetKey(uint64_t to) { key = to; } - static ProbingVocabuaryEntry Make(uint64_t key, WordIndex value) { - ProbingVocabuaryEntry ret; + static ProbingVocabularyEntry Make(uint64_t key, WordIndex value) { + ProbingVocabularyEntry ret; ret.key = key; ret.value = value; return ret; @@ -132,13 +133,18 @@ class ProbingVocabulary : public base::Vocabulary { return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0; } + static uint64_t Size(uint64_t entries, float probing_multiplier); + // This just unwraps Config to get the probing_multiplier. static uint64_t Size(uint64_t entries, const Config &config); // Vocab words are [0, Bound()). WordIndex Bound() const { return bound_; } // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. - void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); + void SetupMemory(void *start, std::size_t allocated); + void SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) { + SetupMemory(start, allocated); + } void Relocate(void *new_start); @@ -147,8 +153,9 @@ class ProbingVocabulary : public base::Vocabulary { WordIndex Insert(const StringPiece &str); template void FinishedLoading(Weights * /*reorder_vocab*/) { - InternalFinishedLoading(); + FinishedLoading(); } + void FinishedLoading(); std::size_t UnkCountChangePadding() const { return 0; } @@ -157,9 +164,7 @@ class ProbingVocabulary : public base::Vocabulary { void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset); private: - void InternalFinishedLoading(); - - typedef util::ProbingHashTable Lookup; + typedef util::ProbingHashTable Lookup; Lookup lookup_; @@ -181,7 +186,64 @@ template void CheckSpecials(const Config &config, const Vocab &voc if (vocab.EndSentence() == vocab.NotFound()) MissingSentenceMarker(config, ""); } +class WriteUniqueWords { + public: + explicit WriteUniqueWords(int fd) : word_list_(fd) {} + + void operator()(const StringPiece &word) { + word_list_ << word << '\0'; + } + + private: + util::FakeOFStream word_list_; +}; + +class NoOpUniqueWords { + public: + NoOpUniqueWords() {} + void operator()(const StringPiece &word) {} +}; + +template class GrowableVocab { + public: + static std::size_t MemUsage(WordIndex content) { + return Lookup::MemUsage(content > 2 ? content : 2); + } + + // Does not take ownership of write_wordi + template GrowableVocab(WordIndex initial_size, const NewWordConstruct &new_word_construct = NewWordAction()) + : lookup_(initial_size), new_word_(new_word_construct) { + FindOrInsert(""); // Force 0 + FindOrInsert(""); // Force 1 + FindOrInsert(""); // Force 2 + } + + WordIndex Index(const StringPiece &str) const { + Lookup::ConstIterator i; + return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0; + } + + WordIndex FindOrInsert(const StringPiece &word) { + ProbingVocabularyEntry entry = ProbingVocabularyEntry::Make(util::MurmurHashNative(word.data(), word.size()), Size()); + Lookup::MutableIterator it; + if (!lookup_.FindOrInsert(entry, it)) { + new_word_(word); + UTIL_THROW_IF(Size() >= std::numeric_limits::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh"); + } + return it->value; + } + + WordIndex Size() const { return lookup_.Size(); } + + private: + typedef util::AutoProbing Lookup; + + Lookup lookup_; + + NewWordAction new_word_; +}; + } // namespace ngram } // namespace lm -#endif // LM_VOCAB__ +#endif // LM_VOCAB_H diff --git a/klm/lm/weights.hh b/klm/lm/weights.hh index bd5d8034..da1963d8 100644 --- a/klm/lm/weights.hh +++ b/klm/lm/weights.hh @@ -1,5 +1,5 @@ -#ifndef LM_WEIGHTS__ -#define LM_WEIGHTS__ +#ifndef LM_WEIGHTS_H +#define LM_WEIGHTS_H // Weights for n-grams. Probability and possibly a backoff. @@ -19,4 +19,4 @@ struct RestWeights { }; } // namespace lm -#endif // LM_WEIGHTS__ +#endif // LM_WEIGHTS_H diff --git a/klm/lm/word_index.hh b/klm/lm/word_index.hh index e09557a7..a5a0fda8 100644 --- a/klm/lm/word_index.hh +++ b/klm/lm/word_index.hh @@ -1,6 +1,6 @@ // Separate header because this is used often. -#ifndef LM_WORD_INDEX__ -#define LM_WORD_INDEX__ +#ifndef LM_WORD_INDEX_H +#define LM_WORD_INDEX_H #include diff --git a/klm/lm/wrappers/README b/klm/lm/wrappers/README new file mode 100644 index 00000000..56c34c23 --- /dev/null +++ b/klm/lm/wrappers/README @@ -0,0 +1,3 @@ +This directory is for wrappers around other people's LMs, presenting an interface similar to KenLM's. You will need to have their LM installed. + +NPLM is a work in progress. diff --git a/klm/lm/wrappers/nplm.cc b/klm/lm/wrappers/nplm.cc new file mode 100644 index 00000000..70622bd2 --- /dev/null +++ b/klm/lm/wrappers/nplm.cc @@ -0,0 +1,90 @@ +#include "lm/wrappers/nplm.hh" +#include "util/exception.hh" +#include "util/file.hh" + +#include + +#include + +#include "neuralLM.h" + +namespace lm { +namespace np { + +Vocabulary::Vocabulary(const nplm::vocabulary &vocab) + : base::Vocabulary(vocab.lookup_word(""), vocab.lookup_word(""), vocab.lookup_word("")), + vocab_(vocab), null_word_(vocab.lookup_word("")) {} + +Vocabulary::~Vocabulary() {} + +WordIndex Vocabulary::Index(const std::string &str) const { + return vocab_.lookup_word(str); +} + +bool Model::Recognize(const std::string &name) { + try { + util::scoped_fd file(util::OpenReadOrThrow(name.c_str())); + char magic_check[16]; + util::ReadOrThrow(file.get(), magic_check, sizeof(magic_check)); + const char nnlm_magic[] = "\\config\nversion "; + return !memcmp(magic_check, nnlm_magic, 16); + } catch (const util::Exception &) { + return false; + } +} + +Model::Model(const std::string &file, std::size_t cache) + : base_instance_(new nplm::neuralLM(file)), vocab_(base_instance_->get_vocabulary()), cache_size_(cache) { + UTIL_THROW_IF(base_instance_->get_order() > NPLM_MAX_ORDER, util::Exception, "This NPLM has order " << (unsigned int)base_instance_->get_order() << " but the KenLM wrapper was compiled with " << NPLM_MAX_ORDER << ". Change the defintion of NPLM_MAX_ORDER and recompile."); + // log10 compatible with backoff models. + base_instance_->set_log_base(10.0); + State begin_sentence, null_context; + std::fill(begin_sentence.words, begin_sentence.words + NPLM_MAX_ORDER - 1, base_instance_->lookup_word("")); + null_word_ = base_instance_->lookup_word(""); + std::fill(null_context.words, null_context.words + NPLM_MAX_ORDER - 1, null_word_); + + Init(begin_sentence, null_context, vocab_, base_instance_->get_order()); +} + +Model::~Model() {} + +FullScoreReturn Model::FullScore(const State &from, const WordIndex new_word, State &out_state) const { + nplm::neuralLM *lm = backend_.get(); + if (!lm) { + lm = new nplm::neuralLM(*base_instance_); + backend_.reset(lm); + lm->set_cache(cache_size_); + } + // State is in natural word order. + FullScoreReturn ret; + for (int i = 0; i < lm->get_order() - 1; ++i) { + lm->staging_ngram()(i) = from.words[i]; + } + lm->staging_ngram()(lm->get_order() - 1) = new_word; + ret.prob = lm->lookup_from_staging(); + // Always say full order. + ret.ngram_length = lm->get_order(); + // Shift everything down by one. + memcpy(out_state.words, from.words + 1, sizeof(WordIndex) * (lm->get_order() - 2)); + out_state.words[lm->get_order() - 2] = new_word; + // Fill in trailing words with zeros so state comparison works. + memset(out_state.words + lm->get_order() - 1, 0, sizeof(WordIndex) * (NPLM_MAX_ORDER - lm->get_order())); + return ret; +} + +// TODO: optimize with direct call? +FullScoreReturn Model::FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const { + // State is in natural word order. The API here specifies reverse order. + std::size_t state_length = std::min(Order() - 1, context_rend - context_rbegin); + State state; + // Pad with null words. + for (lm::WordIndex *i = state.words; i < state.words + Order() - 1 - state_length; ++i) { + *i = null_word_; + } + // Put new words at the end. + std::reverse_copy(context_rbegin, context_rbegin + state_length, state.words + Order() - 1 - state_length); + return FullScore(state, new_word, out_state); +} + +} // namespace np +} // namespace lm diff --git a/klm/lm/wrappers/nplm.hh b/klm/lm/wrappers/nplm.hh new file mode 100644 index 00000000..b7dd4a21 --- /dev/null +++ b/klm/lm/wrappers/nplm.hh @@ -0,0 +1,83 @@ +#ifndef LM_WRAPPERS_NPLM_H +#define LM_WRAPPERS_NPLM_H + +#include "lm/facade.hh" +#include "lm/max_order.hh" +#include "util/string_piece.hh" + +#include +#include + +/* Wrapper to NPLM "by Ashish Vaswani, with contributions from David Chiang + * and Victoria Fossum." + * http://nlg.isi.edu/software/nplm/ + */ + +namespace nplm { +class vocabulary; +class neuralLM; +} // namespace nplm + +namespace lm { +namespace np { + +class Vocabulary : public base::Vocabulary { + public: + Vocabulary(const nplm::vocabulary &vocab); + + ~Vocabulary(); + + WordIndex Index(const std::string &str) const; + + // TODO: lobby them to support StringPiece + WordIndex Index(const StringPiece &str) const { + return Index(std::string(str.data(), str.size())); + } + + lm::WordIndex NullWord() const { return null_word_; } + + private: + const nplm::vocabulary &vocab_; + + const lm::WordIndex null_word_; +}; + +// Sorry for imposing my limitations on your code. +#define NPLM_MAX_ORDER 7 + +struct State { + WordIndex words[NPLM_MAX_ORDER - 1]; +}; + +class Model : public lm::base::ModelFacade { + private: + typedef lm::base::ModelFacade P; + + public: + // Does this look like an NPLM? + static bool Recognize(const std::string &file); + + explicit Model(const std::string &file, std::size_t cache_size = 1 << 20); + + ~Model(); + + FullScoreReturn FullScore(const State &from, const WordIndex new_word, State &out_state) const; + + FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const; + + private: + boost::scoped_ptr base_instance_; + + mutable boost::thread_specific_ptr backend_; + + Vocabulary vocab_; + + lm::WordIndex null_word_; + + const std::size_t cache_size_; +}; + +} // namespace np +} // namespace lm + +#endif // LM_WRAPPERS_NPLM_H -- cgit v1.2.3