From ddb3eb611b00a2a80936b92b95e94d33896990da Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Wed, 10 Oct 2012 19:08:57 +0100 Subject: Update KenLM --- klm/lm/binary_format.cc | 16 ++++++++++-- klm/lm/build_binary.cc | 2 ++ klm/lm/search_hashed.cc | 2 +- klm/lm/search_hashed.hh | 2 +- klm/lm/search_trie.cc | 2 +- klm/lm/sri_test.cc | 65 ------------------------------------------------- klm/lm/trie_sort.cc | 20 ++++++--------- klm/lm/trie_sort.hh | 2 -- klm/util/file.cc | 27 +++++++++++++++----- klm/util/file.hh | 6 ++--- klm/util/file_piece.cc | 2 ++ 11 files changed, 52 insertions(+), 94 deletions(-) delete mode 100644 klm/lm/sri_test.cc diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index fd841e59..efa67056 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -83,7 +83,13 @@ void WriteHeader(void *to, const Parameters ¶ms) { uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) { if (config.write_mmap) { std::size_t total = TotalHeaderSize(order) + memory_size; - backing.vocab.reset(util::MapZeroedWrite(config.write_mmap, total, backing.file), total, util::scoped_memory::MMAP_ALLOCATED); + backing.file.reset(util::CreateOrThrow(config.write_mmap)); + if (config.write_method == Config::WRITE_MMAP) { + backing.vocab.reset(util::MapZeroedWrite(backing.file.get(), total), total, util::scoped_memory::MMAP_ALLOCATED); + } else { + util::ResizeOrThrow(backing.file.get(), 0); + util::MapAnonymous(total, backing.vocab); + } strncpy(reinterpret_cast(backing.vocab.get()), kMagicIncomplete, TotalHeaderSize(order)); return reinterpret_cast(backing.vocab.get()) + TotalHeaderSize(order); } else { @@ -121,12 +127,14 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector &counts, std::size_t vocab_pad, Backing &backing) { if (!config.write_mmap) return; - util::SyncOrThrow(backing.vocab.get(), backing.vocab.size()); switch (config.write_method) { case Config::WRITE_MMAP: + util::SyncOrThrow(backing.vocab.get(), backing.vocab.size()); util::SyncOrThrow(backing.search.get(), backing.search.size()); break; case Config::WRITE_AFTER: + util::SeekOrThrow(backing.file.get(), 0); + util::WriteOrThrow(backing.file.get(), backing.vocab.get(), backing.vocab.size()); util::SeekOrThrow(backing.file.get(), backing.vocab.size() + vocab_pad); util::WriteOrThrow(backing.file.get(), backing.search.get(), backing.search.size()); util::FSyncOrThrow(backing.file.get()); @@ -141,6 +149,10 @@ void FinishFile(const Config &config, ModelType model_type, unsigned int search_ params.fixed.has_vocabulary = config.include_vocab; params.fixed.search_version = search_version; WriteHeader(backing.vocab.get(), params); + if (config.write_method == Config::WRITE_AFTER) { + util::SeekOrThrow(backing.file.get(), 0); + util::WriteOrThrow(backing.file.get(), backing.vocab.get(), TotalHeaderSize(counts.size())); + } } namespace detail { diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index efe99899..2b8c9d5b 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -11,6 +11,8 @@ #ifdef WIN32 #include "util/getopt.hh" +#else +#include #endif namespace lm { diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index 13942309..a1623834 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -234,7 +234,7 @@ template <> void HashedSearch::DispatchBuild(util::FilePiece &f, c ApplyBuild(f, counts, config, vocab, warn, build); } -template <> void HashedSearch::DispatchBuild(util::FilePiece &f, const std::vector &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) { +template <> void HashedSearch::DispatchBuild(util::FilePiece &f, const std::vector &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) { switch (config.rest_function) { case Config::REST_MAX: { diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index 3bcde921..a52f107b 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -161,7 +161,7 @@ template class HashedSearch { {} static uint64_t Size(uint64_t count) { - return (count + 1) * sizeof(ProbBackoff); // +1 for hallucinate + return (count + 1) * sizeof(typename Value::Weights); // +1 for hallucinate } const typename Value::Weights &Lookup(WordIndex index) const { diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 832cc9f7..debcfd07 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -89,7 +89,7 @@ class BackoffMessages { if (!HasExtension(weights.backoff)) { weights.backoff = kExtensionBackoff; UTIL_THROW_IF(fseek(unigrams, -sizeof(weights), SEEK_CUR), util::ErrnoException, "Seeking backwards to denote unigram extension failed."); - WriteOrThrow(unigrams, &weights, sizeof(weights)); + util::WriteOrThrow(unigrams, &weights, sizeof(weights)); } const ProbPointer &write_to = *reinterpret_cast(current_ + sizeof(WordIndex)); base[write_to.array][write_to.index] += weights.backoff; diff --git a/klm/lm/sri_test.cc b/klm/lm/sri_test.cc deleted file mode 100644 index e697d722..00000000 --- a/klm/lm/sri_test.cc +++ /dev/null @@ -1,65 +0,0 @@ -#include "lm/sri.hh" - -#include - -#define BOOST_TEST_MODULE SRITest -#include - -namespace lm { -namespace sri { -namespace { - -#define StartTest(word, ngram, score) \ - ret = model.FullScore( \ - state, \ - model.GetVocabulary().Index(word), \ - out);\ - BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \ - BOOST_CHECK_EQUAL(static_cast(ngram), ret.ngram_length); \ - BOOST_CHECK_EQUAL(std::min(ngram, 5 - 1), out.valid_length_); - -#define AppendTest(word, ngram, score) \ - StartTest(word, ngram, score) \ - state = out; - -template void Starters(M &model) { - FullScoreReturn ret; - Model::State state(model.BeginSentenceState()); - Model::State out; - - StartTest("looking", 2, -0.4846522); - - // , probability plus backoff - StartTest(",", 1, -1.383514 + -0.4149733); - // probability plus backoff - StartTest("this_is_not_found", 0, -1.995635 + -0.4149733); -} - -template void Continuation(M &model) { - FullScoreReturn ret; - Model::State state(model.BeginSentenceState()); - Model::State out; - - AppendTest("looking", 2, -0.484652); - AppendTest("on", 3, -0.348837); - AppendTest("a", 4, -0.0155266); - AppendTest("little", 5, -0.00306122); - State preserve = state; - AppendTest("the", 1, -4.04005); - AppendTest("biarritz", 1, -1.9889); - AppendTest("not_found", 0, -2.29666); - AppendTest("more", 1, -1.20632); - AppendTest(".", 2, -0.51363); - AppendTest("", 3, -0.0191651); - - state = preserve; - AppendTest("more", 5, -0.00181395); - AppendTest("loin", 5, -0.0432557); -} - -BOOST_AUTO_TEST_CASE(starters) { Model m("test.arpa", 5); Starters(m); } -BOOST_AUTO_TEST_CASE(continuation) { Model m("test.arpa", 5); Continuation(m); } - -} // namespace -} // namespace sri -} // namespace lm diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc index 0d83221e..8663e94e 100644 --- a/klm/lm/trie_sort.cc +++ b/klm/lm/trie_sort.cc @@ -22,12 +22,6 @@ namespace lm { namespace ngram { namespace trie { - -void WriteOrThrow(FILE *to, const void *data, size_t size) { - assert(size); - if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size); -} - namespace { typedef util::SizedIterator NGramIter; @@ -95,12 +89,12 @@ FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &make // Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator. if (context_begin == context_end) return out.release(); PartialIter i(context_begin); - WriteOrThrow(out.get(), i->Data(), context_size); + util::WriteOrThrow(out.get(), i->Data(), context_size); const void *previous = i->Data(); ++i; for (; i != context_end; ++i) { if (memcmp(previous, i->Data(), context_size)) { - WriteOrThrow(out.get(), i->Data(), context_size); + util::WriteOrThrow(out.get(), i->Data(), context_size); previous = i->Data(); } } @@ -116,7 +110,7 @@ struct ThrowCombine { // Useful for context files that just contain records with no value. struct FirstCombine { void operator()(std::size_t entry_size, const void *first, const void * /*second*/, FILE *out) const { - WriteOrThrow(out, first, entry_size); + util::WriteOrThrow(out, first, entry_size); } }; @@ -129,10 +123,10 @@ template FILE *MergeSortedFiles(FILE *first_file, FILE *second_f EntryCompare less(order); while (first && second) { if (less(first.Data(), second.Data())) { - WriteOrThrow(out_file.get(), first.Data(), entry_size); + util::WriteOrThrow(out_file.get(), first.Data(), entry_size); ++first; } else if (less(second.Data(), first.Data())) { - WriteOrThrow(out_file.get(), second.Data(), entry_size); + util::WriteOrThrow(out_file.get(), second.Data(), entry_size); ++second; } else { combine(entry_size, first.Data(), second.Data(), out_file.get()); @@ -140,7 +134,7 @@ template FILE *MergeSortedFiles(FILE *first_file, FILE *second_f } } for (RecordReader &remains = (first ? first : second); remains; ++remains) { - WriteOrThrow(out_file.get(), remains.Data(), entry_size); + util::WriteOrThrow(out_file.get(), remains.Data(), entry_size); } return out_file.release(); } @@ -164,7 +158,7 @@ void RecordReader::Init(FILE *file, std::size_t entry_size) { void RecordReader::Overwrite(const void *start, std::size_t amount) { long internal = (uint8_t*)start - (uint8_t*)data_.get(); UTIL_THROW_IF(fseek(file_, internal - entry_size_, SEEK_CUR), util::ErrnoException, "Couldn't seek backwards for revision"); - WriteOrThrow(file_, start, amount); + util::WriteOrThrow(file_, start, amount); long forward = entry_size_ - internal - amount; #if !defined(_WIN32) && !defined(_WIN64) if (forward) diff --git a/klm/lm/trie_sort.hh b/klm/lm/trie_sort.hh index 1e6fce51..2197b80c 100644 --- a/klm/lm/trie_sort.hh +++ b/klm/lm/trie_sort.hh @@ -29,8 +29,6 @@ struct Config; namespace trie { -void WriteOrThrow(FILE *to, const void *data, size_t size); - class EntryCompare : public std::binary_function { public: explicit EntryCompare(unsigned char order) : order_(order) {} diff --git a/klm/util/file.cc b/klm/util/file.cc index ff5e64c9..6bf879ac 100644 --- a/klm/util/file.cc +++ b/klm/util/file.cc @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -111,6 +112,11 @@ void WriteOrThrow(int fd, const void *data_void, std::size_t size) { } } +void WriteOrThrow(FILE *to, const void *data, std::size_t size) { + assert(size); + if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size); +} + void FSyncOrThrow(int fd) { // Apparently windows doesn't have fsync? #if !defined(_WIN32) && !defined(_WIN64) @@ -148,6 +154,12 @@ std::FILE *FDOpenOrThrow(scoped_fd &file) { return ret; } +std::FILE *FOpenOrThrow(const char *path, const char *mode) { + std::FILE *ret; + UTIL_THROW_IF(!(ret = fopen(path, mode)), util::ErrnoException, "Could not fopen " << path << " for " << mode); + return ret; +} + TempMaker::TempMaker(const std::string &prefix) : base_(prefix) { base_ += "XXXXXX"; } @@ -247,7 +259,9 @@ mkstemp_and_unlink(char *tmpl) /* Modified for windows and to unlink */ // fd = open (tmpl, O_RDWR | O_CREAT | O_EXCL, _S_IREAD | _S_IWRITE); - fd = _open (tmpl, _O_RDWR | _O_CREAT | _O_TEMPORARY | _O_EXCL | _O_BINARY, _S_IREAD | _S_IWRITE); + int flags = _O_RDWR | _O_CREAT | _O_EXCL | _O_BINARY; + flags |= _O_TEMPORARY; + fd = _open (tmpl, flags, _S_IREAD | _S_IWRITE); if (fd >= 0) { errno = save_errno; @@ -265,17 +279,18 @@ mkstemp_and_unlink(char *tmpl) int mkstemp_and_unlink(char *tmpl) { int ret = mkstemp(tmpl); - if (ret == -1) return -1; - UTIL_THROW_IF(unlink(tmpl), util::ErrnoException, "Failed to delete " << tmpl); + if (ret != -1) { + UTIL_THROW_IF(unlink(tmpl), util::ErrnoException, "Failed to delete " << tmpl); + } return ret; } #endif int TempMaker::Make() const { - std::string copy(base_); - copy.push_back(0); + std::string name(base_); + name.push_back(0); int ret; - UTIL_THROW_IF(-1 == (ret = mkstemp_and_unlink(©[0])), util::ErrnoException, "Failed to make a temporary based on " << base_); + UTIL_THROW_IF(-1 == (ret = mkstemp_and_unlink(&name[0])), util::ErrnoException, "Failed to make a temporary based on " << base_); return ret; } diff --git a/klm/util/file.hh b/klm/util/file.hh index 8af1ff4f..185cb1f3 100644 --- a/klm/util/file.hh +++ b/klm/util/file.hh @@ -80,6 +80,7 @@ void ReadOrThrow(int fd, void *to, std::size_t size); std::size_t ReadOrEOF(int fd, void *to_void, std::size_t amount); void WriteOrThrow(int fd, const void *data_void, std::size_t size); +void WriteOrThrow(FILE *to, const void *data, std::size_t size); void FSyncOrThrow(int fd); @@ -90,6 +91,8 @@ void SeekEnd(int fd); std::FILE *FDOpenOrThrow(scoped_fd &file); +std::FILE *FOpenOrThrow(const char *path, const char *mode); + class TempMaker { public: explicit TempMaker(const std::string &prefix); @@ -98,9 +101,6 @@ class TempMaker { int Make() const; std::FILE *MakeFile() const; - // This will force you to close the fd instead of leaving it open. - std::string Name(scoped_fd &opened) const; - private: std::string base_; }; diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc index 19a68728..280f438c 100644 --- a/klm/util/file_piece.cc +++ b/klm/util/file_piece.cc @@ -5,6 +5,8 @@ #include "util/mmap.hh" #ifdef WIN32 #include +#else +#include #endif // WIN32 #include -- cgit v1.2.3