Diffstat (limited to 'klm/lm/builder/corpus_count.cc')
-rw-r--r--  klm/lm/builder/corpus_count.cc | 82
1 file changed, 59 insertions(+), 23 deletions(-)
diff --git a/klm/lm/builder/corpus_count.cc b/klm/lm/builder/corpus_count.cc
index abea4ed0..aea93ad1 100644
--- a/klm/lm/builder/corpus_count.cc
+++ b/klm/lm/builder/corpus_count.cc
@@ -3,6 +3,7 @@
#include "lm/builder/ngram.hh"
#include "lm/lm_exception.hh"
#include "lm/word_index.hh"
+#include "util/fake_ofstream.hh"
#include "util/file.hh"
#include "util/file_piece.hh"
#include "util/murmur_hash.hh"
@@ -23,39 +24,71 @@ namespace lm {
namespace builder {
namespace {
+#pragma pack(push)
+#pragma pack(4)
+struct VocabEntry {
+ typedef uint64_t Key;
+
+ uint64_t GetKey() const { return key; }
+ void SetKey(uint64_t to) { key = to; }
+
+ uint64_t key;
+ lm::WordIndex value;
+};
+#pragma pack(pop)
+
+const float kProbingMultiplier = 1.5;
+
class VocabHandout {
public:
- explicit VocabHandout(int fd) {
- util::scoped_fd duped(util::DupOrThrow(fd));
- word_list_.reset(util::FDOpenOrThrow(duped));
-
+ static std::size_t MemUsage(WordIndex initial_guess) {
+ if (initial_guess < 2) initial_guess = 2;
+ return util::CheckOverflow(Table::Size(initial_guess, kProbingMultiplier));
+ }
+
+ explicit VocabHandout(int fd, WordIndex initial_guess) :
+ table_backing_(util::CallocOrThrow(MemUsage(initial_guess))),
+ table_(table_backing_.get(), MemUsage(initial_guess)),
+ double_cutoff_(std::max<std::size_t>(initial_guess * 1.1, 1)),
+ word_list_(fd) {
Lookup("<unk>"); // Force 0
Lookup("<s>"); // Force 1
Lookup("</s>"); // Force 2
}
WordIndex Lookup(const StringPiece &word) {
- uint64_t hashed = util::MurmurHashNative(word.data(), word.size());
- std::pair<Seen::iterator, bool> ret(seen_.insert(std::pair<uint64_t, lm::WordIndex>(hashed, seen_.size())));
- if (ret.second) {
- char null_delimit = 0;
- util::WriteOrThrow(word_list_.get(), word.data(), word.size());
- util::WriteOrThrow(word_list_.get(), &null_delimit, 1);
- UTIL_THROW_IF(seen_.size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh.");
+ VocabEntry entry;
+ entry.key = util::MurmurHashNative(word.data(), word.size());
+ entry.value = table_.SizeNoSerialization();
+
+ Table::MutableIterator it;
+ if (table_.FindOrInsert(entry, it))
+ return it->value;
+ word_list_ << word << '\0';
+ UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh.");
+ if (Size() >= double_cutoff_) {
+ table_backing_.call_realloc(table_.DoubleTo());
+ table_.Double(table_backing_.get());
+ double_cutoff_ *= 2;
}
- return ret.first->second;
+ return entry.value;
}
WordIndex Size() const {
- return seen_.size();
+ return table_.SizeNoSerialization();
}
private:
- typedef boost::unordered_map<uint64_t, lm::WordIndex> Seen;
+ // TODO: factor out a resizable probing hash table.
+ // TODO: use mremap on linux to get all zeros on resizes.
+ util::scoped_malloc table_backing_;
- Seen seen_;
+ typedef util::ProbingHashTable<VocabEntry, util::IdentityHash> Table;
+ Table table_;
- util::scoped_FILE word_list_;
+ std::size_t double_cutoff_;
+
+ util::FakeOFStream word_list_;
};
class DedupeHash : public std::unary_function<const WordIndex *, bool> {
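The hunk above swaps boost::unordered_map for an open-addressing (probing) hash table over packed entries, doubling its backing storage whenever Size() crosses double_cutoff_ and streaming each new word through util::FakeOFStream. Below is a minimal, compilable sketch of that find-or-insert-then-double pattern under assumed names (ToyProbingTable and Slot are illustrative, not kenlm's util::ProbingHashTable API); unlike the patch, which grows in place via call_realloc and Table::Double, the sketch rehashes into freshly zeroed storage, and it omits the word list I/O and the WordIndex overflow check.

#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <iostream>

struct Slot {
  uint64_t key;    // 0 doubles as the "empty" sentinel in this sketch
  uint32_t value;
};

class ToyProbingTable {
 public:
  explicit ToyProbingTable(std::size_t buckets)
      : buckets_(buckets), entries_(0),
        slots_(static_cast<Slot*>(std::calloc(buckets, sizeof(Slot)))) {}
  ~ToyProbingTable() { std::free(slots_); }

  // Mirrors the FindOrInsert pattern in VocabHandout::Lookup: return the
  // existing id for key, or assign it the next id (the current size).
  uint32_t FindOrInsert(uint64_t key) {
    assert(key != 0);
    for (std::size_t i = key % buckets_;; i = (i + 1) % buckets_) {
      if (slots_[i].key == key) return slots_[i].value;   // seen before
      if (slots_[i].key == 0) {                           // empty: insert
        uint32_t id = static_cast<uint32_t>(entries_);
        slots_[i].key = key;
        slots_[i].value = id;
        if (++entries_ * 2 > buckets_) Double();          // keep load < 0.5
        return id;
      }
    }
  }

  std::size_t Size() const { return entries_; }

 private:
  void Double() {
    Slot *old = slots_;
    std::size_t old_buckets = buckets_;
    buckets_ *= 2;
    slots_ = static_cast<Slot*>(std::calloc(buckets_, sizeof(Slot)));
    // Rehash every occupied bucket into the larger, zeroed table.
    for (std::size_t i = 0; i < old_buckets; ++i) {
      if (!old[i].key) continue;
      std::size_t j = old[i].key % buckets_;
      while (slots_[j].key) j = (j + 1) % buckets_;
      slots_[j] = old[i];
    }
    std::free(old);
  }

  std::size_t buckets_, entries_;
  Slot *slots_;
};

int main() {
  ToyProbingTable table(4);
  std::cout << table.FindOrInsert(42) << ' '    // 0: first insert
            << table.FindOrInsert(99) << ' '    // 1: second insert
            << table.FindOrInsert(42) << '\n';  // 0: found, not reinserted
}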
@@ -85,6 +118,7 @@ class DedupeEquals : public std::binary_function<const WordIndex *, const WordIn
struct DedupeEntry {
typedef WordIndex *Key;
Key GetKey() const { return key; }
+ void SetKey(WordIndex *to) { key = to; }
Key key;
static DedupeEntry Construct(WordIndex *at) {
DedupeEntry ret;
@@ -95,8 +129,6 @@ struct DedupeEntry {
typedef util::ProbingHashTable<DedupeEntry, DedupeHash, DedupeEquals> Dedupe;
-const float kProbingMultiplier = 1.5;
-
class Writer {
public:
Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size)
@@ -105,7 +137,7 @@ class Writer {
dedupe_(dedupe_mem, dedupe_mem_size, &dedupe_invalid_[0], DedupeHash(order), DedupeEquals(order)),
buffer_(new WordIndex[order - 1]),
block_size_(position.GetChain().BlockSize()) {
- dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0]));
+ dedupe_.Clear();
assert(Dedupe::Size(position.GetChain().BlockSize() / position.GetChain().EntrySize(), kProbingMultiplier) == dedupe_mem_size);
if (order == 1) {
// Add special words. AdjustCounts is responsible if order != 1.
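DedupeEntry gained SetKey above, and here dedupe_.Clear() drops its argument: a table that remembers the invalid key it was constructed with can stamp that key back into every bucket itself, which requires write access through the entry type. A self-contained sketch of that clearing pattern follows, with made-up names (ToyEntry, ClearBuckets), not kenlm's util signatures.

#include <cassert>
#include <cstdint>

// Illustrative only: an entry exposing both GetKey and SetKey, and a clear
// that writes the stored invalid key into every bucket so each slot reads
// as empty on the next probe.
struct ToyEntry {
  typedef uint64_t Key;
  Key GetKey() const { return key; }
  void SetKey(Key to) { key = to; }
  Key key;
};

template <class Entry>
void ClearBuckets(Entry *begin, Entry *end, typename Entry::Key invalid) {
  for (Entry *i = begin; i != end; ++i) i->SetKey(invalid);
}

int main() {
  ToyEntry buckets[8];
  ClearBuckets(buckets, buckets + 8, static_cast<uint64_t>(0));
  for (int i = 0; i < 8; ++i) assert(buckets[i].GetKey() == 0);
}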
@@ -149,7 +181,7 @@ class Writer {
}
// Block end. Need to store the context in a temporary buffer.
std::copy(gram_.begin() + 1, gram_.end(), buffer_.get());
- dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0]));
+ dedupe_.Clear();
block_->SetValidSize(block_size_);
gram_.ReBase((++block_)->Get());
std::copy(buffer_.get(), buffer_.get() + gram_.Order() - 1, gram_.begin());
@@ -187,18 +219,22 @@ float CorpusCount::DedupeMultiplier(std::size_t order) {
return kProbingMultiplier * static_cast<float>(sizeof(DedupeEntry)) / static_cast<float>(NGram::TotalSize(order));
}
+std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) {
+ return VocabHandout::MemUsage(vocab_estimate);
+}
+
CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block)
: from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count),
dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)),
dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) {
- token_count_ = 0;
- type_count_ = 0;
}
void CorpusCount::Run(const util::stream::ChainPosition &position) {
UTIL_TIMER("(%w s) Counted n-grams\n");
- VocabHandout vocab(vocab_write_);
+ VocabHandout vocab(vocab_write_, type_count_);
+ token_count_ = 0;
+ type_count_ = 0;
const WordIndex end_sentence = vocab.Lookup("</s>");
Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
uint64_t count = 0;
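The final hunks expose the vocabulary table's footprint to callers: CorpusCount::VocabUsage(vocab_estimate) forwards to VocabHandout::MemUsage, and Run() now seeds the table with type_count_ as the initial size guess before resetting both counters (a reset moved out of the constructor). A rough sketch of the budgeting this enables; EstimateVocabBytes is a made-up name and the arithmetic is a simplification of Table::Size, not kenlm's exact formula.

#include <cstddef>
#include <cstdint>
#include <iostream>

// Hypothetical stand-in for the shape of VocabHandout::MemUsage: round the
// type guess up by the probing multiplier and price it in packed 12-byte
// entries (uint64_t key + uint32_t value under #pragma pack(4), on typical
// compilers). Constants here are illustrative.
std::size_t EstimateVocabBytes(uint32_t type_guess) {
  const float kMultiplier = 1.5f;        // same spirit as kProbingMultiplier
  const std::size_t kEntryBytes = 12;    // sizeof(VocabEntry) with pack(4)
  if (type_guess < 2) type_guess = 2;    // matches the guard in MemUsage
  return static_cast<std::size_t>(kMultiplier * type_guess) * kEntryBytes;
}

int main() {
  // e.g. a caller expecting ~1M word types reserves the table up front, as
  // the builder pipeline can via CorpusCount::VocabUsage(vocab_estimate).
  std::cout << EstimateVocabBytes(1000000) << " bytes\n";  // 18000000
}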