summaryrefslogtreecommitdiff
path: root/klm
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2012-05-16 13:24:08 -0700
committerChris Dyer <cdyer@cab.ark.cs.cmu.edu>2012-05-26 22:59:54 -0400
commit149232c38eec558ddb1097698d1570aacb67b59f (patch)
tree5860b4d6f681eeb04a1020cbb2fe7e6ac394af99 /klm
parent01ecc09f8e3a82c32bf7dd2f90c12554becea71d (diff)
Big kenlm change includes lower order models for probing only. And other stuff.
Diffstat (limited to 'klm')
-rw-r--r--klm/lm/Jamfile2
-rw-r--r--klm/lm/Makefile.am1
-rw-r--r--klm/lm/binary_format.cc2
-rw-r--r--klm/lm/build_binary.cc97
-rw-r--r--klm/lm/config.cc1
-rw-r--r--klm/lm/config.hh22
-rw-r--r--klm/lm/left.hh110
-rw-r--r--klm/lm/left_test.cc83
-rw-r--r--klm/lm/max_order.hh2
-rw-r--r--klm/lm/model.cc192
-rw-r--r--klm/lm/model.hh93
-rw-r--r--klm/lm/model_test.cc42
-rw-r--r--klm/lm/model_type.hh13
-rw-r--r--klm/lm/ngram_query.cc18
-rw-r--r--klm/lm/ngram_query.hh47
-rw-r--r--klm/lm/quantize.cc20
-rw-r--r--klm/lm/quantize.hh164
-rw-r--r--klm/lm/read_arpa.cc8
-rw-r--r--klm/lm/read_arpa.hh14
-rw-r--r--klm/lm/return.hh3
-rw-r--r--klm/lm/search_hashed.cc243
-rw-r--r--klm/lm/search_hashed.hh229
-rw-r--r--klm/lm/search_trie.cc38
-rw-r--r--klm/lm/search_trie.hh71
-rw-r--r--klm/lm/state.hh123
-rw-r--r--klm/lm/trie.cc61
-rw-r--r--klm/lm/trie.hh61
-rw-r--r--klm/lm/value.hh157
-rw-r--r--klm/lm/value_build.cc58
-rw-r--r--klm/lm/value_build.hh97
-rw-r--r--klm/lm/vocab.cc2
-rw-r--r--klm/lm/vocab.hh6
-rw-r--r--klm/lm/weights.hh5
-rw-r--r--klm/util/Jamfile2
-rw-r--r--klm/util/Makefile.am3
-rw-r--r--klm/util/bit_packing.hh7
-rw-r--r--klm/util/ersatz_progress.cc8
-rw-r--r--klm/util/ersatz_progress.hh4
-rw-r--r--klm/util/file.cc10
-rw-r--r--klm/util/file.hh3
-rw-r--r--klm/util/file_piece.cc17
-rw-r--r--klm/util/file_piece.hh10
-rw-r--r--klm/util/have.hh10
-rw-r--r--klm/util/mmap.cc14
-rw-r--r--klm/util/murmur_hash.cc11
-rw-r--r--klm/util/murmur_hash.hh6
-rw-r--r--klm/util/probing_hash_table.hh21
-rw-r--r--klm/util/usage.cc46
-rw-r--r--klm/util/usage.hh8
49 files changed, 1460 insertions, 805 deletions
diff --git a/klm/lm/Jamfile b/klm/lm/Jamfile
index b84dbb35..b1971d88 100644
--- a/klm/lm/Jamfile
+++ b/klm/lm/Jamfile
@@ -1,4 +1,4 @@
-lib kenlm : bhiksha.cc binary_format.cc config.cc lm_exception.cc model.cc quantize.cc read_arpa.cc search_hashed.cc search_trie.cc trie.cc trie_sort.cc virtual_interface.cc vocab.cc ../util//kenutil : <include>.. : : <include>.. <library>../util//kenutil ;
+lib kenlm : bhiksha.cc binary_format.cc config.cc lm_exception.cc model.cc quantize.cc read_arpa.cc search_hashed.cc search_trie.cc trie.cc trie_sort.cc value_build.cc virtual_interface.cc vocab.cc ../util//kenutil : <include>.. : : <include>.. <library>../util//kenutil ;
import testing ;
diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am
index 54fd7f68..a12c5f03 100644
--- a/klm/lm/Makefile.am
+++ b/klm/lm/Makefile.am
@@ -24,6 +24,7 @@ libklm_a_SOURCES = \
search_trie.cc \
trie.cc \
trie_sort.cc \
+ value_build.cc \
virtual_interface.cc \
vocab.cc
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc
index 4796f6d1..a56e998e 100644
--- a/klm/lm/binary_format.cc
+++ b/klm/lm/binary_format.cc
@@ -57,7 +57,7 @@ struct Sanity {
}
};
-const char *kModelNames[6] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"};
+const char *kModelNames[6] = {"probing hash tables", "probing hash tables with rest costs", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"};
std::size_t TotalHeaderSize(unsigned char order) {
return ALIGN8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order);
diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc
index 8cbb69d0..c4a01cb4 100644
--- a/klm/lm/build_binary.cc
+++ b/klm/lm/build_binary.cc
@@ -66,16 +66,28 @@ uint8_t ParseBitCount(const char *from) {
return val;
}
+void ParseFileList(const char *from, std::vector<std::string> &to) {
+ to.clear();
+ while (true) {
+ const char *i;
+ for (i = from; *i && *i != ' '; ++i) {}
+ to.push_back(std::string(from, i - from));
+ if (!*i) break;
+ from = i + 1;
+ }
+}
+
void ShowSizes(const char *file, const lm::ngram::Config &config) {
std::vector<uint64_t> counts;
util::FilePiece f(file);
lm::ReadARPACounts(f, counts);
- std::size_t sizes[5];
+ std::size_t sizes[6];
sizes[0] = ProbingModel::Size(counts, config);
- sizes[1] = TrieModel::Size(counts, config);
- sizes[2] = QuantTrieModel::Size(counts, config);
- sizes[3] = ArrayTrieModel::Size(counts, config);
- sizes[4] = QuantArrayTrieModel::Size(counts, config);
+ sizes[1] = RestProbingModel::Size(counts, config);
+ sizes[2] = TrieModel::Size(counts, config);
+ sizes[3] = QuantTrieModel::Size(counts, config);
+ sizes[4] = ArrayTrieModel::Size(counts, config);
+ sizes[5] = QuantArrayTrieModel::Size(counts, config);
std::size_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(size_t));
std::size_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(size_t));
std::size_t divide;
@@ -99,10 +111,11 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) {
for (long int i = 0; i < length - 2; ++i) std::cout << ' ';
std::cout << prefix << "B\n"
"probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n"
- "trie " << std::setw(length) << (sizes[1] / divide) << " without quantization\n"
- "trie " << std::setw(length) << (sizes[2] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n"
- "trie " << std::setw(length) << (sizes[3] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n"
- "trie " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits<< " array pointer compression and quantization\n";
+ "probing " << std::setw(length) << (sizes[1] / divide) << " assuming -r -p " << config.probing_multiplier << "\n"
+ "trie " << std::setw(length) << (sizes[2] / divide) << " without quantization\n"
+ "trie " << std::setw(length) << (sizes[3] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n"
+ "trie " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n"
+ "trie " << std::setw(length) << (sizes[5] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits<< " array pointer compression and quantization\n";
}
void ProbingQuantizationUnsupported() {
@@ -118,10 +131,10 @@ int main(int argc, char *argv[]) {
using namespace lm::ngram;
try {
- bool quantize = false, set_backoff_bits = false, bhiksha = false, set_write_method = false;
+ bool quantize = false, set_backoff_bits = false, bhiksha = false, set_write_method = false, rest = false;
lm::ngram::Config config;
int opt;
- while ((opt = getopt(argc, argv, "q:b:a:u:p:t:m:w:si")) != -1) {
+ while ((opt = getopt(argc, argv, "q:b:a:u:p:t:m:w:sir:")) != -1) {
switch(opt) {
case 'q':
config.prob_bits = ParseBitCount(optarg);
@@ -164,6 +177,11 @@ int main(int argc, char *argv[]) {
case 'i':
config.positive_log_probability = lm::SILENT;
break;
+ case 'r':
+ rest = true;
+ ParseFileList(optarg, config.rest_lower_files);
+ config.rest_function = Config::REST_LOWER;
+ break;
default:
Usage(argv[0]);
}
@@ -174,35 +192,48 @@ int main(int argc, char *argv[]) {
}
if (optind + 1 == argc) {
ShowSizes(argv[optind], config);
- } else if (optind + 2 == argc) {
+ return 0;
+ }
+ const char *model_type;
+ const char *from_file;
+
+ if (optind + 2 == argc) {
+ model_type = "probing";
+ from_file = argv[optind];
config.write_mmap = argv[optind + 1];
- if (quantize || set_backoff_bits) ProbingQuantizationUnsupported();
- ProbingModel(argv[optind], config);
} else if (optind + 3 == argc) {
- const char *model_type = argv[optind];
- const char *from_file = argv[optind + 1];
+ model_type = argv[optind];
+ from_file = argv[optind + 1];
config.write_mmap = argv[optind + 2];
- if (!strcmp(model_type, "probing")) {
- if (!set_write_method) config.write_method = Config::WRITE_AFTER;
- if (quantize || set_backoff_bits) ProbingQuantizationUnsupported();
+ } else {
+ Usage(argv[0]);
+ }
+ if (!strcmp(model_type, "probing")) {
+ if (!set_write_method) config.write_method = Config::WRITE_AFTER;
+ if (quantize || set_backoff_bits) ProbingQuantizationUnsupported();
+ if (rest) {
+ RestProbingModel(from_file, config);
+ } else {
ProbingModel(from_file, config);
- } else if (!strcmp(model_type, "trie")) {
- if (!set_write_method) config.write_method = Config::WRITE_MMAP;
- if (quantize) {
- if (bhiksha) {
- QuantArrayTrieModel(from_file, config);
- } else {
- QuantTrieModel(from_file, config);
- }
+ }
+ } else if (!strcmp(model_type, "trie")) {
+ if (rest) {
+ std::cerr << "Rest + trie is not supported yet." << std::endl;
+ return 1;
+ }
+ if (!set_write_method) config.write_method = Config::WRITE_MMAP;
+ if (quantize) {
+ if (bhiksha) {
+ QuantArrayTrieModel(from_file, config);
} else {
- if (bhiksha) {
- ArrayTrieModel(from_file, config);
- } else {
- TrieModel(from_file, config);
- }
+ QuantTrieModel(from_file, config);
}
} else {
- Usage(argv[0]);
+ if (bhiksha) {
+ ArrayTrieModel(from_file, config);
+ } else {
+ TrieModel(from_file, config);
+ }
}
} else {
Usage(argv[0]);
diff --git a/klm/lm/config.cc b/klm/lm/config.cc
index dbe762b3..f9d988ca 100644
--- a/klm/lm/config.cc
+++ b/klm/lm/config.cc
@@ -19,6 +19,7 @@ Config::Config() :
write_mmap(NULL),
write_method(WRITE_AFTER),
include_vocab(true),
+ rest_function(REST_MAX),
prob_bits(8),
backoff_bits(8),
pointer_bhiksha_bits(22),
diff --git a/klm/lm/config.hh b/klm/lm/config.hh
index 01b75632..739cee9c 100644
--- a/klm/lm/config.hh
+++ b/klm/lm/config.hh
@@ -1,11 +1,13 @@
#ifndef LM_CONFIG__
#define LM_CONFIG__
-#include <iosfwd>
-
#include "lm/lm_exception.hh"
#include "util/mmap.hh"
+#include <iosfwd>
+#include <string>
+#include <vector>
+
/* Configuration for ngram model. Separate header to reduce pollution. */
namespace lm {
@@ -63,23 +65,33 @@ struct Config {
const char *temporary_directory_prefix;
// Level of complaining to do when loading from ARPA instead of binary format.
- typedef enum {ALL, EXPENSIVE, NONE} ARPALoadComplain;
+ enum ARPALoadComplain {ALL, EXPENSIVE, NONE};
ARPALoadComplain arpa_complain;
// While loading an ARPA file, also write out this binary format file. Set
// to NULL to disable.
const char *write_mmap;
- typedef enum {
+ enum WriteMethod {
WRITE_MMAP, // Map the file directly.
WRITE_AFTER // Write after we're done.
- } WriteMethod;
+ };
WriteMethod write_method;
// Include the vocab in the binary file? Only effective if write_mmap != NULL.
bool include_vocab;
+ // Left rest options. Only used when the model includes rest costs.
+ enum RestFunction {
+ REST_MAX, // Maximum of any score to the left
+ REST_LOWER, // Use lower-order files given below.
+ };
+ RestFunction rest_function;
+ // Only used for REST_LOWER.
+ std::vector<std::string> rest_lower_files;
+
+
// Quantization options. Only effective for QuantTrieModel. One value is
// reserved for each of prob and backoff, so 2^bits - 1 buckets will be used
diff --git a/klm/lm/left.hh b/klm/lm/left.hh
index a07f9803..c00af88a 100644
--- a/klm/lm/left.hh
+++ b/klm/lm/left.hh
@@ -39,7 +39,7 @@
#define LM_LEFT__
#include "lm/max_order.hh"
-#include "lm/model.hh"
+#include "lm/state.hh"
#include "lm/return.hh"
#include "util/murmur_hash.hh"
@@ -49,72 +49,6 @@
namespace lm {
namespace ngram {
-struct Left {
- bool operator==(const Left &other) const {
- return
- (length == other.length) &&
- pointers[length - 1] == other.pointers[length - 1];
- }
-
- int Compare(const Left &other) const {
- if (length != other.length) return length < other.length ? -1 : 1;
- if (pointers[length - 1] > other.pointers[length - 1]) return 1;
- if (pointers[length - 1] < other.pointers[length - 1]) return -1;
- return 0;
- }
-
- bool operator<(const Left &other) const {
- if (length != other.length) return length < other.length;
- return pointers[length - 1] < other.pointers[length - 1];
- }
-
- void ZeroRemaining() {
- for (uint64_t * i = pointers + length; i < pointers + kMaxOrder - 1; ++i)
- *i = 0;
- }
-
- unsigned char length;
- uint64_t pointers[kMaxOrder - 1];
-};
-
-inline size_t hash_value(const Left &left) {
- return util::MurmurHashNative(&left.length, 1, left.pointers[left.length - 1]);
-}
-
-struct ChartState {
- bool operator==(const ChartState &other) {
- return (left == other.left) && (right == other.right) && (full == other.full);
- }
-
- int Compare(const ChartState &other) const {
- int lres = left.Compare(other.left);
- if (lres) return lres;
- int rres = right.Compare(other.right);
- if (rres) return rres;
- return (int)full - (int)other.full;
- }
-
- bool operator<(const ChartState &other) const {
- return Compare(other) == -1;
- }
-
- void ZeroRemaining() {
- left.ZeroRemaining();
- right.ZeroRemaining();
- }
-
- Left left;
- bool full;
- State right;
-};
-
-inline size_t hash_value(const ChartState &state) {
- size_t hashes[2];
- hashes[0] = hash_value(state.left);
- hashes[1] = hash_value(state.right);
- return util::MurmurHashNative(hashes, sizeof(size_t) * 2, state.full);
-}
-
template <class M> class RuleScore {
public:
explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), prob_(0.0) {
@@ -131,29 +65,30 @@ template <class M> class RuleScore {
void Terminal(WordIndex word) {
State copy(out_.right);
FullScoreReturn ret(model_.FullScore(copy, word, out_.right));
- prob_ += ret.prob;
- if (left_done_) return;
+ if (left_done_) { prob_ += ret.prob; return; }
if (ret.independent_left) {
+ prob_ += ret.prob;
left_done_ = true;
return;
}
out_.left.pointers[out_.left.length++] = ret.extend_left;
+ prob_ += ret.rest;
if (out_.right.length != copy.length + 1)
left_done_ = true;
}
// Faster version of NonTerminal for the case where the rule begins with a non-terminal.
- void BeginNonTerminal(const ChartState &in, float prob) {
+ void BeginNonTerminal(const ChartState &in, float prob = 0.0) {
prob_ = prob;
out_ = in;
- left_done_ = in.full;
+ left_done_ = in.left.full;
}
- void NonTerminal(const ChartState &in, float prob) {
+ void NonTerminal(const ChartState &in, float prob = 0.0) {
prob_ += prob;
if (!in.left.length) {
- if (in.full) {
+ if (in.left.full) {
for (const float *i = out_.right.backoff; i < out_.right.backoff + out_.right.length; ++i) prob_ += *i;
left_done_ = true;
out_.right = in.right;
@@ -163,12 +98,15 @@ template <class M> class RuleScore {
if (!out_.right.length) {
out_.right = in.right;
- if (left_done_) return;
+ if (left_done_) {
+ prob_ += model_.UnRest(in.left.pointers, in.left.pointers + in.left.length, 1);
+ return;
+ }
if (out_.left.length) {
left_done_ = true;
} else {
out_.left = in.left;
- left_done_ = in.full;
+ left_done_ = in.left.full;
}
return;
}
@@ -186,7 +124,7 @@ template <class M> class RuleScore {
std::swap(back, back2);
}
- if (in.full) {
+ if (in.left.full) {
for (const float *i = back; i != back + next_use; ++i) prob_ += *i;
left_done_ = true;
out_.right = in.right;
@@ -213,10 +151,17 @@ template <class M> class RuleScore {
float Finish() {
// A N-1-gram might extend left and right but we should still set full to true because it's an N-1-gram.
- out_.full = left_done_ || (out_.left.length == model_.Order() - 1);
+ out_.left.full = left_done_ || (out_.left.length == model_.Order() - 1);
return prob_;
}
+ void Reset() {
+ prob_ = 0.0;
+ left_done_ = false;
+ out_.left.length = 0;
+ out_.right.length = 0;
+ }
+
private:
bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) {
ProcessRet(model_.ExtendLeft(
@@ -228,8 +173,9 @@ template <class M> class RuleScore {
if (next_use != out_.right.length) {
left_done_ = true;
if (!next_use) {
- out_.right = in.right;
// Early exit.
+ out_.right = in.right;
+ prob_ += model_.UnRest(in.left.pointers + extend_length, in.left.pointers + in.left.length, extend_length + 1);
return true;
}
}
@@ -238,13 +184,17 @@ template <class M> class RuleScore {
}
void ProcessRet(const FullScoreReturn &ret) {
- prob_ += ret.prob;
- if (left_done_) return;
+ if (left_done_) {
+ prob_ += ret.prob;
+ return;
+ }
if (ret.independent_left) {
+ prob_ += ret.prob;
left_done_ = true;
return;
}
out_.left.pointers[out_.left.length++] = ret.extend_left;
+ prob_ += ret.rest;
}
const M &model_;
diff --git a/klm/lm/left_test.cc b/klm/lm/left_test.cc
index c85e5efa..b23e6a0f 100644
--- a/klm/lm/left_test.cc
+++ b/klm/lm/left_test.cc
@@ -24,7 +24,7 @@ template <class M> void Short(const M &m) {
Term("loin");
BOOST_CHECK_CLOSE(-1.206319 - 0.3561665, score.Finish(), 0.001);
}
- BOOST_CHECK(base.full);
+ BOOST_CHECK(base.left.full);
BOOST_CHECK_EQUAL(2, base.left.length);
BOOST_CHECK_EQUAL(1, base.right.length);
VCheck("loin", base.right.words[0]);
@@ -40,7 +40,7 @@ template <class M> void Short(const M &m) {
BOOST_CHECK_EQUAL(3, more_left.left.length);
BOOST_CHECK_EQUAL(1, more_left.right.length);
VCheck("loin", more_left.right.words[0]);
- BOOST_CHECK(more_left.full);
+ BOOST_CHECK(more_left.left.full);
ChartState shorter;
{
@@ -52,7 +52,7 @@ template <class M> void Short(const M &m) {
BOOST_CHECK_EQUAL(1, shorter.left.length);
BOOST_CHECK_EQUAL(1, shorter.right.length);
VCheck("loin", shorter.right.words[0]);
- BOOST_CHECK(shorter.full);
+ BOOST_CHECK(shorter.left.full);
}
template <class M> void Charge(const M &m) {
@@ -66,7 +66,7 @@ template <class M> void Charge(const M &m) {
BOOST_CHECK_EQUAL(1, base.left.length);
BOOST_CHECK_EQUAL(1, base.right.length);
VCheck("more", base.right.words[0]);
- BOOST_CHECK(base.full);
+ BOOST_CHECK(base.left.full);
ChartState extend;
{
@@ -78,7 +78,7 @@ template <class M> void Charge(const M &m) {
BOOST_CHECK_EQUAL(2, extend.left.length);
BOOST_CHECK_EQUAL(1, extend.right.length);
VCheck("more", extend.right.words[0]);
- BOOST_CHECK(extend.full);
+ BOOST_CHECK(extend.left.full);
ChartState tobos;
{
@@ -91,9 +91,9 @@ template <class M> void Charge(const M &m) {
BOOST_CHECK_EQUAL(1, tobos.right.length);
}
-template <class M> float LeftToRight(const M &m, const std::vector<WordIndex> &words) {
+template <class M> float LeftToRight(const M &m, const std::vector<WordIndex> &words, bool begin_sentence = false) {
float ret = 0.0;
- State right = m.NullContextState();
+ State right = begin_sentence ? m.BeginSentenceState() : m.NullContextState();
for (std::vector<WordIndex>::const_iterator i = words.begin(); i != words.end(); ++i) {
State copy(right);
ret += m.Score(copy, *i, right);
@@ -101,12 +101,12 @@ template <class M> float LeftToRight(const M &m, const std::vector<WordIndex> &w
return ret;
}
-template <class M> float RightToLeft(const M &m, const std::vector<WordIndex> &words) {
+template <class M> float RightToLeft(const M &m, const std::vector<WordIndex> &words, bool begin_sentence = false) {
float ret = 0.0;
ChartState state;
state.left.length = 0;
state.right.length = 0;
- state.full = false;
+ state.left.full = false;
for (std::vector<WordIndex>::const_reverse_iterator i = words.rbegin(); i != words.rend(); ++i) {
ChartState copy(state);
RuleScore<M> score(m, state);
@@ -114,10 +114,17 @@ template <class M> float RightToLeft(const M &m, const std::vector<WordIndex> &w
score.NonTerminal(copy, ret);
ret = score.Finish();
}
+ if (begin_sentence) {
+ ChartState copy(state);
+ RuleScore<M> score(m, state);
+ score.BeginSentence();
+ score.NonTerminal(copy, ret);
+ ret = score.Finish();
+ }
return ret;
}
-template <class M> float TreeMiddle(const M &m, const std::vector<WordIndex> &words) {
+template <class M> float TreeMiddle(const M &m, const std::vector<WordIndex> &words, bool begin_sentence = false) {
std::vector<std::pair<ChartState, float> > states(words.size());
for (unsigned int i = 0; i < words.size(); ++i) {
RuleScore<M> score(m, states[i].first);
@@ -137,7 +144,19 @@ template <class M> float TreeMiddle(const M &m, const std::vector<WordIndex> &wo
}
std::swap(states, upper);
}
- return states.empty() ? 0 : states.back().second;
+
+ if (states.empty()) return 0.0;
+
+ if (begin_sentence) {
+ ChartState ignored;
+ RuleScore<M> score(m, ignored);
+ score.BeginSentence();
+ score.NonTerminal(states.front().first, states.front().second);
+ return score.Finish();
+ } else {
+ return states.front().second;
+ }
+
}
template <class M> void LookupVocab(const M &m, const StringPiece &str, std::vector<WordIndex> &out) {
@@ -148,16 +167,15 @@ template <class M> void LookupVocab(const M &m, const StringPiece &str, std::vec
}
#define TEXT_TEST(str) \
-{ \
- std::vector<WordIndex> words; \
LookupVocab(m, str, words); \
- float expect = LeftToRight(m, words); \
- BOOST_CHECK_CLOSE(expect, RightToLeft(m, words), 0.001); \
- BOOST_CHECK_CLOSE(expect, TreeMiddle(m, words), 0.001); \
-}
+ expect = LeftToRight(m, words, rest); \
+ BOOST_CHECK_CLOSE(expect, RightToLeft(m, words, rest), 0.001); \
+ BOOST_CHECK_CLOSE(expect, TreeMiddle(m, words, rest), 0.001); \
// Build sentences, or parts thereof, from right to left.
-template <class M> void GrowBig(const M &m) {
+template <class M> void GrowBig(const M &m, bool rest = false) {
+ std::vector<WordIndex> words;
+ float expect;
TEXT_TEST("in biarritz watching considering looking . on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown </s>");
TEXT_TEST("on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown </s>");
TEXT_TEST("on a little more loin also would consider higher to look good");
@@ -171,6 +189,14 @@ template <class M> void GrowBig(const M &m) {
TEXT_TEST("consider higher");
}
+template <class M> void GrowSmall(const M &m, bool rest = false) {
+ std::vector<WordIndex> words;
+ float expect;
+ TEXT_TEST("in biarritz watching considering looking . </s>");
+ TEXT_TEST("in biarritz watching considering looking .");
+ TEXT_TEST("in biarritz");
+}
+
template <class M> void AlsoWouldConsiderHigher(const M &m) {
ChartState also;
{
@@ -210,7 +236,7 @@ template <class M> void AlsoWouldConsiderHigher(const M &m) {
}
BOOST_CHECK_EQUAL(1, consider.left.length);
BOOST_CHECK_EQUAL(1, consider.right.length);
- BOOST_CHECK(!consider.full);
+ BOOST_CHECK(!consider.left.full);
ChartState higher;
float higher_score;
@@ -222,7 +248,7 @@ template <class M> void AlsoWouldConsiderHigher(const M &m) {
BOOST_CHECK_CLOSE(-1.509559, higher_score, 0.001);
BOOST_CHECK_EQUAL(1, higher.left.length);
BOOST_CHECK_EQUAL(1, higher.right.length);
- BOOST_CHECK(!higher.full);
+ BOOST_CHECK(!higher.left.full);
VCheck("higher", higher.right.words[0]);
BOOST_CHECK_CLOSE(-0.30103, higher.right.backoff[0], 0.001);
@@ -234,7 +260,7 @@ template <class M> void AlsoWouldConsiderHigher(const M &m) {
BOOST_CHECK_CLOSE(-1.509559 - 1.687872 - 0.30103, score.Finish(), 0.001);
}
BOOST_CHECK_EQUAL(2, consider_higher.left.length);
- BOOST_CHECK(!consider_higher.full);
+ BOOST_CHECK(!consider_higher.left.full);
ChartState full;
{
@@ -246,12 +272,6 @@ template <class M> void AlsoWouldConsiderHigher(const M &m) {
BOOST_CHECK_EQUAL(4, full.right.length);
}
-template <class M> void GrowSmall(const M &m) {
- TEXT_TEST("in biarritz watching considering looking . </s>");
- TEXT_TEST("in biarritz watching considering looking .");
- TEXT_TEST("in biarritz");
-}
-
#define CHECK_SCORE(str, val) \
{ \
float got = val; \
@@ -315,7 +335,7 @@ template <class M> void FullGrow(const M &m) {
CHECK_SCORE("looking . </s>", l2_scores[1] = score.Finish());
}
BOOST_CHECK_EQUAL(l2[1].left.length, 1);
- BOOST_CHECK(l2[1].full);
+ BOOST_CHECK(l2[1].left.full);
ChartState top;
{
@@ -362,6 +382,13 @@ BOOST_AUTO_TEST_CASE(ArrayTrieAll) {
Everything<ArrayTrieModel>();
}
+BOOST_AUTO_TEST_CASE(RestProbing) {
+ Config config;
+ config.messages = NULL;
+ RestProbingModel m(FileLocation(), config);
+ GrowBig(m, true);
+}
+
} // namespace
} // namespace ngram
} // namespace lm
diff --git a/klm/lm/max_order.hh b/klm/lm/max_order.hh
index 71cd23dd..aff9de27 100644
--- a/klm/lm/max_order.hh
+++ b/klm/lm/max_order.hh
@@ -6,7 +6,7 @@ namespace ngram {
// Having this limit means that State can be
// (kMaxOrder - 1) * sizeof(float) bytes instead of
// sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead
-const unsigned char kMaxOrder = 6;
+const unsigned char kMaxOrder = 5;
} // namespace ngram
} // namespace lm
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index 478ebed1..c081788c 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -38,10 +38,13 @@ template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::Ge
State begin_sentence = State();
begin_sentence.length = 1;
begin_sentence.words[0] = vocab_.BeginSentence();
- begin_sentence.backoff[0] = search_.unigram.Lookup(begin_sentence.words[0]).backoff;
+ typename Search::Node ignored_node;
+ bool ignored_independent_left;
+ uint64_t ignored_extend_left;
+ begin_sentence.backoff[0] = search_.LookupUnigram(begin_sentence.words[0], ignored_node, ignored_independent_left, ignored_extend_left).Backoff();
State null_context = State();
null_context.length = 0;
- P::Init(begin_sentence, null_context, vocab_, search_.MiddleEnd() - search_.MiddleBegin() + 2);
+ P::Init(begin_sentence, null_context, vocab_, search_.Order());
}
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd) {
@@ -50,6 +53,9 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
search_.LoadedBinary();
}
+namespace {
+} // namespace
+
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) {
// Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any.
util::FilePiece f(backing_.file.release(), file, config.messages);
@@ -79,8 +85,8 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
if (!vocab_.SawUnk()) {
assert(config.unknown_missing != THROW_UP);
// Default probabilities for unknown.
- search_.unigram.Unknown().backoff = 0.0;
- search_.unigram.Unknown().prob = config.unknown_missing_logprob;
+ search_.UnknownUnigram().backoff = 0.0;
+ search_.UnknownUnigram().prob = config.unknown_missing_logprob;
}
FinishFile(config, kModelType, kVersion, counts, vocab_.UnkCountChangePadding(), backing_);
} catch (util::Exception &e) {
@@ -109,20 +115,22 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
// Add the backoff weights for n-grams of order start to (context_rend - context_rbegin).
unsigned char start = ret.ngram_length;
if (context_rend - context_rbegin < static_cast<std::ptrdiff_t>(start)) return ret;
+
+ bool independent_left;
+ uint64_t extend_left;
+ typename Search::Node node;
if (start <= 1) {
- ret.prob += search_.unigram.Lookup(*context_rbegin).backoff;
+ ret.prob += search_.LookupUnigram(*context_rbegin, node, independent_left, extend_left).Backoff();
start = 2;
- }
- typename Search::Node node;
- if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) {
+ } else if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) {
return ret;
}
- float backoff;
// i is the order of the backoff we're looking for.
- typename Search::MiddleIter mid_iter = search_.MiddleBegin() + start - 2;
- for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i, ++mid_iter) {
- if (!search_.LookupMiddleNoProb(*mid_iter, *i, backoff, node)) break;
- ret.prob += backoff;
+ unsigned char order_minus_2 = 0;
+ for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i, ++order_minus_2) {
+ typename Search::MiddlePointer p(search_.LookupMiddle(order_minus_2, *i, node, independent_left, extend_left));
+ if (!p.Found()) break;
+ ret.prob += p.Backoff();
}
return ret;
}
@@ -134,17 +142,20 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
out_state.length = 0;
return;
}
- FullScoreReturn ignored;
typename Search::Node node;
- search_.LookupUnigram(*context_rbegin, out_state.backoff[0], node, ignored);
+ bool independent_left;
+ uint64_t extend_left;
+ out_state.backoff[0] = search_.LookupUnigram(*context_rbegin, node, independent_left, extend_left).Backoff();
out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0;
float *backoff_out = out_state.backoff + 1;
- typename Search::MiddleIter mid(search_.MiddleBegin());
- for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++mid) {
- if (!search_.LookupMiddleNoProb(*mid, *i, *backoff_out, node)) {
+ unsigned char order_minus_2 = 0;
+ for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++order_minus_2) {
+ typename Search::MiddlePointer p(search_.LookupMiddle(order_minus_2, *i, node, independent_left, extend_left));
+ if (!p.Found()) {
std::copy(context_rbegin, context_rbegin + out_state.length, out_state.words);
return;
}
+ *backoff_out = p.Backoff();
if (HasExtension(*backoff_out)) out_state.length = i - context_rbegin + 1;
}
std::copy(context_rbegin, context_rbegin + out_state.length, out_state.words);
@@ -158,43 +169,29 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
float *backoff_out,
unsigned char &next_use) const {
FullScoreReturn ret;
- float subtract_me;
- typename Search::Node node(search_.Unpack(extend_pointer, extend_length, subtract_me));
- ret.prob = subtract_me;
- ret.ngram_length = extend_length;
- next_use = 0;
- // If this function is called, then it does depend on left words.
- ret.independent_left = false;
- ret.extend_left = extend_pointer;
- typename Search::MiddleIter mid_iter(search_.MiddleBegin() + extend_length - 1);
- const WordIndex *i = add_rbegin;
- for (; ; ++i, ++backoff_out, ++mid_iter) {
- if (i == add_rend) {
- // Ran out of words.
- for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b;
- ret.prob -= subtract_me;
- return ret;
- }
- if (mid_iter == search_.MiddleEnd()) break;
- if (ret.independent_left || !search_.LookupMiddle(*mid_iter, *i, *backoff_out, node, ret)) {
- // Didn't match a word.
- ret.independent_left = true;
- for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b;
- ret.prob -= subtract_me;
- return ret;
- }
- ret.ngram_length = mid_iter - search_.MiddleBegin() + 2;
- if (HasExtension(*backoff_out)) next_use = i - add_rbegin + 1;
- }
-
- if (ret.independent_left || !search_.LookupLongest(*i, ret.prob, node)) {
- // The last backoff weight, for Order() - 1.
- ret.prob += backoff_in[i - add_rbegin];
+ typename Search::Node node;
+ if (extend_length == 1) {
+ typename Search::UnigramPointer ptr(search_.LookupUnigram(static_cast<WordIndex>(extend_pointer), node, ret.independent_left, ret.extend_left));
+ ret.rest = ptr.Rest();
+ ret.prob = ptr.Prob();
+ assert(!ret.independent_left);
} else {
- ret.ngram_length = P::Order();
+ typename Search::MiddlePointer ptr(search_.Unpack(extend_pointer, extend_length, node));
+ ret.rest = ptr.Rest();
+ ret.prob = ptr.Prob();
+ ret.extend_left = extend_pointer;
+ // If this function is called, then it does depend on left words.
+ ret.independent_left = false;
}
- ret.independent_left = true;
+ float subtract_me = ret.rest;
+ ret.ngram_length = extend_length;
+ next_use = extend_length;
+ ResumeScore(add_rbegin, add_rend, extend_length - 1, node, backoff_out, next_use, ret);
+ next_use -= extend_length;
+ // Charge backoffs.
+ for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b;
ret.prob -= subtract_me;
+ ret.rest -= subtract_me;
return ret;
}
@@ -215,66 +212,83 @@ void CopyRemainingHistory(const WordIndex *from, State &out_state) {
* new_word.
*/
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ScoreExceptBackoff(
- const WordIndex *context_rbegin,
- const WordIndex *context_rend,
+ const WordIndex *const context_rbegin,
+ const WordIndex *const context_rend,
const WordIndex new_word,
State &out_state) const {
FullScoreReturn ret;
// ret.ngram_length contains the last known non-blank ngram length.
ret.ngram_length = 1;
- float *backoff_out(out_state.backoff);
typename Search::Node node;
- search_.LookupUnigram(new_word, *backoff_out, node, ret);
+ typename Search::UnigramPointer uni(search_.LookupUnigram(new_word, node, ret.independent_left, ret.extend_left));
+ out_state.backoff[0] = uni.Backoff();
+ ret.prob = uni.Prob();
+ ret.rest = uni.Rest();
+
// This is the length of the context that should be used for continuation to the right.
- out_state.length = HasExtension(*backoff_out) ? 1 : 0;
+ out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0;
// We'll write the word anyway since it will probably be used and does no harm being there.
out_state.words[0] = new_word;
if (context_rbegin == context_rend) return ret;
- ++backoff_out;
-
- // Ok start by looking up the bigram.
- const WordIndex *hist_iter = context_rbegin;
- typename Search::MiddleIter mid_iter(search_.MiddleBegin());
- for (; ; ++mid_iter, ++hist_iter, ++backoff_out) {
- if (hist_iter == context_rend) {
- // Ran out of history. Typically no backoff, but this could be a blank.
- CopyRemainingHistory(context_rbegin, out_state);
- // ret.prob was already set.
- return ret;
- }
- if (mid_iter == search_.MiddleEnd()) break;
+ ResumeScore(context_rbegin, context_rend, 0, node, out_state.backoff + 1, out_state.length, ret);
+ CopyRemainingHistory(context_rbegin, out_state);
+ return ret;
+}
- if (ret.independent_left || !search_.LookupMiddle(*mid_iter, *hist_iter, *backoff_out, node, ret)) {
- // Didn't find an ngram using hist_iter.
- CopyRemainingHistory(context_rbegin, out_state);
- // ret.prob was already set.
- ret.independent_left = true;
- return ret;
- }
- ret.ngram_length = hist_iter - context_rbegin + 2;
+template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::ResumeScore(const WordIndex *hist_iter, const WordIndex *const context_rend, unsigned char order_minus_2, typename Search::Node &node, float *backoff_out, unsigned char &next_use, FullScoreReturn &ret) const {
+ for (; ; ++order_minus_2, ++hist_iter, ++backoff_out) {
+ if (hist_iter == context_rend) return;
+ if (ret.independent_left) return;
+ if (order_minus_2 == P::Order() - 2) break;
+
+ typename Search::MiddlePointer pointer(search_.LookupMiddle(order_minus_2, *hist_iter, node, ret.independent_left, ret.extend_left));
+ if (!pointer.Found()) return;
+ *backoff_out = pointer.Backoff();
+ ret.prob = pointer.Prob();
+ ret.rest = pointer.Rest();
+ ret.ngram_length = order_minus_2 + 2;
if (HasExtension(*backoff_out)) {
- out_state.length = ret.ngram_length;
+ next_use = ret.ngram_length;
}
}
-
- // It passed every lookup in search_.middle. All that's left is to check search_.longest.
- if (!ret.independent_left && search_.LookupLongest(*hist_iter, ret.prob, node)) {
- // It's an P::Order()-gram.
+ ret.independent_left = true;
+ typename Search::LongestPointer longest(search_.LookupLongest(*hist_iter, node));
+ if (longest.Found()) {
+ ret.prob = longest.Prob();
+ ret.rest = ret.prob;
// There is no blank in longest_.
ret.ngram_length = P::Order();
}
- // This handles (N-1)-grams and N-grams.
- CopyRemainingHistory(context_rbegin, out_state);
- ret.independent_left = true;
+}
+
+template <class Search, class VocabularyT> float GenericModel<Search, VocabularyT>::InternalUnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const {
+ float ret;
+ typename Search::Node node;
+ if (first_length == 1) {
+ if (pointers_begin >= pointers_end) return 0.0;
+ bool independent_left;
+ uint64_t extend_left;
+ typename Search::UnigramPointer ptr(search_.LookupUnigram(static_cast<WordIndex>(*pointers_begin), node, independent_left, extend_left));
+ ret = ptr.Prob() - ptr.Rest();
+ ++first_length;
+ ++pointers_begin;
+ } else {
+ ret = 0.0;
+ }
+ for (const uint64_t *i = pointers_begin; i < pointers_end; ++i, ++first_length) {
+ typename Search::MiddlePointer ptr(search_.Unpack(*i, first_length, node));
+ ret += ptr.Prob() - ptr.Rest();
+ }
return ret;
}
-template class GenericModel<ProbingHashedSearch, ProbingVocabulary>; // HASH_PROBING
-template class GenericModel<trie::TrieSearch<DontQuantize, trie::DontBhiksha>, SortedVocabulary>; // TRIE_SORTED
+template class GenericModel<HashedSearch<BackoffValue>, ProbingVocabulary>;
+template class GenericModel<HashedSearch<RestValue>, ProbingVocabulary>;
+template class GenericModel<trie::TrieSearch<DontQuantize, trie::DontBhiksha>, SortedVocabulary>;
template class GenericModel<trie::TrieSearch<DontQuantize, trie::ArrayBhiksha>, SortedVocabulary>;
-template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiksha>, SortedVocabulary>; // TRIE_SORTED_QUANT
+template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiksha>, SortedVocabulary>;
template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary>;
} // namespace detail
diff --git a/klm/lm/model.hh b/klm/lm/model.hh
index 6ea62a78..be872178 100644
--- a/klm/lm/model.hh
+++ b/klm/lm/model.hh
@@ -9,6 +9,8 @@
#include "lm/quantize.hh"
#include "lm/search_hashed.hh"
#include "lm/search_trie.hh"
+#include "lm/state.hh"
+#include "lm/value.hh"
#include "lm/vocab.hh"
#include "lm/weights.hh"
@@ -23,48 +25,6 @@ namespace util { class FilePiece; }
namespace lm {
namespace ngram {
-
-// This is a POD but if you want memcmp to return the same as operator==, call
-// ZeroRemaining first.
-class State {
- public:
- bool operator==(const State &other) const {
- if (length != other.length) return false;
- return !memcmp(words, other.words, length * sizeof(WordIndex));
- }
-
- // Three way comparison function.
- int Compare(const State &other) const {
- if (length != other.length) return length < other.length ? -1 : 1;
- return memcmp(words, other.words, length * sizeof(WordIndex));
- }
-
- bool operator<(const State &other) const {
- if (length != other.length) return length < other.length;
- return memcmp(words, other.words, length * sizeof(WordIndex)) < 0;
- }
-
- // Call this before using raw memcmp.
- void ZeroRemaining() {
- for (unsigned char i = length; i < kMaxOrder - 1; ++i) {
- words[i] = 0;
- backoff[i] = 0.0;
- }
- }
-
- unsigned char Length() const { return length; }
-
- // You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD.
- // This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit.
- WordIndex words[kMaxOrder - 1];
- float backoff[kMaxOrder - 1];
- unsigned char length;
-};
-
-inline size_t hash_value(const State &state) {
- return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length);
-}
-
namespace detail {
// Should return the same results as SRI.
@@ -119,8 +79,7 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
/* More efficient version of FullScore where a partial n-gram has already
* been scored.
- * NOTE: THE RETURNED .prob IS RELATIVE, NOT ABSOLUTE. So for example, if
- * the n-gram does not end up extending further left, then 0 is returned.
+ * NOTE: THE RETURNED .rest AND .prob ARE RELATIVE TO THE .rest RETURNED BEFORE.
*/
FullScoreReturn ExtendLeft(
// Additional context in reverse order. This will update add_rend to
@@ -136,12 +95,24 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
// Amount of additional content that should be considered by the next call.
unsigned char &next_use) const;
+ /* Return probabilities minus rest costs for an array of pointers. The
+ * first length should be the length of the n-gram to which pointers_begin
+ * points.
+ */
+ float UnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const {
+ // Compiler should optimize this if away.
+ return Search::kDifferentRest ? InternalUnRest(pointers_begin, pointers_end, first_length) : 0.0;
+ }
+
private:
friend void lm::ngram::LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to);
static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config);
- FullScoreReturn ScoreExceptBackoff(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const;
+ FullScoreReturn ScoreExceptBackoff(const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const;
+
+ // Score bigrams and above. Do not include backoff.
+ void ResumeScore(const WordIndex *context_rbegin, const WordIndex *const context_rend, unsigned char starting_order_minus_2, typename Search::Node &node, float *backoff_out, unsigned char &next_use, FullScoreReturn &ret) const;
// Appears after Size in the cc file.
void SetupMemory(void *start, const std::vector<uint64_t> &counts, const Config &config);
@@ -150,32 +121,38 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
void InitializeFromARPA(const char *file, const Config &config);
+ float InternalUnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const;
+
Backing &MutableBacking() { return backing_; }
Backing backing_;
VocabularyT vocab_;
- typedef typename Search::Middle Middle;
-
Search search_;
};
} // namespace detail
-// These must also be instantiated in the cc file.
-typedef ::lm::ngram::ProbingVocabulary Vocabulary;
-typedef detail::GenericModel<detail::ProbingHashedSearch, Vocabulary> ProbingModel; // HASH_PROBING
-// Default implementation. No real reason for it to be the default.
-typedef ProbingModel Model;
+// Instead of typedef, inherit. This allows the Model etc to be forward declared.
+// Oh the joys of C and C++.
+#define LM_COMMA() ,
+#define LM_NAME_MODEL(name, from)\
+class name : public from {\
+ public:\
+ name(const char *file, const Config &config = Config()) : from(file, config) {}\
+};
-// Smaller implementation.
-typedef ::lm::ngram::SortedVocabulary SortedVocabulary;
-typedef detail::GenericModel<trie::TrieSearch<DontQuantize, trie::DontBhiksha>, SortedVocabulary> TrieModel; // TRIE_SORTED
-typedef detail::GenericModel<trie::TrieSearch<DontQuantize, trie::ArrayBhiksha>, SortedVocabulary> ArrayTrieModel;
+LM_NAME_MODEL(ProbingModel, detail::GenericModel<detail::HashedSearch<BackoffValue> LM_COMMA() ProbingVocabulary>);
+LM_NAME_MODEL(RestProbingModel, detail::GenericModel<detail::HashedSearch<RestValue> LM_COMMA() ProbingVocabulary>);
+LM_NAME_MODEL(TrieModel, detail::GenericModel<trie::TrieSearch<DontQuantize LM_COMMA() trie::DontBhiksha> LM_COMMA() SortedVocabulary>);
+LM_NAME_MODEL(ArrayTrieModel, detail::GenericModel<trie::TrieSearch<DontQuantize LM_COMMA() trie::ArrayBhiksha> LM_COMMA() SortedVocabulary>);
+LM_NAME_MODEL(QuantTrieModel, detail::GenericModel<trie::TrieSearch<SeparatelyQuantize LM_COMMA() trie::DontBhiksha> LM_COMMA() SortedVocabulary>);
+LM_NAME_MODEL(QuantArrayTrieModel, detail::GenericModel<trie::TrieSearch<SeparatelyQuantize LM_COMMA() trie::ArrayBhiksha> LM_COMMA() SortedVocabulary>);
-typedef detail::GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiksha>, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED
-typedef detail::GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary> QuantArrayTrieModel;
+// Default implementation. No real reason for it to be the default.
+typedef ::lm::ngram::ProbingVocabulary Vocabulary;
+typedef ProbingModel Model;
} // namespace ngram
} // namespace lm
diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc
index 461704d4..8a122c60 100644
--- a/klm/lm/model_test.cc
+++ b/klm/lm/model_test.cc
@@ -30,7 +30,15 @@ const char *TestNoUnkLocation() {
return "test_nounk.arpa";
}
return boost::unit_test::framework::master_test_suite().argv[2];
+}
+template <class Model> State GetState(const Model &model, const char *word, const State &in) {
+ WordIndex context[in.length + 1];
+ context[0] = model.GetVocabulary().Index(word);
+ std::copy(in.words, in.words + in.length, context + 1);
+ State ret;
+ model.GetState(context, context + in.length + 1, ret);
+ return ret;
}
#define StartTest(word, ngram, score, indep_left) \
@@ -42,14 +50,7 @@ const char *TestNoUnkLocation() {
BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
BOOST_CHECK_GE(std::min<unsigned char>(ngram, 5 - 1), out.length); \
BOOST_CHECK_EQUAL(indep_left, ret.independent_left); \
- {\
- WordIndex context[state.length + 1]; \
- context[0] = model.GetVocabulary().Index(word); \
- std::copy(state.words, state.words + state.length, context + 1); \
- State get_state; \
- model.GetState(context, context + state.length + 1, get_state); \
- BOOST_CHECK_EQUAL(out, get_state); \
- }
+ BOOST_CHECK_EQUAL(out, GetState(model, word, state));
#define AppendTest(word, ngram, score, indep_left) \
StartTest(word, ngram, score, indep_left) \
@@ -182,7 +183,7 @@ template <class M> void ExtendLeftTest(const M &model) {
FullScoreReturn extend_none(model.ExtendLeft(NULL, NULL, NULL, little.extend_left, 1, NULL, next_use));
BOOST_CHECK_EQUAL(0, next_use);
BOOST_CHECK_EQUAL(little.extend_left, extend_none.extend_left);
- BOOST_CHECK_CLOSE(0.0, extend_none.prob, 0.001);
+ BOOST_CHECK_CLOSE(little.prob - little.rest, extend_none.prob, 0.001);
BOOST_CHECK_EQUAL(1, extend_none.ngram_length);
const WordIndex a = model.GetVocabulary().Index("a");
@@ -191,7 +192,7 @@ template <class M> void ExtendLeftTest(const M &model) {
FullScoreReturn extend_a(model.ExtendLeft(&a, &a + 1, &backoff_in, little.extend_left, 1, backoff_out, next_use));
BOOST_CHECK_EQUAL(1, next_use);
BOOST_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001);
- BOOST_CHECK_CLOSE(-0.09132547 - kLittleProb, extend_a.prob, 0.001);
+ BOOST_CHECK_CLOSE(-0.09132547 - little.rest, extend_a.prob, 0.001);
BOOST_CHECK_EQUAL(2, extend_a.ngram_length);
BOOST_CHECK(!extend_a.independent_left);
@@ -199,7 +200,7 @@ template <class M> void ExtendLeftTest(const M &model) {
FullScoreReturn extend_on(model.ExtendLeft(&on, &on + 1, &backoff_in, extend_a.extend_left, 2, backoff_out, next_use));
BOOST_CHECK_EQUAL(1, next_use);
BOOST_CHECK_CLOSE(-0.4771212, backoff_out[0], 0.001);
- BOOST_CHECK_CLOSE(-0.0283603 - -0.09132547, extend_on.prob, 0.001);
+ BOOST_CHECK_CLOSE(-0.0283603 - (extend_a.rest + little.rest), extend_on.prob, 0.001);
BOOST_CHECK_EQUAL(3, extend_on.ngram_length);
BOOST_CHECK(!extend_on.independent_left);
@@ -209,7 +210,7 @@ template <class M> void ExtendLeftTest(const M &model) {
BOOST_CHECK_EQUAL(2, next_use);
BOOST_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001);
BOOST_CHECK_CLOSE(-0.4771212, backoff_out[1], 0.001);
- BOOST_CHECK_CLOSE(-0.0283603 - kLittleProb, extend_both.prob, 0.001);
+ BOOST_CHECK_CLOSE(-0.0283603 - little.rest, extend_both.prob, 0.001);
BOOST_CHECK_EQUAL(3, extend_both.ngram_length);
BOOST_CHECK(!extend_both.independent_left);
BOOST_CHECK_EQUAL(extend_on.extend_left, extend_both.extend_left);
@@ -399,7 +400,10 @@ template <class ModelT> void BinaryTest() {
}
BOOST_AUTO_TEST_CASE(write_and_read_probing) {
- BinaryTest<Model>();
+ BinaryTest<ProbingModel>();
+}
+BOOST_AUTO_TEST_CASE(write_and_read_rest_probing) {
+ BinaryTest<RestProbingModel>();
}
BOOST_AUTO_TEST_CASE(write_and_read_trie) {
BinaryTest<TrieModel>();
@@ -414,6 +418,18 @@ BOOST_AUTO_TEST_CASE(write_and_read_quant_array_trie) {
BinaryTest<QuantArrayTrieModel>();
}
+BOOST_AUTO_TEST_CASE(rest_max) {
+ Config config;
+ config.arpa_complain = Config::NONE;
+ config.messages = NULL;
+
+ RestProbingModel model(TestLocation(), config);
+ State state, out;
+ FullScoreReturn ret(model.FullScore(model.NullContextState(), model.GetVocabulary().Index("."), state));
+ BOOST_CHECK_CLOSE(-0.2705918, ret.rest, 0.001);
+ BOOST_CHECK_CLOSE(-0.01916512, model.FullScore(state, model.GetVocabulary().EndSentence(), out).rest, 0.001);
+}
+
} // namespace
} // namespace ngram
} // namespace lm
diff --git a/klm/lm/model_type.hh b/klm/lm/model_type.hh
index 5057ed25..8b35c793 100644
--- a/klm/lm/model_type.hh
+++ b/klm/lm/model_type.hh
@@ -6,10 +6,17 @@ namespace ngram {
/* Not the best numbering system, but it grew this way for historical reasons
* and I want to preserve existing binary files. */
-typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3, ARRAY_TRIE_SORTED=4, QUANT_ARRAY_TRIE_SORTED=5} ModelType;
+typedef enum {PROBING=0, REST_PROBING=1, TRIE=2, QUANT_TRIE=3, ARRAY_TRIE=4, QUANT_ARRAY_TRIE=5} ModelType;
-const static ModelType kQuantAdd = static_cast<ModelType>(QUANT_TRIE_SORTED - TRIE_SORTED);
-const static ModelType kArrayAdd = static_cast<ModelType>(ARRAY_TRIE_SORTED - TRIE_SORTED);
+// Historical names.
+const ModelType HASH_PROBING = PROBING;
+const ModelType TRIE_SORTED = TRIE;
+const ModelType QUANT_TRIE_SORTED = QUANT_TRIE;
+const ModelType ARRAY_TRIE_SORTED = ARRAY_TRIE;
+const ModelType QUANT_ARRAY_TRIE_SORTED = QUANT_ARRAY_TRIE;
+
+const static ModelType kQuantAdd = static_cast<ModelType>(QUANT_TRIE - TRIE);
+const static ModelType kArrayAdd = static_cast<ModelType>(ARRAY_TRIE - TRIE);
} // namespace ngram
} // namespace lm
diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc
index 8f7a0e1c..49757d9a 100644
--- a/klm/lm/ngram_query.cc
+++ b/klm/lm/ngram_query.cc
@@ -12,22 +12,24 @@ int main(int argc, char *argv[]) {
ModelType model_type;
if (RecognizeBinary(argv[1], model_type)) {
switch(model_type) {
- case HASH_PROBING:
+ case PROBING:
Query<lm::ngram::ProbingModel>(argv[1], sentence_context, std::cin, std::cout);
break;
- case TRIE_SORTED:
+ case REST_PROBING:
+ Query<lm::ngram::RestProbingModel>(argv[1], sentence_context, std::cin, std::cout);
+ break;
+ case TRIE:
Query<TrieModel>(argv[1], sentence_context, std::cin, std::cout);
break;
- case QUANT_TRIE_SORTED:
+ case QUANT_TRIE:
Query<QuantTrieModel>(argv[1], sentence_context, std::cin, std::cout);
break;
- case ARRAY_TRIE_SORTED:
+ case ARRAY_TRIE:
Query<ArrayTrieModel>(argv[1], sentence_context, std::cin, std::cout);
break;
- case QUANT_ARRAY_TRIE_SORTED:
+ case QUANT_ARRAY_TRIE:
Query<QuantArrayTrieModel>(argv[1], sentence_context, std::cin, std::cout);
break;
- case HASH_SORTED:
default:
std::cerr << "Unrecognized kenlm model type " << model_type << std::endl;
abort();
@@ -35,8 +37,8 @@ int main(int argc, char *argv[]) {
} else {
Query<ProbingModel>(argv[1], sentence_context, std::cin, std::cout);
}
-
- PrintUsage("Total time including destruction:\n");
+ std::cerr << "Total time including destruction:\n";
+ util::PrintUsage(std::cerr);
} catch (const std::exception &e) {
std::cerr << e.what() << std::endl;
return 1;
diff --git a/klm/lm/ngram_query.hh b/klm/lm/ngram_query.hh
index 4990df22..dfcda170 100644
--- a/klm/lm/ngram_query.hh
+++ b/klm/lm/ngram_query.hh
@@ -3,51 +3,20 @@
#include "lm/enumerate_vocab.hh"
#include "lm/model.hh"
+#include "util/usage.hh"
#include <cstdlib>
-#include <fstream>
#include <iostream>
+#include <ostream>
+#include <istream>
#include <string>
-#include <ctype.h>
-#if !defined(_WIN32) && !defined(_WIN64)
-#include <sys/resource.h>
-#include <sys/time.h>
-#endif
-
namespace lm {
namespace ngram {
-#if !defined(_WIN32) && !defined(_WIN64)
-float FloatSec(const struct timeval &tv) {
- return static_cast<float>(tv.tv_sec) + (static_cast<float>(tv.tv_usec) / 1000000000.0);
-}
-#endif
-
-void PrintUsage(const char *message) {
-#if !defined(_WIN32) && !defined(_WIN64)
- struct rusage usage;
- if (getrusage(RUSAGE_SELF, &usage)) {
- perror("getrusage");
- return;
- }
- std::cerr << message;
- std::cerr << "user\t" << FloatSec(usage.ru_utime) << "\nsys\t" << FloatSec(usage.ru_stime) << '\n';
-
- // Linux doesn't set memory usage :-(.
- std::ifstream status("/proc/self/status", std::ios::in);
- std::string line;
- while (getline(status, line)) {
- if (!strncmp(line.c_str(), "VmRSS:\t", 7)) {
- std::cerr << "rss " << (line.c_str() + 7) << '\n';
- break;
- }
- }
-#endif
-}
-
template <class Model> void Query(const Model &model, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) {
- PrintUsage("Loading statistics:\n");
+ std::cerr << "Loading statistics:\n";
+ util::PrintUsage(std::cerr);
typename Model::State state, out;
lm::FullScoreReturn ret;
std::string word;
@@ -84,13 +53,13 @@ template <class Model> void Query(const Model &model, bool sentence_context, std
out_stream << "</s>=" << model.GetVocabulary().EndSentence() << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t';
}
out_stream << "Total: " << total << " OOV: " << oov << '\n';
- }
- PrintUsage("After queries:\n");
+ }
+ std::cerr << "After queries:\n";
+ util::PrintUsage(std::cerr);
}
template <class M> void Query(const char *file, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) {
Config config;
-// config.load_method = util::LAZY;
M model(file, config);
Query(model, sentence_context, in_stream, out_stream);
}
diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc
index a8e0cb21..b58c3f3f 100644
--- a/klm/lm/quantize.cc
+++ b/klm/lm/quantize.cc
@@ -47,9 +47,7 @@ void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector<uint64
util::AdvanceOrThrow(fd, -3);
}
-void SeparatelyQuantize::SetupMemory(void *start, const Config &config) {
- // Reserve 8 byte header for bit counts.
- start_ = reinterpret_cast<float*>(static_cast<uint8_t*>(start) + 8);
+void SeparatelyQuantize::SetupMemory(void *base, unsigned char order, const Config &config) {
prob_bits_ = config.prob_bits;
backoff_bits_ = config.backoff_bits;
// We need the reserved values.
@@ -57,25 +55,35 @@ void SeparatelyQuantize::SetupMemory(void *start, const Config &config) {
if (config.backoff_bits == 0) UTIL_THROW(ConfigException, "You can't quantize backoff to zero");
if (config.prob_bits > 25) UTIL_THROW(ConfigException, "For efficiency reasons, quantizing probability supports at most 25 bits. Currently you have requested " << static_cast<unsigned>(config.prob_bits) << " bits.");
if (config.backoff_bits > 25) UTIL_THROW(ConfigException, "For efficiency reasons, quantizing backoff supports at most 25 bits. Currently you have requested " << static_cast<unsigned>(config.backoff_bits) << " bits.");
+ // Reserve 8 byte header for bit counts.
+ actual_base_ = static_cast<uint8_t*>(base);
+ float *start = reinterpret_cast<float*>(actual_base_ + 8);
+ for (unsigned char i = 0; i < order - 2; ++i) {
+ tables_[i][0] = Bins(prob_bits_, start);
+ start += (1ULL << prob_bits_);
+ tables_[i][1] = Bins(backoff_bits_, start);
+ start += (1ULL << backoff_bits_);
+ }
+ longest_ = tables_[order - 2][0] = Bins(prob_bits_, start);
}
void SeparatelyQuantize::Train(uint8_t order, std::vector<float> &prob, std::vector<float> &backoff) {
TrainProb(order, prob);
// Backoff
- float *centers = start_ + TableStart(order) + ProbTableLength();
+ float *centers = tables_[order - 2][1].Populate();
*(centers++) = kNoExtensionBackoff;
*(centers++) = kExtensionBackoff;
MakeBins(backoff, centers, (1ULL << backoff_bits_) - 2);
}
void SeparatelyQuantize::TrainProb(uint8_t order, std::vector<float> &prob) {
- float *centers = start_ + TableStart(order);
+ float *centers = tables_[order - 2][0].Populate();
MakeBins(prob, centers, (1ULL << prob_bits_));
}
void SeparatelyQuantize::FinishedLoading(const Config &config) {
- uint8_t *actual_base = reinterpret_cast<uint8_t*>(start_) - 8;
+ uint8_t *actual_base = actual_base_;
*(actual_base++) = kSeparatelyQuantizeVersion; // version
*(actual_base++) = config.prob_bits;
*(actual_base++) = config.backoff_bits;
diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh
index 6d130a57..3e9153e3 100644
--- a/klm/lm/quantize.hh
+++ b/klm/lm/quantize.hh
@@ -3,6 +3,7 @@
#include "lm/blank.hh"
#include "lm/config.hh"
+#include "lm/max_order.hh"
#include "lm/model_type.hh"
#include "util/bit_packing.hh"
@@ -27,37 +28,60 @@ class DontQuantize {
static uint8_t MiddleBits(const Config &/*config*/) { return 63; }
static uint8_t LongestBits(const Config &/*config*/) { return 31; }
- struct Middle {
- void Write(void *base, uint64_t bit_offset, float prob, float backoff) const {
- util::WriteNonPositiveFloat31(base, bit_offset, prob);
- util::WriteFloat32(base, bit_offset + 31, backoff);
- }
- void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const {
- prob = util::ReadNonPositiveFloat31(base, bit_offset);
- backoff = util::ReadFloat32(base, bit_offset + 31);
- }
- void ReadProb(const void *base, uint64_t bit_offset, float &prob) const {
- prob = util::ReadNonPositiveFloat31(base, bit_offset);
- }
- void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const {
- backoff = util::ReadFloat32(base, bit_offset + 31);
- }
- uint8_t TotalBits() const { return 63; }
+ class MiddlePointer {
+ public:
+ MiddlePointer(const DontQuantize & /*quant*/, unsigned char /*order_minus_2*/, util::BitAddress address) : address_(address) {}
+
+ MiddlePointer() : address_(NULL, 0) {}
+
+ bool Found() const {
+ return address_.base != NULL;
+ }
+
+ float Prob() const {
+ return util::ReadNonPositiveFloat31(address_.base, address_.offset);
+ }
+
+ float Backoff() const {
+ return util::ReadFloat32(address_.base, address_.offset + 31);
+ }
+
+ float Rest() const { return Prob(); }
+
+ void Write(float prob, float backoff) {
+ util::WriteNonPositiveFloat31(address_.base, address_.offset, prob);
+ util::WriteFloat32(address_.base, address_.offset + 31, backoff);
+ }
+
+ private:
+ util::BitAddress address_;
};
- struct Longest {
- void Write(void *base, uint64_t bit_offset, float prob) const {
- util::WriteNonPositiveFloat31(base, bit_offset, prob);
- }
- void Read(const void *base, uint64_t bit_offset, float &prob) const {
- prob = util::ReadNonPositiveFloat31(base, bit_offset);
- }
- uint8_t TotalBits() const { return 31; }
+ class LongestPointer {
+ public:
+ explicit LongestPointer(const DontQuantize &/*quant*/, util::BitAddress address) : address_(address) {}
+
+ LongestPointer() : address_(NULL, 0) {}
+
+ bool Found() const {
+ return address_.base != NULL;
+ }
+
+ float Prob() const {
+ return util::ReadNonPositiveFloat31(address_.base, address_.offset);
+ }
+
+ void Write(float prob) {
+ util::WriteNonPositiveFloat31(address_.base, address_.offset, prob);
+ }
+
+ private:
+ util::BitAddress address_;
};
DontQuantize() {}
- void SetupMemory(void * /*start*/, const Config & /*config*/) {}
+ void SetupMemory(void * /*start*/, unsigned char /*order*/, const Config & /*config*/) {}
static const bool kTrain = false;
// These should never be called because kTrain is false.
@@ -65,9 +89,6 @@ class DontQuantize {
void TrainProb(uint8_t, std::vector<float> &/*prob*/) {}
void FinishedLoading(const Config &) {}
-
- Middle Mid(uint8_t /*order*/) const { return Middle(); }
- Longest Long(uint8_t /*order*/) const { return Longest(); }
};
class SeparatelyQuantize {
@@ -77,7 +98,9 @@ class SeparatelyQuantize {
// Sigh C++ default constructor
Bins() {}
- Bins(uint8_t bits, const float *const begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {}
+ Bins(uint8_t bits, float *begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {}
+
+ float *Populate() { return begin_; }
uint64_t EncodeProb(float value) const {
return Encode(value, 0);
@@ -98,13 +121,13 @@ class SeparatelyQuantize {
private:
uint64_t Encode(float value, size_t reserved) const {
- const float *above = std::lower_bound(begin_ + reserved, end_, value);
+ const float *above = std::lower_bound(static_cast<const float*>(begin_) + reserved, end_, value);
if (above == begin_ + reserved) return reserved;
if (above == end_) return end_ - begin_ - 1;
return above - begin_ - (value - *(above - 1) < *above - value);
}
- const float *begin_;
+ float *begin_;
const float *end_;
uint8_t bits_;
uint64_t mask_;
@@ -125,65 +148,61 @@ class SeparatelyQuantize {
static uint8_t MiddleBits(const Config &config) { return config.prob_bits + config.backoff_bits; }
static uint8_t LongestBits(const Config &config) { return config.prob_bits; }
- class Middle {
+ class MiddlePointer {
public:
- Middle(uint8_t prob_bits, const float *prob_begin, uint8_t backoff_bits, const float *backoff_begin) :
- total_bits_(prob_bits + backoff_bits), total_mask_((1ULL << total_bits_) - 1), prob_(prob_bits, prob_begin), backoff_(backoff_bits, backoff_begin) {}
+ MiddlePointer(const SeparatelyQuantize &quant, unsigned char order_minus_2, const util::BitAddress &address) : bins_(quant.GetTables(order_minus_2)), address_(address) {}
- void Write(void *base, uint64_t bit_offset, float prob, float backoff) const {
- util::WriteInt57(base, bit_offset, total_bits_,
- (prob_.EncodeProb(prob) << backoff_.Bits()) | backoff_.EncodeBackoff(backoff));
- }
+ MiddlePointer() : address_(NULL, 0) {}
- void ReadProb(const void *base, uint64_t bit_offset, float &prob) const {
- prob = prob_.Decode(util::ReadInt25(base, bit_offset + backoff_.Bits(), prob_.Bits(), prob_.Mask()));
- }
+ bool Found() const { return address_.base != NULL; }
- void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const {
- uint64_t both = util::ReadInt57(base, bit_offset, total_bits_, total_mask_);
- prob = prob_.Decode(both >> backoff_.Bits());
- backoff = backoff_.Decode(both & backoff_.Mask());
+ float Prob() const {
+ return ProbBins().Decode(util::ReadInt25(address_.base, address_.offset + BackoffBins().Bits(), ProbBins().Bits(), ProbBins().Mask()));
}
- void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const {
- backoff = backoff_.Decode(util::ReadInt25(base, bit_offset, backoff_.Bits(), backoff_.Mask()));
+ float Backoff() const {
+ return BackoffBins().Decode(util::ReadInt25(address_.base, address_.offset, BackoffBins().Bits(), BackoffBins().Mask()));
}
- uint8_t TotalBits() const {
- return total_bits_;
+ float Rest() const { return Prob(); }
+
+ void Write(float prob, float backoff) const {
+ util::WriteInt57(address_.base, address_.offset, ProbBins().Bits() + BackoffBins().Bits(),
+ (ProbBins().EncodeProb(prob) << BackoffBins().Bits()) | BackoffBins().EncodeBackoff(backoff));
}
private:
- const uint8_t total_bits_;
- const uint64_t total_mask_;
- const Bins prob_;
- const Bins backoff_;
+ const Bins &ProbBins() const { return bins_[0]; }
+ const Bins &BackoffBins() const { return bins_[1]; }
+ const Bins *bins_;
+
+ util::BitAddress address_;
};
- class Longest {
+ class LongestPointer {
public:
- // Sigh C++ default constructor
- Longest() {}
+ LongestPointer(const SeparatelyQuantize &quant, const util::BitAddress &address) : table_(&quant.LongestTable()), address_(address) {}
+
+ LongestPointer() : address_(NULL, 0) {}
- Longest(uint8_t prob_bits, const float *prob_begin) : prob_(prob_bits, prob_begin) {}
+ bool Found() const { return address_.base != NULL; }
- void Write(void *base, uint64_t bit_offset, float prob) const {
- util::WriteInt25(base, bit_offset, prob_.Bits(), prob_.EncodeProb(prob));
+ void Write(float prob) const {
+ util::WriteInt25(address_.base, address_.offset, table_->Bits(), table_->EncodeProb(prob));
}
- void Read(const void *base, uint64_t bit_offset, float &prob) const {
- prob = prob_.Decode(util::ReadInt25(base, bit_offset, prob_.Bits(), prob_.Mask()));
+ float Prob() const {
+ return table_->Decode(util::ReadInt25(address_.base, address_.offset, table_->Bits(), table_->Mask()));
}
- uint8_t TotalBits() const { return prob_.Bits(); }
-
private:
- Bins prob_;
+ const Bins *table_;
+ util::BitAddress address_;
};
SeparatelyQuantize() {}
- void SetupMemory(void *start, const Config &config);
+ void SetupMemory(void *start, unsigned char order, const Config &config);
static const bool kTrain = true;
// Assumes 0.0 is removed from backoff.
@@ -193,18 +212,17 @@ class SeparatelyQuantize {
void FinishedLoading(const Config &config);
- Middle Mid(uint8_t order) const {
- const float *table = start_ + TableStart(order);
- return Middle(prob_bits_, table, backoff_bits_, table + ProbTableLength());
- }
+ const Bins *GetTables(unsigned char order_minus_2) const { return tables_[order_minus_2]; }
- Longest Long(uint8_t order) const { return Longest(prob_bits_, start_ + TableStart(order)); }
+ const Bins &LongestTable() const { return longest_; }
private:
- size_t TableStart(uint8_t order) const { return ((1ULL << prob_bits_) + (1ULL << backoff_bits_)) * static_cast<uint64_t>(order - 2); }
- size_t ProbTableLength() const { return (1ULL << prob_bits_); }
+ Bins tables_[kMaxOrder - 1][2];
+
+ Bins longest_;
+
+ uint8_t *actual_base_;
- float *start_;
uint8_t prob_bits_, backoff_bits_;
};
diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc
index 05f761be..2d9a337d 100644
--- a/klm/lm/read_arpa.cc
+++ b/klm/lm/read_arpa.cc
@@ -83,7 +83,7 @@ void ReadBackoff(util::FilePiece &in, Prob &/*weights*/) {
}
}
-void ReadBackoff(util::FilePiece &in, ProbBackoff &weights) {
+void ReadBackoff(util::FilePiece &in, float &backoff) {
// Always make zero negative.
// Negative zero means that no (n+1)-gram has this n-gram as context.
// Therefore the hypothesis state can be shorter. Of course, many n-grams
@@ -91,12 +91,12 @@ void ReadBackoff(util::FilePiece &in, ProbBackoff &weights) {
// back and set the backoff to positive zero in these cases.
switch (in.get()) {
case '\t':
- weights.backoff = in.ReadFloat();
- if (weights.backoff == ngram::kExtensionBackoff) weights.backoff = ngram::kNoExtensionBackoff;
+ backoff = in.ReadFloat();
+ if (backoff == ngram::kExtensionBackoff) backoff = ngram::kNoExtensionBackoff;
if ((in.get() != '\n')) UTIL_THROW(FormatLoadException, "Expected newline after backoff");
break;
case '\n':
- weights.backoff = ngram::kNoExtensionBackoff;
+ backoff = ngram::kNoExtensionBackoff;
break;
default:
UTIL_THROW(FormatLoadException, "Expected tab or newline for backoff");
diff --git a/klm/lm/read_arpa.hh b/klm/lm/read_arpa.hh
index ab996bde..234d130c 100644
--- a/klm/lm/read_arpa.hh
+++ b/klm/lm/read_arpa.hh
@@ -16,7 +16,13 @@ void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number);
void ReadNGramHeader(util::FilePiece &in, unsigned int length);
void ReadBackoff(util::FilePiece &in, Prob &weights);
-void ReadBackoff(util::FilePiece &in, ProbBackoff &weights);
+void ReadBackoff(util::FilePiece &in, float &backoff);
+inline void ReadBackoff(util::FilePiece &in, ProbBackoff &weights) {
+ ReadBackoff(in, weights.backoff);
+}
+inline void ReadBackoff(util::FilePiece &in, RestWeights &weights) {
+ ReadBackoff(in, weights.backoff);
+}
void ReadEnd(util::FilePiece &in);
@@ -35,7 +41,7 @@ class PositiveProbWarn {
WarningAction action_;
};
-template <class Voc> void Read1Gram(util::FilePiece &f, Voc &vocab, ProbBackoff *unigrams, PositiveProbWarn &warn) {
+template <class Voc, class Weights> void Read1Gram(util::FilePiece &f, Voc &vocab, Weights *unigrams, PositiveProbWarn &warn) {
try {
float prob = f.ReadFloat();
if (prob > 0.0) {
@@ -43,7 +49,7 @@ template <class Voc> void Read1Gram(util::FilePiece &f, Voc &vocab, ProbBackoff
prob = 0.0;
}
if (f.get() != '\t') UTIL_THROW(FormatLoadException, "Expected tab after probability");
- ProbBackoff &value = unigrams[vocab.Insert(f.ReadDelimited(kARPASpaces))];
+ Weights &value = unigrams[vocab.Insert(f.ReadDelimited(kARPASpaces))];
value.prob = prob;
ReadBackoff(f, value);
} catch(util::Exception &e) {
@@ -53,7 +59,7 @@ template <class Voc> void Read1Gram(util::FilePiece &f, Voc &vocab, ProbBackoff
}
// Return true if a positive log probability came out.
-template <class Voc> void Read1Grams(util::FilePiece &f, std::size_t count, Voc &vocab, ProbBackoff *unigrams, PositiveProbWarn &warn) {
+template <class Voc, class Weights> void Read1Grams(util::FilePiece &f, std::size_t count, Voc &vocab, Weights *unigrams, PositiveProbWarn &warn) {
ReadNGramHeader(f, 1);
for (std::size_t i = 0; i < count; ++i) {
Read1Gram(f, vocab, unigrams, warn);
diff --git a/klm/lm/return.hh b/klm/lm/return.hh
index 1b55091b..622320ce 100644
--- a/klm/lm/return.hh
+++ b/klm/lm/return.hh
@@ -33,6 +33,9 @@ struct FullScoreReturn {
*/
bool independent_left;
uint64_t extend_left; // Defined only if independent_left
+
+ // Rest cost for extension to the left.
+ float rest;
};
} // namespace lm
diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc
index 1d6fb5be..13942309 100644
--- a/klm/lm/search_hashed.cc
+++ b/klm/lm/search_hashed.cc
@@ -3,7 +3,9 @@
#include "lm/binary_format.hh"
#include "lm/blank.hh"
#include "lm/lm_exception.hh"
+#include "lm/model.hh"
#include "lm/read_arpa.hh"
+#include "lm/value.hh"
#include "lm/vocab.hh"
#include "util/bit_packing.hh"
@@ -14,6 +16,8 @@
namespace lm {
namespace ngram {
+class ProbingModel;
+
namespace {
/* These are passed to ReadNGrams so that n-grams with zero backoff that appear as context will still be used in state. */
@@ -37,9 +41,9 @@ template <class Middle> class ActivateLowerMiddle {
Middle &modify_;
};
-class ActivateUnigram {
+template <class Weights> class ActivateUnigram {
public:
- explicit ActivateUnigram(ProbBackoff *unigram) : modify_(unigram) {}
+ explicit ActivateUnigram(Weights *unigram) : modify_(unigram) {}
void operator()(const WordIndex *vocab_ids, const unsigned int /*n*/) {
// assert(n == 2);
@@ -47,43 +51,124 @@ class ActivateUnigram {
}
private:
- ProbBackoff *modify_;
+ Weights *modify_;
};
-template <class Middle> void FixSRI(int lower, float negative_lower_prob, unsigned int n, const uint64_t *keys, const WordIndex *vocab_ids, ProbBackoff *unigrams, std::vector<Middle> &middle) {
- ProbBackoff blank;
- blank.backoff = kNoExtensionBackoff;
- // Fix SRI's stupidity.
- // Note that negative_lower_prob is the negative of the probability (so it's currently >= 0). We still want the sign bit off to indicate left extension, so I just do -= on the backoffs.
- blank.prob = negative_lower_prob;
- // An entry was found at lower (order lower + 2).
- // We need to insert blanks starting at lower + 1 (order lower + 3).
- unsigned int fix = static_cast<unsigned int>(lower + 1);
- uint64_t backoff_hash = detail::CombineWordHash(static_cast<uint64_t>(vocab_ids[1]), vocab_ids[2]);
- if (fix == 0) {
- // Insert a missing bigram.
- blank.prob -= unigrams[vocab_ids[1]].backoff;
- SetExtension(unigrams[vocab_ids[1]].backoff);
- // Bigram including a unigram's backoff
- middle[0].Insert(detail::ProbBackoffEntry::Make(keys[0], blank));
- fix = 1;
- } else {
- for (unsigned int i = 3; i < fix + 2; ++i) backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[i]);
+// Find the lower order entry, inserting blanks along the way as necessary.
+template <class Value> void FindLower(
+ const std::vector<uint64_t> &keys,
+ typename Value::Weights &unigram,
+ std::vector<util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> > &middle,
+ std::vector<typename Value::Weights *> &between) {
+ typename util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash>::MutableIterator iter;
+ typename Value::ProbingEntry entry;
+ // Backoff will always be 0.0. We'll get the probability and rest in another pass.
+ entry.value.backoff = kNoExtensionBackoff;
+ // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb.
+ for (int lower = keys.size() - 2; ; --lower) {
+ if (lower == -1) {
+ between.push_back(&unigram);
+ return;
+ }
+ entry.key = keys[lower];
+ bool found = middle[lower].FindOrInsert(entry, iter);
+ between.push_back(&iter->value);
+ if (found) return;
+ }
+}
+
+// Between usually has single entry, the value to adjust. But sometimes SRI stupidly pruned entries so it has unitialized blank values to be set here.
+template <class Added, class Build> void AdjustLower(
+ const Added &added,
+ const Build &build,
+ std::vector<typename Build::Value::Weights *> &between,
+ const unsigned int n,
+ const std::vector<WordIndex> &vocab_ids,
+ typename Build::Value::Weights *unigrams,
+ std::vector<util::ProbingHashTable<typename Build::Value::ProbingEntry, util::IdentityHash> > &middle) {
+ typedef typename Build::Value Value;
+ if (between.size() == 1) {
+ build.MarkExtends(*between.front(), added);
+ return;
+ }
+ typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle;
+ float prob = -fabs(between.back()->prob);
+ // Order of the n-gram on which probabilities are based.
+ unsigned char basis = n - between.size();
+ assert(basis != 0);
+ typename Build::Value::Weights **change = &between.back();
+ // Skip the basis.
+ --change;
+ if (basis == 1) {
+ // Hallucinate a bigram based on a unigram's backoff and a unigram probability.
+ float &backoff = unigrams[vocab_ids[1]].backoff;
+ SetExtension(backoff);
+ prob += backoff;
+ (*change)->prob = prob;
+ build.SetRest(&*vocab_ids.begin(), 2, **change);
+ basis = 2;
+ --change;
}
- // fix >= 1. Insert trigrams and above.
- for (; fix <= n - 3; ++fix) {
+ uint64_t backoff_hash = static_cast<uint64_t>(vocab_ids[1]);
+ for (unsigned char i = 2; i <= basis; ++i) {
+ backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[i]);
+ }
+ for (; basis < n - 1; ++basis, --change) {
typename Middle::MutableIterator gotit;
- if (middle[fix - 1].UnsafeMutableFind(backoff_hash, gotit)) {
+ if (middle[basis - 2].UnsafeMutableFind(backoff_hash, gotit)) {
float &backoff = gotit->value.backoff;
SetExtension(backoff);
- blank.prob -= backoff;
+ prob += backoff;
}
- middle[fix].Insert(detail::ProbBackoffEntry::Make(keys[fix], blank));
- backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[fix + 2]);
+ (*change)->prob = prob;
+ build.SetRest(&*vocab_ids.begin(), basis + 1, **change);
+ backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[basis+1]);
+ }
+
+ typename std::vector<typename Value::Weights *>::const_iterator i(between.begin());
+ build.MarkExtends(**i, added);
+ const typename Value::Weights *longer = *i;
+ // Everything has probability but is not marked as extending.
+ for (++i; i != between.end(); ++i) {
+ build.MarkExtends(**i, *longer);
+ longer = *i;
}
}
-template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, ProbBackoff *unigrams, std::vector<Middle> &middle, Activate activate, Store &store, PositiveProbWarn &warn) {
+// Continue marking lower entries even they know that they extend left. This is used for upper/lower bounds.
+template <class Build> void MarkLower(
+ const std::vector<uint64_t> &keys,
+ const Build &build,
+ typename Build::Value::Weights &unigram,
+ std::vector<util::ProbingHashTable<typename Build::Value::ProbingEntry, util::IdentityHash> > &middle,
+ int start_order,
+ const typename Build::Value::Weights &longer) {
+ if (start_order == 0) return;
+ typename util::ProbingHashTable<typename Build::Value::ProbingEntry, util::IdentityHash>::MutableIterator iter;
+ // Hopefully the compiler will realize that if MarkExtends always returns false, it can simplify this code.
+ for (int even_lower = start_order - 2 /* index in middle */; ; --even_lower) {
+ if (even_lower == -1) {
+ build.MarkExtends(unigram, longer);
+ return;
+ }
+ middle[even_lower].UnsafeMutableFind(keys[even_lower], iter);
+ if (!build.MarkExtends(iter->value, longer)) return;
+ }
+}
+
+template <class Build, class Activate, class Store> void ReadNGrams(
+ util::FilePiece &f,
+ const unsigned int n,
+ const size_t count,
+ const ProbingVocabulary &vocab,
+ const Build &build,
+ typename Build::Value::Weights *unigrams,
+ std::vector<util::ProbingHashTable<typename Build::Value::ProbingEntry, util::IdentityHash> > &middle,
+ Activate activate,
+ Store &store,
+ PositiveProbWarn &warn) {
+ typedef typename Build::Value Value;
+ typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle;
assert(n >= 2);
ReadNGramHeader(f, n);
@@ -91,38 +176,25 @@ template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(
// vocab ids of words in reverse order.
std::vector<WordIndex> vocab_ids(n);
std::vector<uint64_t> keys(n-1);
- typename Store::Entry::Value value;
- typename Middle::MutableIterator found;
+ typename Store::Entry entry;
+ std::vector<typename Value::Weights *> between;
for (size_t i = 0; i < count; ++i) {
- ReadNGram(f, n, vocab, &*vocab_ids.begin(), value, warn);
+ ReadNGram(f, n, vocab, &*vocab_ids.begin(), entry.value, warn);
+ build.SetRest(&*vocab_ids.begin(), n, entry.value);
keys[0] = detail::CombineWordHash(static_cast<uint64_t>(vocab_ids.front()), vocab_ids[1]);
for (unsigned int h = 1; h < n - 1; ++h) {
keys[h] = detail::CombineWordHash(keys[h-1], vocab_ids[h+1]);
}
// Initially the sign bit is on, indicating it does not extend left. Most already have this but there might +0.0.
- util::SetSign(value.prob);
- store.Insert(Store::Entry::Make(keys[n-2], value));
- // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb.
- int lower;
- util::FloatEnc fix_prob;
- for (lower = n - 3; ; --lower) {
- if (lower == -1) {
- fix_prob.f = unigrams[vocab_ids.front()].prob;
- fix_prob.i &= ~util::kSignBit;
- unigrams[vocab_ids.front()].prob = fix_prob.f;
- break;
- }
- if (middle[lower].UnsafeMutableFind(keys[lower], found)) {
- // Turn off sign bit to indicate that it extends left.
- fix_prob.f = found->value.prob;
- fix_prob.i &= ~util::kSignBit;
- found->value.prob = fix_prob.f;
- // We don't need to recurse further down because this entry already set the bits for lower entries.
- break;
- }
- }
- if (lower != static_cast<int>(n) - 3) FixSRI(lower, fix_prob.f, n, &*keys.begin(), &*vocab_ids.begin(), unigrams, middle);
+ util::SetSign(entry.value.prob);
+ entry.key = keys[n-2];
+
+ store.Insert(entry);
+ between.clear();
+ FindLower<Value>(keys, unigrams[vocab_ids.front()], middle, between);
+ AdjustLower<typename Store::Entry::Value, Build>(entry.value, build, between, n, vocab_ids, unigrams, middle);
+ if (Build::kMarkEvenLower) MarkLower<Build>(keys, build, unigrams[vocab_ids.front()], middle, n - between.size() - 1, *between.back());
activate(&*vocab_ids.begin(), n);
}
@@ -132,9 +204,9 @@ template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(
} // namespace
namespace detail {
-template <class MiddleT, class LongestT> uint8_t *TemplateHashedSearch<MiddleT, LongestT>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
+template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
std::size_t allocated = Unigram::Size(counts[0]);
- unigram = Unigram(start, allocated);
+ unigram_ = Unigram(start, counts[0], allocated);
start += allocated;
for (unsigned int n = 2; n < counts.size(); ++n) {
allocated = Middle::Size(counts[n - 1], config.probing_multiplier);
@@ -142,31 +214,63 @@ template <class MiddleT, class LongestT> uint8_t *TemplateHashedSearch<MiddleT,
start += allocated;
}
allocated = Longest::Size(counts.back(), config.probing_multiplier);
- longest = Longest(start, allocated);
+ longest_ = Longest(start, allocated);
start += allocated;
return start;
}
-template <class MiddleT, class LongestT> template <class Voc> void TemplateHashedSearch<MiddleT, LongestT>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing) {
+template <class Value> void HashedSearch<Value>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, Backing &backing) {
// TODO: fix sorted.
SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), Size(counts, config), backing), counts, config);
PositiveProbWarn warn(config.positive_log_probability);
-
- Read1Grams(f, counts[0], vocab, unigram.Raw(), warn);
+ Read1Grams(f, counts[0], vocab, unigram_.Raw(), warn);
CheckSpecials(config, vocab);
+ DispatchBuild(f, counts, config, vocab, warn);
+}
+
+template <> void HashedSearch<BackoffValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) {
+ NoRestBuild build;
+ ApplyBuild(f, counts, config, vocab, warn, build);
+}
+
+template <> void HashedSearch<RestValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) {
+ switch (config.rest_function) {
+ case Config::REST_MAX:
+ {
+ MaxRestBuild build;
+ ApplyBuild(f, counts, config, vocab, warn, build);
+ }
+ break;
+ case Config::REST_LOWER:
+ {
+ LowerRestBuild<ProbingModel> build(config, counts.size(), vocab);
+ ApplyBuild(f, counts, config, vocab, warn, build);
+ }
+ break;
+ }
+}
+
+template <class Value> template <class Build> void HashedSearch<Value>::ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build) {
+ for (WordIndex i = 0; i < counts[0]; ++i) {
+ build.SetRest(&i, (unsigned int)1, unigram_.Raw()[i]);
+ }
try {
if (counts.size() > 2) {
- ReadNGrams(f, 2, counts[1], vocab, unigram.Raw(), middle_, ActivateUnigram(unigram.Raw()), middle_[0], warn);
+ ReadNGrams<Build, ActivateUnigram<typename Value::Weights>, Middle>(
+ f, 2, counts[1], vocab, build, unigram_.Raw(), middle_, ActivateUnigram<typename Value::Weights>(unigram_.Raw()), middle_[0], warn);
}
for (unsigned int n = 3; n < counts.size(); ++n) {
- ReadNGrams(f, n, counts[n-1], vocab, unigram.Raw(), middle_, ActivateLowerMiddle<Middle>(middle_[n-3]), middle_[n-2], warn);
+ ReadNGrams<Build, ActivateLowerMiddle<Middle>, Middle>(
+ f, n, counts[n-1], vocab, build, unigram_.Raw(), middle_, ActivateLowerMiddle<Middle>(middle_[n-3]), middle_[n-2], warn);
}
if (counts.size() > 2) {
- ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, unigram.Raw(), middle_, ActivateLowerMiddle<Middle>(middle_.back()), longest, warn);
+ ReadNGrams<Build, ActivateLowerMiddle<Middle>, Longest>(
+ f, counts.size(), counts[counts.size() - 1], vocab, build, unigram_.Raw(), middle_, ActivateLowerMiddle<Middle>(middle_.back()), longest_, warn);
} else {
- ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, unigram.Raw(), middle_, ActivateUnigram(unigram.Raw()), longest, warn);
+ ReadNGrams<Build, ActivateUnigram<typename Value::Weights>, Longest>(
+ f, counts.size(), counts[counts.size() - 1], vocab, build, unigram_.Raw(), middle_, ActivateUnigram<typename Value::Weights>(unigram_.Raw()), longest_, warn);
}
} catch (util::ProbingSizeException &e) {
UTIL_THROW(util::ProbingSizeException, "Avoid pruning n-grams like \"bar baz quux\" when \"foo bar baz quux\" is still in the model. KenLM will work when this pruning happens, but the probing model assumes these events are rare enough that using blank space in the probing hash table will cover all of them. Increase probing_multiplier (-p to build_binary) to add more blank spaces.\n");
@@ -174,17 +278,16 @@ template <class MiddleT, class LongestT> template <class Voc> void TemplateHashe
ReadEnd(f);
}
-template <class MiddleT, class LongestT> void TemplateHashedSearch<MiddleT, LongestT>::LoadedBinary() {
- unigram.LoadedBinary();
+template <class Value> void HashedSearch<Value>::LoadedBinary() {
+ unigram_.LoadedBinary();
for (typename std::vector<Middle>::iterator i = middle_.begin(); i != middle_.end(); ++i) {
i->LoadedBinary();
}
- longest.LoadedBinary();
+ longest_.LoadedBinary();
}
-template class TemplateHashedSearch<ProbingHashedSearch::Middle, ProbingHashedSearch::Longest>;
-
-template void TemplateHashedSearch<ProbingHashedSearch::Middle, ProbingHashedSearch::Longest>::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &, ProbingVocabulary &vocab, Backing &backing);
+template class HashedSearch<BackoffValue>;
+template class HashedSearch<RestValue>;
} // namespace detail
} // namespace ngram
diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh
index 4352c72d..7e8c1220 100644
--- a/klm/lm/search_hashed.hh
+++ b/klm/lm/search_hashed.hh
@@ -19,6 +19,7 @@ namespace util { class FilePiece; }
namespace lm {
namespace ngram {
struct Backing;
+class ProbingVocabulary;
namespace detail {
inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) {
@@ -26,54 +27,48 @@ inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) {
return ret;
}
-struct HashedSearch {
- typedef uint64_t Node;
-
- class Unigram {
- public:
- Unigram() {}
-
- Unigram(void *start, std::size_t /*allocated*/) : unigram_(static_cast<ProbBackoff*>(start)) {}
-
- static std::size_t Size(uint64_t count) {
- return (count + 1) * sizeof(ProbBackoff); // +1 for hallucinate <unk>
- }
-
- const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index]; }
+#pragma pack(push)
+#pragma pack(4)
+struct ProbEntry {
+ uint64_t key;
+ Prob value;
+ typedef uint64_t Key;
+ typedef Prob Value;
+ uint64_t GetKey() const {
+ return key;
+ }
+};
- ProbBackoff &Unknown() { return unigram_[0]; }
+#pragma pack(pop)
- void LoadedBinary() {}
+class LongestPointer {
+ public:
+ explicit LongestPointer(const float &to) : to_(&to) {}
- // For building.
- ProbBackoff *Raw() { return unigram_; }
+ LongestPointer() : to_(NULL) {}
- private:
- ProbBackoff *unigram_;
- };
+ bool Found() const {
+ return to_ != NULL;
+ }
- Unigram unigram;
+ float Prob() const {
+ return *to_;
+ }
- void LookupUnigram(WordIndex word, float &backoff, Node &next, FullScoreReturn &ret) const {
- const ProbBackoff &entry = unigram.Lookup(word);
- util::FloatEnc val;
- val.f = entry.prob;
- ret.independent_left = (val.i & util::kSignBit);
- ret.extend_left = static_cast<uint64_t>(word);
- val.i |= util::kSignBit;
- ret.prob = val.f;
- backoff = entry.backoff;
- next = static_cast<Node>(word);
- }
+ private:
+ const float *to_;
};
-template <class MiddleT, class LongestT> class TemplateHashedSearch : public HashedSearch {
+template <class Value> class HashedSearch {
public:
- typedef MiddleT Middle;
+ typedef uint64_t Node;
- typedef LongestT Longest;
- Longest longest;
+ typedef typename Value::ProbingProxy UnigramPointer;
+ typedef typename Value::ProbingProxy MiddlePointer;
+ typedef ::lm::ngram::detail::LongestPointer LongestPointer;
+ static const ModelType kModelType = Value::kProbingModelType;
+ static const bool kDifferentRest = Value::kDifferentRest;
static const unsigned int kVersion = 0;
// TODO: move probing_multiplier here with next binary file format update.
@@ -89,64 +84,55 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config);
- template <class Voc> void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing);
+ void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, Backing &backing);
- typedef typename std::vector<Middle>::const_iterator MiddleIter;
+ void LoadedBinary();
- MiddleIter MiddleBegin() const { return middle_.begin(); }
- MiddleIter MiddleEnd() const { return middle_.end(); }
+ unsigned char Order() const {
+ return middle_.size() + 2;
+ }
- Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const {
- util::FloatEnc val;
- if (extend_length == 1) {
- val.f = unigram.Lookup(static_cast<uint64_t>(extend_pointer)).prob;
- } else {
- typename Middle::ConstIterator found;
- if (!middle_[extend_length - 2].Find(extend_pointer, found)) {
- std::cerr << "Extend pointer " << extend_pointer << " should have been found for length " << (unsigned) extend_length << std::endl;
- abort();
- }
- val.f = found->value.prob;
- }
- val.i |= util::kSignBit;
- prob = val.f;
- return extend_pointer;
+ typename Value::Weights &UnknownUnigram() { return unigram_.Unknown(); }
+
+ UnigramPointer LookupUnigram(WordIndex word, Node &next, bool &independent_left, uint64_t &extend_left) const {
+ extend_left = static_cast<uint64_t>(word);
+ next = extend_left;
+ UnigramPointer ret(unigram_.Lookup(word));
+ independent_left = ret.IndependentLeft();
+ return ret;
}
- bool LookupMiddle(const Middle &middle, WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const {
- node = CombineWordHash(node, word);
+#pragma GCC diagnostic ignored "-Wuninitialized"
+ MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const {
+ node = extend_pointer;
typename Middle::ConstIterator found;
- if (!middle.Find(node, found)) return false;
- util::FloatEnc enc;
- enc.f = found->value.prob;
- ret.independent_left = (enc.i & util::kSignBit);
- ret.extend_left = node;
- enc.i |= util::kSignBit;
- ret.prob = enc.f;
- backoff = found->value.backoff;
- return true;
+ bool got = middle_[extend_length - 2].Find(extend_pointer, found);
+ assert(got);
+ (void)got;
+ return MiddlePointer(found->value);
}
- void LoadedBinary();
-
- bool LookupMiddleNoProb(const Middle &middle, WordIndex word, float &backoff, Node &node) const {
+ MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_pointer) const {
node = CombineWordHash(node, word);
typename Middle::ConstIterator found;
- if (!middle.Find(node, found)) return false;
- backoff = found->value.backoff;
- return true;
+ if (!middle_[order_minus_2].Find(node, found)) {
+ independent_left = true;
+ return MiddlePointer();
+ }
+ extend_pointer = node;
+ MiddlePointer ret(found->value);
+ independent_left = ret.IndependentLeft();
+ return ret;
}
- bool LookupLongest(WordIndex word, float &prob, Node &node) const {
+ LongestPointer LookupLongest(WordIndex word, const Node &node) const {
// Sign bit is always on because longest n-grams do not extend left.
- node = CombineWordHash(node, word);
typename Longest::ConstIterator found;
- if (!longest.Find(node, found)) return false;
- prob = found->value.prob;
- return true;
+ if (!longest_.Find(CombineWordHash(node, word), found)) return LongestPointer();
+ return LongestPointer(found->value.prob);
}
- // Geenrate a node without necessarily checking that it actually exists.
+ // Generate a node without necessarily checking that it actually exists.
// Optionally return false if it's know to not exist.
bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
assert(begin != end);
@@ -158,55 +144,54 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has
}
private:
- std::vector<Middle> middle_;
-};
+ // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild.
+ void DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn);
-/* These look like perfect candidates for a template, right? Ancient gcc (4.1
- * on RedHat stale linux) doesn't pack templates correctly. ProbBackoffEntry
- * is a multiple of 8 bytes anyway. ProbEntry is 12 bytes so it's set to pack.
- */
-struct ProbBackoffEntry {
- uint64_t key;
- ProbBackoff value;
- typedef uint64_t Key;
- typedef ProbBackoff Value;
- uint64_t GetKey() const {
- return key;
- }
- static ProbBackoffEntry Make(uint64_t key, ProbBackoff value) {
- ProbBackoffEntry ret;
- ret.key = key;
- ret.value = value;
- return ret;
- }
-};
+ template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build);
-#pragma pack(push)
-#pragma pack(4)
-struct ProbEntry {
- uint64_t key;
- Prob value;
- typedef uint64_t Key;
- typedef Prob Value;
- uint64_t GetKey() const {
- return key;
- }
- static ProbEntry Make(uint64_t key, Prob value) {
- ProbEntry ret;
- ret.key = key;
- ret.value = value;
- return ret;
- }
-};
+ class Unigram {
+ public:
+ Unigram() {}
-#pragma pack(pop)
+ Unigram(void *start, uint64_t count, std::size_t /*allocated*/) :
+ unigram_(static_cast<typename Value::Weights*>(start))
+#ifdef DEBUG
+ , count_(count)
+#endif
+ {}
+
+ static std::size_t Size(uint64_t count) {
+ return (count + 1) * sizeof(ProbBackoff); // +1 for hallucinate <unk>
+ }
+
+ const typename Value::Weights &Lookup(WordIndex index) const {
+#ifdef DEBUG
+ assert(index < count_);
+#endif
+ return unigram_[index];
+ }
+
+ typename Value::Weights &Unknown() { return unigram_[0]; }
+ void LoadedBinary() {}
-struct ProbingHashedSearch : public TemplateHashedSearch<
- util::ProbingHashTable<ProbBackoffEntry, util::IdentityHash>,
- util::ProbingHashTable<ProbEntry, util::IdentityHash> > {
+ // For building.
+ typename Value::Weights *Raw() { return unigram_; }
+
+ private:
+ typename Value::Weights *unigram_;
+#ifdef DEBUG
+ uint64_t count_;
+#endif
+ };
+
+ Unigram unigram_;
+
+ typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle;
+ std::vector<Middle> middle_;
- static const ModelType kModelType = HASH_PROBING;
+ typedef util::ProbingHashTable<ProbEntry, util::IdentityHash> Longest;
+ Longest longest_;
};
} // namespace detail
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index ffadfa94..18e80d5a 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -273,8 +273,9 @@ class FindBlanks {
// Phase to actually write n-grams to the trie.
template <class Quant, class Bhiksha> class WriteEntries {
public:
- WriteEntries(RecordReader *contexts, UnigramValue *unigrams, BitPackedMiddle<typename Quant::Middle, Bhiksha> *middle, BitPackedLongest<typename Quant::Longest> &longest, unsigned char order, SRISucks &sri) :
+ WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle<Bhiksha> *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) :
contexts_(contexts),
+ quant_(quant),
unigrams_(unigrams),
middle_(middle),
longest_(longest),
@@ -290,7 +291,7 @@ template <class Quant, class Bhiksha> class WriteEntries {
void MiddleBlank(const unsigned char order, const WordIndex *indices, unsigned char /*lower*/, float /*prob_base*/) {
ProbBackoff weights = sri_.GetBlank(order_, order, indices);
- middle_[order - 2].Insert(indices[order - 1], weights.prob, weights.backoff);
+ typename Quant::MiddlePointer(quant_, order - 2, middle_[order - 2].Insert(indices[order - 1])).Write(weights.prob, weights.backoff);
}
void Middle(const unsigned char order, const void *data) {
@@ -301,21 +302,22 @@ template <class Quant, class Bhiksha> class WriteEntries {
SetExtension(weights.backoff);
++context;
}
- middle_[order - 2].Insert(words[order - 1], weights.prob, weights.backoff);
+ typename Quant::MiddlePointer(quant_, order - 2, middle_[order - 2].Insert(words[order - 1])).Write(weights.prob, weights.backoff);
}
void Longest(const void *data) {
const WordIndex *words = reinterpret_cast<const WordIndex*>(data);
- longest_.Insert(words[order_ - 1], reinterpret_cast<const Prob*>(words + order_)->prob);
+ typename Quant::LongestPointer(quant_, longest_.Insert(words[order_ - 1])).Write(reinterpret_cast<const Prob*>(words + order_)->prob);
}
void Cleanup() {}
private:
RecordReader *contexts_;
+ const Quant &quant_;
UnigramValue *const unigrams_;
- BitPackedMiddle<typename Quant::Middle, Bhiksha> *const middle_;
- BitPackedLongest<typename Quant::Longest> &longest_;
+ BitPackedMiddle<Bhiksha> *const middle_;
+ BitPackedLongest &longest_;
BitPacked &bigram_pack_;
const unsigned char order_;
SRISucks &sri_;
@@ -380,7 +382,7 @@ template <class Doing> class BlankManager {
};
template <class Doing> void RecursiveInsert(const unsigned char total_order, const WordIndex unigram_count, RecordReader *input, std::ostream *progress_out, const char *message, Doing &doing) {
- util::ErsatzProgress progress(progress_out, message, unigram_count + 1);
+ util::ErsatzProgress progress(unigram_count + 1, progress_out, message);
WordIndex unigram = 0;
std::priority_queue<Gram> grams;
grams.push(Gram(&unigram, 1));
@@ -502,7 +504,7 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
inputs[i-2].Rewind();
}
if (Quant::kTrain) {
- util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0));
+ util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0), config.messages, "Quantizing");
for (unsigned char i = 2; i < counts.size(); ++i) {
TrainQuantizer(i, counts[i-1], sri.Values(i), inputs[i-2], progress, quant);
}
@@ -510,7 +512,7 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
quant.FinishedLoading(config);
}
- UnigramValue *unigrams = out.unigram.Raw();
+ UnigramValue *unigrams = out.unigram_.Raw();
PopulateUnigramWeights(unigram_file.get(), counts[0], contexts[0], unigrams);
unigram_file.reset();
@@ -519,7 +521,7 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
}
// Fill entries except unigram probabilities.
{
- WriteEntries<Quant, Bhiksha> writer(contexts, unigrams, out.middle_begin_, out.longest, counts.size(), sri);
+ WriteEntries<Quant, Bhiksha> writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri);
RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Writing trie", writer);
}
@@ -544,14 +546,14 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
for (typename TrieSearch<Quant, Bhiksha>::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) {
i->FinishedLoading((i+1)->InsertIndex(), config);
}
- (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex(), config);
+ (out.middle_end_ - 1)->FinishedLoading(out.longest_.InsertIndex(), config);
}
}
template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
- quant_.SetupMemory(start, config);
+ quant_.SetupMemory(start, counts.size(), config);
start += Quant::Size(counts.size(), config);
- unigram.Init(start);
+ unigram_.Init(start);
start += Unigram::Size(counts[0]);
FreeMiddles();
middle_begin_ = static_cast<Middle*>(malloc(sizeof(Middle) * (counts.size() - 2)));
@@ -565,23 +567,23 @@ template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::Setup
for (unsigned char i = counts.size() - 1; i >= 2; --i) {
new (middle_begin_ + i - 2) Middle(
middle_starts[i-2],
- quant_.Mid(i),
+ quant_.MiddleBits(config),
counts[i-1],
counts[0],
counts[i],
- (i == counts.size() - 1) ? static_cast<const BitPacked&>(longest) : static_cast<const BitPacked &>(middle_begin_[i-1]),
+ (i == counts.size() - 1) ? static_cast<const BitPacked&>(longest_) : static_cast<const BitPacked &>(middle_begin_[i-1]),
config);
}
- longest.Init(start, quant_.Long(counts.size()), counts[0]);
+ longest_.Init(start, quant_.LongestBits(config), counts[0]);
return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
}
template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBinary() {
- unigram.LoadedBinary();
+ unigram_.LoadedBinary();
for (Middle *i = middle_begin_; i != middle_end_; ++i) {
i->LoadedBinary();
}
- longest.LoadedBinary();
+ longest_.LoadedBinary();
}
template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {
diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh
index 5155ca02..10b22ab1 100644
--- a/klm/lm/search_trie.hh
+++ b/klm/lm/search_trie.hh
@@ -28,13 +28,11 @@ template <class Quant, class Bhiksha> class TrieSearch {
public:
typedef NodeRange Node;
- typedef ::lm::ngram::trie::Unigram Unigram;
- Unigram unigram;
-
- typedef trie::BitPackedMiddle<typename Quant::Middle, Bhiksha> Middle;
+ typedef ::lm::ngram::trie::UnigramPointer UnigramPointer;
+ typedef typename Quant::MiddlePointer MiddlePointer;
+ typedef typename Quant::LongestPointer LongestPointer;
- typedef trie::BitPackedLongest<typename Quant::Longest> Longest;
- Longest longest;
+ static const bool kDifferentRest = false;
static const ModelType kModelType = static_cast<ModelType>(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd);
@@ -62,55 +60,46 @@ template <class Quant, class Bhiksha> class TrieSearch {
void LoadedBinary();
- typedef const Middle *MiddleIter;
+ void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing);
- const Middle *MiddleBegin() const { return middle_begin_; }
- const Middle *MiddleEnd() const { return middle_end_; }
+ unsigned char Order() const {
+ return middle_end_ - middle_begin_ + 2;
+ }
- void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing);
+ ProbBackoff &UnknownUnigram() { return unigram_.Unknown(); }
- void LookupUnigram(WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const {
- unigram.Find(word, ret.prob, backoff, node);
- ret.independent_left = (node.begin == node.end);
- ret.extend_left = static_cast<uint64_t>(word);
+ UnigramPointer LookupUnigram(WordIndex word, Node &next, bool &independent_left, uint64_t &extend_left) const {
+ extend_left = static_cast<uint64_t>(word);
+ UnigramPointer ret(unigram_.Find(word, next));
+ independent_left = (next.begin == next.end);
+ return ret;
}
- bool LookupMiddle(const Middle &mid, WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const {
- if (!mid.Find(word, ret.prob, backoff, node, ret.extend_left)) return false;
- ret.independent_left = (node.begin == node.end);
- return true;
+ MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const {
+ return MiddlePointer(quant_, extend_length - 2, middle_begin_[extend_length - 2].ReadEntry(extend_pointer, node));
}
- bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const {
- return mid.FindNoProb(word, backoff, node);
+ MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_left) const {
+ util::BitAddress address(middle_begin_[order_minus_2].Find(word, node, extend_left));
+ independent_left = (address.base == NULL) || (node.begin == node.end);
+ return MiddlePointer(quant_, order_minus_2, address);
}
- bool LookupLongest(WordIndex word, float &prob, const Node &node) const {
- return longest.Find(word, prob, node);
+ LongestPointer LookupLongest(WordIndex word, const Node &node) const {
+ return LongestPointer(quant_, longest_.Find(word, node));
}
bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
- // TODO: don't decode backoff.
assert(begin != end);
- FullScoreReturn ignored;
- float ignored_backoff;
- LookupUnigram(*begin, ignored_backoff, node, ignored);
+ bool independent_left;
+ uint64_t ignored;
+ LookupUnigram(*begin, node, independent_left, ignored);
for (const WordIndex *i = begin + 1; i < end; ++i) {
- if (!LookupMiddleNoProb(middle_begin_[i - begin - 1], *i, ignored_backoff, node)) return false;
+ if (independent_left || !LookupMiddle(i - begin - 1, *i, node, independent_left, ignored).Found()) return false;
}
return true;
}
- Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const {
- if (extend_length == 1) {
- float ignored;
- Node ret;
- unigram.Find(static_cast<WordIndex>(extend_pointer), prob, ignored, ret);
- return ret;
- }
- return middle_begin_[extend_length - 2].ReadEntry(extend_pointer, prob);
- }
-
private:
friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
@@ -122,8 +111,16 @@ template <class Quant, class Bhiksha> class TrieSearch {
free(middle_begin_);
}
+ typedef trie::BitPackedMiddle<Bhiksha> Middle;
+
+ typedef trie::BitPackedLongest Longest;
+ Longest longest_;
+
Middle *middle_begin_, *middle_end_;
Quant quant_;
+
+ typedef ::lm::ngram::trie::Unigram Unigram;
+ Unigram unigram_;
};
} // namespace trie
diff --git a/klm/lm/state.hh b/klm/lm/state.hh
new file mode 100644
index 00000000..c7438414
--- /dev/null
+++ b/klm/lm/state.hh
@@ -0,0 +1,123 @@
+#ifndef LM_STATE__
+#define LM_STATE__
+
+#include "lm/max_order.hh"
+#include "lm/word_index.hh"
+#include "util/murmur_hash.hh"
+
+#include <string.h>
+
+namespace lm {
+namespace ngram {
+
+// This is a POD but if you want memcmp to return the same as operator==, call
+// ZeroRemaining first.
+class State {
+ public:
+ bool operator==(const State &other) const {
+ if (length != other.length) return false;
+ return !memcmp(words, other.words, length * sizeof(WordIndex));
+ }
+
+ // Three way comparison function.
+ int Compare(const State &other) const {
+ if (length != other.length) return length < other.length ? -1 : 1;
+ return memcmp(words, other.words, length * sizeof(WordIndex));
+ }
+
+ bool operator<(const State &other) const {
+ if (length != other.length) return length < other.length;
+ return memcmp(words, other.words, length * sizeof(WordIndex)) < 0;
+ }
+
+ // Call this before using raw memcmp.
+ void ZeroRemaining() {
+ for (unsigned char i = length; i < kMaxOrder - 1; ++i) {
+ words[i] = 0;
+ backoff[i] = 0.0;
+ }
+ }
+
+ unsigned char Length() const { return length; }
+
+ // You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD.
+ // This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit.
+ WordIndex words[kMaxOrder - 1];
+ float backoff[kMaxOrder - 1];
+ unsigned char length;
+};
+
+inline uint64_t hash_value(const State &state, uint64_t seed = 0) {
+ return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length, seed);
+}
+
+struct Left {
+ bool operator==(const Left &other) const {
+ return
+ (length == other.length) &&
+ pointers[length - 1] == other.pointers[length - 1] &&
+ full == other.full;
+ }
+
+ int Compare(const Left &other) const {
+ if (length < other.length) return -1;
+ if (length > other.length) return 1;
+ if (pointers[length - 1] > other.pointers[length - 1]) return 1;
+ if (pointers[length - 1] < other.pointers[length - 1]) return -1;
+ return (int)full - (int)other.full;
+ }
+
+ bool operator<(const Left &other) const {
+ return Compare(other) == -1;
+ }
+
+ void ZeroRemaining() {
+ for (uint64_t * i = pointers + length; i < pointers + kMaxOrder - 1; ++i)
+ *i = 0;
+ }
+
+ uint64_t pointers[kMaxOrder - 1];
+ unsigned char length;
+ bool full;
+};
+
+inline uint64_t hash_value(const Left &left) {
+ unsigned char add[2];
+ add[0] = left.length;
+ add[1] = left.full;
+ return util::MurmurHashNative(add, 2, left.length ? left.pointers[left.length - 1] : 0);
+}
+
+struct ChartState {
+ bool operator==(const ChartState &other) {
+ return (right == other.right) && (left == other.left);
+ }
+
+ int Compare(const ChartState &other) const {
+ int lres = left.Compare(other.left);
+ if (lres) return lres;
+ return right.Compare(other.right);
+ }
+
+ bool operator<(const ChartState &other) const {
+ return Compare(other) == -1;
+ }
+
+ void ZeroRemaining() {
+ left.ZeroRemaining();
+ right.ZeroRemaining();
+ }
+
+ Left left;
+ State right;
+};
+
+inline uint64_t hash_value(const ChartState &state) {
+ return hash_value(state.right, hash_value(state.left));
+}
+
+
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_STATE__
diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc
index 20075bb8..0f1ca574 100644
--- a/klm/lm/trie.cc
+++ b/klm/lm/trie.cc
@@ -1,7 +1,6 @@
#include "lm/trie.hh"
#include "lm/bhiksha.hh"
-#include "lm/quantize.hh"
#include "util/bit_packing.hh"
#include "util/exception.hh"
#include "util/sorted_uniform.hh"
@@ -58,91 +57,71 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits)
max_vocab_ = max_vocab;
}
-template <class Quant, class Bhiksha> std::size_t BitPackedMiddle<Quant, Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) {
+template <class Bhiksha> std::size_t BitPackedMiddle<Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) {
return Bhiksha::Size(entries + 1, max_ptr, config) + BaseSize(entries, max_vocab, quant_bits + Bhiksha::InlineBits(entries + 1, max_ptr, config));
}
-template <class Quant, class Bhiksha> BitPackedMiddle<Quant, Bhiksha>::BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config) :
+template <class Bhiksha> BitPackedMiddle<Bhiksha>::BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config) :
BitPacked(),
- quant_(quant),
+ quant_bits_(quant_bits),
// If the offset of the method changes, also change TrieSearch::UpdateConfigFromBinary.
bhiksha_(base, entries + 1, max_next, config),
next_source_(&next_source) {
if (entries + 1 >= (1ULL << 57) || (max_next >= (1ULL << 57))) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions.");
- BaseInit(reinterpret_cast<uint8_t*>(base) + Bhiksha::Size(entries + 1, max_next, config), max_vocab, quant.TotalBits() + bhiksha_.InlineBits());
+ BaseInit(reinterpret_cast<uint8_t*>(base) + Bhiksha::Size(entries + 1, max_next, config), max_vocab, quant_bits_ + bhiksha_.InlineBits());
}
-template <class Quant, class Bhiksha> void BitPackedMiddle<Quant, Bhiksha>::Insert(WordIndex word, float prob, float backoff) {
+template <class Bhiksha> util::BitAddress BitPackedMiddle<Bhiksha>::Insert(WordIndex word) {
assert(word <= word_mask_);
uint64_t at_pointer = insert_index_ * total_bits_;
util::WriteInt57(base_, at_pointer, word_bits_, word);
at_pointer += word_bits_;
- quant_.Write(base_, at_pointer, prob, backoff);
- at_pointer += quant_.TotalBits();
+ util::BitAddress ret(base_, at_pointer);
+ at_pointer += quant_bits_;
uint64_t next = next_source_->InsertIndex();
bhiksha_.WriteNext(base_, at_pointer, insert_index_, next);
-
++insert_index_;
+ return ret;
}
-template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::Find(WordIndex word, float &prob, float &backoff, NodeRange &range, uint64_t &pointer) const {
+template <class Bhiksha> util::BitAddress BitPackedMiddle<Bhiksha>::Find(WordIndex word, NodeRange &range, uint64_t &pointer) const {
uint64_t at_pointer;
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) {
- return false;
+ return util::BitAddress(NULL, 0);
}
pointer = at_pointer;
at_pointer *= total_bits_;
at_pointer += word_bits_;
+ bhiksha_.ReadNext(base_, at_pointer + quant_bits_, pointer, total_bits_, range);
- quant_.Read(base_, at_pointer, prob, backoff);
- at_pointer += quant_.TotalBits();
-
- bhiksha_.ReadNext(base_, at_pointer, pointer, total_bits_, range);
-
- return true;
+ return util::BitAddress(base_, at_pointer);
}
-template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const {
- uint64_t index;
- if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, index)) return false;
- uint64_t at_pointer = index * total_bits_;
- at_pointer += word_bits_;
- quant_.ReadBackoff(base_, at_pointer, backoff);
- at_pointer += quant_.TotalBits();
- bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range);
- return true;
-}
-
-template <class Quant, class Bhiksha> void BitPackedMiddle<Quant, Bhiksha>::FinishedLoading(uint64_t next_end, const Config &config) {
+template <class Bhiksha> void BitPackedMiddle<Bhiksha>::FinishedLoading(uint64_t next_end, const Config &config) {
uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - bhiksha_.InlineBits();
bhiksha_.WriteNext(base_, last_next_write, insert_index_ + 1, next_end);
bhiksha_.FinishedLoading(config);
}
-template <class Quant> void BitPackedLongest<Quant>::Insert(WordIndex index, float prob) {
+util::BitAddress BitPackedLongest::Insert(WordIndex index) {
assert(index <= word_mask_);
uint64_t at_pointer = insert_index_ * total_bits_;
util::WriteInt57(base_, at_pointer, word_bits_, index);
at_pointer += word_bits_;
- quant_.Write(base_, at_pointer, prob);
++insert_index_;
+ return util::BitAddress(base_, at_pointer);
}
-template <class Quant> bool BitPackedLongest<Quant>::Find(WordIndex word, float &prob, const NodeRange &range) const {
+util::BitAddress BitPackedLongest::Find(WordIndex word, const NodeRange &range) const {
uint64_t at_pointer;
- if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false;
+ if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return util::BitAddress(NULL, 0);
at_pointer = at_pointer * total_bits_ + word_bits_;
- quant_.Read(base_, at_pointer, prob);
- return true;
+ return util::BitAddress(base_, at_pointer);
}
-template class BitPackedMiddle<DontQuantize::Middle, DontBhiksha>;
-template class BitPackedMiddle<DontQuantize::Middle, ArrayBhiksha>;
-template class BitPackedMiddle<SeparatelyQuantize::Middle, DontBhiksha>;
-template class BitPackedMiddle<SeparatelyQuantize::Middle, ArrayBhiksha>;
-template class BitPackedLongest<DontQuantize::Longest>;
-template class BitPackedLongest<SeparatelyQuantize::Longest>;
+template class BitPackedMiddle<DontBhiksha>;
+template class BitPackedMiddle<ArrayBhiksha>;
} // namespace trie
} // namespace ngram
diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh
index ebe9910f..eff93292 100644
--- a/klm/lm/trie.hh
+++ b/klm/lm/trie.hh
@@ -1,12 +1,13 @@
#ifndef LM_TRIE__
#define LM_TRIE__
-#include <stdint.h>
+#include "lm/weights.hh"
+#include "lm/word_index.hh"
+#include "util/bit_packing.hh"
#include <cstddef>
-#include "lm/word_index.hh"
-#include "lm/weights.hh"
+#include <stdint.h>
namespace lm {
namespace ngram {
@@ -24,6 +25,22 @@ struct UnigramValue {
uint64_t Next() const { return next; }
};
+class UnigramPointer {
+ public:
+ explicit UnigramPointer(const ProbBackoff &to) : to_(&to) {}
+
+ UnigramPointer() : to_(NULL) {}
+
+ bool Found() const { return to_ != NULL; }
+
+ float Prob() const { return to_->prob; }
+ float Backoff() const { return to_->backoff; }
+ float Rest() const { return Prob(); }
+
+ private:
+ const ProbBackoff *to_;
+};
+
class Unigram {
public:
Unigram() {}
@@ -47,12 +64,11 @@ class Unigram {
void LoadedBinary() {}
- void Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const {
+ UnigramPointer Find(WordIndex word, NodeRange &next) const {
UnigramValue *val = unigram_ + word;
- prob = val->weights.prob;
- backoff = val->weights.backoff;
next.begin = val->next;
next.end = (val+1)->next;
+ return UnigramPointer(val->weights);
}
private:
@@ -81,40 +97,36 @@ class BitPacked {
uint64_t insert_index_, max_vocab_;
};
-template <class Quant, class Bhiksha> class BitPackedMiddle : public BitPacked {
+template <class Bhiksha> class BitPackedMiddle : public BitPacked {
public:
static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config);
// next_source need not be initialized.
- BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config);
+ BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config);
- void Insert(WordIndex word, float prob, float backoff);
+ util::BitAddress Insert(WordIndex word);
void FinishedLoading(uint64_t next_end, const Config &config);
void LoadedBinary() { bhiksha_.LoadedBinary(); }
- bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range, uint64_t &pointer) const;
-
- bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const;
+ util::BitAddress Find(WordIndex word, NodeRange &range, uint64_t &pointer) const;
- NodeRange ReadEntry(uint64_t pointer, float &prob) {
+ util::BitAddress ReadEntry(uint64_t pointer, NodeRange &range) {
uint64_t addr = pointer * total_bits_;
addr += word_bits_;
- quant_.ReadProb(base_, addr, prob);
- NodeRange ret;
- bhiksha_.ReadNext(base_, addr + quant_.TotalBits(), pointer, total_bits_, ret);
- return ret;
+ bhiksha_.ReadNext(base_, addr + quant_bits_, pointer, total_bits_, range);
+ return util::BitAddress(base_, addr);
}
private:
- Quant quant_;
+ uint8_t quant_bits_;
Bhiksha bhiksha_;
const BitPacked *next_source_;
};
-template <class Quant> class BitPackedLongest : public BitPacked {
+class BitPackedLongest : public BitPacked {
public:
static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) {
return BaseSize(entries, max_vocab, quant_bits);
@@ -122,19 +134,18 @@ template <class Quant> class BitPackedLongest : public BitPacked {
BitPackedLongest() {}
- void Init(void *base, const Quant &quant, uint64_t max_vocab) {
- quant_ = quant;
- BaseInit(base, max_vocab, quant_.TotalBits());
+ void Init(void *base, uint8_t quant_bits, uint64_t max_vocab) {
+ BaseInit(base, max_vocab, quant_bits);
}
void LoadedBinary() {}
- void Insert(WordIndex word, float prob);
+ util::BitAddress Insert(WordIndex word);
- bool Find(WordIndex word, float &prob, const NodeRange &node) const;
+ util::BitAddress Find(WordIndex word, const NodeRange &node) const;
private:
- Quant quant_;
+ uint8_t quant_bits_;
};
} // namespace trie
diff --git a/klm/lm/value.hh b/klm/lm/value.hh
new file mode 100644
index 00000000..85e53f14
--- /dev/null
+++ b/klm/lm/value.hh
@@ -0,0 +1,157 @@
+#ifndef LM_VALUE__
+#define LM_VALUE__
+
+#include "lm/model_type.hh"
+#include "lm/value_build.hh"
+#include "lm/weights.hh"
+#include "util/bit_packing.hh"
+
+#include <inttypes.h>
+
+namespace lm {
+namespace ngram {
+
+// Template proxy for probing unigrams and middle.
+template <class Weights> class GenericProbingProxy {
+ public:
+ explicit GenericProbingProxy(const Weights &to) : to_(&to) {}
+
+ GenericProbingProxy() : to_(0) {}
+
+ bool Found() const { return to_ != 0; }
+
+ float Prob() const {
+ util::FloatEnc enc;
+ enc.f = to_->prob;
+ enc.i |= util::kSignBit;
+ return enc.f;
+ }
+
+ float Backoff() const { return to_->backoff; }
+
+ bool IndependentLeft() const {
+ util::FloatEnc enc;
+ enc.f = to_->prob;
+ return enc.i & util::kSignBit;
+ }
+
+ protected:
+ const Weights *to_;
+};
+
+// Basic proxy for trie unigrams.
+template <class Weights> class GenericTrieUnigramProxy {
+ public:
+ explicit GenericTrieUnigramProxy(const Weights &to) : to_(&to) {}
+
+ GenericTrieUnigramProxy() : to_(0) {}
+
+ bool Found() const { return to_ != 0; }
+ float Prob() const { return to_->prob; }
+ float Backoff() const { return to_->backoff; }
+ float Rest() const { return Prob(); }
+
+ protected:
+ const Weights *to_;
+};
+
+struct BackoffValue {
+ typedef ProbBackoff Weights;
+ static const ModelType kProbingModelType = PROBING;
+
+ class ProbingProxy : public GenericProbingProxy<Weights> {
+ public:
+ explicit ProbingProxy(const Weights &to) : GenericProbingProxy<Weights>(to) {}
+ ProbingProxy() {}
+ float Rest() const { return Prob(); }
+ };
+
+ class TrieUnigramProxy : public GenericTrieUnigramProxy<Weights> {
+ public:
+ explicit TrieUnigramProxy(const Weights &to) : GenericTrieUnigramProxy<Weights>(to) {}
+ TrieUnigramProxy() {}
+ float Rest() const { return Prob(); }
+ };
+
+ struct ProbingEntry {
+ typedef uint64_t Key;
+ typedef Weights Value;
+ uint64_t key;
+ ProbBackoff value;
+ uint64_t GetKey() const { return key; }
+ };
+
+ struct TrieUnigramValue {
+ Weights weights;
+ uint64_t next;
+ uint64_t Next() const { return next; }
+ };
+
+ const static bool kDifferentRest = false;
+
+ template <class Model, class C> void Callback(const Config &, unsigned int, typename Model::Vocabulary &, C &callback) {
+ NoRestBuild build;
+ callback(build);
+ }
+};
+
+struct RestValue {
+ typedef RestWeights Weights;
+ static const ModelType kProbingModelType = REST_PROBING;
+
+ class ProbingProxy : public GenericProbingProxy<RestWeights> {
+ public:
+ explicit ProbingProxy(const Weights &to) : GenericProbingProxy<RestWeights>(to) {}
+ ProbingProxy() {}
+ float Rest() const { return to_->rest; }
+ };
+
+ class TrieUnigramProxy : public GenericTrieUnigramProxy<Weights> {
+ public:
+ explicit TrieUnigramProxy(const Weights &to) : GenericTrieUnigramProxy<Weights>(to) {}
+ TrieUnigramProxy() {}
+ float Rest() const { return to_->rest; }
+ };
+
+// gcc 4.1 doesn't properly back dependent types :-(.
+#pragma pack(push)
+#pragma pack(4)
+ struct ProbingEntry {
+ typedef uint64_t Key;
+ typedef Weights Value;
+ Key key;
+ Value value;
+ Key GetKey() const { return key; }
+ };
+
+ struct TrieUnigramValue {
+ Weights weights;
+ uint64_t next;
+ uint64_t Next() const { return next; }
+ };
+#pragma pack(pop)
+
+ const static bool kDifferentRest = true;
+
+ template <class Model, class C> void Callback(const Config &config, unsigned int order, typename Model::Vocabulary &vocab, C &callback) {
+ switch (config.rest_function) {
+ case Config::REST_MAX:
+ {
+ MaxRestBuild build;
+ callback(build);
+ }
+ break;
+ case Config::REST_LOWER:
+ {
+ LowerRestBuild<Model> build(config, order, vocab);
+ callback(build);
+ }
+ break;
+ }
+ }
+};
+
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_VALUE__
diff --git a/klm/lm/value_build.cc b/klm/lm/value_build.cc
new file mode 100644
index 00000000..6124f8da
--- /dev/null
+++ b/klm/lm/value_build.cc
@@ -0,0 +1,58 @@
+#include "lm/value_build.hh"
+
+#include "lm/model.hh"
+#include "lm/read_arpa.hh"
+
+namespace lm {
+namespace ngram {
+
+template <class Model> LowerRestBuild<Model>::LowerRestBuild(const Config &config, unsigned int order, const typename Model::Vocabulary &vocab) {
+ UTIL_THROW_IF(config.rest_lower_files.size() != order - 1, ConfigException, "This model has order " << order << " so there should be " << (order - 1) << " lower-order models for rest cost purposes.");
+ Config for_lower = config;
+ for_lower.rest_lower_files.clear();
+
+ // Unigram models aren't supported, so this is a custom loader.
+ // TODO: optimize the unigram loading?
+ {
+ util::FilePiece uni(config.rest_lower_files[0].c_str());
+ std::vector<uint64_t> number;
+ ReadARPACounts(uni, number);
+ UTIL_THROW_IF(number.size() != 1, FormatLoadException, "Expected the unigram model to have order 1, not " << number.size());
+ ReadNGramHeader(uni, 1);
+ unigrams_.resize(number[0]);
+ unigrams_[0] = config.unknown_missing_logprob;
+ PositiveProbWarn warn;
+ for (uint64_t i = 0; i < number[0]; ++i) {
+ WordIndex w;
+ Prob entry;
+ ReadNGram(uni, 1, vocab, &w, entry, warn);
+ unigrams_[w] = entry.prob;
+ }
+ }
+
+ try {
+ for (unsigned int i = 2; i < order; ++i) {
+ models_.push_back(new Model(config.rest_lower_files[i - 1].c_str(), for_lower));
+ UTIL_THROW_IF(models_.back()->Order() != i, FormatLoadException, "Lower order file " << config.rest_lower_files[i-1] << " should have order " << i);
+ }
+ } catch (...) {
+ for (typename std::vector<const Model*>::const_iterator i = models_.begin(); i != models_.end(); ++i) {
+ delete *i;
+ }
+ models_.clear();
+ throw;
+ }
+
+ // TODO: force/check same vocab.
+}
+
+template <class Model> LowerRestBuild<Model>::~LowerRestBuild() {
+ for (typename std::vector<const Model*>::const_iterator i = models_.begin(); i != models_.end(); ++i) {
+ delete *i;
+ }
+}
+
+template class LowerRestBuild<ProbingModel>;
+
+} // namespace ngram
+} // namespace lm
diff --git a/klm/lm/value_build.hh b/klm/lm/value_build.hh
new file mode 100644
index 00000000..687a41a0
--- /dev/null
+++ b/klm/lm/value_build.hh
@@ -0,0 +1,97 @@
+#ifndef LM_VALUE_BUILD__
+#define LM_VALUE_BUILD__
+
+#include "lm/weights.hh"
+#include "lm/word_index.hh"
+#include "util/bit_packing.hh"
+
+#include <vector>
+
+namespace lm {
+namespace ngram {
+
+class Config;
+class BackoffValue;
+class RestValue;
+
+class NoRestBuild {
+ public:
+ typedef BackoffValue Value;
+
+ NoRestBuild() {}
+
+ void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {}
+ void SetRest(const WordIndex *, unsigned int, const ProbBackoff &) const {}
+
+ template <class Second> bool MarkExtends(ProbBackoff &weights, const Second &) const {
+ util::UnsetSign(weights.prob);
+ return false;
+ }
+
+ // Probing doesn't need to go back to unigram.
+ const static bool kMarkEvenLower = false;
+};
+
+class MaxRestBuild {
+ public:
+ typedef RestValue Value;
+
+ MaxRestBuild() {}
+
+ void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {}
+ void SetRest(const WordIndex *, unsigned int, RestWeights &weights) const {
+ weights.rest = weights.prob;
+ util::SetSign(weights.rest);
+ }
+
+ bool MarkExtends(RestWeights &weights, const RestWeights &to) const {
+ util::UnsetSign(weights.prob);
+ if (weights.rest >= to.rest) return false;
+ weights.rest = to.rest;
+ return true;
+ }
+ bool MarkExtends(RestWeights &weights, const Prob &to) const {
+ util::UnsetSign(weights.prob);
+ if (weights.rest >= to.prob) return false;
+ weights.rest = to.prob;
+ return true;
+ }
+
+ // Probing does need to go back to unigram.
+ const static bool kMarkEvenLower = true;
+};
+
+template <class Model> class LowerRestBuild {
+ public:
+ typedef RestValue Value;
+
+ LowerRestBuild(const Config &config, unsigned int order, const typename Model::Vocabulary &vocab);
+
+ ~LowerRestBuild();
+
+ void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {}
+ void SetRest(const WordIndex *vocab_ids, unsigned int n, RestWeights &weights) const {
+ typename Model::State ignored;
+ if (n == 1) {
+ weights.rest = unigrams_[*vocab_ids];
+ } else {
+ weights.rest = models_[n-2]->FullScoreForgotState(vocab_ids + 1, vocab_ids + n, *vocab_ids, ignored).prob;
+ }
+ }
+
+ template <class Second> bool MarkExtends(RestWeights &weights, const Second &) const {
+ util::UnsetSign(weights.prob);
+ return false;
+ }
+
+ const static bool kMarkEvenLower = false;
+
+ std::vector<float> unigrams_;
+
+ std::vector<const Model*> models_;
+};
+
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_VALUE_BUILD__
diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc
index 9fd698bb..5de68f16 100644
--- a/klm/lm/vocab.cc
+++ b/klm/lm/vocab.cc
@@ -196,7 +196,7 @@ WordIndex ProbingVocabulary::Insert(const StringPiece &str) {
}
}
-void ProbingVocabulary::FinishedLoading(ProbBackoff * /*reorder_vocab*/) {
+void ProbingVocabulary::InternalFinishedLoading() {
lookup_.FinishedInserting();
header_->bound = bound_;
header_->version = kProbingVocabularyVersion;
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh
index 06fdefe4..c3efcb4a 100644
--- a/klm/lm/vocab.hh
+++ b/klm/lm/vocab.hh
@@ -141,7 +141,9 @@ class ProbingVocabulary : public base::Vocabulary {
WordIndex Insert(const StringPiece &str);
- void FinishedLoading(ProbBackoff *reorder_vocab);
+ template <class Weights> void FinishedLoading(Weights * /*reorder_vocab*/) {
+ InternalFinishedLoading();
+ }
std::size_t UnkCountChangePadding() const { return 0; }
@@ -150,6 +152,8 @@ class ProbingVocabulary : public base::Vocabulary {
void LoadedBinary(bool have_words, int fd, EnumerateVocab *to);
private:
+ void InternalFinishedLoading();
+
typedef util::ProbingHashTable<ProbingVocabuaryEntry, util::IdentityHash> Lookup;
Lookup lookup_;
diff --git a/klm/lm/weights.hh b/klm/lm/weights.hh
index 1f38cf5e..bd5d8034 100644
--- a/klm/lm/weights.hh
+++ b/klm/lm/weights.hh
@@ -12,6 +12,11 @@ struct ProbBackoff {
float prob;
float backoff;
};
+struct RestWeights {
+ float prob;
+ float backoff;
+ float rest;
+};
} // namespace lm
#endif // LM_WEIGHTS__
diff --git a/klm/util/Jamfile b/klm/util/Jamfile
index b8c14347..3ee2c2c2 100644
--- a/klm/util/Jamfile
+++ b/klm/util/Jamfile
@@ -1,4 +1,4 @@
-lib kenutil : bit_packing.cc ersatz_progress.cc exception.cc file.cc file_piece.cc mmap.cc murmur_hash.cc ../..//z : <include>.. : : <include>.. ;
+lib kenutil : bit_packing.cc ersatz_progress.cc exception.cc file.cc file_piece.cc mmap.cc murmur_hash.cc usage.cc ../..//z : <include>.. : : <include>.. ;
import testing ;
diff --git a/klm/util/Makefile.am b/klm/util/Makefile.am
index a8d6299b..5ceccf2c 100644
--- a/klm/util/Makefile.am
+++ b/klm/util/Makefile.am
@@ -25,6 +25,7 @@ libklm_util_a_SOURCES = \
file.cc \
file_piece.cc \
mmap.cc \
- murmur_hash.cc
+ murmur_hash.cc \
+ usage.cc
AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I..
diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh
index 73a5cb22..dcbd814c 100644
--- a/klm/util/bit_packing.hh
+++ b/klm/util/bit_packing.hh
@@ -174,6 +174,13 @@ struct BitsMask {
uint64_t mask;
};
+struct BitAddress {
+ BitAddress(void *in_base, uint64_t in_offset) : base(in_base), offset(in_offset) {}
+
+ void *base;
+ uint64_t offset;
+};
+
} // namespace util
#endif // UTIL_BIT_PACKING__
diff --git a/klm/util/ersatz_progress.cc b/klm/util/ersatz_progress.cc
index a82ce672..07b14e26 100644
--- a/klm/util/ersatz_progress.cc
+++ b/klm/util/ersatz_progress.cc
@@ -12,17 +12,17 @@ namespace { const unsigned char kWidth = 100; }
ErsatzProgress::ErsatzProgress() : current_(0), next_(std::numeric_limits<std::size_t>::max()), complete_(next_), out_(NULL) {}
ErsatzProgress::~ErsatzProgress() {
- if (!out_) return;
- Finished();
+ if (out_) Finished();
}
-ErsatzProgress::ErsatzProgress(std::ostream *to, const std::string &message, std::size_t complete)
+ErsatzProgress::ErsatzProgress(std::size_t complete, std::ostream *to, const std::string &message)
: current_(0), next_(complete / kWidth), complete_(complete), stones_written_(0), out_(to) {
if (!out_) {
next_ = std::numeric_limits<std::size_t>::max();
return;
}
- *out_ << message << "\n----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n";
+ if (!message.empty()) *out_ << message << '\n';
+ *out_ << "----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n";
}
void ErsatzProgress::Milestone() {
diff --git a/klm/util/ersatz_progress.hh b/klm/util/ersatz_progress.hh
index 92c345fe..f709dc51 100644
--- a/klm/util/ersatz_progress.hh
+++ b/klm/util/ersatz_progress.hh
@@ -1,7 +1,7 @@
#ifndef UTIL_ERSATZ_PROGRESS__
#define UTIL_ERSATZ_PROGRESS__
-#include <iosfwd>
+#include <iostream>
#include <string>
// Ersatz version of boost::progress so core language model doesn't depend on
@@ -14,7 +14,7 @@ class ErsatzProgress {
ErsatzProgress();
// Null means no output. The null value is useful for passing along the ostream pointer from another caller.
- ErsatzProgress(std::ostream *to, const std::string &message, std::size_t complete);
+ explicit ErsatzProgress(std::size_t complete, std::ostream *to = &std::cerr, const std::string &message = "");
~ErsatzProgress();
diff --git a/klm/util/file.cc b/klm/util/file.cc
index de206bc8..1bd056fc 100644
--- a/klm/util/file.cc
+++ b/klm/util/file.cc
@@ -43,16 +43,6 @@ int OpenReadOrThrow(const char *name) {
return ret;
}
-int CreateOrThrow(const char *name) {
- int ret;
-#if defined(_WIN32) || defined(_WIN64)
- UTIL_THROW_IF(-1 == (ret = _open(name, _O_CREAT | _O_TRUNC | _O_RDWR, _S_IREAD | _S_IWRITE)), ErrnoException, "while creating " << name);
-#else
- UTIL_THROW_IF(-1 == (ret = open(name, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH)), ErrnoException, "while creating " << name);
-#endif
- return ret;
-}
-
uint64_t SizeFile(int fd) {
#if defined(_WIN32) || defined(_WIN64)
__int64 ret = _filelengthi64(fd);
diff --git a/klm/util/file.hh b/klm/util/file.hh
index 72c8ea76..5c57e2a9 100644
--- a/klm/util/file.hh
+++ b/klm/util/file.hh
@@ -65,10 +65,7 @@ class scoped_FILE {
std::FILE *file_;
};
-// Open for read only.
int OpenReadOrThrow(const char *name);
-// Create file if it doesn't exist, truncate if it does. Opened for write.
-int CreateOrThrow(const char *name);
// Return value for SizeFile when it can't size properly.
const uint64_t kBadSize = (uint64_t)-1;
diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc
index 081e662b..7b6a01dd 100644
--- a/klm/util/file_piece.cc
+++ b/klm/util/file_piece.cc
@@ -18,31 +18,35 @@
#include <sys/types.h>
#include <sys/stat.h>
+#ifdef HAVE_ZLIB
+#include <zlib.h>
+#endif
+
namespace util {
ParseNumberException::ParseNumberException(StringPiece value) throw() {
*this << "Could not parse \"" << value << "\" into a number";
}
+GZException::GZException(void *file) {
#ifdef HAVE_ZLIB
-GZException::GZException(gzFile file) {
int num;
- *this << gzerror( file, &num) << " from zlib";
-}
+ *this << gzerror(file, &num) << " from zlib";
#endif // HAVE_ZLIB
+}
// Sigh this is the only way I could come up with to do a _const_ bool. It has ' ', '\f', '\n', '\r', '\t', and '\v' (same as isspace on C locale).
const bool kSpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
FilePiece::FilePiece(const char *name, std::ostream *show_progress, std::size_t min_buffer) :
file_(OpenReadOrThrow(name)), total_size_(SizeFile(file_.get())), page_(SizePage()),
- progress_(total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name, total_size_) {
+ progress_(total_size_, total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name) {
Initialize(name, show_progress, min_buffer);
}
FilePiece::FilePiece(int fd, const char *name, std::ostream *show_progress, std::size_t min_buffer) :
file_(fd), total_size_(SizeFile(file_.get())), page_(SizePage()),
- progress_(total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name, total_size_) {
+ progress_(total_size_, total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name) {
Initialize(name, show_progress, min_buffer);
}
@@ -149,8 +153,9 @@ template <class T> T FilePiece::ReadNumber() {
SkipSpaces();
while (last_space_ < position_) {
if (at_end_) {
+ if (position_ >= position_end_) throw EndOfFileException();
// Hallucinate a null off the end of the file.
- std::string buffer(position_, position_end_);
+ std::string buffer(position_, position_end_ - position_);
char *end;
T ret;
ParseNumber(buffer.c_str(), end, ret);
diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh
index af93d8aa..b81ac0e2 100644
--- a/klm/util/file_piece.hh
+++ b/klm/util/file_piece.hh
@@ -13,10 +13,6 @@
#include <stdint.h>
-#ifdef HAVE_ZLIB
-#include <zlib.h>
-#endif
-
namespace util {
class ParseNumberException : public Exception {
@@ -27,9 +23,7 @@ class ParseNumberException : public Exception {
class GZException : public Exception {
public:
-#ifdef HAVE_ZLIB
- explicit GZException(gzFile file);
-#endif
+ explicit GZException(void *file);
GZException() throw() {}
~GZException() throw() {}
};
@@ -123,7 +117,7 @@ class FilePiece {
std::string file_name_;
#ifdef HAVE_ZLIB
- gzFile gz_file_;
+ void *gz_file_;
#endif // HAVE_ZLIB
};
diff --git a/klm/util/have.hh b/klm/util/have.hh
index f2f0cf90..b8181e99 100644
--- a/klm/util/have.hh
+++ b/klm/util/have.hh
@@ -3,13 +3,21 @@
#define UTIL_HAVE__
#ifndef HAVE_ZLIB
+#if !defined(_WIN32) && !defined(_WIN64)
#define HAVE_ZLIB
#endif
+#endif
-// #define HAVE_ICU
+#ifndef HAVE_ICU
+//#define HAVE_ICU
+#endif
#ifndef HAVE_BOOST
#define HAVE_BOOST
#endif
+#ifndef HAVE_THREADS
+//#define HAVE_THREADS
+#endif
+
#endif // UTIL_HAVE__
diff --git a/klm/util/mmap.cc b/klm/util/mmap.cc
index 2db35b56..e0d2570b 100644
--- a/klm/util/mmap.cc
+++ b/klm/util/mmap.cc
@@ -171,6 +171,20 @@ void *MapZeroedWrite(int fd, std::size_t size) {
return MapOrThrow(size, true, kFileFlags, false, fd, 0);
}
+namespace {
+
+int CreateOrThrow(const char *name) {
+ int ret;
+#if defined(_WIN32) || defined(_WIN64)
+ UTIL_THROW_IF(-1 == (ret = _open(name, _O_CREAT | _O_TRUNC | _O_RDWR, _S_IREAD | _S_IWRITE)), ErrnoException, "while creating " << name);
+#else
+ UTIL_THROW_IF(-1 == (ret = open(name, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH)), ErrnoException, "while creating " << name);
+#endif
+ return ret;
+}
+
+} // namespace
+
void *MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file) {
file.reset(CreateOrThrow(name));
try {
diff --git a/klm/util/murmur_hash.cc b/klm/util/murmur_hash.cc
index 6accc21a..4f519312 100644
--- a/klm/util/murmur_hash.cc
+++ b/klm/util/murmur_hash.cc
@@ -23,7 +23,7 @@ namespace util {
// 64-bit hash for 64-bit platforms
-uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed )
+uint64_t MurmurHash64A ( const void * key, std::size_t len, uint64_t seed )
{
const uint64_t m = 0xc6a4a7935bd1e995ULL;
const int r = 47;
@@ -81,7 +81,7 @@ uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed )
// 64-bit hash for 32-bit platforms
-uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed )
+uint64_t MurmurHash64B ( const void * key, std::size_t len, uint64_t seed )
{
const unsigned int m = 0x5bd1e995;
const int r = 24;
@@ -150,17 +150,18 @@ uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed )
return h;
}
+
// Trick to test for 64-bit architecture at compile time.
namespace {
-template <unsigned L> uint64_t MurmurHashNativeBackend(const void * key, std::size_t len, unsigned int seed) {
+template <unsigned L> inline uint64_t MurmurHashNativeBackend(const void * key, std::size_t len, uint64_t seed) {
return MurmurHash64A(key, len, seed);
}
-template <> uint64_t MurmurHashNativeBackend<4>(const void * key, std::size_t len, unsigned int seed) {
+template <> inline uint64_t MurmurHashNativeBackend<4>(const void * key, std::size_t len, uint64_t seed) {
return MurmurHash64B(key, len, seed);
}
} // namespace
-uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) {
+uint64_t MurmurHashNative(const void * key, std::size_t len, uint64_t seed) {
return MurmurHashNativeBackend<sizeof(void*)>(key, len, seed);
}
diff --git a/klm/util/murmur_hash.hh b/klm/util/murmur_hash.hh
index 638aaeb2..ae7e88de 100644
--- a/klm/util/murmur_hash.hh
+++ b/klm/util/murmur_hash.hh
@@ -5,9 +5,9 @@
namespace util {
-uint64_t MurmurHash64A(const void * key, std::size_t len, unsigned int seed = 0);
-uint64_t MurmurHash64B(const void * key, std::size_t len, unsigned int seed = 0);
-uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed = 0);
+uint64_t MurmurHash64A(const void * key, std::size_t len, uint64_t seed = 0);
+uint64_t MurmurHash64B(const void * key, std::size_t len, uint64_t seed = 0);
+uint64_t MurmurHashNative(const void * key, std::size_t len, uint64_t seed = 0);
} // namespace util
diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh
index f466cebc..3354b68e 100644
--- a/klm/util/probing_hash_table.hh
+++ b/klm/util/probing_hash_table.hh
@@ -78,12 +78,33 @@ template <class EntryT, class HashT, class EqualT = std::equal_to<typename Entry
}
}
+ // Return true if the value was found (and not inserted). This is consistent with Find but the opposite if hash_map!
+ template <class T> bool FindOrInsert(const T &t, MutableIterator &out) {
+#ifdef DEBUG
+ assert(initialized_);
+#endif
+ for (MutableIterator i(begin_ + (hash_(t.GetKey()) % buckets_));;) {
+ Key got(i->GetKey());
+ if (equal_(got, t.GetKey())) { out = i; return true; }
+ if (equal_(got, invalid_)) {
+ UTIL_THROW_IF(++entries_ >= buckets_, ProbingSizeException, "Hash table with " << buckets_ << " buckets is full.");
+ *i = t;
+ out = i;
+ return false;
+ }
+ if (++i == end_) i = begin_;
+ }
+ }
+
void FinishedInserting() {}
void LoadedBinary() {}
// Don't change anything related to GetKey,
template <class Key> bool UnsafeMutableFind(const Key key, MutableIterator &out) {
+#ifdef DEBUG
+ assert(initialized_);
+#endif
for (MutableIterator i(begin_ + (hash_(key) % buckets_));;) {
Key got(i->GetKey());
if (equal_(got, key)) { out = i; return true; }
diff --git a/klm/util/usage.cc b/klm/util/usage.cc
new file mode 100644
index 00000000..e5cf76f0
--- /dev/null
+++ b/klm/util/usage.cc
@@ -0,0 +1,46 @@
+#include "util/usage.hh"
+
+#include <fstream>
+#include <ostream>
+
+#include <string.h>
+#include <ctype.h>
+#if !defined(_WIN32) && !defined(_WIN64)
+#include <sys/resource.h>
+#include <sys/time.h>
+#endif
+
+namespace util {
+
+namespace {
+#if !defined(_WIN32) && !defined(_WIN64)
+float FloatSec(const struct timeval &tv) {
+ return static_cast<float>(tv.tv_sec) + (static_cast<float>(tv.tv_usec) / 1000000.0);
+}
+#endif
+} // namespace
+
+void PrintUsage(std::ostream &out) {
+#if !defined(_WIN32) && !defined(_WIN64)
+ struct rusage usage;
+ if (getrusage(RUSAGE_SELF, &usage)) {
+ perror("getrusage");
+ return;
+ }
+ out << "user\t" << FloatSec(usage.ru_utime) << "\nsys\t" << FloatSec(usage.ru_stime) << '\n';
+
+ // Linux doesn't set memory usage :-(.
+ std::ifstream status("/proc/self/status", std::ios::in);
+ std::string line;
+ while (getline(status, line)) {
+ if (!strncmp(line.c_str(), "VmRSS:\t", 7)) {
+ out << "VmRSS: " << (line.c_str() + 7) << '\n';
+ break;
+ } else if (!strncmp(line.c_str(), "VmPeak:\t", 8)) {
+ out << "VmPeak: " << (line.c_str() + 8) << '\n';
+ }
+ }
+#endif
+}
+
+} // namespace util
diff --git a/klm/util/usage.hh b/klm/util/usage.hh
new file mode 100644
index 00000000..d331ff74
--- /dev/null
+++ b/klm/util/usage.hh
@@ -0,0 +1,8 @@
+#ifndef UTIL_USAGE__
+#define UTIL_USAGE__
+#include <iosfwd>
+
+namespace util {
+void PrintUsage(std::ostream &to);
+} // namespace util
+#endif // UTIL_USAGE__