summaryrefslogtreecommitdiff
path: root/klm/lm
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm')
-rw-r--r--klm/lm/binary_format.cc16
-rw-r--r--klm/lm/build_binary.cc2
-rw-r--r--klm/lm/search_hashed.cc2
-rw-r--r--klm/lm/search_hashed.hh2
-rw-r--r--klm/lm/search_trie.cc2
-rw-r--r--klm/lm/sri_test.cc65
-rw-r--r--klm/lm/trie_sort.cc20
-rw-r--r--klm/lm/trie_sort.hh2
8 files changed, 26 insertions, 85 deletions
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc
index fd841e59..efa67056 100644
--- a/klm/lm/binary_format.cc
+++ b/klm/lm/binary_format.cc
@@ -83,7 +83,13 @@ void WriteHeader(void *to, const Parameters &params) {
uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) {
if (config.write_mmap) {
std::size_t total = TotalHeaderSize(order) + memory_size;
- backing.vocab.reset(util::MapZeroedWrite(config.write_mmap, total, backing.file), total, util::scoped_memory::MMAP_ALLOCATED);
+ backing.file.reset(util::CreateOrThrow(config.write_mmap));
+ if (config.write_method == Config::WRITE_MMAP) {
+ backing.vocab.reset(util::MapZeroedWrite(backing.file.get(), total), total, util::scoped_memory::MMAP_ALLOCATED);
+ } else {
+ util::ResizeOrThrow(backing.file.get(), 0);
+ util::MapAnonymous(total, backing.vocab);
+ }
strncpy(reinterpret_cast<char*>(backing.vocab.get()), kMagicIncomplete, TotalHeaderSize(order));
return reinterpret_cast<uint8_t*>(backing.vocab.get()) + TotalHeaderSize(order);
} else {
@@ -121,12 +127,14 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t
void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) {
if (!config.write_mmap) return;
- util::SyncOrThrow(backing.vocab.get(), backing.vocab.size());
switch (config.write_method) {
case Config::WRITE_MMAP:
+ util::SyncOrThrow(backing.vocab.get(), backing.vocab.size());
util::SyncOrThrow(backing.search.get(), backing.search.size());
break;
case Config::WRITE_AFTER:
+ util::SeekOrThrow(backing.file.get(), 0);
+ util::WriteOrThrow(backing.file.get(), backing.vocab.get(), backing.vocab.size());
util::SeekOrThrow(backing.file.get(), backing.vocab.size() + vocab_pad);
util::WriteOrThrow(backing.file.get(), backing.search.get(), backing.search.size());
util::FSyncOrThrow(backing.file.get());
@@ -141,6 +149,10 @@ void FinishFile(const Config &config, ModelType model_type, unsigned int search_
params.fixed.has_vocabulary = config.include_vocab;
params.fixed.search_version = search_version;
WriteHeader(backing.vocab.get(), params);
+ if (config.write_method == Config::WRITE_AFTER) {
+ util::SeekOrThrow(backing.file.get(), 0);
+ util::WriteOrThrow(backing.file.get(), backing.vocab.get(), TotalHeaderSize(counts.size()));
+ }
}
namespace detail {
diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc
index efe99899..2b8c9d5b 100644
--- a/klm/lm/build_binary.cc
+++ b/klm/lm/build_binary.cc
@@ -11,6 +11,8 @@
#ifdef WIN32
#include "util/getopt.hh"
+#else
+#include <unistd.h>
#endif
namespace lm {
diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc
index 13942309..a1623834 100644
--- a/klm/lm/search_hashed.cc
+++ b/klm/lm/search_hashed.cc
@@ -234,7 +234,7 @@ template <> void HashedSearch<BackoffValue>::DispatchBuild(util::FilePiece &f, c
ApplyBuild(f, counts, config, vocab, warn, build);
}
-template <> void HashedSearch<RestValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) {
+template <> void HashedSearch<RestValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) {
switch (config.rest_function) {
case Config::REST_MAX:
{
diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh
index 3bcde921..a52f107b 100644
--- a/klm/lm/search_hashed.hh
+++ b/klm/lm/search_hashed.hh
@@ -161,7 +161,7 @@ template <class Value> class HashedSearch {
{}
static uint64_t Size(uint64_t count) {
- return (count + 1) * sizeof(ProbBackoff); // +1 for hallucinate <unk>
+ return (count + 1) * sizeof(typename Value::Weights); // +1 for hallucinate <unk>
}
const typename Value::Weights &Lookup(WordIndex index) const {
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index 832cc9f7..debcfd07 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -89,7 +89,7 @@ class BackoffMessages {
if (!HasExtension(weights.backoff)) {
weights.backoff = kExtensionBackoff;
UTIL_THROW_IF(fseek(unigrams, -sizeof(weights), SEEK_CUR), util::ErrnoException, "Seeking backwards to denote unigram extension failed.");
- WriteOrThrow(unigrams, &weights, sizeof(weights));
+ util::WriteOrThrow(unigrams, &weights, sizeof(weights));
}
const ProbPointer &write_to = *reinterpret_cast<const ProbPointer*>(current_ + sizeof(WordIndex));
base[write_to.array][write_to.index] += weights.backoff;
diff --git a/klm/lm/sri_test.cc b/klm/lm/sri_test.cc
deleted file mode 100644
index e697d722..00000000
--- a/klm/lm/sri_test.cc
+++ /dev/null
@@ -1,65 +0,0 @@
-#include "lm/sri.hh"
-
-#include <stdlib.h>
-
-#define BOOST_TEST_MODULE SRITest
-#include <boost/test/unit_test.hpp>
-
-namespace lm {
-namespace sri {
-namespace {
-
-#define StartTest(word, ngram, score) \
- ret = model.FullScore( \
- state, \
- model.GetVocabulary().Index(word), \
- out);\
- BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
- BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
- BOOST_CHECK_EQUAL(std::min<unsigned char>(ngram, 5 - 1), out.valid_length_);
-
-#define AppendTest(word, ngram, score) \
- StartTest(word, ngram, score) \
- state = out;
-
-template <class M> void Starters(M &model) {
- FullScoreReturn ret;
- Model::State state(model.BeginSentenceState());
- Model::State out;
-
- StartTest("looking", 2, -0.4846522);
-
- // , probability plus <s> backoff
- StartTest(",", 1, -1.383514 + -0.4149733);
- // <unk> probability plus <s> backoff
- StartTest("this_is_not_found", 0, -1.995635 + -0.4149733);
-}
-
-template <class M> void Continuation(M &model) {
- FullScoreReturn ret;
- Model::State state(model.BeginSentenceState());
- Model::State out;
-
- AppendTest("looking", 2, -0.484652);
- AppendTest("on", 3, -0.348837);
- AppendTest("a", 4, -0.0155266);
- AppendTest("little", 5, -0.00306122);
- State preserve = state;
- AppendTest("the", 1, -4.04005);
- AppendTest("biarritz", 1, -1.9889);
- AppendTest("not_found", 0, -2.29666);
- AppendTest("more", 1, -1.20632);
- AppendTest(".", 2, -0.51363);
- AppendTest("</s>", 3, -0.0191651);
-
- state = preserve;
- AppendTest("more", 5, -0.00181395);
- AppendTest("loin", 5, -0.0432557);
-}
-
-BOOST_AUTO_TEST_CASE(starters) { Model m("test.arpa", 5); Starters(m); }
-BOOST_AUTO_TEST_CASE(continuation) { Model m("test.arpa", 5); Continuation(m); }
-
-} // namespace
-} // namespace sri
-} // namespace lm
diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc
index 0d83221e..8663e94e 100644
--- a/klm/lm/trie_sort.cc
+++ b/klm/lm/trie_sort.cc
@@ -22,12 +22,6 @@
namespace lm {
namespace ngram {
namespace trie {
-
-void WriteOrThrow(FILE *to, const void *data, size_t size) {
- assert(size);
- if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size);
-}
-
namespace {
typedef util::SizedIterator NGramIter;
@@ -95,12 +89,12 @@ FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &make
// Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator.
if (context_begin == context_end) return out.release();
PartialIter i(context_begin);
- WriteOrThrow(out.get(), i->Data(), context_size);
+ util::WriteOrThrow(out.get(), i->Data(), context_size);
const void *previous = i->Data();
++i;
for (; i != context_end; ++i) {
if (memcmp(previous, i->Data(), context_size)) {
- WriteOrThrow(out.get(), i->Data(), context_size);
+ util::WriteOrThrow(out.get(), i->Data(), context_size);
previous = i->Data();
}
}
@@ -116,7 +110,7 @@ struct ThrowCombine {
// Useful for context files that just contain records with no value.
struct FirstCombine {
void operator()(std::size_t entry_size, const void *first, const void * /*second*/, FILE *out) const {
- WriteOrThrow(out, first, entry_size);
+ util::WriteOrThrow(out, first, entry_size);
}
};
@@ -129,10 +123,10 @@ template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_f
EntryCompare less(order);
while (first && second) {
if (less(first.Data(), second.Data())) {
- WriteOrThrow(out_file.get(), first.Data(), entry_size);
+ util::WriteOrThrow(out_file.get(), first.Data(), entry_size);
++first;
} else if (less(second.Data(), first.Data())) {
- WriteOrThrow(out_file.get(), second.Data(), entry_size);
+ util::WriteOrThrow(out_file.get(), second.Data(), entry_size);
++second;
} else {
combine(entry_size, first.Data(), second.Data(), out_file.get());
@@ -140,7 +134,7 @@ template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_f
}
}
for (RecordReader &remains = (first ? first : second); remains; ++remains) {
- WriteOrThrow(out_file.get(), remains.Data(), entry_size);
+ util::WriteOrThrow(out_file.get(), remains.Data(), entry_size);
}
return out_file.release();
}
@@ -164,7 +158,7 @@ void RecordReader::Init(FILE *file, std::size_t entry_size) {
void RecordReader::Overwrite(const void *start, std::size_t amount) {
long internal = (uint8_t*)start - (uint8_t*)data_.get();
UTIL_THROW_IF(fseek(file_, internal - entry_size_, SEEK_CUR), util::ErrnoException, "Couldn't seek backwards for revision");
- WriteOrThrow(file_, start, amount);
+ util::WriteOrThrow(file_, start, amount);
long forward = entry_size_ - internal - amount;
#if !defined(_WIN32) && !defined(_WIN64)
if (forward)
diff --git a/klm/lm/trie_sort.hh b/klm/lm/trie_sort.hh
index 1e6fce51..2197b80c 100644
--- a/klm/lm/trie_sort.hh
+++ b/klm/lm/trie_sort.hh
@@ -29,8 +29,6 @@ struct Config;
namespace trie {
-void WriteOrThrow(FILE *to, const void *data, size_t size);
-
class EntryCompare : public std::binary_function<const void*, const void*, bool> {
public:
explicit EntryCompare(unsigned char order) : order_(order) {}