path: root/klm/lm
author    Kenneth Heafield <kheafiel@cluster10.lti.ece.cmu.local>  2011-03-09 13:40:23 -0500
committer Kenneth Heafield <kheafiel@cluster10.lti.ece.cmu.local>  2011-03-09 13:40:23 -0500
commit    75310799a6ee82b742ba69abab951a74fd0d19fc (patch)
tree      7be29bf130d75b478ef1f2449ca67a7eddfe0781 /klm/lm
parent    b2a018f6a5fc34dc799aa41a05a16eb79aa95de1 (diff)
kenlm sync
Diffstat (limited to 'klm/lm')
 klm/lm/binary_format.cc      |  35
 klm/lm/binary_format.hh      |   6
 klm/lm/build_binary.cc       |  26
 klm/lm/config.cc             |   1
 klm/lm/config.hh             |   7
 klm/lm/exception.cc          |  21
 klm/lm/exception.hh          |  40
 klm/lm/lm_exception.cc       |   4
 klm/lm/lm_exception.hh       |   2
 klm/lm/model.cc              |  22
 klm/lm/ngram.hh              | 226
 klm/lm/ngram_build_binary.cc |  13
 klm/lm/ngram_config.hh       |  58
 klm/lm/ngram_query.cc        |  33
 klm/lm/ngram_test.cc         |  91
 klm/lm/read_arpa.cc          |  83
 klm/lm/read_arpa.hh          |   3
 klm/lm/search_hashed.cc      |   5
 klm/lm/search_trie.cc        |  50
 klm/lm/test.binary           | bin 1660 -> 0 bytes
 klm/lm/virtual_interface.cc  |   2
 klm/lm/vocab.cc              |  24
 klm/lm/vocab.hh              |  10
 23 files changed, 163 insertions(+), 599 deletions(-)
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc
index 2a6aff34..9be0bc8e 100644
--- a/klm/lm/binary_format.cc
+++ b/klm/lm/binary_format.cc
@@ -9,6 +9,7 @@
#include <fcntl.h>
#include <errno.h>
#include <stdlib.h>
+#include <string.h>
#include <sys/mman.h>
#include <sys/types.h>
#include <sys/stat.h>
@@ -19,6 +20,8 @@ namespace ngram {
namespace {
const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version";
const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 4\n\0";
+// This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed).
+const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n";
const long int kMagicVersion = 4;
// Test values.
@@ -81,6 +84,7 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_
if (config.write_mmap) {
std::size_t total = TotalHeaderSize(order) + memory_size;
backing.vocab.reset(util::MapZeroedWrite(config.write_mmap, total, backing.file), total, util::scoped_memory::MMAP_ALLOCATED);
+ strncpy(reinterpret_cast<char*>(backing.vocab.get()), kMagicIncomplete, TotalHeaderSize(order));
return reinterpret_cast<uint8_t*>(backing.vocab.get()) + TotalHeaderSize(order);
} else {
backing.vocab.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED);
@@ -88,17 +92,8 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_
}
}
-uint8_t *GrowForSearch(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, std::size_t memory_size, Backing &backing) {
+uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing) {
if (config.write_mmap) {
- // header and vocab share the same mmap. The header is written here because we know the counts.
- Parameters params;
- params.counts = counts;
- params.fixed.order = counts.size();
- params.fixed.probing_multiplier = config.probing_multiplier;
- params.fixed.model_type = model_type;
- params.fixed.has_vocabulary = config.include_vocab;
- WriteHeader(backing.vocab.get(), params);
-
// Grow the file to accomodate the search, using zeros.
if (-1 == ftruncate(backing.file.get(), backing.vocab.size() + memory_size))
UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (backing.vocab.size() + memory_size) << " failed");
@@ -115,6 +110,19 @@ uint8_t *GrowForSearch(const Config &config, ModelType model_type, const std::ve
}
}
+void FinishFile(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, Backing &backing) {
+ if (config.write_mmap) {
+ // header and vocab share the same mmap. The header is written here because we know the counts.
+ Parameters params;
+ params.counts = counts;
+ params.fixed.order = counts.size();
+ params.fixed.probing_multiplier = config.probing_multiplier;
+ params.fixed.model_type = model_type;
+ params.fixed.has_vocabulary = config.include_vocab;
+ WriteHeader(backing.vocab.get(), params);
+ }
+}
+
namespace detail {
bool IsBinaryFormat(int fd) {
@@ -130,14 +138,17 @@ bool IsBinaryFormat(int fd) {
Sanity reference_header = Sanity();
reference_header.SetToReference();
if (!memcmp(memory.get(), &reference_header, sizeof(Sanity))) return true;
+ if (!memcmp(memory.get(), kMagicIncomplete, strlen(kMagicIncomplete))) {
+ UTIL_THROW(FormatLoadException, "This binary file did not finish building");
+ }
if (!memcmp(memory.get(), kMagicBeforeVersion, strlen(kMagicBeforeVersion))) {
char *end_ptr;
const char *begin_version = static_cast<const char*>(memory.get()) + strlen(kMagicBeforeVersion);
long int version = strtol(begin_version, &end_ptr, 10);
if ((end_ptr != begin_version) && version != kMagicVersion) {
- UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to rebuild your binary LM from the ARPA. Sorry.");
+ UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to use the ARPA to rebuild your binary");
}
- UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture.");
+ UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture");
}
return false;
}
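
(Aside, not part of the diff: the net effect of this file's changes is a three-step write path. SetupJustVocab now stamps kMagicIncomplete into the header region, GrowForSearch only grows the file, and the new FinishFile writes the real header last, so a build that dies partway leaves a file IsBinaryFormat will reject. A rough, illustrative sketch of the intended call sequence; sizes and the model type below are placeholders, not taken from this commit.)

#include "lm/binary_format.hh"
#include "lm/config.hh"

#include <stdint.h>
#include <vector>

// Illustrative only: how a model builder is expected to sequence the helpers.
void BuildSketch(const lm::ngram::Config &config, const std::vector<uint64_t> &counts) {
  lm::ngram::Backing backing;
  // Writes kMagicIncomplete into the header area, so an interrupted build
  // leaves a detectably incomplete file rather than a bogus header.
  uint8_t *vocab_mem = lm::ngram::SetupJustVocab(config, counts.size(), /*memory_size=*/1 << 20, backing);
  // ... populate the vocabulary at vocab_mem ...
  // Extends the mmap for the search structure; still no valid header.
  uint8_t *search_mem = lm::ngram::GrowForSearch(config, /*memory_size=*/1 << 20, backing);
  // ... populate the search structure at search_mem ...
  // Only after everything succeeded is the real header (with counts) written.
  lm::ngram::FinishFile(config, lm::ngram::HASH_PROBING, counts, backing);
  (void)vocab_mem; (void)search_mem;
}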
diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh
index 2d66f813..72d8c159 100644
--- a/klm/lm/binary_format.hh
+++ b/klm/lm/binary_format.hh
@@ -43,7 +43,9 @@ struct Backing {
uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing);
// Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin.
-uint8_t *GrowForSearch(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, std::size_t memory_size, Backing &backing);
+uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing);
+
+void FinishFile(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, Backing &backing);
namespace detail {
@@ -81,7 +83,7 @@ template <class To> void LoadLM(const char *file, const Config &config, To &to)
to.InitializeFromARPA(file, config);
}
} catch (util::Exception &e) {
- e << " in file " << file;
+ e << " File: " << file;
throw;
}
diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc
index 144c57e0..d6dd5994 100644
--- a/klm/lm/build_binary.cc
+++ b/klm/lm/build_binary.cc
@@ -1,6 +1,8 @@
#include "lm/model.hh"
#include "util/file_piece.hh"
+#include <cstdlib>
+#include <exception>
#include <iostream>
#include <iomanip>
@@ -13,8 +15,10 @@ namespace ngram {
namespace {
void Usage(const char *name) {
- std::cerr << "Usage: " << name << " [-u unknown_probability] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [type] input.arpa output.mmap\n\n"
-"Where type is one of probing, trie, or sorted:\n\n"
+ std::cerr << "Usage: " << name << " [-u unknown_probability] [-s] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [type] input.arpa output.mmap\n\n"
+"-u sets the default probability for <unk> if the ARPA file does not have one.\n"
+"-s allows models to be built even if they do not have <s> and </s>.\n\n"
+"type is one of probing, trie, or sorted:\n\n"
"probing uses a probing hash table. It is the fastest but uses the most memory.\n"
"-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n"
"trie is a straightforward trie with bit-level packing. It uses the least\n"
@@ -65,12 +69,25 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) {
} // namespace lm
} // namespace
+void terminate_handler() {
+ try { throw; }
+ catch(const std::exception& e) {
+ std::cerr << e.what() << std::endl;
+ }
+ catch(...) {
+ std::cerr << "A non-standard exception was thrown." << std::endl;
+ }
+ std::abort();
+}
+
int main(int argc, char *argv[]) {
using namespace lm::ngram;
+ std::set_terminate(terminate_handler);
+
lm::ngram::Config config;
int opt;
- while ((opt = getopt(argc, argv, "u:p:t:m:")) != -1) {
+ while ((opt = getopt(argc, argv, "su:p:t:m:")) != -1) {
switch(opt) {
case 'u':
config.unknown_missing_prob = ParseFloat(optarg);
@@ -84,6 +101,9 @@ int main(int argc, char *argv[]) {
case 'm':
config.building_memory = ParseUInt(optarg) * 1048576;
break;
+ case 's':
+ config.sentence_marker_missing = lm::ngram::Config::SILENT;
+ break;
default:
Usage(argv[0]);
}
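
(Aside, not part of the diff: the new -s switch just sets the Config field added in config.hh below, so library users can get the same behavior directly. A minimal, illustrative sketch; the file names are placeholders.)

#include "lm/model.hh"

// Illustrative: accept an ARPA file that lacks <s>/</s>, mirroring build_binary -s,
// and write the binary format at the same time.
int main() {
  lm::ngram::Config config;
  config.sentence_marker_missing = lm::ngram::Config::SILENT;
  config.write_mmap = "output.mmap";
  lm::ngram::ProbingModel model("input.arpa", config);
  return 0;
}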
diff --git a/klm/lm/config.cc b/klm/lm/config.cc
index 2831d578..d8773fe5 100644
--- a/klm/lm/config.cc
+++ b/klm/lm/config.cc
@@ -9,6 +9,7 @@ Config::Config() :
messages(&std::cerr),
enumerate_vocab(NULL),
unknown_missing(COMPLAIN),
+ sentence_marker_missing(THROW_UP),
unknown_missing_prob(0.0),
probing_multiplier(1.5),
building_memory(1073741824ULL), // 1 GB
diff --git a/klm/lm/config.hh b/klm/lm/config.hh
index 767fa5f9..17f67df3 100644
--- a/klm/lm/config.hh
+++ b/klm/lm/config.hh
@@ -27,9 +27,12 @@ struct Config {
// ONLY EFFECTIVE WHEN READING ARPA
+ typedef enum {THROW_UP, COMPLAIN, SILENT} WarningAction;
// What to do when <unk> isn't in the provided model.
- typedef enum {THROW_UP, COMPLAIN, SILENT} UnknownMissing;
- UnknownMissing unknown_missing;
+ WarningAction unknown_missing;
+ // What to do when <s> or </s> is missing from the model.
+ // If THROW_UP, the exception will be of type util::SpecialWordMissingException.
+ WarningAction sentence_marker_missing;
// The probability to substitute for <unk> if it's missing from the model.
// No effect if the model has <unk> or unknown_missing == THROW_UP.
diff --git a/klm/lm/exception.cc b/klm/lm/exception.cc
deleted file mode 100644
index 59a1650d..00000000
--- a/klm/lm/exception.cc
+++ /dev/null
@@ -1,21 +0,0 @@
-#include "lm/exception.hh"
-
-#include<errno.h>
-#include<stdio.h>
-
-namespace lm {
-
-LoadException::LoadException() throw() {}
-LoadException::~LoadException() throw() {}
-VocabLoadException::VocabLoadException() throw() {}
-VocabLoadException::~VocabLoadException() throw() {}
-
-FormatLoadException::FormatLoadException() throw() {}
-FormatLoadException::~FormatLoadException() throw() {}
-
-SpecialWordMissingException::SpecialWordMissingException(StringPiece which) throw() {
- *this << "Missing special word " << which;
-}
-SpecialWordMissingException::~SpecialWordMissingException() throw() {}
-
-} // namespace lm
diff --git a/klm/lm/exception.hh b/klm/lm/exception.hh
deleted file mode 100644
index 95109012..00000000
--- a/klm/lm/exception.hh
+++ /dev/null
@@ -1,40 +0,0 @@
-#ifndef LM_EXCEPTION__
-#define LM_EXCEPTION__
-
-#include "util/exception.hh"
-#include "util/string_piece.hh"
-
-#include <exception>
-#include <string>
-
-namespace lm {
-
-class LoadException : public util::Exception {
- public:
- virtual ~LoadException() throw();
-
- protected:
- LoadException() throw();
-};
-
-class VocabLoadException : public LoadException {
- public:
- virtual ~VocabLoadException() throw();
- VocabLoadException() throw();
-};
-
-class FormatLoadException : public LoadException {
- public:
- FormatLoadException() throw();
- ~FormatLoadException() throw();
-};
-
-class SpecialWordMissingException : public VocabLoadException {
- public:
- explicit SpecialWordMissingException(StringPiece which) throw();
- ~SpecialWordMissingException() throw();
-};
-
-} // namespace lm
-
-#endif
diff --git a/klm/lm/lm_exception.cc b/klm/lm/lm_exception.cc
index 473849d1..0b572e98 100644
--- a/klm/lm/lm_exception.cc
+++ b/klm/lm/lm_exception.cc
@@ -17,9 +17,7 @@ FormatLoadException::~FormatLoadException() throw() {}
VocabLoadException::VocabLoadException() throw() {}
VocabLoadException::~VocabLoadException() throw() {}
-SpecialWordMissingException::SpecialWordMissingException(StringPiece which) throw() {
- *this << "Missing special word " << which;
-}
+SpecialWordMissingException::SpecialWordMissingException() throw() {}
SpecialWordMissingException::~SpecialWordMissingException() throw() {}
} // namespace lm
diff --git a/klm/lm/lm_exception.hh b/klm/lm/lm_exception.hh
index 3773c572..aa3ca886 100644
--- a/klm/lm/lm_exception.hh
+++ b/klm/lm/lm_exception.hh
@@ -39,7 +39,7 @@ class VocabLoadException : public LoadException {
class SpecialWordMissingException : public VocabLoadException {
public:
- explicit SpecialWordMissingException(StringPiece which) throw();
+ explicit SpecialWordMissingException() throw();
~SpecialWordMissingException() throw();
};
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index 146fe07b..14949e97 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -82,25 +82,13 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
}
- // TODO: fail faster?
if (!vocab_.SawUnk()) {
- switch(config.unknown_missing) {
- case Config::THROW_UP:
- {
- SpecialWordMissingException e("<unk>");
- e << " and configuration was set to throw if unknown is missing";
- throw e;
- }
- case Config::COMPLAIN:
- if (config.messages) *config.messages << "Language model is missing <unk>. Substituting probability " << config.unknown_missing_prob << "." << std::endl;
- // There's no break;. This is by design.
- case Config::SILENT:
- // Default probabilities for unknown.
- search_.unigram.Unknown().backoff = 0.0;
- search_.unigram.Unknown().prob = config.unknown_missing_prob;
- break;
- }
+ assert(config.unknown_missing != Config::THROW_UP);
+ // Default probabilities for unknown.
+ search_.unigram.Unknown().backoff = 0.0;
+ search_.unigram.Unknown().prob = config.unknown_missing_prob;
}
+ FinishFile(config, kModelType, counts, backing_);
}
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
diff --git a/klm/lm/ngram.hh b/klm/lm/ngram.hh
deleted file mode 100644
index 899a80e8..00000000
--- a/klm/lm/ngram.hh
+++ /dev/null
@@ -1,226 +0,0 @@
-#ifndef LM_NGRAM__
-#define LM_NGRAM__
-
-#include "lm/facade.hh"
-#include "lm/ngram_config.hh"
-#include "util/key_value_packing.hh"
-#include "util/mmap.hh"
-#include "util/probing_hash_table.hh"
-#include "util/scoped.hh"
-#include "util/sorted_uniform.hh"
-#include "util/string_piece.hh"
-
-#include <algorithm>
-#include <memory>
-#include <vector>
-
-namespace util { class FilePiece; }
-
-namespace lm {
-namespace ngram {
-
-// If you need higher order, change this and recompile.
-// Having this limit means that State can be
-// (kMaxOrder - 1) * sizeof(float) bytes instead of
-// sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead
-const std::size_t kMaxOrder = 6;
-
-// This is a POD.
-class State {
- public:
- bool operator==(const State &other) const {
- if (valid_length_ != other.valid_length_) return false;
- const WordIndex *end = history_ + valid_length_;
- for (const WordIndex *first = history_, *second = other.history_;
- first != end; ++first, ++second) {
- if (*first != *second) return false;
- }
- // If the histories are equal, so are the backoffs.
- return true;
- }
-
- // You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD.
- // This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit.
- WordIndex history_[kMaxOrder - 1];
- float backoff_[kMaxOrder - 1];
- unsigned char valid_length_;
-};
-
-size_t hash_value(const State &state);
-
-namespace detail {
-
-uint64_t HashForVocab(const char *str, std::size_t len);
-inline uint64_t HashForVocab(const StringPiece &str) {
- return HashForVocab(str.data(), str.length());
-}
-
-struct Prob {
- float prob;
- void SetBackoff(float to);
- void ZeroBackoff() {}
-};
-// No inheritance so this will be a POD.
-struct ProbBackoff {
- float prob;
- float backoff;
- void SetBackoff(float to) { backoff = to; }
- void ZeroBackoff() { backoff = 0.0; }
-};
-
-} // namespace detail
-
-// Vocabulary based on sorted uniform find storing only uint64_t values and using their offsets as indices.
-class SortedVocabulary : public base::Vocabulary {
- private:
- // Sorted uniform requires a GetKey function.
- struct Entry {
- uint64_t GetKey() const { return key; }
- uint64_t key;
- bool operator<(const Entry &other) const {
- return key < other.key;
- }
- };
-
- public:
- SortedVocabulary();
-
- WordIndex Index(const StringPiece &str) const {
- const Entry *found;
- if (util::SortedUniformFind<const Entry *, uint64_t>(begin_, end_, detail::HashForVocab(str), found)) {
- return found - begin_ + 1; // +1 because <unk> is 0 and does not appear in the lookup table.
- } else {
- return 0;
- }
- }
-
- // Ignores second argument for consistency with probing hash which has a float here.
- static size_t Size(std::size_t entries, float ignored = 0.0);
-
- // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
- void Init(void *start, std::size_t allocated, std::size_t entries);
-
- WordIndex Insert(const StringPiece &str);
-
- // Returns true if unknown was seen. Reorders reorder_vocab so that the IDs are sorted.
- bool FinishedLoading(detail::ProbBackoff *reorder_vocab);
-
- void LoadedBinary();
-
- private:
- Entry *begin_, *end_;
-
- bool saw_unk_;
-};
-
-namespace detail {
-
-// Vocabulary storing a map from uint64_t to WordIndex.
-template <class Search> class MapVocabulary : public base::Vocabulary {
- public:
- MapVocabulary();
-
- WordIndex Index(const StringPiece &str) const {
- typename Lookup::ConstIterator i;
- return lookup_.Find(HashForVocab(str), i) ? i->GetValue() : 0;
- }
-
- static size_t Size(std::size_t entries, float probing_multiplier) {
- return Lookup::Size(entries, probing_multiplier);
- }
-
- // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
- void Init(void *start, std::size_t allocated, std::size_t entries);
-
- WordIndex Insert(const StringPiece &str);
-
- // Returns true if unknown was seen. Does nothing with reorder_vocab.
- bool FinishedLoading(ProbBackoff *reorder_vocab);
-
- void LoadedBinary();
-
- private:
- typedef typename Search::template Table<WordIndex>::T Lookup;
- Lookup lookup_;
-
- bool saw_unk_;
-};
-
-// std::identity is an SGI extension :-(
-struct IdentityHash : public std::unary_function<uint64_t, size_t> {
- size_t operator()(uint64_t arg) const { return static_cast<size_t>(arg); }
-};
-
-// Should return the same results as SRI.
-// Why VocabularyT instead of just Vocabulary? ModelFacade defines Vocabulary.
-template <class Search, class VocabularyT> class GenericModel : public base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> {
- private:
- typedef base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> P;
- public:
- // Get the size of memory that will be mapped given ngram counts. This
- // does not include small non-mapped control structures, such as this class
- // itself.
- static size_t Size(const std::vector<size_t> &counts, const Config &config = Config());
-
- GenericModel(const char *file, Config config = Config());
-
- FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const;
-
- private:
- // Appears after Size in the cc.
- void SetupMemory(char *start, const std::vector<size_t> &counts, const Config &config);
-
- void LoadFromARPA(util::FilePiece &f, const std::vector<size_t> &counts, const Config &config);
-
- util::scoped_fd mapped_file_;
-
- // memory_ is the raw block of memory backing vocab_, unigram_, [middle.begin(), middle.end()), and longest_.
- util::scoped_mmap memory_;
-
- VocabularyT vocab_;
-
- ProbBackoff *unigram_;
-
- typedef typename Search::template Table<ProbBackoff>::T Middle;
- std::vector<Middle> middle_;
-
- typedef typename Search::template Table<Prob>::T Longest;
- Longest longest_;
-};
-
-struct ProbingSearch {
- typedef float Init;
-
- static const unsigned char kBinaryTag = 1;
-
- template <class Value> struct Table {
- typedef util::ByteAlignedPacking<uint64_t, Value> Packing;
- typedef util::ProbingHashTable<Packing, IdentityHash> T;
- };
-};
-
-struct SortedUniformSearch {
- // This is ignored.
- typedef float Init;
-
- static const unsigned char kBinaryTag = 2;
-
- template <class Value> struct Table {
- typedef util::ByteAlignedPacking<uint64_t, Value> Packing;
- typedef util::SortedUniformMap<Packing> T;
- };
-};
-
-} // namespace detail
-
-// These must also be instantiated in the cc file.
-typedef detail::MapVocabulary<detail::ProbingSearch> Vocabulary;
-typedef detail::GenericModel<detail::ProbingSearch, Vocabulary> Model;
-
-// SortedVocabulary was defined above.
-typedef detail::GenericModel<detail::SortedUniformSearch, SortedVocabulary> SortedModel;
-
-} // namespace ngram
-} // namespace lm
-
-#endif // LM_NGRAM__
diff --git a/klm/lm/ngram_build_binary.cc b/klm/lm/ngram_build_binary.cc
deleted file mode 100644
index 9dab30a1..00000000
--- a/klm/lm/ngram_build_binary.cc
+++ /dev/null
@@ -1,13 +0,0 @@
-#include "lm/ngram.hh"
-
-#include <iostream>
-
-int main(int argc, char *argv[]) {
- if (argc != 3) {
- std::cerr << "Usage: " << argv[0] << " input.arpa output.mmap" << std::endl;
- return 1;
- }
- lm::ngram::Config config;
- config.write_mmap = argv[2];
- lm::ngram::Model(argv[1], config);
-}
diff --git a/klm/lm/ngram_config.hh b/klm/lm/ngram_config.hh
deleted file mode 100644
index a7b3afae..00000000
--- a/klm/lm/ngram_config.hh
+++ /dev/null
@@ -1,58 +0,0 @@
-#ifndef LM_NGRAM_CONFIG__
-#define LM_NGRAM_CONFIG__
-
-/* Configuration for ngram model. Separate header to reduce pollution. */
-
-#include <iostream>
-
-namespace lm { namespace ngram {
-
-struct Config {
- /* EFFECTIVE FOR BOTH ARPA AND BINARY READS */
- // Where to log messages including the progress bar. Set to NULL for
- // silence.
- std::ostream *messages;
-
-
-
- /* ONLY EFFECTIVE WHEN READING ARPA */
-
- // What to do when <unk> isn't in the provided model.
- typedef enum {THROW_UP, COMPLAIN, SILENT} UnknownMissing;
- UnknownMissing unknown_missing;
-
- // The probability to substitute for <unk> if it's missing from the model.
- // No effect if the model has <unk> or unknown_missing == THROW_UP.
- float unknown_missing_prob;
-
- // Size multiplier for probing hash table. Must be > 1. Space is linear in
- // this. Time is probing_multiplier / (probing_multiplier - 1). No effect
- // for sorted variant.
- // If you find yourself setting this to a low number, consider using the
- // Sorted version instead which has lower memory consumption.
- float probing_multiplier;
-
- // While loading an ARPA file, also write out this binary format file. Set
- // to NULL to disable.
- const char *write_mmap;
-
-
-
- /* ONLY EFFECTIVE WHEN READING BINARY */
- bool prefault;
-
-
-
- // Defaults.
- Config() :
- messages(&std::cerr),
- unknown_missing(COMPLAIN),
- unknown_missing_prob(0.0),
- probing_multiplier(1.5),
- write_mmap(NULL),
- prefault(false) {}
-};
-
-} /* namespace ngram */ } /* namespace lm */
-
-#endif // LM_NGRAM_CONFIG__
diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc
index d6da02e3..9454a6d1 100644
--- a/klm/lm/ngram_query.cc
+++ b/klm/lm/ngram_query.cc
@@ -35,14 +35,14 @@ void PrintUsage(const char *message) {
}
}
-template <class Model> void Query(const Model &model) {
+template <class Model> void Query(const Model &model, bool sentence_context) {
PrintUsage("Loading statistics:\n");
typename Model::State state, out;
lm::FullScoreReturn ret;
std::string word;
while (std::cin) {
- state = model.BeginSentenceState();
+ state = sentence_context ? model.BeginSentenceState() : model.NullContextState();
float total = 0.0;
bool got = false;
unsigned int oov = 0;
@@ -52,7 +52,7 @@ template <class Model> void Query(const Model &model) {
if (vocab == 0) ++oov;
ret = model.FullScore(state, vocab, out);
total += ret.prob;
- std::cout << word << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\n';
+ std::cout << word << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t';
state = out;
char c;
while (true) {
@@ -67,9 +67,11 @@ template <class Model> void Query(const Model &model) {
if (c == '\n') break;
}
if (!got && !std::cin) break;
- ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out);
- total += ret.prob;
- std::cout << "</s>=" << model.GetVocabulary().EndSentence() << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\n';
+ if (sentence_context) {
+ ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out);
+ total += ret.prob;
+ std::cout << "</s>=" << model.GetVocabulary().EndSentence() << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t';
+ }
std::cout << "Total: " << total << " OOV: " << oov << '\n';
}
PrintUsage("After queries:\n");
@@ -82,29 +84,30 @@ template <class Model> void Query(const char *name) {
}
int main(int argc, char *argv[]) {
- if (argc < 2) {
- std::cerr << "Pass language model name." << std::endl;
- return 0;
+ if (!(argc == 2 || (argc == 3 && !strcmp(argv[2], "null")))) {
+ std::cerr << "Usage: " << argv[0] << " lm_file [null]" << std::endl;
+ std::cerr << "Input is wrapped in <s> and </s> unless null is passed." << std::endl;
+ return 1;
}
+ bool sentence_context = (argc == 2);
lm::ngram::ModelType model_type;
if (lm::ngram::RecognizeBinary(argv[1], model_type)) {
switch(model_type) {
case lm::ngram::HASH_PROBING:
- Query<lm::ngram::ProbingModel>(argv[1]);
- break;
- case lm::ngram::HASH_SORTED:
- Query<lm::ngram::SortedModel>(argv[1]);
+ Query<lm::ngram::ProbingModel>(argv[1], sentence_context);
break;
case lm::ngram::TRIE_SORTED:
- Query<lm::ngram::TrieModel>(argv[1]);
+ Query<lm::ngram::TrieModel>(argv[1], sentence_context);
break;
+ case lm::ngram::HASH_SORTED:
default:
std::cerr << "Unrecognized kenlm model type " << model_type << std::endl;
abort();
}
} else {
- Query<lm::ngram::ProbingModel>(argv[1]);
+ Query<lm::ngram::ProbingModel>(argv[1], sentence_context);
}
PrintUsage("Total time including destruction:\n");
+ return 0;
}
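
(Aside, not part of the diff: the optional "null" argument queries from NullContextState() instead of BeginSentenceState() and skips the implicit </s>. A rough in-code equivalent, with the model file and words as placeholders.)

#include "lm/model.hh"

#include <iostream>

// Illustrative: score a phrase with no sentence boundaries, like `ngram_query lm.binary null`.
int main() {
  lm::ngram::ProbingModel model("lm.binary");
  lm::ngram::ProbingModel::State state(model.NullContextState()), out;
  const char *words[] = {"this", "is", "a", "phrase"};
  float total = 0.0;
  for (unsigned i = 0; i < sizeof(words) / sizeof(*words); ++i) {
    lm::FullScoreReturn ret = model.FullScore(state, model.GetVocabulary().Index(words[i]), out);
    total += ret.prob;
    state = out;
  }
  // No </s> is scored in null-context mode, matching the patched ngram_query.
  std::cout << "Total: " << total << std::endl;
  return 0;
}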
diff --git a/klm/lm/ngram_test.cc b/klm/lm/ngram_test.cc
deleted file mode 100644
index 031e0348..00000000
--- a/klm/lm/ngram_test.cc
+++ /dev/null
@@ -1,91 +0,0 @@
-#include "lm/ngram.hh"
-
-#include <stdlib.h>
-
-#define BOOST_TEST_MODULE NGramTest
-#include <boost/test/unit_test.hpp>
-
-namespace lm {
-namespace ngram {
-namespace {
-
-#define StartTest(word, ngram, score) \
- ret = model.FullScore( \
- state, \
- model.GetVocabulary().Index(word), \
- out);\
- BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
- BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
- BOOST_CHECK_EQUAL(std::min<unsigned char>(ngram, 5 - 1), out.valid_length_);
-
-#define AppendTest(word, ngram, score) \
- StartTest(word, ngram, score) \
- state = out;
-
-template <class M> void Starters(M &model) {
- FullScoreReturn ret;
- Model::State state(model.BeginSentenceState());
- Model::State out;
-
- StartTest("looking", 2, -0.4846522);
-
- // , probability plus <s> backoff
- StartTest(",", 1, -1.383514 + -0.4149733);
- // <unk> probability plus <s> backoff
- StartTest("this_is_not_found", 0, -1.995635 + -0.4149733);
-}
-
-template <class M> void Continuation(M &model) {
- FullScoreReturn ret;
- Model::State state(model.BeginSentenceState());
- Model::State out;
-
- AppendTest("looking", 2, -0.484652);
- AppendTest("on", 3, -0.348837);
- AppendTest("a", 4, -0.0155266);
- AppendTest("little", 5, -0.00306122);
- State preserve = state;
- AppendTest("the", 1, -4.04005);
- AppendTest("biarritz", 1, -1.9889);
- AppendTest("not_found", 0, -2.29666);
- AppendTest("more", 1, -1.20632);
- AppendTest(".", 2, -0.51363);
- AppendTest("</s>", 3, -0.0191651);
-
- state = preserve;
- AppendTest("more", 5, -0.00181395);
- AppendTest("loin", 5, -0.0432557);
-}
-
-BOOST_AUTO_TEST_CASE(starters_probing) { Model m("test.arpa"); Starters(m); }
-BOOST_AUTO_TEST_CASE(continuation_probing) { Model m("test.arpa"); Continuation(m); }
-BOOST_AUTO_TEST_CASE(starters_sorted) { SortedModel m("test.arpa"); Starters(m); }
-BOOST_AUTO_TEST_CASE(continuation_sorted) { SortedModel m("test.arpa"); Continuation(m); }
-
-BOOST_AUTO_TEST_CASE(write_and_read_probing) {
- Config config;
- config.write_mmap = "test.binary";
- {
- Model copy_model("test.arpa", config);
- }
- Model binary("test.binary");
- Starters(binary);
- Continuation(binary);
-}
-
-BOOST_AUTO_TEST_CASE(write_and_read_sorted) {
- Config config;
- config.write_mmap = "test.binary";
- config.prefault = true;
- {
- SortedModel copy_model("test.arpa", config);
- }
- SortedModel binary("test.binary");
- Starters(binary);
- Continuation(binary);
-}
-
-
-} // namespace
-} // namespace ngram
-} // namespace lm
diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc
index d0fe67f0..0e90196d 100644
--- a/klm/lm/read_arpa.cc
+++ b/klm/lm/read_arpa.cc
@@ -6,6 +6,7 @@
#include <vector>
#include <ctype.h>
+#include <string.h>
#include <inttypes.h>
namespace lm {
@@ -22,14 +23,20 @@ bool IsEntirelyWhiteSpace(const StringPiece &line) {
return true;
}
-template <class F> void GenericReadARPACounts(F &in, std::vector<uint64_t> &number) {
+const char kBinaryMagic[] = "mmap lm http://kheafield.com/code";
+
+} // namespace
+
+void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {
number.clear();
StringPiece line;
if (!IsEntirelyWhiteSpace(line = in.ReadLine())) {
if ((line.size() >= 2) && (line.data()[0] == 0x1f) && (static_cast<unsigned char>(line.data()[1]) == 0x8b)) {
- UTIL_THROW(FormatLoadException, "Looks like a gzip file. If this is an ARPA file, run\nzcat " << in.FileName() << " |kenlm/build_binary /dev/stdin " << in.FileName() << ".binary\nIf this already in binary format, you need to decompress it because mmap doesn't work on top of gzip.");
+ UTIL_THROW(FormatLoadException, "Looks like a gzip file. If this is an ARPA file, pipe " << in.FileName() << " through zcat. If this already in binary format, you need to decompress it because mmap doesn't work on top of gzip.");
}
- UTIL_THROW(FormatLoadException, "First line was \"" << static_cast<int>(line.data()[1]) << "\" not blank");
+ if (static_cast<size_t>(line.size()) >= strlen(kBinaryMagic) && StringPiece(line.data(), strlen(kBinaryMagic)) == kBinaryMagic)
+ UTIL_THROW(FormatLoadException, "This looks like a binary file but got sent to the ARPA parser. Did you compress the binary file or pass a binary file where only ARPA files are accepted?");
+ UTIL_THROW(FormatLoadException, "First line was \"" << line.data() << "\" not blank");
}
if ((line = in.ReadLine()) != "\\data\\") UTIL_THROW(FormatLoadException, "second line was \"" << line << "\" not \\data\\.");
while (!IsEntirelyWhiteSpace(line = in.ReadLine())) {
@@ -49,66 +56,14 @@ template <class F> void GenericReadARPACounts(F &in, std::vector<uint64_t> &numb
}
}
-template <class F> void GenericReadNGramHeader(F &in, unsigned int length) {
- StringPiece line;
+void ReadNGramHeader(util::FilePiece &in, unsigned int length) {
+ StringPiece line;
while (IsEntirelyWhiteSpace(line = in.ReadLine())) {}
std::stringstream expected;
expected << '\\' << length << "-grams:";
if (line != expected.str()) UTIL_THROW(FormatLoadException, "Was expecting n-gram header " << expected.str() << " but got " << line << " instead");
}
-template <class F> void GenericReadEnd(F &in) {
- StringPiece line;
- do {
- line = in.ReadLine();
- } while (IsEntirelyWhiteSpace(line));
- if (line != "\\end\\") UTIL_THROW(FormatLoadException, "Expected \\end\\ but the ARPA file has " << line);
-}
-
-class FakeFilePiece {
- public:
- explicit FakeFilePiece(std::istream &in) : in_(in) {
- in_.exceptions(std::ios::failbit | std::ios::badbit | std::ios::eofbit);
- }
-
- StringPiece ReadLine() throw(util::EndOfFileException) {
- getline(in_, buffer_);
- return StringPiece(buffer_);
- }
-
- float ReadFloat() {
- float ret;
- in_ >> ret;
- return ret;
- }
-
- const char *FileName() const {
- // This only used for error messages and we don't know the file name. . .
- return "$file";
- }
-
- private:
- std::istream &in_;
- std::string buffer_;
-};
-
-} // namespace
-
-void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {
- GenericReadARPACounts(in, number);
-}
-void ReadARPACounts(std::istream &in, std::vector<uint64_t> &number) {
- FakeFilePiece fake(in);
- GenericReadARPACounts(fake, number);
-}
-void ReadNGramHeader(util::FilePiece &in, unsigned int length) {
- GenericReadNGramHeader(in, length);
-}
-void ReadNGramHeader(std::istream &in, unsigned int length) {
- FakeFilePiece fake(in);
- GenericReadNGramHeader(fake, length);
-}
-
void ReadBackoff(util::FilePiece &in, Prob &/*weights*/) {
switch (in.get()) {
case '\t':
@@ -146,20 +101,18 @@ void ReadBackoff(util::FilePiece &in, ProbBackoff &weights) {
}
void ReadEnd(util::FilePiece &in) {
- GenericReadEnd(in);
StringPiece line;
+ do {
+ line = in.ReadLine();
+ } while (IsEntirelyWhiteSpace(line));
+ if (line != "\\end\\") UTIL_THROW(FormatLoadException, "Expected \\end\\ but the ARPA file has " << line);
+
try {
while (true) {
line = in.ReadLine();
if (!IsEntirelyWhiteSpace(line)) UTIL_THROW(FormatLoadException, "Trailing line " << line);
}
- } catch (const util::EndOfFileException &e) {
- return;
- }
-}
-void ReadEnd(std::istream &in) {
- FakeFilePiece fake(in);
- GenericReadEnd(fake);
+ } catch (const util::EndOfFileException &e) {}
}
} // namespace lm
diff --git a/klm/lm/read_arpa.hh b/klm/lm/read_arpa.hh
index 4efdd29d..4953d40e 100644
--- a/klm/lm/read_arpa.hh
+++ b/klm/lm/read_arpa.hh
@@ -13,15 +13,12 @@
namespace lm {
void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number);
-void ReadARPACounts(std::istream &in, std::vector<uint64_t> &number);
void ReadNGramHeader(util::FilePiece &in, unsigned int length);
-void ReadNGramHeader(std::istream &in, unsigned int length);
void ReadBackoff(util::FilePiece &in, Prob &weights);
void ReadBackoff(util::FilePiece &in, ProbBackoff &weights);
void ReadEnd(util::FilePiece &in);
-void ReadEnd(std::istream &in);
extern const bool kARPASpaces[256];
diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc
index f97ec790..bb3b955a 100644
--- a/klm/lm/search_hashed.cc
+++ b/klm/lm/search_hashed.cc
@@ -83,9 +83,10 @@ namespace detail {
template <class MiddleT, class LongestT> template <class Voc> void TemplateHashedSearch<MiddleT, LongestT>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing) {
// TODO: fix sorted.
- SetupMemory(GrowForSearch(config, HASH_PROBING, counts, Size(counts, config), backing), counts, config);
+ SetupMemory(GrowForSearch(config, Size(counts, config), backing), counts, config);
- Read1Grams(f, counts[0], vocab, unigram.Raw());
+ Read1Grams(f, counts[0], vocab, unigram.Raw());
+ CheckSpecials(config, vocab);
try {
if (counts.size() > 2) {
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index 1060ddef..63631223 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -11,6 +11,7 @@
#include "lm/word_index.hh"
#include "util/ersatz_progress.hh"
#include "util/file_piece.hh"
+#include "util/have.hh"
#include "util/proxy_iterator.hh"
#include "util/scoped.hh"
@@ -20,7 +21,6 @@
#include <cstdio>
#include <deque>
#include <limits>
-//#include <parallel/algorithm>
#include <vector>
#include <sys/mman.h>
@@ -170,7 +170,7 @@ template <class Proxy> class CompareRecords : public std::binary_function<const
return Compare(reinterpret_cast<const WordIndex*>(first.data()), second.Indices());
}
bool operator()(const std::string &first, const std::string &second) const {
- return Compare(reinterpret_cast<const WordIndex*>(first.data()), reinterpret_cast<const WordIndex*>(first.data()));
+ return Compare(reinterpret_cast<const WordIndex*>(first.data()), reinterpret_cast<const WordIndex*>(second.data()));
}
private:
@@ -384,7 +384,6 @@ void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_fil
PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size));
PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, context_size));
- // TODO: __gnu_parallel::sort here.
std::sort(context_begin, context_end, CompareRecords<PartialViewProxy>(order - 1));
std::string name(ngram_file_name + kContextSuffix);
@@ -406,16 +405,16 @@ void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_fil
class ContextReader {
public:
- ContextReader() : length_(0) {}
+ ContextReader() : valid_(false) {}
- ContextReader(const char *name, size_t length) : file_(OpenOrThrow(name, "r")), length_(length), words_(length), valid_(true) {
- ++*this;
+ ContextReader(const char *name, unsigned char order) {
+ Reset(name, order);
}
- void Reset(const char *name, size_t length) {
+ void Reset(const char *name, unsigned char order) {
file_.reset(OpenOrThrow(name, "r"));
- length_ = length;
- words_.resize(length);
+ length_ = sizeof(WordIndex) * static_cast<size_t>(order);
+ words_.resize(order);
valid_ = true;
++*this;
}
@@ -449,14 +448,14 @@ void MergeContextFiles(const std::string &first_base, const std::string &second_
const size_t context_size = sizeof(WordIndex) * (order - 1);
std::string first_name(first_base + kContextSuffix);
std::string second_name(second_base + kContextSuffix);
- ContextReader first(first_name.c_str(), context_size), second(second_name.c_str(), context_size);
+ ContextReader first(first_name.c_str(), order - 1), second(second_name.c_str(), order - 1);
RemoveOrThrow(first_name.c_str());
RemoveOrThrow(second_name.c_str());
std::string out_name(out_base + kContextSuffix);
util::scoped_FILE out(OpenOrThrow(out_name.c_str(), "w"));
while (first && second) {
for (const WordIndex *f = *first, *s = *second; ; ++f, ++s) {
- if (f == *first + order) {
+ if (f == *first + order - 1) {
// Equal.
WriteOrThrow(out.get(), *first, context_size);
++first;
@@ -475,7 +474,10 @@ void MergeContextFiles(const std::string &first_base, const std::string &second_
}
}
}
- CopyRestOrThrow((first ? first : second).GetFile(), out.get());
+ ContextReader &remaining = first ? first : second;
+ if (!remaining) return;
+ WriteOrThrow(out.get(), *remaining, context_size);
+ CopyRestOrThrow(remaining.GetFile(), out.get());
}
void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order) {
@@ -502,7 +504,7 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st
}
// Sort full records by full n-gram.
EntryProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size);
- // TODO: __gnu_parallel::sort here.
+ // parallel_sort uses too much RAM
std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), CompareRecords<EntryProxy>(order));
files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order, weights_size));
WriteContextFile(begin, out_end, files.back(), entry_size, order);
@@ -533,21 +535,22 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st
}
}
-void ARPAToSortedFiles(util::FilePiece &f, const std::vector<uint64_t> &counts, std::size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
+void ARPAToSortedFiles(const Config &config, util::FilePiece &f, const std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
{
std::string unigram_name = file_prefix + "unigrams";
util::scoped_fd unigram_file;
util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), counts[0] * sizeof(ProbBackoff), unigram_file), counts[0] * sizeof(ProbBackoff));
Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()));
+ CheckSpecials(config, vocab);
}
// Only use as much buffer as we need.
size_t buffer_use = 0;
for (unsigned int order = 2; order < counts.size(); ++order) {
- buffer_use = std::max(buffer_use, size_t((sizeof(WordIndex) * order + 2 * sizeof(float)) * counts[order - 1]));
+ buffer_use = std::max<size_t>(buffer_use, static_cast<size_t>((sizeof(WordIndex) * order + 2 * sizeof(float)) * counts[order - 1]));
}
- buffer_use = std::max(buffer_use, size_t((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back()));
- buffer = std::min(buffer, buffer_use);
+ buffer_use = std::max<size_t>(buffer_use, static_cast<size_t>((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back()));
+ buffer = std::min<size_t>(buffer, buffer_use);
util::scoped_memory mem;
mem.reset(malloc(buffer), buffer, util::scoped_memory::MALLOC_ALLOCATED);
@@ -767,7 +770,7 @@ void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<u
}
}
-void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &counts, const Config &config, TrieSearch &out, Backing &backing) {
+void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch &out, Backing &backing) {
SortedFileReader inputs[counts.size() - 1];
ContextReader contexts[counts.size() - 1];
@@ -777,7 +780,7 @@ void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &coun
inputs[i-2].Init(assembled.str(), i);
RemoveOrThrow(assembled.str().c_str());
assembled << kContextSuffix;
- contexts[i-2].Reset(assembled.str().c_str(), (i-1) * sizeof(WordIndex));
+ contexts[i-2].Reset(assembled.str().c_str(), i-1);
RemoveOrThrow(assembled.str().c_str());
}
@@ -787,8 +790,9 @@ void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &coun
counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]);
}
SanityCheckCounts(counts, fixed_counts);
+ counts = fixed_counts;
- out.SetupMemory(GrowForSearch(config, TrieSearch::kModelType, fixed_counts, TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config);
+ out.SetupMemory(GrowForSearch(config, TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config);
for (unsigned char i = 2; i <= counts.size(); ++i) {
inputs[i-2].Rewind();
@@ -811,7 +815,7 @@ void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &coun
++contexts[0];
}
}
- unlink(name.c_str());
+ RemoveOrThrow(name.c_str());
}
// Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.
@@ -823,7 +827,7 @@ void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &coun
for (const WordIndex *i = *context; i != *context + order - 1; ++i) {
e << ' ' << *i;
}
- e << " so this context must appear in the model as a " << static_cast<unsigned int>(order - 1) << "-gram but it does not.";
+ e << " so this context must appear in the model as a " << static_cast<unsigned int>(order - 1) << "-gram but it does not";
throw e;
}
}
@@ -868,7 +872,7 @@ void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::v
// Add directory delimiter. Assumes a real operating system.
temporary_directory += '/';
// At least 1MB sorting memory.
- ARPAToSortedFiles(f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab);
+ ARPAToSortedFiles(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab);
BuildTrie(temporary_directory, counts, config, *this, backing);
if (rmdir(temporary_directory.c_str()) && config.messages) {
diff --git a/klm/lm/test.binary b/klm/lm/test.binary
deleted file mode 100644
index 90bd2b76..00000000
--- a/klm/lm/test.binary
+++ /dev/null
Binary files differ
diff --git a/klm/lm/virtual_interface.cc b/klm/lm/virtual_interface.cc
index c5a64972..17a74c3c 100644
--- a/klm/lm/virtual_interface.cc
+++ b/klm/lm/virtual_interface.cc
@@ -11,8 +11,6 @@ void Vocabulary::SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, Wo
begin_sentence_ = begin_sentence;
end_sentence_ = end_sentence;
not_found_ = not_found;
- if (begin_sentence_ == not_found_) throw SpecialWordMissingException("<s>");
- if (end_sentence_ == not_found_) throw SpecialWordMissingException("</s>");
}
Model::~Model() {}
diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc
index ae79c727..415f8331 100644
--- a/klm/lm/vocab.cc
+++ b/klm/lm/vocab.cc
@@ -187,5 +187,29 @@ void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) {
SetSpecial(Index("<s>"), Index("</s>"), 0);
}
+void MissingUnknown(const Config &config) throw(SpecialWordMissingException) {
+ switch(config.unknown_missing) {
+ case Config::SILENT:
+ return;
+ case Config::COMPLAIN:
+ if (config.messages) *config.messages << "The ARPA file is missing <unk>. Substituting probability " << config.unknown_missing_prob << "." << std::endl;
+ break;
+ case Config::THROW_UP:
+ UTIL_THROW(SpecialWordMissingException, "The ARPA file is missing <unk> and the model is configured to throw an exception.");
+ }
+}
+
+void MissingSentenceMarker(const Config &config, const char *str) throw(SpecialWordMissingException) {
+ switch (config.sentence_marker_missing) {
+ case Config::SILENT:
+ return;
+ case Config::COMPLAIN:
+ if (config.messages) *config.messages << "Missing special word " << str << "; will treat it as <unk>.";
+ break;
+ case Config::THROW_UP:
+ UTIL_THROW(SpecialWordMissingException, "The ARPA file is missing " << str << " and the model is configured to reject these models. Run build_binary -s to disable this check.");
+ }
+}
+
} // namespace ngram
} // namespace lm
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh
index b584c82f..546c1649 100644
--- a/klm/lm/vocab.hh
+++ b/klm/lm/vocab.hh
@@ -2,6 +2,7 @@
#define LM_VOCAB__
#include "lm/enumerate_vocab.hh"
+#include "lm/lm_exception.hh"
#include "lm/virtual_interface.hh"
#include "util/key_value_packing.hh"
#include "util/probing_hash_table.hh"
@@ -134,6 +135,15 @@ class ProbingVocabulary : public base::Vocabulary {
EnumerateVocab *enumerate_;
};
+void MissingUnknown(const Config &config) throw(SpecialWordMissingException);
+void MissingSentenceMarker(const Config &config, const char *str) throw(SpecialWordMissingException);
+
+template <class Vocab> void CheckSpecials(const Config &config, const Vocab &vocab) throw(SpecialWordMissingException) {
+ if (!vocab.SawUnk()) MissingUnknown(config);
+ if (vocab.BeginSentence() == vocab.NotFound()) MissingSentenceMarker(config, "<s>");
+ if (vocab.EndSentence() == vocab.NotFound()) MissingSentenceMarker(config, "</s>");
+}
+
} // namespace ngram
} // namespace lm
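
(Aside, not part of the diff: CheckSpecials is now the single place where the THROW_UP/COMPLAIN/SILENT policy for <unk>, <s>, and </s> is applied; the hashed and trie searches above call it right after Read1Grams. An illustrative fragment.)

#include "lm/config.hh"
#include "lm/vocab.hh"

// Illustrative only: enforce the configured policy after the unigram pass.
void AfterUnigrams(const lm::ngram::Config &config, const lm::ngram::SortedVocabulary &vocab) {
  // Throws SpecialWordMissingException, warns via config.messages, or stays
  // silent, depending on the Config fields introduced in this commit.
  lm::ngram::CheckSpecials(config, vocab);
}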