path: root/klm/lm/search_trie.cc
author    Chris Dyer <cdyer@cs.cmu.edu>  2011-01-25 22:30:48 +0200
committer Chris Dyer <cdyer@cs.cmu.edu>  2011-01-25 22:30:48 +0200
commit    c4ade3091b812ca135ae6520fa7173e1bbf28754 (patch)
tree      2528af208f6dafd0c27dcbec0d2da291a9c93ca2 /klm/lm/search_trie.cc
parent    d04c0ca2d9df0e147239b18e90650ca8bd51d594 (diff)
update kenlm
Diffstat (limited to 'klm/lm/search_trie.cc')
-rw-r--r--  klm/lm/search_trie.cc  |  302
1 file changed, 257 insertions(+), 45 deletions(-)
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index 3aeeeca3..1060ddef 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -3,6 +3,7 @@
#include "lm/blank.hh"
#include "lm/lm_exception.hh"
+#include "lm/max_order.hh"
#include "lm/read_arpa.hh"
#include "lm/trie.hh"
#include "lm/vocab.hh"
@@ -27,6 +28,7 @@
#include <sys/stat.h>
#include <fcntl.h>
#include <stdlib.h>
+#include <unistd.h>
namespace lm {
namespace ngram {
@@ -98,7 +100,7 @@ class EntryProxy {
}
const WordIndex *Indices() const {
- return static_cast<const WordIndex*>(inner_.Data());
+ return reinterpret_cast<const WordIndex*>(inner_.Data());
}
private:
@@ -114,17 +116,57 @@ class EntryProxy {
typedef util::ProxyIterator<EntryProxy> NGramIter;
-class CompareRecords : public std::binary_function<const EntryProxy &, const EntryProxy &, bool> {
+// Proxy for an entry except there is some extra cruft between the entries. This is used to sort (n-1)-grams using the same memory as the sorted n-grams.
+class PartialViewProxy {
+ public:
+ PartialViewProxy() : attention_size_(0), inner_() {}
+
+ PartialViewProxy(void *ptr, std::size_t block_size, std::size_t attention_size) : attention_size_(attention_size), inner_(ptr, block_size) {}
+
+ operator std::string() const {
+ return std::string(reinterpret_cast<const char*>(inner_.Data()), attention_size_);
+ }
+
+ PartialViewProxy &operator=(const PartialViewProxy &from) {
+ memcpy(inner_.Data(), from.inner_.Data(), attention_size_);
+ return *this;
+ }
+
+ PartialViewProxy &operator=(const std::string &from) {
+ memcpy(inner_.Data(), from.data(), attention_size_);
+ return *this;
+ }
+
+ const WordIndex *Indices() const {
+ return reinterpret_cast<const WordIndex*>(inner_.Data());
+ }
+
+ private:
+ friend class util::ProxyIterator<PartialViewProxy>;
+
+ typedef std::string value_type;
+
+ const std::size_t attention_size_;
+
+ typedef EntryIterator InnerIterator;
+ InnerIterator &Inner() { return inner_; }
+ const InnerIterator &Inner() const { return inner_; }
+ InnerIterator inner_;
+};
+
+typedef util::ProxyIterator<PartialViewProxy> PartialIter;
+
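
The proxy above strides through memory in steps of block_size but compares and copies only the first attention_size bytes; that is what lets WriteContextFile, later in this patch, sort the (order-1)-word contexts in place inside the n-gram buffer. A minimal sketch of the wiring, reusing the names from that function:

    // Sketch only; begin/end bound the n-gram records, entry_size is the
    // full record width, and CompareRecords is defined just below.
    const std::size_t context_size = sizeof(WordIndex) * (order - 1);
    PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size));
    PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, context_size));
    std::sort(context_begin, context_end, CompareRecords<PartialViewProxy>(order - 1));
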
+template <class Proxy> class CompareRecords : public std::binary_function<const Proxy &, const Proxy &, bool> {
public:
explicit CompareRecords(unsigned char order) : order_(order) {}
- bool operator()(const EntryProxy &first, const EntryProxy &second) const {
+ bool operator()(const Proxy &first, const Proxy &second) const {
return Compare(first.Indices(), second.Indices());
}
- bool operator()(const EntryProxy &first, const std::string &second) const {
+ bool operator()(const Proxy &first, const std::string &second) const {
return Compare(first.Indices(), reinterpret_cast<const WordIndex*>(second.data()));
}
- bool operator()(const std::string &first, const EntryProxy &second) const {
+ bool operator()(const std::string &first, const Proxy &second) const {
return Compare(reinterpret_cast<const WordIndex*>(first.data()), second.Indices());
}
bool operator()(const std::string &first, const std::string &second) const {
@@ -144,6 +186,12 @@ class CompareRecords : public std::binary_function<const EntryProxy &, const Ent
unsigned char order_;
};
+FILE *OpenOrThrow(const char *name, const char *mode) {
+ FILE *ret = fopen(name, mode);
+ if (!ret) UTIL_THROW(util::ErrnoException, "Could not open " << name << " for " << mode);
+ return ret;
+}
+
void WriteOrThrow(FILE *to, const void *data, size_t size) {
assert(size);
if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size);
@@ -163,14 +211,26 @@ void CopyOrThrow(FILE *from, FILE *to, size_t size) {
}
}
+void CopyRestOrThrow(FILE *from, FILE *to) {
+ char buf[kCopyBufSize];
+ size_t amount;
+ while ((amount = fread(buf, 1, kCopyBufSize, from))) {
+ WriteOrThrow(to, buf, amount);
+ }
+ if (!feof(from)) UTIL_THROW(util::ErrnoException, "Short read");
+}
+
+void RemoveOrThrow(const char *name) {
+ if (std::remove(name)) UTIL_THROW(util::ErrnoException, "Could not remove " << name);
+}
+
std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::string &file_prefix, std::size_t batch, unsigned char order, std::size_t weights_size) {
const std::size_t entry_size = sizeof(WordIndex) * order + weights_size;
const std::size_t prefix_size = sizeof(WordIndex) * (order - 1);
std::stringstream assembled;
assembled << file_prefix << static_cast<unsigned int>(order) << '_' << batch;
std::string ret(assembled.str());
- util::scoped_FILE out(fopen(ret.c_str(), "w"));
- if (!out.get()) UTIL_THROW(util::ErrnoException, "Couldn't open " << assembled.str().c_str() << " for writing");
+ util::scoped_FILE out(OpenOrThrow(ret.c_str(), "w"));
// Compress entries that begin with the same (order-1) words.
for (const uint8_t *group_begin = static_cast<const uint8_t*>(mem_begin); group_begin != static_cast<const uint8_t*>(mem_end);) {
const uint8_t *group_end;
@@ -194,8 +254,7 @@ class SortedFileReader {
SortedFileReader() : ended_(false) {}
void Init(const std::string &name, unsigned char order) {
- file_.reset(fopen(name.c_str(), "r"));
- if (!file_.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " for read");
+ file_.reset(OpenOrThrow(name.c_str(), "r"));
header_.resize(order - 1);
NextHeader();
}
@@ -262,12 +321,13 @@ void CopyFullRecord(SortedFileReader &from, FILE *to, std::size_t weights_size)
CopyOrThrow(from.File(), to, (weights_size + sizeof(WordIndex)) * count);
}
-void MergeSortedFiles(const char *first_name, const char *second_name, const char *out, std::size_t weights_size, unsigned char order) {
+void MergeSortedFiles(const std::string &first_name, const std::string &second_name, const std::string &out, std::size_t weights_size, unsigned char order) {
SortedFileReader first, second;
- first.Init(first_name, order);
- second.Init(second_name, order);
- util::scoped_FILE out_file(fopen(out, "w"));
- if (!out_file.get()) UTIL_THROW(util::ErrnoException, "Could not open " << out << " for write");
+ first.Init(first_name.c_str(), order);
+ RemoveOrThrow(first_name.c_str());
+ second.Init(second_name.c_str(), order);
+ RemoveOrThrow(second_name.c_str());
+ util::scoped_FILE out_file(OpenOrThrow(out.c_str(), "w"));
while (!first.Ended() && !second.Ended()) {
if (first.HeaderVector() < second.HeaderVector()) {
CopyFullRecord(first, out_file.get(), weights_size);
@@ -316,10 +376,109 @@ void MergeSortedFiles(const char *first_name, const char *second_name, const cha
}
}
-void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order) {
- if (order == 1) return;
- ConvertToSorted(f, vocab, counts, mem, file_prefix, order - 1);
+const char *kContextSuffix = "_contexts";
+
+void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_file_name, std::size_t entry_size, unsigned char order) {
+ const size_t context_size = sizeof(WordIndex) * (order - 1);
+ // Sort just the contexts using the same memory.
+ PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size));
+ PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, context_size));
+
+ // TODO: __gnu_parallel::sort here.
+ std::sort(context_begin, context_end, CompareRecords<PartialViewProxy>(order - 1));
+
+ std::string name(ngram_file_name + kContextSuffix);
+ util::scoped_FILE out(OpenOrThrow(name.c_str(), "w"));
+
+ // Write out to file and uniquify at the same time. Could have used unique_copy if there were an appropriate OutputIterator.
+ if (context_begin == context_end) return;
+ PartialIter i(context_begin);
+ WriteOrThrow(out.get(), i->Indices(), context_size);
+ const WordIndex *previous = i->Indices();
+ ++i;
+ for (; i != context_end; ++i) {
+ if (memcmp(previous, i->Indices(), context_size)) {
+ WriteOrThrow(out.get(), i->Indices(), context_size);
+ previous = i->Indices();
+ }
+ }
+}
+class ContextReader {
+ public:
+ ContextReader() : length_(0) {}
+
+ ContextReader(const char *name, size_t length) : file_(OpenOrThrow(name, "r")), length_(length), words_(length), valid_(true) {
+ ++*this;
+ }
+
+ void Reset(const char *name, size_t length) {
+ file_.reset(OpenOrThrow(name, "r"));
+ length_ = length;
+ words_.resize(length);
+ valid_ = true;
+ ++*this;
+ }
+
+ ContextReader &operator++() {
+ if (1 != fread(&*words_.begin(), length_, 1, file_.get())) {
+ if (!feof(file_.get()))
+ UTIL_THROW(util::ErrnoException, "Short read");
+ valid_ = false;
+ }
+ return *this;
+ }
+
+ const WordIndex *operator*() const { return &*words_.begin(); }
+
+ operator bool() const { return valid_; }
+
+ FILE *GetFile() { return file_.get(); }
+
+ private:
+ util::scoped_FILE file_;
+
+ size_t length_;
+
+ std::vector<WordIndex> words_;
+
+ bool valid_;
+};
+
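
A usage sketch for ContextReader, with a hypothetical file name and order; the reader keeps one record of lookahead and tests false only once the file is exhausted:

    // Hypothetical: iterate the unique two-word contexts of a trigram batch.
    ContextReader reader("prefix3_0_contexts", 2 * sizeof(WordIndex));
    while (reader) {
      const WordIndex *context = *reader;  // two sorted WordIndex values
      // ... consume context[0] and context[1] ...
      ++reader;
    }
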
+void MergeContextFiles(const std::string &first_base, const std::string &second_base, const std::string &out_base, unsigned char order) {
+ const size_t context_size = sizeof(WordIndex) * (order - 1);
+ std::string first_name(first_base + kContextSuffix);
+ std::string second_name(second_base + kContextSuffix);
+ ContextReader first(first_name.c_str(), context_size), second(second_name.c_str(), context_size);
+ RemoveOrThrow(first_name.c_str());
+ RemoveOrThrow(second_name.c_str());
+ std::string out_name(out_base + kContextSuffix);
+ util::scoped_FILE out(OpenOrThrow(out_name.c_str(), "w"));
+ while (first && second) {
+ for (const WordIndex *f = *first, *s = *second; ; ++f, ++s) {
+ if (f == *first + order - 1) {
+ // Equal.
+ WriteOrThrow(out.get(), *first, context_size);
+ ++first;
+ ++second;
+ break;
+ }
+ if (*f < *s) {
+ // First lower
+ WriteOrThrow(out.get(), *first, context_size);
+ ++first;
+ break;
+ } else if (*f > *s) {
+ WriteOrThrow(out.get(), *second, context_size);
+ ++second;
+ break;
+ }
+ }
+ }
+ // Flush the live reader's buffered record before copying the rest of its file.
+ ContextReader &remaining = first ? first : second;
+ if (remaining) WriteOrThrow(out.get(), *remaining, context_size);
+ CopyRestOrThrow(remaining.GetFile(), out.get());
+}
+
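
The loop above is a streaming set union of two sorted, duplicate-free inputs, comparing word by word exactly as lexicographic comparison of equal-length vectors would. An in-memory analogue, assuming hypothetical sorted inputs first_contexts and second_contexts with <algorithm>, <iterator>, and <vector> included:

    typedef std::vector<WordIndex> Context;
    std::vector<Context> merged;
    std::set_union(first_contexts.begin(), first_contexts.end(),
                   second_contexts.begin(), second_contexts.end(),
                   std::back_inserter(merged));
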
+void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order) {
ReadNGramHeader(f, order);
const size_t count = counts[order - 1];
// Size of weights. Does it include backoff?
@@ -341,11 +500,13 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st
ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size));
}
}
- // TODO: __gnu_parallel::sort here.
+ // Sort full records by full n-gram.
EntryProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size);
- std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), CompareRecords(order));
-
+ // TODO: __gnu_parallel::sort here.
+ std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), CompareRecords<EntryProxy>(order));
files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order, weights_size));
+ WriteContextFile(begin, out_end, files.back(), entry_size, order);
+
done += (out_end - begin) / entry_size;
}
@@ -356,10 +517,9 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st
std::stringstream assembled;
assembled << file_prefix << static_cast<unsigned int>(order) << "_merge_" << (merge_count++);
files.push_back(assembled.str());
- MergeSortedFiles(files[0].c_str(), files[1].c_str(), files.back().c_str(), weights_size, order);
- if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]);
+ MergeSortedFiles(files[0], files[1], files.back(), weights_size, order);
+ MergeContextFiles(files[0], files[1], files.back(), order);
files.pop_front();
- if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]);
files.pop_front();
}
if (!files.empty()) {
@@ -367,6 +527,9 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st
assembled << file_prefix << static_cast<unsigned int>(order) << "_merged";
std::string merged_name(assembled.str());
if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str());
+ std::string context_name = files[0] + kContextSuffix;
+ merged_name += kContextSuffix;
+ if (std::rename(context_name.c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << context_name << " to " << merged_name.c_str());
}
}
@@ -378,26 +541,38 @@ void ARPAToSortedFiles(util::FilePiece &f, const std::vector<uint64_t> &counts,
Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()));
}
+ // Only use as much buffer as we need.
+ size_t buffer_use = 0;
+ for (unsigned int order = 2; order < counts.size(); ++order) {
+ buffer_use = std::max(buffer_use, size_t((sizeof(WordIndex) * order + 2 * sizeof(float)) * counts[order - 1]));
+ }
+ buffer_use = std::max(buffer_use, size_t((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back()));
+ buffer = std::min(buffer, buffer_use);
+
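
As a worked example of the sizing above, with hypothetical counts of 1M unigrams, 5M bigrams, and 7M trigrams and 4-byte WordIndex and float: the middle-order loop covers only order 2 and needs (4*2 + 2*4) * 5M = 80 MB, while the highest order stores one float instead of two and needs (4*3 + 4) * 7M = 112 MB, so buffer_use is 112 MB and the sort buffer is capped there.
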
util::scoped_memory mem;
mem.reset(malloc(buffer), buffer, util::scoped_memory::MALLOC_ALLOCATED);
if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer);
- ConvertToSorted(f, vocab, counts, mem, file_prefix, counts.size());
+
+ for (unsigned char order = 2; order <= counts.size(); ++order) {
+ ConvertToSorted(f, vocab, counts, mem, file_prefix, order);
+ }
ReadEnd(f);
}
bool HeadMatch(const WordIndex *words, const WordIndex *const words_end, const WordIndex *header) {
for (; words != words_end; ++words, ++header) {
if (*words != *header) {
- assert(*words <= *header);
+ //assert(*words <= *header);
return false;
}
}
return true;
}
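
With the assertion commented out, the body of HeadMatch reduces to a single standard call; a sketch, assuming <algorithm> is included:

    // Equivalent formulation of HeadMatch, assertion aside.
    return std::equal(words, words_end, header);
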
+// Counting phase
class JustCount {
public:
- JustCount(UnigramValue * /*unigrams*/, BitPackedMiddle * /*middle*/, BitPackedLongest &/*longest*/, uint64_t *counts, unsigned char order)
+ JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, BitPackedMiddle * /*middle*/, BitPackedLongest &/*longest*/, uint64_t *counts, unsigned char order)
: counts_(counts), longest_counts_(counts + order - 1) {}
void Unigrams(WordIndex begin, WordIndex end) {
@@ -408,7 +583,7 @@ class JustCount {
++counts_[mid_idx + 1];
}
- void Middle(const unsigned char mid_idx, WordIndex /*key*/, const ProbBackoff &/*weights*/) {
+ void Middle(const unsigned char mid_idx, const WordIndex * /*before*/, WordIndex /*key*/, const ProbBackoff &/*weights*/) {
++counts_[mid_idx + 1];
}
@@ -427,7 +602,8 @@ class JustCount {
class WriteEntries {
public:
- WriteEntries(UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) :
+ WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) :
+ contexts_(contexts),
unigrams_(unigrams),
middle_(middle),
longest_(longest),
@@ -444,7 +620,13 @@ class WriteEntries {
middle_[mid_idx].Insert(key, kBlankProb, kBlankBackoff);
}
- void Middle(const unsigned char mid_idx, WordIndex key, const ProbBackoff &weights) {
+ void Middle(const unsigned char mid_idx, const WordIndex *before, WordIndex key, ProbBackoff weights) {
+ // Order (mid_idx+2).
+ ContextReader &context = contexts_[mid_idx + 1];
+ if (context && !memcmp(before, *context, sizeof(WordIndex) * (mid_idx + 1)) && (*context)[mid_idx + 1] == key) {
+ SetExtension(weights.backoff);
+ ++context;
+ }
middle_[mid_idx].Insert(key, weights.prob, weights.backoff);
}
@@ -455,6 +637,7 @@ class WriteEntries {
void Cleanup() {}
private:
+ ContextReader *contexts_;
UnigramValue *const unigrams_;
BitPackedMiddle *const middle_;
BitPackedLongest &longest_;
@@ -463,14 +646,15 @@ class WriteEntries {
template <class Doing> class RecursiveInsert {
public:
- RecursiveInsert(SortedFileReader *inputs, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, uint64_t *counts, unsigned char order) :
- doing_(unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), words_(new WordIndex[order]), order_minus_2_(order - 2) {
+ RecursiveInsert(SortedFileReader *inputs, ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, uint64_t *counts, unsigned char order) :
+ doing_(contexts, unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), order_minus_2_(order - 2) {
}
// Outer unigram loop.
void Apply(std::ostream *progress_out, const char *message, WordIndex unigram_count) {
util::ErsatzProgress progress(progress_out, message, unigram_count + 1);
for (words_[0] = 0; ; ++words_[0]) {
+ progress.Set(words_[0]);
WordIndex min_continue = unigram_count;
for (SortedFileReader *other = inputs_; other != inputs_end_; ++other) {
if (other->Ended()) continue;
@@ -479,7 +663,6 @@ template <class Doing> class RecursiveInsert {
// This will write at unigram_count. This is by design so that the next pointers will make sense.
doing_.Unigrams(words_[0], min_continue + 1);
if (min_continue == unigram_count) break;
- progress += min_continue - words_[0];
words_[0] = min_continue;
Middle(0);
}
@@ -497,7 +680,7 @@ template <class Doing> class RecursiveInsert {
SortedFileReader &reader = inputs_[mid_idx];
- if (reader.Ended() || !HeadMatch(words_.get(), words_.get() + mid_idx + 1, reader.Header())) {
+ if (reader.Ended() || !HeadMatch(words_, words_ + mid_idx + 1, reader.Header())) {
// This order doesn't have a header match, but longer ones might.
MiddleAllBlank(mid_idx);
return;
@@ -509,7 +692,7 @@ template <class Doing> class RecursiveInsert {
while (count) {
WordIndex min_continue = std::numeric_limits<WordIndex>::max();
for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) {
- if (!other->Ended() && HeadMatch(words_.get(), words_.get() + mid_idx + 1, other->Header()))
+ if (!other->Ended() && HeadMatch(words_, words_ + mid_idx + 1, other->Header()))
min_continue = std::min(min_continue, other->Header()[mid_idx + 1]);
}
while (true) {
@@ -521,7 +704,7 @@ template <class Doing> class RecursiveInsert {
}
ProbBackoff weights;
reader.ReadWeights(weights);
- doing_.Middle(mid_idx, current, weights);
+ doing_.Middle(mid_idx, words_, current, weights);
--count;
if (current == min_continue) {
words_[mid_idx + 1] = min_continue;
@@ -542,7 +725,7 @@ template <class Doing> class RecursiveInsert {
while (true) {
WordIndex min_continue = std::numeric_limits<WordIndex>::max();
for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) {
- if (!other->Ended() && HeadMatch(words_.get(), words_.get() + mid_idx + 1, other->Header()))
+ if (!other->Ended() && HeadMatch(words_, words_ + mid_idx + 1, other->Header()))
min_continue = std::min(min_continue, other->Header()[mid_idx + 1]);
}
if (min_continue == std::numeric_limits<WordIndex>::max()) return;
@@ -554,7 +737,7 @@ template <class Doing> class RecursiveInsert {
void Longest() {
SortedFileReader &reader = *(inputs_end_ - 1);
- if (reader.Ended() || !HeadMatch(words_.get(), words_.get() + order_minus_2_ + 1, reader.Header())) return;
+ if (reader.Ended() || !HeadMatch(words_, words_ + order_minus_2_ + 1, reader.Header())) return;
WordIndex count = reader.ReadCount();
for (WordIndex i = 0; i < count; ++i) {
WordIndex word = reader.ReadWord();
@@ -571,7 +754,7 @@ template <class Doing> class RecursiveInsert {
SortedFileReader *inputs_;
SortedFileReader *inputs_end_;
- util::scoped_array<WordIndex> words_;
+ WordIndex words_[kMaxOrder];
const unsigned char order_minus_2_;
};
@@ -586,17 +769,21 @@ void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<u
void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &counts, const Config &config, TrieSearch &out, Backing &backing) {
SortedFileReader inputs[counts.size() - 1];
+ ContextReader contexts[counts.size() - 1];
for (unsigned char i = 2; i <= counts.size(); ++i) {
std::stringstream assembled;
assembled << file_prefix << static_cast<unsigned int>(i) << "_merged";
inputs[i-2].Init(assembled.str(), i);
- unlink(assembled.str().c_str());
+ RemoveOrThrow(assembled.str().c_str());
+ assembled << kContextSuffix;
+ contexts[i-2].Reset(assembled.str().c_str(), (i-1) * sizeof(WordIndex));
+ RemoveOrThrow(assembled.str().c_str());
}
std::vector<uint64_t> fixed_counts(counts.size());
{
- RecursiveInsert<JustCount> counter(inputs, NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size());
+ RecursiveInsert<JustCount> counter(inputs, contexts, NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size());
counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]);
}
SanityCheckCounts(counts, fixed_counts);
@@ -609,21 +796,38 @@ void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &coun
UnigramValue *unigrams = out.unigram.Raw();
// Fill entries except unigram probabilities.
{
- RecursiveInsert<WriteEntries> inserter(inputs, unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size());
+ RecursiveInsert<WriteEntries> inserter(inputs, contexts, unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size());
inserter.Apply(config.messages, "Building trie", fixed_counts[0]);
}
// Fill unigram probabilities.
{
std::string name(file_prefix + "unigrams");
- util::scoped_FILE file(fopen(name.c_str(), "r"));
- if (!file.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " failed");
+ util::scoped_FILE file(OpenOrThrow(name.c_str(), "r"));
for (WordIndex i = 0; i < counts[0]; ++i) {
ReadOrThrow(file.get(), &unigrams[i].weights, sizeof(ProbBackoff));
+ if (contexts[0] && **contexts[0] == i) {
+ SetExtension(unigrams[i].weights.backoff);
+ ++contexts[0];
+ }
}
unlink(name.c_str());
}
+ // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.
+ for (unsigned char order = 2; order <= counts.size(); ++order) {
+ const ContextReader &context = contexts[order - 2];
+ if (context) {
+ FormatLoadException e;
+ e << "An " << static_cast<unsigned int>(order) << "-gram has the context (i.e. all but the last word):";
+ for (const WordIndex *i = *context; i != *context + order - 1; ++i) {
+ e << ' ' << *i;
+ }
+ e << " so this context must appear in the model as a " << static_cast<unsigned int>(order - 1) << "-gram but it does not.";
+ throw e;
+ }
+ }
+
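
Concretely: if a trigram "a b c" survived pruning, its context "a b" must exist as a bigram entry. A ContextReader that still tests true here holds exactly such a context whose lower-order entry was never seen, hence the exception.
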
/* Set ending offsets so the last entry will be sized properly */
// Last entry for unigrams was already set.
if (!out.middle.empty()) {
@@ -634,19 +838,27 @@ void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &coun
}
}
+bool IsDirectory(const char *path) {
+ struct stat info;
+ if (0 != stat(path, &info)) return false;
+ return S_ISDIR(info.st_mode);
+}
+
} // namespace
void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {
std::string temporary_directory;
if (config.temporary_directory_prefix) {
temporary_directory = config.temporary_directory_prefix;
+ if (!temporary_directory.empty() && temporary_directory[temporary_directory.size() - 1] != '/' && IsDirectory(temporary_directory.c_str()))
+ temporary_directory += '/';
} else if (config.write_mmap) {
temporary_directory = config.write_mmap;
} else {
temporary_directory = file;
}
// Null on end is a kludge to ensure null termination.
- temporary_directory += "-tmp-XXXXXX";
+ temporary_directory += "_trie_tmp_XXXXXX";
temporary_directory += '\0';
if (!mkdtemp(&temporary_directory[0])) {
UTIL_THROW(util::ErrnoException, "Failed to make a temporary directory based on the name " << temporary_directory.c_str());
@@ -658,7 +870,7 @@ void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::v
// At least 1MB sorting memory.
ARPAToSortedFiles(f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab);
- BuildTrie(temporary_directory.c_str(), counts, config, *this, backing);
+ BuildTrie(temporary_directory, counts, config, *this, backing);
if (rmdir(temporary_directory.c_str()) && config.messages) {
*config.messages << "Failed to delete " << temporary_directory << std::endl;
}