summaryrefslogtreecommitdiff
path: root/klm
diff options
context:
space:
mode:
Diffstat (limited to 'klm')
-rw-r--r--klm/lm/binary_format.cc4
-rw-r--r--klm/lm/blank.hh44
-rw-r--r--klm/lm/build_binary.cc2
-rw-r--r--klm/lm/model.cc97
-rw-r--r--klm/lm/model.hh9
-rw-r--r--klm/lm/model_test.cc49
-rw-r--r--klm/lm/ngram_query.cc27
-rw-r--r--klm/lm/read_arpa.cc17
-rw-r--r--klm/lm/read_arpa.hh6
-rw-r--r--klm/lm/search_hashed.cc52
-rw-r--r--klm/lm/search_trie.cc302
-rw-r--r--klm/lm/vocab.hh1
-rw-r--r--klm/util/bit_packing.cc2
-rw-r--r--klm/util/bit_packing.hh17
-rw-r--r--klm/util/ersatz_progress.cc1
-rw-r--r--klm/util/file_piece.cc24
-rw-r--r--klm/util/file_piece.hh35
-rw-r--r--klm/util/key_value_packing.hh4
-rw-r--r--klm/util/probing_hash_table.hh14
-rw-r--r--klm/util/sorted_uniform.hh10
20 files changed, 550 insertions, 167 deletions
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc
index 3d9700da..2a6aff34 100644
--- a/klm/lm/binary_format.cc
+++ b/klm/lm/binary_format.cc
@@ -18,8 +18,8 @@ namespace lm {
namespace ngram {
namespace {
const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version";
-const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 3\n\0";
-const long int kMagicVersion = 2;
+const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 4\n\0";
+const long int kMagicVersion = 4;
// Test values.
struct Sanity {
diff --git a/klm/lm/blank.hh b/klm/lm/blank.hh
index 639bc98b..4615a09e 100644
--- a/klm/lm/blank.hh
+++ b/klm/lm/blank.hh
@@ -1,12 +1,52 @@
#ifndef LM_BLANK__
#define LM_BLANK__
+
#include <limits>
+#include <inttypes.h>
+#include <math.h>
+
namespace lm {
namespace ngram {
-const float kBlankProb = -std::numeric_limits<float>::quiet_NaN();
-const float kBlankBackoff = std::numeric_limits<float>::infinity();
+/* Suppose "foo bar" appears with zero backoff but there is no trigram
+ * beginning with these words. Then, when scoring "foo bar", the model could
+ * return out_state containing "bar" or even null context if "bar" also has no
+ * backoff and is never followed by another word. Then the backoff is set to
+ * kNoExtensionBackoff. If the n-gram might be extended, then out_state must
+ * contain the full n-gram, in which case kExtensionBackoff is set. In any
+ * case, if an n-gram has non-zero backoff, the full state is returned so
+ * backoff can be properly charged.
+ * These differ only in sign bit because the backoff is in fact zero in either
+ * case.
+ */
+const float kNoExtensionBackoff = -0.0;
+const float kExtensionBackoff = 0.0;
+
+inline void SetExtension(float &backoff) {
+ if (backoff == kNoExtensionBackoff) backoff = kExtensionBackoff;
+}
+
+// This compiles down nicely.
+inline bool HasExtension(const float &backoff) {
+ typedef union { float f; uint32_t i; } UnionValue;
+ UnionValue compare, interpret;
+ compare.f = kNoExtensionBackoff;
+ interpret.f = backoff;
+ return compare.i != interpret.i;
+}
+
+/* Suppose "foo bar baz quux" appears in the ARPA but not "bar baz quux" or
+ * "baz quux" (because they were pruned). 1.2% of n-grams generated by SRI
+ * with default settings on the benchmark data set are like this. Since search
+ * proceeds by finding "quux", "baz quux", "bar baz quux", and finally
+ * "foo bar baz quux" and the trie needs pointer nodes anyway, blanks are
+ * inserted. The blanks have probability kBlankProb and backoff kBlankBackoff.
+ * A blank is recognized by kBlankProb in the probability field; kBlankBackoff
+ * must be 0 so that inference asseses zero backoff from these blanks.
+ */
+const float kBlankProb = -std::numeric_limits<float>::infinity();
+const float kBlankBackoff = kNoExtensionBackoff;
} // namespace ngram
} // namespace lm
diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc
index b340797b..144c57e0 100644
--- a/klm/lm/build_binary.cc
+++ b/klm/lm/build_binary.cc
@@ -21,7 +21,7 @@ void Usage(const char *name) {
"memory and is still faster than SRI or IRST. Building the trie format uses an\n"
"on-disk sort to save memory.\n"
"-t is the temporary directory prefix. Default is the output file name.\n"
-"-m is the amount of memory to use, in MB. Default is 1024MB (1GB).\n\n"
+"-m limits memory use for sorting. Measured in MB. Default is 1024MB.\n\n"
/*"sorted is like probing but uses a sorted uniform map instead of a hash table.\n"
"It uses more memory than trie and is also slower, so there's no real reason to\n"
"use it.\n\n"*/
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index c7ba4908..146fe07b 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -61,10 +61,10 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
// Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any.
util::FilePiece f(backing_.file.release(), file, config.messages);
std::vector<uint64_t> counts;
- // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed with search_.VariableSizeLoad
+ // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_.
ReadARPACounts(f, counts);
- if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit ngram.hh's kMaxOrder to at least this value and recompile.");
+ if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile.");
if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
@@ -114,7 +114,24 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const {
context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
FullScoreReturn ret = ScoreExceptBackoff(context_rbegin, context_rend, new_word, out_state);
- ret.prob += SlowBackoffLookup(context_rbegin, context_rend, ret.ngram_length);
+
+ // Add the backoff weights for n-grams of order start to (context_rend - context_rbegin).
+ unsigned char start = ret.ngram_length;
+ if (context_rend - context_rbegin < static_cast<std::ptrdiff_t>(start)) return ret;
+ if (start <= 1) {
+ ret.prob += search_.unigram.Lookup(*context_rbegin).backoff;
+ start = 2;
+ }
+ typename Search::Node node;
+ if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) {
+ return ret;
+ }
+ float backoff;
+ // i is the order of the backoff we're looking for.
+ for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i) {
+ if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, backoff, node)) break;
+ ret.prob += backoff;
+ }
return ret;
}
@@ -128,8 +145,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
float ignored_prob;
typename Search::Node node;
search_.LookupUnigram(*context_rbegin, ignored_prob, out_state.backoff_[0], node);
- // Tricky part is that an entry might be blank, but out_state.valid_length_ always has the last non-blank n-gram length.
- out_state.valid_length_ = 1;
+ out_state.valid_length_ = HasExtension(out_state.backoff_[0]) ? 1 : 0;
float *backoff_out = out_state.backoff_ + 1;
const typename Search::Middle *mid = &*search_.middle.begin();
for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++mid) {
@@ -137,36 +153,21 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
std::copy(context_rbegin, context_rbegin + out_state.valid_length_, out_state.history_);
return;
}
- if (*backoff_out != kBlankBackoff) {
- out_state.valid_length_ = i - context_rbegin + 1;
- } else {
- *backoff_out = 0.0;
- }
+ if (HasExtension(*backoff_out)) out_state.valid_length_ = i - context_rbegin + 1;
}
std::copy(context_rbegin, context_rbegin + out_state.valid_length_, out_state.history_);
}
-template <class Search, class VocabularyT> float GenericModel<Search, VocabularyT>::SlowBackoffLookup(
- const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const {
- // Add the backoff weights for n-grams of order start to (context_rend - context_rbegin).
- if (context_rend - context_rbegin < static_cast<std::ptrdiff_t>(start)) return 0.0;
- float ret = 0.0;
- if (start == 1) {
- ret += search_.unigram.Lookup(*context_rbegin).backoff;
- start = 2;
- }
- typename Search::Node node;
- if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) {
- return 0.0;
- }
- float backoff;
- // i is the order of the backoff we're looking for.
- for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i) {
- if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, backoff, node)) break;
- if (backoff != kBlankBackoff) ret += backoff;
- }
- return ret;
+namespace {
+// Do a paraonoid copy of history, assuming new_word has already been copied
+// (hence the -1). out_state.valid_length_ could be zero so I avoided using
+// std::copy.
+void CopyRemainingHistory(const WordIndex *from, State &out_state) {
+ WordIndex *out = out_state.history_ + 1;
+ const WordIndex *in_end = from + static_cast<ptrdiff_t>(out_state.valid_length_) - 1;
+ for (const WordIndex *in = from; in < in_end; ++in, ++out) *out = *in;
}
+} // namespace
/* Ugly optimized function. Produce a score excluding backoff.
* The search goes in increasing order of ngram length.
@@ -179,28 +180,26 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
const WordIndex new_word,
State &out_state) const {
FullScoreReturn ret;
- // ret.ngram_length contains the last known good (non-blank) ngram length.
+ // ret.ngram_length contains the last known non-blank ngram length.
ret.ngram_length = 1;
typename Search::Node node;
float *backoff_out(out_state.backoff_);
search_.LookupUnigram(new_word, ret.prob, *backoff_out, node);
+ // This is the length of the context that should be used for continuation.
+ out_state.valid_length_ = HasExtension(*backoff_out) ? 1 : 0;
+ // We'll write the word anyway since it will probably be used and does no harm being there.
out_state.history_[0] = new_word;
- if (context_rbegin == context_rend) {
- out_state.valid_length_ = 1;
- return ret;
- }
+ if (context_rbegin == context_rend) return ret;
++backoff_out;
// Ok now we now that the bigram contains known words. Start by looking it up.
-
const WordIndex *hist_iter = context_rbegin;
typename std::vector<Middle>::const_iterator mid_iter = search_.middle.begin();
for (; ; ++mid_iter, ++hist_iter, ++backoff_out) {
if (hist_iter == context_rend) {
// Ran out of history. Typically no backoff, but this could be a blank.
- out_state.valid_length_ = ret.ngram_length;
- std::copy(context_rbegin, context_rbegin + ret.ngram_length - 1, out_state.history_ + 1);
+ CopyRemainingHistory(context_rbegin, out_state);
// ret.prob was already set.
return ret;
}
@@ -210,32 +209,32 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
float revert = ret.prob;
if (!search_.LookupMiddle(*mid_iter, *hist_iter, ret.prob, *backoff_out, node)) {
// Didn't find an ngram using hist_iter.
- std::copy(context_rbegin, context_rbegin + ret.ngram_length - 1, out_state.history_ + 1);
- out_state.valid_length_ = ret.ngram_length;
+ CopyRemainingHistory(context_rbegin, out_state);
// ret.prob was already set.
return ret;
}
- if (*backoff_out == kBlankBackoff) {
- *backoff_out = 0.0;
+ if (ret.prob == kBlankProb) {
+ // It's a blank. Go back to the old probability.
ret.prob = revert;
} else {
ret.ngram_length = hist_iter - context_rbegin + 2;
+ if (HasExtension(*backoff_out)) {
+ out_state.valid_length_ = ret.ngram_length;
+ }
}
}
// It passed every lookup in search_.middle. All that's left is to check search_.longest.
if (!search_.LookupLongest(*hist_iter, ret.prob, node)) {
- //assert(ret.ngram_length <= P::Order() - 1);
- out_state.valid_length_ = ret.ngram_length;
- std::copy(context_rbegin, context_rbegin + ret.ngram_length - 1, out_state.history_ + 1);
+ // Failed to find a longest n-gram. Fall back to the most recent non-blank.
+ CopyRemainingHistory(context_rbegin, out_state);
// ret.prob was already set.
return ret;
}
- // It's an P::Order()-gram. There is no blank in longest_.
- // out_state.valid_length_ is still P::Order() - 1 because the next lookup will only need that much.
- std::copy(context_rbegin, context_rbegin + P::Order() - 2, out_state.history_ + 1);
- out_state.valid_length_ = P::Order() - 1;
+ // It's an P::Order()-gram.
+ CopyRemainingHistory(context_rbegin, out_state);
+ // There is no blank in longest_.
ret.ngram_length = P::Order();
return ret;
}
diff --git a/klm/lm/model.hh b/klm/lm/model.hh
index 8183bdf5..fd9640c3 100644
--- a/klm/lm/model.hh
+++ b/klm/lm/model.hh
@@ -4,6 +4,7 @@
#include "lm/binary_format.hh"
#include "lm/config.hh"
#include "lm/facade.hh"
+#include "lm/max_order.hh"
#include "lm/search_hashed.hh"
#include "lm/search_trie.hh"
#include "lm/vocab.hh"
@@ -19,12 +20,6 @@ namespace util { class FilePiece; }
namespace lm {
namespace ngram {
-// If you need higher order, change this and recompile.
-// Having this limit means that State can be
-// (kMaxOrder - 1) * sizeof(float) bytes instead of
-// sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead
-const unsigned char kMaxOrder = 6;
-
// This is a POD but if you want memcmp to return the same as operator==, call
// ZeroRemaining first.
class State {
@@ -56,6 +51,8 @@ class State {
}
}
+ unsigned char ValidLength() const { return valid_length_; }
+
// You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD.
// This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit.
WordIndex history_[kMaxOrder - 1];
diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc
index 89bbf2e8..548c098d 100644
--- a/klm/lm/model_test.cc
+++ b/klm/lm/model_test.cc
@@ -8,6 +8,15 @@
namespace lm {
namespace ngram {
+
+std::ostream &operator<<(std::ostream &o, const State &state) {
+ o << "State length " << static_cast<unsigned int>(state.valid_length_) << ':';
+ for (const WordIndex *i = state.history_; i < state.history_ + state.valid_length_; ++i) {
+ o << ' ' << *i;
+ }
+ return o;
+}
+
namespace {
#define StartTest(word, ngram, score) \
@@ -17,7 +26,15 @@ namespace {
out);\
BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
- BOOST_CHECK_EQUAL(std::min<unsigned char>(ngram, 5 - 1), out.valid_length_);
+ BOOST_CHECK_GE(std::min<unsigned char>(ngram, 5 - 1), out.valid_length_); \
+ {\
+ WordIndex context[state.valid_length_ + 1]; \
+ context[0] = model.GetVocabulary().Index(word); \
+ std::copy(state.history_, state.history_ + state.valid_length_, context + 1); \
+ State get_state; \
+ model.GetState(context, context + state.valid_length_ + 1, get_state); \
+ BOOST_CHECK_EQUAL(out, get_state); \
+ }
#define AppendTest(word, ngram, score) \
StartTest(word, ngram, score) \
@@ -52,10 +69,13 @@ template <class M> void Continuation(const M &model) {
AppendTest("more", 1, -1.20632 - 20.0);
AppendTest(".", 2, -0.51363);
AppendTest("</s>", 3, -0.0191651);
+ BOOST_CHECK_EQUAL(0, state.valid_length_);
state = preserve;
AppendTest("more", 5, -0.00181395);
+ BOOST_CHECK_EQUAL(4, state.valid_length_);
AppendTest("loin", 5, -0.0432557);
+ BOOST_CHECK_EQUAL(1, state.valid_length_);
}
template <class M> void Blanks(const M &model) {
@@ -68,6 +88,7 @@ template <class M> void Blanks(const M &model) {
State preserve = state;
AppendTest("higher", 4, -4);
AppendTest("looking", 5, -5);
+ BOOST_CHECK_EQUAL(1, state.valid_length_);
state = preserve;
AppendTest("not_found", 1, -1.995635 - 7.0 - 0.30103);
@@ -94,6 +115,29 @@ template <class M> void Unknowns(const M &model) {
AppendTest("not_found3", 3, -6);
}
+template <class M> void MinimalState(const M &model) {
+ FullScoreReturn ret;
+ State state(model.NullContextState());
+ State out;
+
+ AppendTest("baz", 1, -6.535897);
+ BOOST_CHECK_EQUAL(0, state.valid_length_);
+ state = model.NullContextState();
+ AppendTest("foo", 1, -3.141592);
+ BOOST_CHECK_EQUAL(1, state.valid_length_);
+ AppendTest("bar", 2, -6.0);
+ // Has to include the backoff weight.
+ BOOST_CHECK_EQUAL(1, state.valid_length_);
+ AppendTest("bar", 1, -2.718281 + 3.0);
+ BOOST_CHECK_EQUAL(1, state.valid_length_);
+
+ state = model.NullContextState();
+ AppendTest("to", 1, -1.687872);
+ AppendTest("look", 2, -0.2922095);
+ BOOST_CHECK_EQUAL(2, state.valid_length_);
+ AppendTest("good", 3, -7);
+}
+
#define StatelessTest(word, provide, ngram, score) \
ret = model.FullScoreForgotState(indices + num_words - word, indices + num_words - word + provide, indices[num_words - word - 1], state); \
BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
@@ -154,6 +198,7 @@ template <class M> void Everything(const M &m) {
Continuation(m);
Blanks(m);
Unknowns(m);
+ MinimalState(m);
Stateless(m);
}
@@ -167,7 +212,7 @@ class ExpectEnumerateVocab : public EnumerateVocab {
}
void Check(const base::Vocabulary &vocab) {
- BOOST_CHECK_EQUAL(34ULL, seen.size());
+ BOOST_CHECK_EQUAL(37ULL, seen.size());
BOOST_REQUIRE(!seen.empty());
BOOST_CHECK_EQUAL("<unk>", seen[0]);
for (WordIndex i = 0; i < seen.size(); ++i) {
diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc
index 3fa8cb03..d6da02e3 100644
--- a/klm/lm/ngram_query.cc
+++ b/klm/lm/ngram_query.cc
@@ -6,6 +6,8 @@
#include <iostream>
#include <string>
+#include <ctype.h>
+
#include <sys/resource.h>
#include <sys/time.h>
@@ -43,35 +45,38 @@ template <class Model> void Query(const Model &model) {
state = model.BeginSentenceState();
float total = 0.0;
bool got = false;
+ unsigned int oov = 0;
while (std::cin >> word) {
got = true;
lm::WordIndex vocab = model.GetVocabulary().Index(word);
+ if (vocab == 0) ++oov;
ret = model.FullScore(state, vocab, out);
total += ret.prob;
std::cout << word << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\n';
state = out;
- if (std::cin.get() == '\n') break;
+ char c;
+ while (true) {
+ c = std::cin.get();
+ if (!std::cin) break;
+ if (c == '\n') break;
+ if (!isspace(c)) {
+ std::cin.unget();
+ break;
+ }
+ }
+ if (c == '\n') break;
}
if (!got && !std::cin) break;
ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out);
total += ret.prob;
std::cout << "</s>=" << model.GetVocabulary().EndSentence() << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\n';
- std::cout << "Total: " << total << '\n';
+ std::cout << "Total: " << total << " OOV: " << oov << '\n';
}
PrintUsage("After queries:\n");
}
-class PrintVocab : public lm::ngram::EnumerateVocab {
- public:
- void Add(lm::WordIndex index, const StringPiece &str) {
- std::cerr << "vocab " << index << ' ' << str << '\n';
- }
-};
-
template <class Model> void Query(const char *name) {
lm::ngram::Config config;
- PrintVocab printer;
- config.enumerate_vocab = &printer;
Model model(name, config);
Query(model);
}
diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc
index 262a9c6a..d0fe67f0 100644
--- a/klm/lm/read_arpa.cc
+++ b/klm/lm/read_arpa.cc
@@ -1,5 +1,7 @@
#include "lm/read_arpa.hh"
+#include "lm/blank.hh"
+
#include <cstdlib>
#include <vector>
@@ -8,6 +10,9 @@
namespace lm {
+// 1 for '\t', '\n', and ' '. This is stricter than isspace.
+const bool kARPASpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
+
namespace {
bool IsEntirelyWhiteSpace(const StringPiece &line) {
@@ -116,21 +121,27 @@ void ReadBackoff(util::FilePiece &in, Prob &/*weights*/) {
case '\n':
break;
default:
- UTIL_THROW(FormatLoadException, "Expected tab or newline after unigram");
+ UTIL_THROW(FormatLoadException, "Expected tab or newline for backoff");
}
}
void ReadBackoff(util::FilePiece &in, ProbBackoff &weights) {
+ // Always make zero negative.
+ // Negative zero means that no (n+1)-gram has this n-gram as context.
+ // Therefore the hypothesis state can be shorter. Of course, many n-grams
+ // are context for (n+1)-grams. An algorithm in the data structure will go
+ // back and set the backoff to positive zero in these cases.
switch (in.get()) {
case '\t':
weights.backoff = in.ReadFloat();
+ if (weights.backoff == ngram::kExtensionBackoff) weights.backoff = ngram::kNoExtensionBackoff;
if ((in.get() != '\n')) UTIL_THROW(FormatLoadException, "Expected newline after backoff");
break;
case '\n':
- weights.backoff = 0.0;
+ weights.backoff = ngram::kNoExtensionBackoff;
break;
default:
- UTIL_THROW(FormatLoadException, "Expected tab or newline after unigram");
+ UTIL_THROW(FormatLoadException, "Expected tab or newline for backoff");
}
}
diff --git a/klm/lm/read_arpa.hh b/klm/lm/read_arpa.hh
index 571fcbc5..4efdd29d 100644
--- a/klm/lm/read_arpa.hh
+++ b/klm/lm/read_arpa.hh
@@ -23,12 +23,14 @@ void ReadBackoff(util::FilePiece &in, ProbBackoff &weights);
void ReadEnd(util::FilePiece &in);
void ReadEnd(std::istream &in);
+extern const bool kARPASpaces[256];
+
template <class Voc> void Read1Gram(util::FilePiece &f, Voc &vocab, ProbBackoff *unigrams) {
try {
float prob = f.ReadFloat();
if (prob > 0) UTIL_THROW(FormatLoadException, "Positive probability " << prob);
if (f.get() != '\t') UTIL_THROW(FormatLoadException, "Expected tab after probability");
- ProbBackoff &value = unigrams[vocab.Insert(f.ReadDelimited())];
+ ProbBackoff &value = unigrams[vocab.Insert(f.ReadDelimited(kARPASpaces))];
value.prob = prob;
ReadBackoff(f, value);
} catch(util::Exception &e) {
@@ -50,7 +52,7 @@ template <class Voc, class Weights> void ReadNGram(util::FilePiece &f, const uns
weights.prob = f.ReadFloat();
if (weights.prob > 0) UTIL_THROW(FormatLoadException, "Positive probability " << weights.prob);
for (WordIndex *vocab_out = reverse_indices + n - 1; vocab_out >= reverse_indices; --vocab_out) {
- *vocab_out = vocab.Index(f.ReadDelimited());
+ *vocab_out = vocab.Index(f.ReadDelimited(kARPASpaces));
}
ReadBackoff(f, weights);
} catch(util::Exception &e) {
diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc
index 9200aeb6..00d03f4e 100644
--- a/klm/lm/search_hashed.cc
+++ b/klm/lm/search_hashed.cc
@@ -14,7 +14,41 @@ namespace ngram {
namespace {
-template <class Voc, class Store, class Middle> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, std::vector<Middle> &middle, Store &store) {
+/* These are passed to ReadNGrams so that n-grams with zero backoff that appear as context will still be used in state. */
+template <class Middle> class ActivateLowerMiddle {
+ public:
+ explicit ActivateLowerMiddle(Middle &middle) : modify_(middle) {}
+
+ void operator()(const WordIndex *vocab_ids, const unsigned int n) {
+ uint64_t hash = static_cast<WordIndex>(vocab_ids[1]);
+ for (const WordIndex *i = vocab_ids + 2; i < vocab_ids + n; ++i) {
+ hash = detail::CombineWordHash(hash, *i);
+ }
+ typename Middle::MutableIterator i;
+ // TODO: somehow get text of n-gram for this error message.
+ if (!modify_.UnsafeMutableFind(hash, i))
+ UTIL_THROW(FormatLoadException, "The context of every " << n << "-gram should appear as a " << (n-1) << "-gram");
+ SetExtension(i->MutableValue().backoff);
+ }
+
+ private:
+ Middle &modify_;
+};
+
+class ActivateUnigram {
+ public:
+ explicit ActivateUnigram(ProbBackoff *unigram) : modify_(unigram) {}
+
+ void operator()(const WordIndex *vocab_ids, const unsigned int /*n*/) {
+ // assert(n == 2);
+ SetExtension(modify_[vocab_ids[1]].backoff);
+ }
+
+ private:
+ ProbBackoff *modify_;
+};
+
+template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, std::vector<Middle> &middle, Activate activate, Store &store) {
ReadNGramHeader(f, n);
ProbBackoff blank;
@@ -38,6 +72,7 @@ template <class Voc, class Store, class Middle> void ReadNGrams(util::FilePiece
if (middle[lower].Find(keys[lower], found)) break;
middle[lower].Insert(Middle::Packing::Make(keys[lower], blank));
}
+ activate(vocab_ids, n);
}
store.FinishedInserting();
@@ -53,12 +88,19 @@ template <class MiddleT, class LongestT> template <class Voc> void TemplateHashe
Read1Grams(f, counts[0], vocab, unigram.Raw());
try {
- for (unsigned int n = 2; n < counts.size(); ++n) {
- ReadNGrams(f, n, counts[n-1], vocab, middle, middle[n-2]);
+ if (counts.size() > 2) {
+ ReadNGrams(f, 2, counts[1], vocab, middle, ActivateUnigram(unigram.Raw()), middle[0]);
+ }
+ for (unsigned int n = 3; n < counts.size(); ++n) {
+ ReadNGrams(f, n, counts[n-1], vocab, middle, ActivateLowerMiddle<Middle>(middle[n-3]), middle[n-2]);
+ }
+ if (counts.size() > 2) {
+ ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, ActivateUnigram(unigram.Raw()), longest);
+ } else {
+ ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, ActivateLowerMiddle<Middle>(middle.back()), longest);
}
- ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, longest);
} catch (util::ProbingSizeException &e) {
- UTIL_THROW(util::ProbingSizeException, "Avoid pruning n-grams like \"bar baz quux\" when \"foo bar baz quux\" is still in the model. KenLM will work when this pruning happens, but the probing model assumes these events are rare enough that using blank space in the probing hash table will cover all of them. Increase probing_multiplier (-p to build_binary) to add more blank spaces. ");
+ UTIL_THROW(util::ProbingSizeException, "Avoid pruning n-grams like \"bar baz quux\" when \"foo bar baz quux\" is still in the model. KenLM will work when this pruning happens, but the probing model assumes these events are rare enough that using blank space in the probing hash table will cover all of them. Increase probing_multiplier (-p to build_binary) to add more blank spaces.\n");
}
}
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index 3aeeeca3..1060ddef 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -3,6 +3,7 @@
#include "lm/blank.hh"
#include "lm/lm_exception.hh"
+#include "lm/max_order.hh"
#include "lm/read_arpa.hh"
#include "lm/trie.hh"
#include "lm/vocab.hh"
@@ -27,6 +28,7 @@
#include <sys/stat.h>
#include <fcntl.h>
#include <stdlib.h>
+#include <unistd.h>
namespace lm {
namespace ngram {
@@ -98,7 +100,7 @@ class EntryProxy {
}
const WordIndex *Indices() const {
- return static_cast<const WordIndex*>(inner_.Data());
+ return reinterpret_cast<const WordIndex*>(inner_.Data());
}
private:
@@ -114,17 +116,57 @@ class EntryProxy {
typedef util::ProxyIterator<EntryProxy> NGramIter;
-class CompareRecords : public std::binary_function<const EntryProxy &, const EntryProxy &, bool> {
+// Proxy for an entry except there is some extra cruft between the entries. This is used to sort (n-1)-grams using the same memory as the sorted n-grams.
+class PartialViewProxy {
+ public:
+ PartialViewProxy() : attention_size_(0), inner_() {}
+
+ PartialViewProxy(void *ptr, std::size_t block_size, std::size_t attention_size) : attention_size_(attention_size), inner_(ptr, block_size) {}
+
+ operator std::string() const {
+ return std::string(reinterpret_cast<const char*>(inner_.Data()), attention_size_);
+ }
+
+ PartialViewProxy &operator=(const PartialViewProxy &from) {
+ memcpy(inner_.Data(), from.inner_.Data(), attention_size_);
+ return *this;
+ }
+
+ PartialViewProxy &operator=(const std::string &from) {
+ memcpy(inner_.Data(), from.data(), attention_size_);
+ return *this;
+ }
+
+ const WordIndex *Indices() const {
+ return reinterpret_cast<const WordIndex*>(inner_.Data());
+ }
+
+ private:
+ friend class util::ProxyIterator<PartialViewProxy>;
+
+ typedef std::string value_type;
+
+ const std::size_t attention_size_;
+
+ typedef EntryIterator InnerIterator;
+ InnerIterator &Inner() { return inner_; }
+ const InnerIterator &Inner() const { return inner_; }
+ InnerIterator inner_;
+};
+
+typedef util::ProxyIterator<PartialViewProxy> PartialIter;
+
+template <class Proxy> class CompareRecords : public std::binary_function<const Proxy &, const Proxy &, bool> {
public:
explicit CompareRecords(unsigned char order) : order_(order) {}
- bool operator()(const EntryProxy &first, const EntryProxy &second) const {
+ bool operator()(const Proxy &first, const Proxy &second) const {
return Compare(first.Indices(), second.Indices());
}
- bool operator()(const EntryProxy &first, const std::string &second) const {
+ bool operator()(const Proxy &first, const std::string &second) const {
return Compare(first.Indices(), reinterpret_cast<const WordIndex*>(second.data()));
}
- bool operator()(const std::string &first, const EntryProxy &second) const {
+ bool operator()(const std::string &first, const Proxy &second) const {
return Compare(reinterpret_cast<const WordIndex*>(first.data()), second.Indices());
}
bool operator()(const std::string &first, const std::string &second) const {
@@ -144,6 +186,12 @@ class CompareRecords : public std::binary_function<const EntryProxy &, const Ent
unsigned char order_;
};
+FILE *OpenOrThrow(const char *name, const char *mode) {
+ FILE *ret = fopen(name, mode);
+ if (!ret) UTIL_THROW(util::ErrnoException, "Could not open " << name << " for " << mode);
+ return ret;
+}
+
void WriteOrThrow(FILE *to, const void *data, size_t size) {
assert(size);
if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size);
@@ -163,14 +211,26 @@ void CopyOrThrow(FILE *from, FILE *to, size_t size) {
}
}
+void CopyRestOrThrow(FILE *from, FILE *to) {
+ char buf[kCopyBufSize];
+ size_t amount;
+ while ((amount = fread(buf, 1, kCopyBufSize, from))) {
+ WriteOrThrow(to, buf, amount);
+ }
+ if (!feof(from)) UTIL_THROW(util::ErrnoException, "Short read");
+}
+
+void RemoveOrThrow(const char *name) {
+ if (std::remove(name)) UTIL_THROW(util::ErrnoException, "Could not remove " << name);
+}
+
std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::string &file_prefix, std::size_t batch, unsigned char order, std::size_t weights_size) {
const std::size_t entry_size = sizeof(WordIndex) * order + weights_size;
const std::size_t prefix_size = sizeof(WordIndex) * (order - 1);
std::stringstream assembled;
assembled << file_prefix << static_cast<unsigned int>(order) << '_' << batch;
std::string ret(assembled.str());
- util::scoped_FILE out(fopen(ret.c_str(), "w"));
- if (!out.get()) UTIL_THROW(util::ErrnoException, "Couldn't open " << assembled.str().c_str() << " for writing");
+ util::scoped_FILE out(OpenOrThrow(ret.c_str(), "w"));
// Compress entries that being with the same (order-1) words.
for (const uint8_t *group_begin = static_cast<const uint8_t*>(mem_begin); group_begin != static_cast<const uint8_t*>(mem_end);) {
const uint8_t *group_end;
@@ -194,8 +254,7 @@ class SortedFileReader {
SortedFileReader() : ended_(false) {}
void Init(const std::string &name, unsigned char order) {
- file_.reset(fopen(name.c_str(), "r"));
- if (!file_.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " for read");
+ file_.reset(OpenOrThrow(name.c_str(), "r"));
header_.resize(order - 1);
NextHeader();
}
@@ -262,12 +321,13 @@ void CopyFullRecord(SortedFileReader &from, FILE *to, std::size_t weights_size)
CopyOrThrow(from.File(), to, (weights_size + sizeof(WordIndex)) * count);
}
-void MergeSortedFiles(const char *first_name, const char *second_name, const char *out, std::size_t weights_size, unsigned char order) {
+void MergeSortedFiles(const std::string &first_name, const std::string &second_name, const std::string &out, std::size_t weights_size, unsigned char order) {
SortedFileReader first, second;
- first.Init(first_name, order);
- second.Init(second_name, order);
- util::scoped_FILE out_file(fopen(out, "w"));
- if (!out_file.get()) UTIL_THROW(util::ErrnoException, "Could not open " << out << " for write");
+ first.Init(first_name.c_str(), order);
+ RemoveOrThrow(first_name.c_str());
+ second.Init(second_name.c_str(), order);
+ RemoveOrThrow(second_name.c_str());
+ util::scoped_FILE out_file(OpenOrThrow(out.c_str(), "w"));
while (!first.Ended() && !second.Ended()) {
if (first.HeaderVector() < second.HeaderVector()) {
CopyFullRecord(first, out_file.get(), weights_size);
@@ -316,10 +376,109 @@ void MergeSortedFiles(const char *first_name, const char *second_name, const cha
}
}
-void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order) {
- if (order == 1) return;
- ConvertToSorted(f, vocab, counts, mem, file_prefix, order - 1);
+const char *kContextSuffix = "_contexts";
+
+void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_file_name, std::size_t entry_size, unsigned char order) {
+ const size_t context_size = sizeof(WordIndex) * (order - 1);
+ // Sort just the contexts using the same memory.
+ PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size));
+ PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, context_size));
+
+ // TODO: __gnu_parallel::sort here.
+ std::sort(context_begin, context_end, CompareRecords<PartialViewProxy>(order - 1));
+
+ std::string name(ngram_file_name + kContextSuffix);
+ util::scoped_FILE out(OpenOrThrow(name.c_str(), "w"));
+
+ // Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator.
+ if (context_begin == context_end) return;
+ PartialIter i(context_begin);
+ WriteOrThrow(out.get(), i->Indices(), context_size);
+ const WordIndex *previous = i->Indices();
+ ++i;
+ for (; i != context_end; ++i) {
+ if (memcmp(previous, i->Indices(), context_size)) {
+ WriteOrThrow(out.get(), i->Indices(), context_size);
+ previous = i->Indices();
+ }
+ }
+}
+class ContextReader {
+ public:
+ ContextReader() : length_(0) {}
+
+ ContextReader(const char *name, size_t length) : file_(OpenOrThrow(name, "r")), length_(length), words_(length), valid_(true) {
+ ++*this;
+ }
+
+ void Reset(const char *name, size_t length) {
+ file_.reset(OpenOrThrow(name, "r"));
+ length_ = length;
+ words_.resize(length);
+ valid_ = true;
+ ++*this;
+ }
+
+ ContextReader &operator++() {
+ if (1 != fread(&*words_.begin(), length_, 1, file_.get())) {
+ if (!feof(file_.get()))
+ UTIL_THROW(util::ErrnoException, "Short read");
+ valid_ = false;
+ }
+ return *this;
+ }
+
+ const WordIndex *operator*() const { return &*words_.begin(); }
+
+ operator bool() const { return valid_; }
+
+ FILE *GetFile() { return file_.get(); }
+
+ private:
+ util::scoped_FILE file_;
+
+ size_t length_;
+
+ std::vector<WordIndex> words_;
+
+ bool valid_;
+};
+
+void MergeContextFiles(const std::string &first_base, const std::string &second_base, const std::string &out_base, unsigned char order) {
+ const size_t context_size = sizeof(WordIndex) * (order - 1);
+ std::string first_name(first_base + kContextSuffix);
+ std::string second_name(second_base + kContextSuffix);
+ ContextReader first(first_name.c_str(), context_size), second(second_name.c_str(), context_size);
+ RemoveOrThrow(first_name.c_str());
+ RemoveOrThrow(second_name.c_str());
+ std::string out_name(out_base + kContextSuffix);
+ util::scoped_FILE out(OpenOrThrow(out_name.c_str(), "w"));
+ while (first && second) {
+ for (const WordIndex *f = *first, *s = *second; ; ++f, ++s) {
+ if (f == *first + order) {
+ // Equal.
+ WriteOrThrow(out.get(), *first, context_size);
+ ++first;
+ ++second;
+ break;
+ }
+ if (*f < *s) {
+ // First lower
+ WriteOrThrow(out.get(), *first, context_size);
+ ++first;
+ break;
+ } else if (*f > *s) {
+ WriteOrThrow(out.get(), *second, context_size);
+ ++second;
+ break;
+ }
+ }
+ }
+ CopyRestOrThrow((first ? first : second).GetFile(), out.get());
+}
+
+void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order) {
ReadNGramHeader(f, order);
const size_t count = counts[order - 1];
// Size of weights. Does it include backoff?
@@ -341,11 +500,13 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st
ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size));
}
}
- // TODO: __gnu_parallel::sort here.
+ // Sort full records by full n-gram.
EntryProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size);
- std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), CompareRecords(order));
-
+ // TODO: __gnu_parallel::sort here.
+ std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), CompareRecords<EntryProxy>(order));
files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order, weights_size));
+ WriteContextFile(begin, out_end, files.back(), entry_size, order);
+
done += (out_end - begin) / entry_size;
}
@@ -356,10 +517,9 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st
std::stringstream assembled;
assembled << file_prefix << static_cast<unsigned int>(order) << "_merge_" << (merge_count++);
files.push_back(assembled.str());
- MergeSortedFiles(files[0].c_str(), files[1].c_str(), files.back().c_str(), weights_size, order);
- if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]);
+ MergeSortedFiles(files[0], files[1], files.back(), weights_size, order);
+ MergeContextFiles(files[0], files[1], files.back(), order);
files.pop_front();
- if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]);
files.pop_front();
}
if (!files.empty()) {
@@ -367,6 +527,9 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st
assembled << file_prefix << static_cast<unsigned int>(order) << "_merged";
std::string merged_name(assembled.str());
if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str());
+ std::string context_name = files[0] + kContextSuffix;
+ merged_name += kContextSuffix;
+ if (std::rename(context_name.c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << context_name << " to " << merged_name.c_str());
}
}
@@ -378,26 +541,38 @@ void ARPAToSortedFiles(util::FilePiece &f, const std::vector<uint64_t> &counts,
Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()));
}
+ // Only use as much buffer as we need.
+ size_t buffer_use = 0;
+ for (unsigned int order = 2; order < counts.size(); ++order) {
+ buffer_use = std::max(buffer_use, size_t((sizeof(WordIndex) * order + 2 * sizeof(float)) * counts[order - 1]));
+ }
+ buffer_use = std::max(buffer_use, size_t((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back()));
+ buffer = std::min(buffer, buffer_use);
+
util::scoped_memory mem;
mem.reset(malloc(buffer), buffer, util::scoped_memory::MALLOC_ALLOCATED);
if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer);
- ConvertToSorted(f, vocab, counts, mem, file_prefix, counts.size());
+
+ for (unsigned char order = 2; order <= counts.size(); ++order) {
+ ConvertToSorted(f, vocab, counts, mem, file_prefix, order);
+ }
ReadEnd(f);
}
bool HeadMatch(const WordIndex *words, const WordIndex *const words_end, const WordIndex *header) {
for (; words != words_end; ++words, ++header) {
if (*words != *header) {
- assert(*words <= *header);
+ //assert(*words <= *header);
return false;
}
}
return true;
}
+// Counting phrase
class JustCount {
public:
- JustCount(UnigramValue * /*unigrams*/, BitPackedMiddle * /*middle*/, BitPackedLongest &/*longest*/, uint64_t *counts, unsigned char order)
+ JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, BitPackedMiddle * /*middle*/, BitPackedLongest &/*longest*/, uint64_t *counts, unsigned char order)
: counts_(counts), longest_counts_(counts + order - 1) {}
void Unigrams(WordIndex begin, WordIndex end) {
@@ -408,7 +583,7 @@ class JustCount {
++counts_[mid_idx + 1];
}
- void Middle(const unsigned char mid_idx, WordIndex /*key*/, const ProbBackoff &/*weights*/) {
+ void Middle(const unsigned char mid_idx, const WordIndex * /*before*/, WordIndex /*key*/, const ProbBackoff &/*weights*/) {
++counts_[mid_idx + 1];
}
@@ -427,7 +602,8 @@ class JustCount {
class WriteEntries {
public:
- WriteEntries(UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) :
+ WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) :
+ contexts_(contexts),
unigrams_(unigrams),
middle_(middle),
longest_(longest),
@@ -444,7 +620,13 @@ class WriteEntries {
middle_[mid_idx].Insert(key, kBlankProb, kBlankBackoff);
}
- void Middle(const unsigned char mid_idx, WordIndex key, const ProbBackoff &weights) {
+ void Middle(const unsigned char mid_idx, const WordIndex *before, WordIndex key, ProbBackoff weights) {
+ // Order (mid_idx+2).
+ ContextReader &context = contexts_[mid_idx + 1];
+ if (context && !memcmp(before, *context, sizeof(WordIndex) * (mid_idx + 1)) && (*context)[mid_idx + 1] == key) {
+ SetExtension(weights.backoff);
+ ++context;
+ }
middle_[mid_idx].Insert(key, weights.prob, weights.backoff);
}
@@ -455,6 +637,7 @@ class WriteEntries {
void Cleanup() {}
private:
+ ContextReader *contexts_;
UnigramValue *const unigrams_;
BitPackedMiddle *const middle_;
BitPackedLongest &longest_;
@@ -463,14 +646,15 @@ class WriteEntries {
template <class Doing> class RecursiveInsert {
public:
- RecursiveInsert(SortedFileReader *inputs, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, uint64_t *counts, unsigned char order) :
- doing_(unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), words_(new WordIndex[order]), order_minus_2_(order - 2) {
+ RecursiveInsert(SortedFileReader *inputs, ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, uint64_t *counts, unsigned char order) :
+ doing_(contexts, unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), order_minus_2_(order - 2) {
}
// Outer unigram loop.
void Apply(std::ostream *progress_out, const char *message, WordIndex unigram_count) {
util::ErsatzProgress progress(progress_out, message, unigram_count + 1);
for (words_[0] = 0; ; ++words_[0]) {
+ progress.Set(words_[0]);
WordIndex min_continue = unigram_count;
for (SortedFileReader *other = inputs_; other != inputs_end_; ++other) {
if (other->Ended()) continue;
@@ -479,7 +663,6 @@ template <class Doing> class RecursiveInsert {
// This will write at unigram_count. This is by design so that the next pointers will make sense.
doing_.Unigrams(words_[0], min_continue + 1);
if (min_continue == unigram_count) break;
- progress += min_continue - words_[0];
words_[0] = min_continue;
Middle(0);
}
@@ -497,7 +680,7 @@ template <class Doing> class RecursiveInsert {
SortedFileReader &reader = inputs_[mid_idx];
- if (reader.Ended() || !HeadMatch(words_.get(), words_.get() + mid_idx + 1, reader.Header())) {
+ if (reader.Ended() || !HeadMatch(words_, words_ + mid_idx + 1, reader.Header())) {
// This order doesn't have a header match, but longer ones might.
MiddleAllBlank(mid_idx);
return;
@@ -509,7 +692,7 @@ template <class Doing> class RecursiveInsert {
while (count) {
WordIndex min_continue = std::numeric_limits<WordIndex>::max();
for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) {
- if (!other->Ended() && HeadMatch(words_.get(), words_.get() + mid_idx + 1, other->Header()))
+ if (!other->Ended() && HeadMatch(words_, words_ + mid_idx + 1, other->Header()))
min_continue = std::min(min_continue, other->Header()[mid_idx + 1]);
}
while (true) {
@@ -521,7 +704,7 @@ template <class Doing> class RecursiveInsert {
}
ProbBackoff weights;
reader.ReadWeights(weights);
- doing_.Middle(mid_idx, current, weights);
+ doing_.Middle(mid_idx, words_, current, weights);
--count;
if (current == min_continue) {
words_[mid_idx + 1] = min_continue;
@@ -542,7 +725,7 @@ template <class Doing> class RecursiveInsert {
while (true) {
WordIndex min_continue = std::numeric_limits<WordIndex>::max();
for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) {
- if (!other->Ended() && HeadMatch(words_.get(), words_.get() + mid_idx + 1, other->Header()))
+ if (!other->Ended() && HeadMatch(words_, words_ + mid_idx + 1, other->Header()))
min_continue = std::min(min_continue, other->Header()[mid_idx + 1]);
}
if (min_continue == std::numeric_limits<WordIndex>::max()) return;
@@ -554,7 +737,7 @@ template <class Doing> class RecursiveInsert {
void Longest() {
SortedFileReader &reader = *(inputs_end_ - 1);
- if (reader.Ended() || !HeadMatch(words_.get(), words_.get() + order_minus_2_ + 1, reader.Header())) return;
+ if (reader.Ended() || !HeadMatch(words_, words_ + order_minus_2_ + 1, reader.Header())) return;
WordIndex count = reader.ReadCount();
for (WordIndex i = 0; i < count; ++i) {
WordIndex word = reader.ReadWord();
@@ -571,7 +754,7 @@ template <class Doing> class RecursiveInsert {
SortedFileReader *inputs_;
SortedFileReader *inputs_end_;
- util::scoped_array<WordIndex> words_;
+ WordIndex words_[kMaxOrder];
const unsigned char order_minus_2_;
};
@@ -586,17 +769,21 @@ void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<u
void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &counts, const Config &config, TrieSearch &out, Backing &backing) {
SortedFileReader inputs[counts.size() - 1];
+ ContextReader contexts[counts.size() - 1];
for (unsigned char i = 2; i <= counts.size(); ++i) {
std::stringstream assembled;
assembled << file_prefix << static_cast<unsigned int>(i) << "_merged";
inputs[i-2].Init(assembled.str(), i);
- unlink(assembled.str().c_str());
+ RemoveOrThrow(assembled.str().c_str());
+ assembled << kContextSuffix;
+ contexts[i-2].Reset(assembled.str().c_str(), (i-1) * sizeof(WordIndex));
+ RemoveOrThrow(assembled.str().c_str());
}
std::vector<uint64_t> fixed_counts(counts.size());
{
- RecursiveInsert<JustCount> counter(inputs, NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size());
+ RecursiveInsert<JustCount> counter(inputs, contexts, NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size());
counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]);
}
SanityCheckCounts(counts, fixed_counts);
@@ -609,21 +796,38 @@ void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &coun
UnigramValue *unigrams = out.unigram.Raw();
// Fill entries except unigram probabilities.
{
- RecursiveInsert<WriteEntries> inserter(inputs, unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size());
+ RecursiveInsert<WriteEntries> inserter(inputs, contexts, unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size());
inserter.Apply(config.messages, "Building trie", fixed_counts[0]);
}
// Fill unigram probabilities.
{
std::string name(file_prefix + "unigrams");
- util::scoped_FILE file(fopen(name.c_str(), "r"));
- if (!file.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " failed");
+ util::scoped_FILE file(OpenOrThrow(name.c_str(), "r"));
for (WordIndex i = 0; i < counts[0]; ++i) {
ReadOrThrow(file.get(), &unigrams[i].weights, sizeof(ProbBackoff));
+ if (contexts[0] && **contexts[0] == i) {
+ SetExtension(unigrams[i].weights.backoff);
+ ++contexts[0];
+ }
}
unlink(name.c_str());
}
+ // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.
+ for (unsigned char order = 2; order <= counts.size(); ++order) {
+ const ContextReader &context = contexts[order - 2];
+ if (context) {
+ FormatLoadException e;
+ e << "An " << static_cast<unsigned int>(order) << "-gram has the context (i.e. all but the last word):";
+ for (const WordIndex *i = *context; i != *context + order - 1; ++i) {
+ e << ' ' << *i;
+ }
+ e << " so this context must appear in the model as a " << static_cast<unsigned int>(order - 1) << "-gram but it does not.";
+ throw e;
+ }
+ }
+
/* Set ending offsets so the last entry will be sized properly */
// Last entry for unigrams was already set.
if (!out.middle.empty()) {
@@ -634,19 +838,27 @@ void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &coun
}
}
+bool IsDirectory(const char *path) {
+ struct stat info;
+ if (0 != stat(path, &info)) return false;
+ return S_ISDIR(info.st_mode);
+}
+
} // namespace
void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {
std::string temporary_directory;
if (config.temporary_directory_prefix) {
temporary_directory = config.temporary_directory_prefix;
+ if (!temporary_directory.empty() && temporary_directory[temporary_directory.size() - 1] != '/' && IsDirectory(temporary_directory.c_str()))
+ temporary_directory += '/';
} else if (config.write_mmap) {
temporary_directory = config.write_mmap;
} else {
temporary_directory = file;
}
// Null on end is kludge to ensure null termination.
- temporary_directory += "-tmp-XXXXXX";
+ temporary_directory += "_trie_tmp_XXXXXX";
temporary_directory += '\0';
if (!mkdtemp(&temporary_directory[0])) {
UTIL_THROW(util::ErrnoException, "Failed to make a temporary directory based on the name " << temporary_directory.c_str());
@@ -658,7 +870,7 @@ void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::v
// At least 1MB sorting memory.
ARPAToSortedFiles(f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab);
- BuildTrie(temporary_directory.c_str(), counts, config, *this, backing);
+ BuildTrie(temporary_directory, counts, config, *this, backing);
if (rmdir(temporary_directory.c_str()) && config.messages) {
*config.messages << "Failed to delete " << temporary_directory << std::endl;
}
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh
index 8c99d797..b584c82f 100644
--- a/klm/lm/vocab.hh
+++ b/klm/lm/vocab.hh
@@ -65,7 +65,6 @@ class SortedVocabulary : public base::Vocabulary {
}
}
- // Ignores second argument for consistency with probing hash which has a float here.
static size_t Size(std::size_t entries, const Config &config);
// Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
diff --git a/klm/util/bit_packing.cc b/klm/util/bit_packing.cc
index 9d4fdf27..681da5f2 100644
--- a/klm/util/bit_packing.cc
+++ b/klm/util/bit_packing.cc
@@ -22,7 +22,7 @@ uint8_t RequiredBits(uint64_t max_value) {
}
void BitPackingSanity() {
- const detail::FloatEnc neg1 = { -1.0 }, pos1 = { 1.0 };
+ const FloatEnc neg1 = { -1.0 }, pos1 = { 1.0 };
if ((neg1.i ^ pos1.i) != 0x80000000) UTIL_THROW(Exception, "Sign bit is not 0x80000000");
char mem[57+8];
memset(mem, 0, sizeof(mem));
diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh
index 636547b1..70cfc2d2 100644
--- a/klm/util/bit_packing.hh
+++ b/klm/util/bit_packing.hh
@@ -53,29 +53,32 @@ inline void WriteInt57(void *base, uint8_t bit, uint8_t length, uint64_t value)
*reinterpret_cast<uint64_t*>(base) |= (value << BitPackShift(bit, length));
}
-namespace detail { typedef union { float f; uint32_t i; } FloatEnc; }
+typedef union { float f; uint32_t i; } FloatEnc;
+
inline float ReadFloat32(const void *base, uint8_t bit) {
- detail::FloatEnc encoded;
+ FloatEnc encoded;
encoded.i = *reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, 32);
return encoded.f;
}
inline void WriteFloat32(void *base, uint8_t bit, float value) {
- detail::FloatEnc encoded;
+ FloatEnc encoded;
encoded.f = value;
WriteInt57(base, bit, 32, encoded.i);
}
+const uint32_t kSignBit = 0x80000000;
+
inline float ReadNonPositiveFloat31(const void *base, uint8_t bit) {
- detail::FloatEnc encoded;
+ FloatEnc encoded;
encoded.i = *reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, 31);
// Sign bit set means negative.
- encoded.i |= 0x80000000;
+ encoded.i |= kSignBit;
return encoded.f;
}
inline void WriteNonPositiveFloat31(void *base, uint8_t bit, float value) {
- detail::FloatEnc encoded;
+ FloatEnc encoded;
encoded.f = value;
- encoded.i &= ~0x80000000;
+ encoded.i &= ~kSignBit;
WriteInt57(base, bit, 31, encoded.i);
}
diff --git a/klm/util/ersatz_progress.cc b/klm/util/ersatz_progress.cc
index 55c182bd..a82ce672 100644
--- a/klm/util/ersatz_progress.cc
+++ b/klm/util/ersatz_progress.cc
@@ -36,6 +36,7 @@ void ErsatzProgress::Milestone() {
if (stone == kWidth) {
(*out_) << std::endl;
next_ = std::numeric_limits<std::size_t>::max();
+ out_ = NULL;
} else {
next_ = std::max(next_, (stone * complete_) / kWidth);
}
diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc
index 5a667ebb..81eb9bb9 100644
--- a/klm/util/file_piece.cc
+++ b/klm/util/file_piece.cc
@@ -37,6 +37,9 @@ GZException::GZException(void *file) {
#endif // HAVE_ZLIB
}
+// Sigh this is the only way I could come up with to do a _const_ bool. It has ' ', '\f', '\n', '\r', '\t', and '\v' (same as isspace on C locale).
+const bool kSpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
+
int OpenReadOrThrow(const char *name) {
int ret = open(name, O_RDONLY);
if (ret == -1) UTIL_THROW(ErrnoException, "in open (" << name << ") for reading");
@@ -107,13 +110,6 @@ unsigned long int FilePiece::ReadULong() throw(GZException, EndOfFileException,
return ReadNumber<unsigned long int>();
}
-void FilePiece::SkipSpaces() throw (GZException, EndOfFileException) {
- for (; ; ++position_) {
- if (position_ == position_end_) Shift();
- if (!isspace(*position_)) return;
- }
-}
-
void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) throw (GZException) {
#ifdef HAVE_ZLIB
gz_file_ = NULL;
@@ -190,20 +186,6 @@ template <class T> T FilePiece::ReadNumber() throw(GZException, EndOfFileExcepti
return ret;
}
-const char *FilePiece::FindDelimiterOrEOF() throw (GZException, EndOfFileException) {
- for (const char *i = position_; i <= last_space_; ++i) {
- if (isspace(*i)) return i;
- }
- while (!at_end_) {
- size_t skip = position_end_ - position_;
- Shift();
- for (const char *i = position_ + skip; i <= last_space_; ++i) {
- if (isspace(*i)) return i;
- }
- }
- return position_end_;
-}
-
void FilePiece::Shift() throw(GZException, EndOfFileException) {
if (at_end_) {
progress_.Finished();
diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh
index b7697e71..f5249fcf 100644
--- a/klm/util/file_piece.hh
+++ b/klm/util/file_piece.hh
@@ -36,10 +36,13 @@ class GZException : public Exception {
int OpenReadOrThrow(const char *name);
+extern const bool kSpaces[256];
+
// Return value for SizeFile when it can't size properly.
const off_t kBadSize = -1;
off_t SizeFile(int fd);
+// Memory backing the returned StringPiece may vanish on the next call.
class FilePiece {
public:
// 32 MB default.
@@ -57,12 +60,12 @@ class FilePiece {
return *(position_++);
}
- // Memory backing the returned StringPiece may vanish on the next call.
- // Leaves the delimiter, if any, to be returned by get().
- StringPiece ReadDelimited() throw(GZException, EndOfFileException) {
- SkipSpaces();
- return Consume(FindDelimiterOrEOF());
+ // Leaves the delimiter, if any, to be returned by get(). Delimiters defined by isspace().
+ StringPiece ReadDelimited(const bool *delim = kSpaces) throw(GZException, EndOfFileException) {
+ SkipSpaces(delim);
+ return Consume(FindDelimiterOrEOF(delim));
}
+
// Unlike ReadDelimited, this includes leading spaces and consumes the delimiter.
// It is similar to getline in that way.
StringPiece ReadLine(char delim = '\n') throw(GZException, EndOfFileException);
@@ -72,7 +75,13 @@ class FilePiece {
long int ReadLong() throw(GZException, EndOfFileException, ParseNumberException);
unsigned long int ReadULong() throw(GZException, EndOfFileException, ParseNumberException);
- void SkipSpaces() throw (GZException, EndOfFileException);
+ // Skip spaces defined by isspace.
+ void SkipSpaces(const bool *delim = kSpaces) throw (GZException, EndOfFileException) {
+ for (; ; ++position_) {
+ if (position_ == position_end_) Shift();
+ if (!delim[static_cast<unsigned char>(*position_)]) return;
+ }
+ }
off_t Offset() const {
return position_ - data_.begin() + mapped_offset_;
@@ -91,7 +100,19 @@ class FilePiece {
return ret;
}
- const char *FindDelimiterOrEOF() throw(EndOfFileException, GZException);
+ const char *FindDelimiterOrEOF(const bool *delim = kSpaces) throw (GZException, EndOfFileException) {
+ for (const char *i = position_; i < position_end_; ++i) {
+ if (delim[static_cast<unsigned char>(*i)]) return i;
+ }
+ while (!at_end_) {
+ size_t skip = position_end_ - position_;
+ Shift();
+ for (const char *i = position_ + skip; i < position_end_; ++i) {
+ if (delim[static_cast<unsigned char>(*i)]) return i;
+ }
+ }
+ return position_end_;
+ }
void Shift() throw (EndOfFileException, GZException);
// Backends to Shift().
diff --git a/klm/util/key_value_packing.hh b/klm/util/key_value_packing.hh
index 450512ac..b84a5aad 100644
--- a/klm/util/key_value_packing.hh
+++ b/klm/util/key_value_packing.hh
@@ -18,6 +18,8 @@ template <class Key, class Value> struct Entry {
const Key &GetKey() const { return key; }
const Value &GetValue() const { return value; }
+ Value &MutableValue() { return value; }
+
void Set(const Key &key_in, const Value &value_in) {
SetKey(key_in);
SetValue(value_in);
@@ -77,6 +79,8 @@ template <class KeyT, class ValueT> class ByteAlignedPacking {
const Key &GetKey() const { return key; }
const Value &GetValue() const { return value; }
+ Value &MutableValue() { return value; }
+
void Set(const Key &key_in, const Value &value_in) {
SetKey(key_in);
SetValue(value_in);
diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh
index 7b5cdc22..00be0ed7 100644
--- a/klm/util/probing_hash_table.hh
+++ b/klm/util/probing_hash_table.hh
@@ -77,6 +77,16 @@ template <class PackingT, class HashT, class EqualT = std::equal_to<typename Pac
void LoadedBinary() {}
+ // Don't change anything related to GetKey,
+ template <class Key> bool UnsafeMutableFind(const Key key, MutableIterator &out) {
+ for (MutableIterator i(begin_ + (hash_(key) % buckets_));;) {
+ Key got(i->GetKey());
+ if (equal_(got, key)) { out = i; return true; }
+ if (equal_(got, invalid_)) return false;
+ if (++i == end_) i = begin_;
+ }
+ }
+
template <class Key> bool Find(const Key key, ConstIterator &out) const {
#ifdef DEBUG
assert(initialized_);
@@ -84,8 +94,8 @@ template <class PackingT, class HashT, class EqualT = std::equal_to<typename Pac
for (ConstIterator i(begin_ + (hash_(key) % buckets_));;) {
Key got(i->GetKey());
if (equal_(got, key)) { out = i; return true; }
- if (equal_(got, invalid_)) { return false; }
- if (++i == end_) { i = begin_; }
+ if (equal_(got, invalid_)) return false;
+ if (++i == end_) i = begin_;
}
}
diff --git a/klm/util/sorted_uniform.hh b/klm/util/sorted_uniform.hh
index a8e208fb..05826b51 100644
--- a/klm/util/sorted_uniform.hh
+++ b/klm/util/sorted_uniform.hh
@@ -62,6 +62,7 @@ template <class PackingT> class SortedUniformMap {
public:
typedef PackingT Packing;
typedef typename Packing::ConstIterator ConstIterator;
+ typedef typename Packing::MutableIterator MutableIterator;
public:
// Offer consistent API with probing hash.
@@ -113,6 +114,15 @@ template <class PackingT> class SortedUniformMap {
*size_ptr_ = (end_ - begin_);
}
+ // Don't use this to change the key.
+ template <class Key> bool UnsafeMutableFind(const Key key, MutableIterator &out) {
+#ifdef DEBUG
+ assert(initialized_);
+ assert(loaded_);
+#endif
+ return SortedUniformFind<MutableIterator, Key>(begin_, end_, key, out);
+ }
+
// Do not call before FinishedInserting.
template <class Key> bool Find(const Key key, ConstIterator &out) const {
#ifdef DEBUG