From f111672dd611f78656fceb3df3729a290453ef56 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Wed, 21 Sep 2011 18:23:50 -0400 Subject: Updated kenlm. Includes left state support but not the cdec-side use of it. Updated binary format. --- klm/util/scoped.hh | 58 +++++++++++++++++++----------------------------------- 1 file changed, 20 insertions(+), 38 deletions(-) (limited to 'klm/util/scoped.hh') diff --git a/klm/util/scoped.hh b/klm/util/scoped.hh index d36a7df3..12e6652b 100644 --- a/klm/util/scoped.hh +++ b/klm/util/scoped.hh @@ -1,10 +1,11 @@ #ifndef UTIL_SCOPED__ #define UTIL_SCOPED__ -/* Other scoped objects in the style of scoped_ptr. */ +#include "util/exception.hh" +/* Other scoped objects in the style of scoped_ptr. */ #include -#include +#include namespace util { @@ -34,52 +35,33 @@ template class scoped_thing { scoped_thing &operator=(const scoped_thing &); }; -class scoped_fd { +class scoped_malloc { public: - scoped_fd() : fd_(-1) {} + scoped_malloc() : p_(NULL) {} - explicit scoped_fd(int fd) : fd_(fd) {} + scoped_malloc(void *p) : p_(p) {} - ~scoped_fd(); + ~scoped_malloc() { std::free(p_); } - void reset(int to) { - scoped_fd other(fd_); - fd_ = to; + void reset(void *p = NULL) { + scoped_malloc other(p_); + p_ = p; } - int get() const { return fd_; } - - int operator*() const { return fd_; } - - int release() { - int ret = fd_; - fd_ = -1; - return ret; + void call_realloc(std::size_t to) { + void *ret; + UTIL_THROW_IF(!(ret = std::realloc(p_, to)), util::ErrnoException, "realloc to " << to << " bytes failed."); + p_ = ret; } - private: - int fd_; - - scoped_fd(const scoped_fd &); - scoped_fd &operator=(const scoped_fd &); -}; - -class scoped_FILE { - public: - explicit scoped_FILE(std::FILE *file = NULL) : file_(file) {} - - ~scoped_FILE(); - - std::FILE *get() { return file_; } - const std::FILE *get() const { return file_; } - - void reset(std::FILE *to = NULL) { - scoped_FILE other(file_); - file_ = to; - } + void *get() { return p_; } + const void *get() const { return p_; } private: - std::FILE *file_; + void *p_; + + scoped_malloc(const scoped_malloc &); + scoped_malloc &operator=(const scoped_malloc &); }; // Hat tip to boost. -- cgit v1.2.3 From 0e1ffb6c1528e44f63ae8bac466bd5163e973974 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Tue, 11 Oct 2011 14:58:52 -0400 Subject: Trie fixes for SRI --- klm/lm/left.hh | 26 +++++++++++++++----------- klm/lm/search_trie.cc | 7 ++++++- klm/lm/search_trie.hh | 2 +- klm/lm/trie_sort.cc | 2 +- klm/util/scoped.hh | 2 +- 5 files changed, 24 insertions(+), 15 deletions(-) (limited to 'klm/util/scoped.hh') diff --git a/klm/lm/left.hh b/klm/lm/left.hh index effa0560..bb3f5539 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -117,7 +117,7 @@ inline size_t hash_value(const ChartState &state) { template class RuleScore { public: - explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), left_write_(out.left.pointers), prob_(0.0) { + explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), prob_(0.0) { out.left.length = 0; out.right.length = 0; } @@ -130,15 +130,22 @@ template class RuleScore { void Terminal(WordIndex word) { State copy(out_.right); - ProcessRet(model_.FullScore(copy, word, out_.right)); - if (out_.right.length != copy.length + 1) left_done_ = true; + FullScoreReturn ret(model_.FullScore(copy, word, out_.right)); + prob_ += ret.prob; + if (left_done_) return; + if (ret.independent_left) { + left_done_ = true; + return; + } + out_.left.pointers[out_.left.length++] = ret.extend_left; + if (out_.right.length != copy.length + 1) + left_done_ = true; } // Faster version of NonTerminal for the case where the rule begins with a non-terminal. void BeginNonTerminal(const ChartState &in, float prob) { prob_ = prob; out_ = in; - left_write_ = out_.left.pointers + out_.left.length; left_done_ = in.full; } @@ -157,11 +164,10 @@ template class RuleScore { if (!out_.right.length) { out_.right = in.right; if (left_done_) return; - if (left_write_ != out_.left.pointers) { + if (out_.left.length) { left_done_ = true; } else { out_.left = in.left; - left_write_ = out_.left.pointers + in.left.length; left_done_ = in.full; } return; @@ -214,8 +220,8 @@ template class RuleScore { } float Finish() { - out_.left.length = left_write_ - out_.left.pointers; - out_.full = left_done_; + // A N-1-gram might extend left and right but we should still set full to true because it's an N-1-gram. + out_.full = left_done_ || (out_.left.length == model_.Order() - 1); return prob_; } @@ -227,7 +233,7 @@ template class RuleScore { left_done_ = true; return; } - *(left_write_++) = ret.extend_left; + out_.left.pointers[out_.left.length++] = ret.extend_left; } const M &model_; @@ -236,8 +242,6 @@ template class RuleScore { bool left_done_; - uint64_t *left_write_; - float prob_; }; diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 6479813b..5d8c70db 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -151,6 +151,11 @@ class BackoffMessages { private: void FinishedAdding() { Resize(current_ - (uint8_t*)backing_.get()); + // Sort requests in same order as files. + std::sort( + util::SizedIterator(util::SizedProxy(backing_.get(), entry_size_)), + util::SizedIterator(util::SizedProxy(current_, entry_size_)), + util::SizedCompare(EntryCompare((entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex)))); current_ = (uint8_t*)backing_.get(); } @@ -525,7 +530,7 @@ template void BuildTrie(const std::string &file_pre const RecordReader &context = contexts[order - 2]; if (context) { FormatLoadException e; - e << "An " << static_cast(order) << "-gram has context"; + e << "A " << static_cast(order) << "-gram has context"; const WordIndex *ctx = reinterpret_cast(context.Data()); for (const WordIndex *i = ctx; i != ctx + order - 1; ++i) { e << ' ' << *i; diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index c3e02a98..33ae8cff 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -36,7 +36,7 @@ template class TrieSearch { static const ModelType kModelType = static_cast(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd); - static const unsigned int kVersion = 0; + static const unsigned int kVersion = 1; static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config) { Quant::UpdateConfigFromBinary(fd, counts, config); diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc index 86f28493..bb126f18 100644 --- a/klm/lm/trie_sort.cc +++ b/klm/lm/trie_sort.cc @@ -191,7 +191,7 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st assembled << file_prefix << static_cast(order) << "_merge_" << (merge_count++); files.push_back(assembled.str()); MergeSortedFiles(files[0], files[1], files.back(), weights_size, order, ThrowCombine()); - MergeSortedFiles(files[0] + kContextSuffix, files[1] + kContextSuffix, files.back() + kContextSuffix, 0, order, FirstCombine()); + MergeSortedFiles(files[0] + kContextSuffix, files[1] + kContextSuffix, files.back() + kContextSuffix, 0, order - 1, FirstCombine()); files.pop_front(); files.pop_front(); } diff --git a/klm/util/scoped.hh b/klm/util/scoped.hh index 12e6652b..93e2e817 100644 --- a/klm/util/scoped.hh +++ b/klm/util/scoped.hh @@ -50,7 +50,7 @@ class scoped_malloc { void call_realloc(std::size_t to) { void *ret; - UTIL_THROW_IF(!(ret = std::realloc(p_, to)), util::ErrnoException, "realloc to " << to << " bytes failed."); + UTIL_THROW_IF(!(ret = std::realloc(p_, to)) && to, util::ErrnoException, "realloc to " << to << " bytes failed."); p_ = ret; } -- cgit v1.2.3