diff options
Diffstat (limited to 'klm/lm/builder/corpus_count.cc')
-rw-r--r-- | klm/lm/builder/corpus_count.cc | 100 |
1 files changed, 32 insertions, 68 deletions
diff --git a/klm/lm/builder/corpus_count.cc b/klm/lm/builder/corpus_count.cc index ccc06efc..590e79fa 100644 --- a/klm/lm/builder/corpus_count.cc +++ b/klm/lm/builder/corpus_count.cc @@ -2,6 +2,7 @@ #include "lm/builder/ngram.hh" #include "lm/lm_exception.hh" +#include "lm/vocab.hh" #include "lm/word_index.hh" #include "util/fake_ofstream.hh" #include "util/file.hh" @@ -37,60 +38,6 @@ struct VocabEntry { }; #pragma pack(pop) -const float kProbingMultiplier = 1.5; - -class VocabHandout { - public: - static std::size_t MemUsage(WordIndex initial_guess) { - if (initial_guess < 2) initial_guess = 2; - return util::CheckOverflow(Table::Size(initial_guess, kProbingMultiplier)); - } - - explicit VocabHandout(int fd, WordIndex initial_guess) : - table_backing_(util::CallocOrThrow(MemUsage(initial_guess))), - table_(table_backing_.get(), MemUsage(initial_guess)), - double_cutoff_(std::max<std::size_t>(initial_guess * 1.1, 1)), - word_list_(fd) { - Lookup("<unk>"); // Force 0 - Lookup("<s>"); // Force 1 - Lookup("</s>"); // Force 2 - } - - WordIndex Lookup(const StringPiece &word) { - VocabEntry entry; - entry.key = util::MurmurHashNative(word.data(), word.size()); - entry.value = table_.SizeNoSerialization(); - - Table::MutableIterator it; - if (table_.FindOrInsert(entry, it)) - return it->value; - word_list_ << word << '\0'; - UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh."); - if (Size() >= double_cutoff_) { - table_backing_.call_realloc(table_.DoubleTo()); - table_.Double(table_backing_.get()); - double_cutoff_ *= 2; - } - return entry.value; - } - - WordIndex Size() const { - return table_.SizeNoSerialization(); - } - - private: - // TODO: factor out a resizable probing hash table. - // TODO: use mremap on linux to get all zeros on resizes. - util::scoped_malloc table_backing_; - - typedef util::ProbingHashTable<VocabEntry, util::IdentityHash> Table; - Table table_; - - std::size_t double_cutoff_; - - util::FakeOFStream word_list_; -}; - class DedupeHash : public std::unary_function<const WordIndex *, bool> { public: explicit DedupeHash(std::size_t order) : size_(order * sizeof(WordIndex)) {} @@ -127,6 +74,10 @@ struct DedupeEntry { } }; + +// TODO: don't have this here, should be with probing hash table defaults? +const float kProbingMultiplier = 1.5; + typedef util::ProbingHashTable<DedupeEntry, DedupeHash, DedupeEquals> Dedupe; class Writer { @@ -220,37 +171,50 @@ float CorpusCount::DedupeMultiplier(std::size_t order) { } std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) { - return VocabHandout::MemUsage(vocab_estimate); + return ngram::GrowableVocab<ngram::WriteUniqueWords>::MemUsage(vocab_estimate); } -CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block) +CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block, WarningAction disallowed_symbol) : from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count), dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)), - dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) { + dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)), + disallowed_symbol_action_(disallowed_symbol) { } -void CorpusCount::Run(const util::stream::ChainPosition &position) { - UTIL_TIMER("(%w s) Counted n-grams\n"); +namespace { + void ComplainDisallowed(StringPiece word, WarningAction &action) { + switch (action) { + case SILENT: + return; + case COMPLAIN: + std::cerr << "Warning: " << word << " appears in the input. All instances of <s>, </s>, and <unk> will be interpreted as whitespace." << std::endl; + action = SILENT; + return; + case THROW_UP: + UTIL_THROW(FormatLoadException, "Special word " << word << " is not allowed in the corpus. I plan to support models containing <unk> in the future. Pass --skip_symbols to convert these symbols to whitespace."); + } + } +} // namespace - VocabHandout vocab(vocab_write_, type_count_); +void CorpusCount::Run(const util::stream::ChainPosition &position) { + ngram::GrowableVocab<ngram::WriteUniqueWords> vocab(type_count_, vocab_write_); token_count_ = 0; type_count_ = 0; - const WordIndex end_sentence = vocab.Lookup("</s>"); + const WordIndex end_sentence = vocab.FindOrInsert("</s>"); Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_); uint64_t count = 0; bool delimiters[256]; - memset(delimiters, 0, sizeof(delimiters)); - const char kDelimiterSet[] = "\0\t\n\r "; - for (const char *i = kDelimiterSet; i < kDelimiterSet + sizeof(kDelimiterSet); ++i) { - delimiters[static_cast<unsigned char>(*i)] = true; - } + util::BoolCharacter::Build("\0\t\n\r ", delimiters); try { while(true) { StringPiece line(from_.ReadLine()); writer.StartSentence(); for (util::TokenIter<util::BoolCharacter, true> w(line, delimiters); w; ++w) { - WordIndex word = vocab.Lookup(*w); - UTIL_THROW_IF(word <= 2, FormatLoadException, "Special word " << *w << " is not allowed in the corpus. I plan to support models containing <unk> in the future."); + WordIndex word = vocab.FindOrInsert(*w); + if (word <= 2) { + ComplainDisallowed(*w, disallowed_symbol_action_); + continue; + } writer.Append(word); ++count; } |