summaryrefslogtreecommitdiff
path: root/klm/lm
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2013-06-18 11:34:20 -0700
committerKenneth Heafield <github@kheafield.com>2013-06-18 11:34:20 -0700
commit8dc383a74c12d44ab3f51947575ed5828653f4f1 (patch)
tree30371d17e6c2daf3837068b2617e357ad6b37d89 /klm/lm
parent354787aa16539702802b9ea075c4bd8a72071035 (diff)
lazy dd880b4 including kenlm 6eef0f1
Diffstat (limited to 'klm/lm')
-rw-r--r--klm/lm/builder/lmplz_main.cc15
-rw-r--r--klm/lm/builder/ngram.hh2
-rw-r--r--klm/lm/model.cc21
-rw-r--r--klm/lm/model.hh5
-rw-r--r--klm/lm/search_hashed.cc29
-rw-r--r--klm/lm/search_hashed.hh19
-rw-r--r--klm/lm/state.hh2
-rw-r--r--klm/lm/virtual_interface.hh3
-rw-r--r--klm/lm/vocab.hh2
9 files changed, 61 insertions, 37 deletions
diff --git a/klm/lm/builder/lmplz_main.cc b/klm/lm/builder/lmplz_main.cc
index 1e086dcc..c87abdb8 100644
--- a/klm/lm/builder/lmplz_main.cc
+++ b/klm/lm/builder/lmplz_main.cc
@@ -52,13 +52,14 @@ int main(int argc, char *argv[]) {
std::cerr <<
"Builds unpruned language models with modified Kneser-Ney smoothing.\n\n"
"Please cite:\n"
- "@inproceedings{kenlm,\n"
- "author = {Kenneth Heafield},\n"
- "title = {{KenLM}: Faster and Smaller Language Model Queries},\n"
- "booktitle = {Proceedings of the Sixth Workshop on Statistical Machine Translation},\n"
- "month = {July}, year={2011},\n"
- "address = {Edinburgh, UK},\n"
- "publisher = {Association for Computational Linguistics},\n"
+ "@inproceedings{Heafield-estimate,\n"
+ " author = {Kenneth Heafield and Ivan Pouzyrevsky and Jonathan H. Clark and Philipp Koehn},\n"
+ " title = {Scalable Modified {Kneser-Ney} Language Model Estimation},\n"
+ " year = {2013},\n"
+ " month = {8},\n"
+ " booktitle = {Proceedings of the 51st Annual Meeting of the Association for Computational Linguistics},\n"
+ " address = {Sofia, Bulgaria},\n"
+ " url = {http://kheafield.com/professional/edinburgh/estimate\\_paper.pdf},\n"
"}\n\n"
"Provide the corpus on stdin. The ARPA file will be written to stdout. Order of\n"
"the model (-o) is the only mandatory option. As this is an on-disk program,\n"
diff --git a/klm/lm/builder/ngram.hh b/klm/lm/builder/ngram.hh
index 2984ed0b..f5681516 100644
--- a/klm/lm/builder/ngram.hh
+++ b/klm/lm/builder/ngram.hh
@@ -53,7 +53,7 @@ class NGram {
Payload &Value() { return *reinterpret_cast<Payload *>(end_); }
uint64_t &Count() { return Value().count; }
- const uint64_t Count() const { return Value().count; }
+ uint64_t Count() const { return Value().count; }
std::size_t Order() const { return end_ - begin_; }
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index a40fd2fb..a26654a6 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -304,5 +304,26 @@ template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiks
template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary>;
} // namespace detail
+
+base::Model *LoadVirtual(const char *file_name, const Config &config, ModelType model_type) {
+ RecognizeBinary(file_name, model_type);
+ switch (model_type) {
+ case PROBING:
+ return new ProbingModel(file_name, config);
+ case REST_PROBING:
+ return new RestProbingModel(file_name, config);
+ case TRIE:
+ return new TrieModel(file_name, config);
+ case QUANT_TRIE:
+ return new QuantTrieModel(file_name, config);
+ case ARRAY_TRIE:
+ return new ArrayTrieModel(file_name, config);
+ case QUANT_ARRAY_TRIE:
+ return new QuantArrayTrieModel(file_name, config);
+ default:
+ UTIL_THROW(FormatLoadException, "Confused by model type " << model_type);
+ }
+}
+
} // namespace ngram
} // namespace lm
diff --git a/klm/lm/model.hh b/klm/lm/model.hh
index 13ff864e..60f55110 100644
--- a/klm/lm/model.hh
+++ b/klm/lm/model.hh
@@ -153,6 +153,11 @@ LM_NAME_MODEL(QuantArrayTrieModel, detail::GenericModel<trie::TrieSearch<Separat
typedef ::lm::ngram::ProbingVocabulary Vocabulary;
typedef ProbingModel Model;
+/* Autorecognize the file type, load, and return the virtual base class. Don't
+ * use the virtual base class if you can avoid it. Instead, use the above
+ * classes as template arguments to your own virtual feature function.*/
+base::Model *LoadVirtual(const char *file_name, const Config &config = Config(), ModelType if_arpa = PROBING);
+
} // namespace ngram
} // namespace lm
diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc
index 2d6f15b2..62275d27 100644
--- a/klm/lm/search_hashed.cc
+++ b/klm/lm/search_hashed.cc
@@ -54,7 +54,7 @@ template <class Weights> class ActivateUnigram {
Weights *modify_;
};
-// Find the lower order entry, inserting blanks along the way as necessary.
+// Find the lower order entry, inserting blanks along the way as necessary.
template <class Value> void FindLower(
const std::vector<uint64_t> &keys,
typename Value::Weights &unigram,
@@ -64,7 +64,7 @@ template <class Value> void FindLower(
typename Value::ProbingEntry entry;
// Backoff will always be 0.0. We'll get the probability and rest in another pass.
entry.value.backoff = kNoExtensionBackoff;
- // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb.
+ // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb.
for (int lower = keys.size() - 2; ; --lower) {
if (lower == -1) {
between.push_back(&unigram);
@@ -77,11 +77,11 @@ template <class Value> void FindLower(
}
}
-// Between usually has single entry, the value to adjust. But sometimes SRI stupidly pruned entries so it has unitialized blank values to be set here.
+// Between usually has single entry, the value to adjust. But sometimes SRI stupidly pruned entries so it has unitialized blank values to be set here.
template <class Added, class Build> void AdjustLower(
const Added &added,
const Build &build,
- std::vector<typename Build::Value::Weights *> &between,
+ std::vector<typename Build::Value::Weights *> &between,
const unsigned int n,
const std::vector<WordIndex> &vocab_ids,
typename Build::Value::Weights *unigrams,
@@ -93,14 +93,14 @@ template <class Added, class Build> void AdjustLower(
}
typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle;
float prob = -fabs(between.back()->prob);
- // Order of the n-gram on which probabilities are based.
+ // Order of the n-gram on which probabilities are based.
unsigned char basis = n - between.size();
assert(basis != 0);
typename Build::Value::Weights **change = &between.back();
// Skip the basis.
--change;
if (basis == 1) {
- // Hallucinate a bigram based on a unigram's backoff and a unigram probability.
+ // Hallucinate a bigram based on a unigram's backoff and a unigram probability.
float &backoff = unigrams[vocab_ids[1]].backoff;
SetExtension(backoff);
prob += backoff;
@@ -128,14 +128,14 @@ template <class Added, class Build> void AdjustLower(
typename std::vector<typename Value::Weights *>::const_iterator i(between.begin());
build.MarkExtends(**i, added);
const typename Value::Weights *longer = *i;
- // Everything has probability but is not marked as extending.
+ // Everything has probability but is not marked as extending.
for (++i; i != between.end(); ++i) {
build.MarkExtends(**i, *longer);
longer = *i;
}
}
-// Continue marking lower entries even they know that they extend left. This is used for upper/lower bounds.
+// Continue marking lower entries even they know that they extend left. This is used for upper/lower bounds.
template <class Build> void MarkLower(
const std::vector<uint64_t> &keys,
const Build &build,
@@ -144,15 +144,15 @@ template <class Build> void MarkLower(
int start_order,
const typename Build::Value::Weights &longer) {
if (start_order == 0) return;
- typename util::ProbingHashTable<typename Build::Value::ProbingEntry, util::IdentityHash>::MutableIterator iter;
- // Hopefully the compiler will realize that if MarkExtends always returns false, it can simplify this code.
+ // Hopefully the compiler will realize that if MarkExtends always returns false, it can simplify this code.
for (int even_lower = start_order - 2 /* index in middle */; ; --even_lower) {
if (even_lower == -1) {
build.MarkExtends(unigram, longer);
return;
}
- middle[even_lower].UnsafeMutableFind(keys[even_lower], iter);
- if (!build.MarkExtends(iter->value, longer)) return;
+ if (!build.MarkExtends(
+ middle[even_lower].UnsafeMutableMustFind(keys[even_lower])->value,
+ longer)) return;
}
}
@@ -168,7 +168,6 @@ template <class Build, class Activate, class Store> void ReadNGrams(
Store &store,
PositiveProbWarn &warn) {
typedef typename Build::Value Value;
- typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle;
assert(n >= 2);
ReadNGramHeader(f, n);
@@ -186,7 +185,7 @@ template <class Build, class Activate, class Store> void ReadNGrams(
for (unsigned int h = 1; h < n - 1; ++h) {
keys[h] = detail::CombineWordHash(keys[h-1], vocab_ids[h+1]);
}
- // Initially the sign bit is on, indicating it does not extend left. Most already have this but there might +0.0.
+ // Initially the sign bit is on, indicating it does not extend left. Most already have this but there might +0.0.
util::SetSign(entry.value.prob);
entry.key = keys[n-2];
@@ -203,7 +202,7 @@ template <class Build, class Activate, class Store> void ReadNGrams(
} // namespace
namespace detail {
-
+
template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
std::size_t allocated = Unigram::Size(counts[0]);
unigram_ = Unigram(start, counts[0], allocated);
diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh
index 00595796..9d067bc2 100644
--- a/klm/lm/search_hashed.hh
+++ b/klm/lm/search_hashed.hh
@@ -71,7 +71,7 @@ template <class Value> class HashedSearch {
static const bool kDifferentRest = Value::kDifferentRest;
static const unsigned int kVersion = 0;
- // TODO: move probing_multiplier here with next binary file format update.
+ // TODO: move probing_multiplier here with next binary file format update.
static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}
static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
@@ -102,14 +102,9 @@ template <class Value> class HashedSearch {
return ret;
}
-#pragma GCC diagnostic ignored "-Wuninitialized"
MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const {
node = extend_pointer;
- typename Middle::ConstIterator found;
- bool got = middle_[extend_length - 2].Find(extend_pointer, found);
- assert(got);
- (void)got;
- return MiddlePointer(found->value);
+ return MiddlePointer(middle_[extend_length - 2].MustFind(extend_pointer)->value);
}
MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_pointer) const {
@@ -126,14 +121,14 @@ template <class Value> class HashedSearch {
}
LongestPointer LookupLongest(WordIndex word, const Node &node) const {
- // Sign bit is always on because longest n-grams do not extend left.
+ // Sign bit is always on because longest n-grams do not extend left.
typename Longest::ConstIterator found;
if (!longest_.Find(CombineWordHash(node, word), found)) return LongestPointer();
return LongestPointer(found->value.prob);
}
- // Generate a node without necessarily checking that it actually exists.
- // Optionally return false if it's know to not exist.
+ // Generate a node without necessarily checking that it actually exists.
+ // Optionally return false if it's know to not exist.
bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
assert(begin != end);
node = static_cast<Node>(*begin);
@@ -144,7 +139,7 @@ template <class Value> class HashedSearch {
}
private:
- // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild.
+ // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild.
void DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn);
template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build);
@@ -153,7 +148,7 @@ template <class Value> class HashedSearch {
public:
Unigram() {}
- Unigram(void *start, uint64_t count, std::size_t /*allocated*/) :
+ Unigram(void *start, uint64_t count, std::size_t /*allocated*/) :
unigram_(static_cast<typename Value::Weights*>(start))
#ifdef DEBUG
, count_(count)
diff --git a/klm/lm/state.hh b/klm/lm/state.hh
index d8e6c132..a6b9accb 100644
--- a/klm/lm/state.hh
+++ b/klm/lm/state.hh
@@ -91,7 +91,7 @@ inline uint64_t hash_value(const Left &left) {
}
struct ChartState {
- bool operator==(const ChartState &other) {
+ bool operator==(const ChartState &other) const {
return (right == other.right) && (left == other.left);
}
diff --git a/klm/lm/virtual_interface.hh b/klm/lm/virtual_interface.hh
index 6a5a0196..17f064b2 100644
--- a/klm/lm/virtual_interface.hh
+++ b/klm/lm/virtual_interface.hh
@@ -6,6 +6,7 @@
#include "util/string_piece.hh"
#include <string>
+#include <string.h>
namespace lm {
namespace base {
@@ -119,7 +120,9 @@ class Model {
size_t StateSize() const { return state_size_; }
const void *BeginSentenceMemory() const { return begin_sentence_memory_; }
+ void BeginSentenceWrite(void *to) const { memcpy(to, begin_sentence_memory_, StateSize()); }
const void *NullContextMemory() const { return null_context_memory_; }
+ void NullContextWrite(void *to) const { memcpy(to, null_context_memory_, StateSize()); }
// Requires in_state != out_state
virtual float Score(const void *in_state, const WordIndex new_word, void *out_state) const = 0;
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh
index 3902f117..226ae438 100644
--- a/klm/lm/vocab.hh
+++ b/klm/lm/vocab.hh
@@ -25,7 +25,7 @@ uint64_t HashForVocab(const char *str, std::size_t len);
inline uint64_t HashForVocab(const StringPiece &str) {
return HashForVocab(str.data(), str.length());
}
-class ProbingVocabularyHeader;
+struct ProbingVocabularyHeader;
} // namespace detail
class WriteWordsWrapper : public EnumerateVocab {