summaryrefslogtreecommitdiff
path: root/klm/lm/builder/corpus_count.cc
diff options
context:
space:
mode:
authorChris Dyer <redpony@gmail.com>2014-10-13 00:42:37 -0400
committerChris Dyer <redpony@gmail.com>2014-10-13 00:42:37 -0400
commitb1ed81ef3216b212295afa76c5d20a56fb647204 (patch)
tree9633cdc1b8a341dfa58b0b7fec0e2cae44d28835 /klm/lm/builder/corpus_count.cc
parent1b17f61d359be6e1c3cea29f8c100db3bcdd73a0 (diff)
new kenlm
Diffstat (limited to 'klm/lm/builder/corpus_count.cc')
-rw-r--r--klm/lm/builder/corpus_count.cc100
1 files changed, 32 insertions, 68 deletions
diff --git a/klm/lm/builder/corpus_count.cc b/klm/lm/builder/corpus_count.cc
index ccc06efc..590e79fa 100644
--- a/klm/lm/builder/corpus_count.cc
+++ b/klm/lm/builder/corpus_count.cc
@@ -2,6 +2,7 @@
#include "lm/builder/ngram.hh"
#include "lm/lm_exception.hh"
+#include "lm/vocab.hh"
#include "lm/word_index.hh"
#include "util/fake_ofstream.hh"
#include "util/file.hh"
@@ -37,60 +38,6 @@ struct VocabEntry {
};
#pragma pack(pop)
-const float kProbingMultiplier = 1.5;
-
-class VocabHandout {
- public:
- static std::size_t MemUsage(WordIndex initial_guess) {
- if (initial_guess < 2) initial_guess = 2;
- return util::CheckOverflow(Table::Size(initial_guess, kProbingMultiplier));
- }
-
- explicit VocabHandout(int fd, WordIndex initial_guess) :
- table_backing_(util::CallocOrThrow(MemUsage(initial_guess))),
- table_(table_backing_.get(), MemUsage(initial_guess)),
- double_cutoff_(std::max<std::size_t>(initial_guess * 1.1, 1)),
- word_list_(fd) {
- Lookup("<unk>"); // Force 0
- Lookup("<s>"); // Force 1
- Lookup("</s>"); // Force 2
- }
-
- WordIndex Lookup(const StringPiece &word) {
- VocabEntry entry;
- entry.key = util::MurmurHashNative(word.data(), word.size());
- entry.value = table_.SizeNoSerialization();
-
- Table::MutableIterator it;
- if (table_.FindOrInsert(entry, it))
- return it->value;
- word_list_ << word << '\0';
- UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh.");
- if (Size() >= double_cutoff_) {
- table_backing_.call_realloc(table_.DoubleTo());
- table_.Double(table_backing_.get());
- double_cutoff_ *= 2;
- }
- return entry.value;
- }
-
- WordIndex Size() const {
- return table_.SizeNoSerialization();
- }
-
- private:
- // TODO: factor out a resizable probing hash table.
- // TODO: use mremap on linux to get all zeros on resizes.
- util::scoped_malloc table_backing_;
-
- typedef util::ProbingHashTable<VocabEntry, util::IdentityHash> Table;
- Table table_;
-
- std::size_t double_cutoff_;
-
- util::FakeOFStream word_list_;
-};
-
class DedupeHash : public std::unary_function<const WordIndex *, bool> {
public:
explicit DedupeHash(std::size_t order) : size_(order * sizeof(WordIndex)) {}
@@ -127,6 +74,10 @@ struct DedupeEntry {
}
};
+
+// TODO: don't have this here, should be with probing hash table defaults?
+const float kProbingMultiplier = 1.5;
+
typedef util::ProbingHashTable<DedupeEntry, DedupeHash, DedupeEquals> Dedupe;
class Writer {
@@ -220,37 +171,50 @@ float CorpusCount::DedupeMultiplier(std::size_t order) {
}
std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) {
- return VocabHandout::MemUsage(vocab_estimate);
+ return ngram::GrowableVocab<ngram::WriteUniqueWords>::MemUsage(vocab_estimate);
}
-CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block)
+CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block, WarningAction disallowed_symbol)
: from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count),
dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)),
- dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) {
+ dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)),
+ disallowed_symbol_action_(disallowed_symbol) {
}
-void CorpusCount::Run(const util::stream::ChainPosition &position) {
- UTIL_TIMER("(%w s) Counted n-grams\n");
+namespace {
+ void ComplainDisallowed(StringPiece word, WarningAction &action) {
+ switch (action) {
+ case SILENT:
+ return;
+ case COMPLAIN:
+ std::cerr << "Warning: " << word << " appears in the input. All instances of <s>, </s>, and <unk> will be interpreted as whitespace." << std::endl;
+ action = SILENT;
+ return;
+ case THROW_UP:
+ UTIL_THROW(FormatLoadException, "Special word " << word << " is not allowed in the corpus. I plan to support models containing <unk> in the future. Pass --skip_symbols to convert these symbols to whitespace.");
+ }
+ }
+} // namespace
- VocabHandout vocab(vocab_write_, type_count_);
+void CorpusCount::Run(const util::stream::ChainPosition &position) {
+ ngram::GrowableVocab<ngram::WriteUniqueWords> vocab(type_count_, vocab_write_);
token_count_ = 0;
type_count_ = 0;
- const WordIndex end_sentence = vocab.Lookup("</s>");
+ const WordIndex end_sentence = vocab.FindOrInsert("</s>");
Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
uint64_t count = 0;
bool delimiters[256];
- memset(delimiters, 0, sizeof(delimiters));
- const char kDelimiterSet[] = "\0\t\n\r ";
- for (const char *i = kDelimiterSet; i < kDelimiterSet + sizeof(kDelimiterSet); ++i) {
- delimiters[static_cast<unsigned char>(*i)] = true;
- }
+ util::BoolCharacter::Build("\0\t\n\r ", delimiters);
try {
while(true) {
StringPiece line(from_.ReadLine());
writer.StartSentence();
for (util::TokenIter<util::BoolCharacter, true> w(line, delimiters); w; ++w) {
- WordIndex word = vocab.Lookup(*w);
- UTIL_THROW_IF(word <= 2, FormatLoadException, "Special word " << *w << " is not allowed in the corpus. I plan to support models containing <unk> in the future.");
+ WordIndex word = vocab.FindOrInsert(*w);
+ if (word <= 2) {
+ ComplainDisallowed(*w, disallowed_symbol_action_);
+ continue;
+ }
writer.Append(word);
++count;
}