From 7424d6083962d58eef6daa83a036c189ca11431f Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Tue, 11 Oct 2011 14:58:52 -0400 Subject: Trie fixes for SRI --- klm/lm/left.hh | 26 +++++++++++++++----------- klm/lm/search_trie.cc | 7 ++++++- klm/lm/search_trie.hh | 2 +- klm/lm/trie_sort.cc | 2 +- klm/util/scoped.hh | 2 +- 5 files changed, 24 insertions(+), 15 deletions(-) (limited to 'klm') diff --git a/klm/lm/left.hh b/klm/lm/left.hh index effa0560..bb3f5539 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -117,7 +117,7 @@ inline size_t hash_value(const ChartState &state) { template class RuleScore { public: - explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), left_write_(out.left.pointers), prob_(0.0) { + explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), prob_(0.0) { out.left.length = 0; out.right.length = 0; } @@ -130,15 +130,22 @@ template class RuleScore { void Terminal(WordIndex word) { State copy(out_.right); - ProcessRet(model_.FullScore(copy, word, out_.right)); - if (out_.right.length != copy.length + 1) left_done_ = true; + FullScoreReturn ret(model_.FullScore(copy, word, out_.right)); + prob_ += ret.prob; + if (left_done_) return; + if (ret.independent_left) { + left_done_ = true; + return; + } + out_.left.pointers[out_.left.length++] = ret.extend_left; + if (out_.right.length != copy.length + 1) + left_done_ = true; } // Faster version of NonTerminal for the case where the rule begins with a non-terminal. void BeginNonTerminal(const ChartState &in, float prob) { prob_ = prob; out_ = in; - left_write_ = out_.left.pointers + out_.left.length; left_done_ = in.full; } @@ -157,11 +164,10 @@ template class RuleScore { if (!out_.right.length) { out_.right = in.right; if (left_done_) return; - if (left_write_ != out_.left.pointers) { + if (out_.left.length) { left_done_ = true; } else { out_.left = in.left; - left_write_ = out_.left.pointers + in.left.length; left_done_ = in.full; } return; @@ -214,8 +220,8 @@ template class RuleScore { } float Finish() { - out_.left.length = left_write_ - out_.left.pointers; - out_.full = left_done_; + // A N-1-gram might extend left and right but we should still set full to true because it's an N-1-gram. + out_.full = left_done_ || (out_.left.length == model_.Order() - 1); return prob_; } @@ -227,7 +233,7 @@ template class RuleScore { left_done_ = true; return; } - *(left_write_++) = ret.extend_left; + out_.left.pointers[out_.left.length++] = ret.extend_left; } const M &model_; @@ -236,8 +242,6 @@ template class RuleScore { bool left_done_; - uint64_t *left_write_; - float prob_; }; diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 6479813b..5d8c70db 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -151,6 +151,11 @@ class BackoffMessages { private: void FinishedAdding() { Resize(current_ - (uint8_t*)backing_.get()); + // Sort requests in same order as files. + std::sort( + util::SizedIterator(util::SizedProxy(backing_.get(), entry_size_)), + util::SizedIterator(util::SizedProxy(current_, entry_size_)), + util::SizedCompare(EntryCompare((entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex)))); current_ = (uint8_t*)backing_.get(); } @@ -525,7 +530,7 @@ template void BuildTrie(const std::string &file_pre const RecordReader &context = contexts[order - 2]; if (context) { FormatLoadException e; - e << "An " << static_cast(order) << "-gram has context"; + e << "A " << static_cast(order) << "-gram has context"; const WordIndex *ctx = reinterpret_cast(context.Data()); for (const WordIndex *i = ctx; i != ctx + order - 1; ++i) { e << ' ' << *i; diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index c3e02a98..33ae8cff 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -36,7 +36,7 @@ template class TrieSearch { static const ModelType kModelType = static_cast(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd); - static const unsigned int kVersion = 0; + static const unsigned int kVersion = 1; static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config) { Quant::UpdateConfigFromBinary(fd, counts, config); diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc index 86f28493..bb126f18 100644 --- a/klm/lm/trie_sort.cc +++ b/klm/lm/trie_sort.cc @@ -191,7 +191,7 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st assembled << file_prefix << static_cast(order) << "_merge_" << (merge_count++); files.push_back(assembled.str()); MergeSortedFiles(files[0], files[1], files.back(), weights_size, order, ThrowCombine()); - MergeSortedFiles(files[0] + kContextSuffix, files[1] + kContextSuffix, files.back() + kContextSuffix, 0, order, FirstCombine()); + MergeSortedFiles(files[0] + kContextSuffix, files[1] + kContextSuffix, files.back() + kContextSuffix, 0, order - 1, FirstCombine()); files.pop_front(); files.pop_front(); } diff --git a/klm/util/scoped.hh b/klm/util/scoped.hh index 12e6652b..93e2e817 100644 --- a/klm/util/scoped.hh +++ b/klm/util/scoped.hh @@ -50,7 +50,7 @@ class scoped_malloc { void call_realloc(std::size_t to) { void *ret; - UTIL_THROW_IF(!(ret = std::realloc(p_, to)), util::ErrnoException, "realloc to " << to << " bytes failed."); + UTIL_THROW_IF(!(ret = std::realloc(p_, to)) && to, util::ErrnoException, "realloc to " << to << " bytes failed."); p_ = ret; } -- cgit v1.2.3