summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--klm/lm/bhiksha.hh5
-rw-r--r--klm/lm/build_binary.cc2
-rw-r--r--klm/lm/left.hh39
-rw-r--r--klm/lm/vocab.cc1
-rw-r--r--klm/lm/vocab.hh1
-rw-r--r--klm/util/probing_hash_table.hh4
6 files changed, 33 insertions, 19 deletions
diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh
index bc705959..3df43dda 100644
--- a/klm/lm/bhiksha.hh
+++ b/klm/lm/bhiksha.hh
@@ -10,6 +10,9 @@
* Currently only used for next pointers.
*/
+#ifndef LM_BHIKSHA__
+#define LM_BHIKSHA__
+
#include <inttypes.h>
#include <assert.h>
@@ -108,3 +111,5 @@ class ArrayBhiksha {
} // namespace trie
} // namespace ngram
} // namespace lm
+
+#endif // LM_BHIKSHA__
diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc
index b7aee4de..fdb62a71 100644
--- a/klm/lm/build_binary.cc
+++ b/klm/lm/build_binary.cc
@@ -15,7 +15,7 @@ namespace ngram {
namespace {
void Usage(const char *name) {
- std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-c bits] [type] input.arpa [output.mmap]\n\n"
+ std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n"
"-u sets the log10 probability for <unk> if the ARPA file does not have one.\n"
" Default is -100. The ARPA file will always take precedence.\n"
"-s allows models to be built even if they do not have <s> and </s>.\n"
diff --git a/klm/lm/left.hh b/klm/lm/left.hh
index 15464c82..41f71f84 100644
--- a/klm/lm/left.hh
+++ b/klm/lm/left.hh
@@ -175,24 +175,14 @@ template <class M> class RuleScore {
float backoffs[kMaxOrder - 1], backoffs2[kMaxOrder - 1];
float *back = backoffs, *back2 = backoffs2;
- unsigned char next_use;
+ unsigned char next_use = out_.right.length;
// First word
- ProcessRet(model_.ExtendLeft(out_.right.words, out_.right.words + out_.right.length, out_.right.backoff, in.left.pointers[0], 1, back, next_use));
- if (!next_use) {
- left_done_ = true;
- out_.right = in.right;
- return;
- }
+ if (ExtendLeft(in, next_use, 1, out_.right.backoff, back)) return;
+
// Words after the first, so extending a bigram to begin with
- unsigned char extend_length = 2;
- for (const uint64_t *i = in.left.pointers + 1; i < in.left.pointers + in.left.length; ++i, ++extend_length) {
- ProcessRet(model_.ExtendLeft(out_.right.words, out_.right.words + next_use, back, *i, extend_length, back2, next_use));
- if (!next_use) {
- left_done_ = true;
- out_.right = in.right;
- return;
- }
+ for (unsigned char extend_length = 2; extend_length <= in.left.length; ++extend_length) {
+ if (ExtendLeft(in, next_use, extend_length, back, back2)) return;
std::swap(back, back2);
}
@@ -228,6 +218,25 @@ template <class M> class RuleScore {
}
private:
+ bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) {
+ ProcessRet(model_.ExtendLeft(
+ out_.right.words, out_.right.words + next_use, // Words to extend into
+ back_in, // Backoffs to use
+ in.left.pointers[extend_length - 1], extend_length, // Words to be extended
+ back_out, // Backoffs for the next score
+ next_use)); // Length of n-gram to use in next scoring.
+ if (next_use != out_.right.length) {
+ left_done_ = true;
+ if (!next_use) {
+ out_.right = in.right;
+ // Early exit.
+ return true;
+ }
+ }
+ // Continue scoring.
+ return false;
+ }
+
void ProcessRet(const FullScoreReturn &ret) {
prob_ += ret.prob;
if (left_done_) return;
diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc
index 03b0767a..ffec41ca 100644
--- a/klm/lm/vocab.cc
+++ b/klm/lm/vocab.cc
@@ -135,6 +135,7 @@ void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) {
end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1);
ReadWords(fd, to);
SetSpecial(Index("<s>"), Index("</s>"), 0);
+ bound_ = end_ - begin_ + 1;
}
namespace {
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh
index 4cf68196..3c3414fb 100644
--- a/klm/lm/vocab.hh
+++ b/klm/lm/vocab.hh
@@ -66,7 +66,6 @@ class SortedVocabulary : public base::Vocabulary {
static size_t Size(std::size_t entries, const Config &config);
// Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary.
- // While this number is correct, ProbingVocabulary::Bound might not be correct in some cases.
WordIndex Bound() const { return bound_; }
// Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh
index 2ec342a6..8122d69c 100644
--- a/klm/util/probing_hash_table.hh
+++ b/klm/util/probing_hash_table.hh
@@ -61,14 +61,14 @@ template <class PackingT, class HashT, class EqualT = std::equal_to<typename Pac
#endif
{}
- template <class T> void Insert(const T &t) {
+ template <class T> MutableIterator Insert(const T &t) {
if (++entries_ >= buckets_)
UTIL_THROW(ProbingSizeException, "Hash table with " << buckets_ << " buckets is full.");
#ifdef DEBUG
assert(initialized_);
#endif
for (MutableIterator i(begin_ + (hash_(t.GetKey()) % buckets_));;) {
- if (equal_(i->GetKey(), invalid_)) { *i = t; return; }
+ if (equal_(i->GetKey(), invalid_)) { *i = t; return i; }
if (++i == end_) { i = begin_; }
}
}