summary | refs | log | tree | commit | diff
path: root/klm/lm/trie_sort.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/trie_sort.cc')
-rw-r--r-- klm/lm/trie_sort.cc | 21
1 files changed, 15 insertions, 6 deletions
diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc
index 126d43ab..c3f46874 100644
--- a/klm/lm/trie_sort.cc
+++ b/klm/lm/trie_sort.cc
@@ -16,6 +16,7 @@
#include <cstdio>
#include <cstdlib>
#include <deque>
+#include <iterator>
#include <limits>
#include <vector>
@@ -106,14 +107,20 @@ FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &temp_pre
}
struct ThrowCombine {
- void operator()(std::size_t /*entry_size*/, const void * /*first*/, const void * /*second*/, FILE * /*out*/) const {
- UTIL_THROW(FormatLoadException, "Duplicate n-gram detected.");
+ void operator()(std::size_t entry_size, unsigned char order, const void *first, const void *second, FILE * /*out*/) const {
+ const WordIndex *base = reinterpret_cast<const WordIndex*>(first);
+ FormatLoadException e;
+ e << "Duplicate n-gram detected with vocab ids";
+ for (const WordIndex *i = base; i != base + order; ++i) {
+ e << ' ' << *i;
+ }
+ throw e;
}
};
// Useful for context files that just contain records with no value.
struct FirstCombine {
- void operator()(std::size_t entry_size, const void *first, const void * /*second*/, FILE *out) const {
+ void operator()(std::size_t entry_size, unsigned char /*order*/, const void *first, const void * /*second*/, FILE *out) const {
util::WriteOrThrow(out, first, entry_size);
}
};
@@ -133,7 +140,7 @@ template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_f
util::WriteOrThrow(out_file.get(), second.Data(), entry_size);
++second;
} else {
- combine(entry_size, first.Data(), second.Data(), out_file.get());
+ combine(entry_size, order, first.Data(), second.Data(), out_file.get());
++first; ++second;
}
}
@@ -248,11 +255,13 @@ void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vo
uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size;
if (order == counts.size()) {
for (; out != out_end; out += entry_size) {
- ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<Prob*>(out + words_size), warn);
+ std::reverse_iterator<WordIndex*> it(reinterpret_cast<WordIndex*>(out) + order);
+ ReadNGram(f, order, vocab, it, *reinterpret_cast<Prob*>(out + words_size), warn);
}
} else {
for (; out != out_end; out += entry_size) {
- ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size), warn);
+ std::reverse_iterator<WordIndex*> it(reinterpret_cast<WordIndex*>(out) + order);
+ ReadNGram(f, order, vocab, it, *reinterpret_cast<ProbBackoff*>(out + words_size), warn);
}
}
// Sort full records by full n-gram.