diff options
Diffstat (limited to 'klm/lm/trie_sort.cc')
-rw-r--r-- | klm/lm/trie_sort.cc | 21 |
1 files changed, 15 insertions, 6 deletions
diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc index 126d43ab..c3f46874 100644 --- a/klm/lm/trie_sort.cc +++ b/klm/lm/trie_sort.cc @@ -16,6 +16,7 @@ #include <cstdio> #include <cstdlib> #include <deque> +#include <iterator> #include <limits> #include <vector> @@ -106,14 +107,20 @@ FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &temp_pre } struct ThrowCombine { - void operator()(std::size_t /*entry_size*/, const void * /*first*/, const void * /*second*/, FILE * /*out*/) const { - UTIL_THROW(FormatLoadException, "Duplicate n-gram detected."); + void operator()(std::size_t entry_size, unsigned char order, const void *first, const void *second, FILE * /*out*/) const { + const WordIndex *base = reinterpret_cast<const WordIndex*>(first); + FormatLoadException e; + e << "Duplicate n-gram detected with vocab ids"; + for (const WordIndex *i = base; i != base + order; ++i) { + e << ' ' << *i; + } + throw e; } }; // Useful for context files that just contain records with no value. struct FirstCombine { - void operator()(std::size_t entry_size, const void *first, const void * /*second*/, FILE *out) const { + void operator()(std::size_t entry_size, unsigned char /*order*/, const void *first, const void * /*second*/, FILE *out) const { util::WriteOrThrow(out, first, entry_size); } }; @@ -133,7 +140,7 @@ template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_f util::WriteOrThrow(out_file.get(), second.Data(), entry_size); ++second; } else { - combine(entry_size, first.Data(), second.Data(), out_file.get()); + combine(entry_size, order, first.Data(), second.Data(), out_file.get()); ++first; ++second; } } @@ -248,11 +255,13 @@ void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vo uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size; if (order == counts.size()) { for (; out != out_end; out += entry_size) { - ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<Prob*>(out + words_size), warn); + std::reverse_iterator<WordIndex*> it(reinterpret_cast<WordIndex*>(out) + order); + ReadNGram(f, order, vocab, it, *reinterpret_cast<Prob*>(out + words_size), warn); } } else { for (; out != out_end; out += entry_size) { - ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size), warn); + std::reverse_iterator<WordIndex*> it(reinterpret_cast<WordIndex*>(out) + order); + ReadNGram(f, order, vocab, it, *reinterpret_cast<ProbBackoff*>(out + words_size), warn); } } // Sort full records by full n-gram. |