diff options
Diffstat (limited to 'klm')
-rw-r--r-- | klm/lm/bhiksha.hh | 5 | ||||
-rw-r--r-- | klm/lm/build_binary.cc | 2 | ||||
-rw-r--r-- | klm/lm/left.hh | 39 | ||||
-rw-r--r-- | klm/lm/vocab.cc | 1 | ||||
-rw-r--r-- | klm/lm/vocab.hh | 1 | ||||
-rw-r--r-- | klm/util/probing_hash_table.hh | 4 |
6 files changed, 33 insertions, 19 deletions
diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh index bc705959..3df43dda 100644 --- a/klm/lm/bhiksha.hh +++ b/klm/lm/bhiksha.hh @@ -10,6 +10,9 @@ * Currently only used for next pointers. */ +#ifndef LM_BHIKSHA__ +#define LM_BHIKSHA__ + #include <inttypes.h> #include <assert.h> @@ -108,3 +111,5 @@ class ArrayBhiksha { } // namespace trie } // namespace ngram } // namespace lm + +#endif // LM_BHIKSHA__ diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index b7aee4de..fdb62a71 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -15,7 +15,7 @@ namespace ngram { namespace { void Usage(const char *name) { - std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-c bits] [type] input.arpa [output.mmap]\n\n" + std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n" "-u sets the log10 probability for <unk> if the ARPA file does not have one.\n" " Default is -100. The ARPA file will always take precedence.\n" "-s allows models to be built even if they do not have <s> and </s>.\n" diff --git a/klm/lm/left.hh b/klm/lm/left.hh index 15464c82..41f71f84 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -175,24 +175,14 @@ template <class M> class RuleScore { float backoffs[kMaxOrder - 1], backoffs2[kMaxOrder - 1]; float *back = backoffs, *back2 = backoffs2; - unsigned char next_use; + unsigned char next_use = out_.right.length; // First word - ProcessRet(model_.ExtendLeft(out_.right.words, out_.right.words + out_.right.length, out_.right.backoff, in.left.pointers[0], 1, back, next_use)); - if (!next_use) { - left_done_ = true; - out_.right = in.right; - return; - } + if (ExtendLeft(in, next_use, 1, out_.right.backoff, back)) return; + // Words after the first, so extending a bigram to begin with - unsigned char extend_length = 2; - for (const uint64_t *i = in.left.pointers + 1; i < in.left.pointers + in.left.length; ++i, ++extend_length) { - ProcessRet(model_.ExtendLeft(out_.right.words, out_.right.words + next_use, back, *i, extend_length, back2, next_use)); - if (!next_use) { - left_done_ = true; - out_.right = in.right; - return; - } + for (unsigned char extend_length = 2; extend_length <= in.left.length; ++extend_length) { + if (ExtendLeft(in, next_use, extend_length, back, back2)) return; std::swap(back, back2); } @@ -228,6 +218,25 @@ template <class M> class RuleScore { } private: + bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) { + ProcessRet(model_.ExtendLeft( + out_.right.words, out_.right.words + next_use, // Words to extend into + back_in, // Backoffs to use + in.left.pointers[extend_length - 1], extend_length, // Words to be extended + back_out, // Backoffs for the next score + next_use)); // Length of n-gram to use in next scoring. + if (next_use != out_.right.length) { + left_done_ = true; + if (!next_use) { + out_.right = in.right; + // Early exit. + return true; + } + } + // Continue scoring. + return false; + } + void ProcessRet(const FullScoreReturn &ret) { prob_ += ret.prob; if (left_done_) return; diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 03b0767a..ffec41ca 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -135,6 +135,7 @@ void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1); ReadWords(fd, to); SetSpecial(Index("<s>"), Index("</s>"), 0); + bound_ = end_ - begin_ + 1; } namespace { diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 4cf68196..3c3414fb 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -66,7 +66,6 @@ class SortedVocabulary : public base::Vocabulary { static size_t Size(std::size_t entries, const Config &config); // Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary. - // While this number is correct, ProbingVocabulary::Bound might not be correct in some cases. WordIndex Bound() const { return bound_; } // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh index 2ec342a6..8122d69c 100644 --- a/klm/util/probing_hash_table.hh +++ b/klm/util/probing_hash_table.hh @@ -61,14 +61,14 @@ template <class PackingT, class HashT, class EqualT = std::equal_to<typename Pac #endif {} - template <class T> void Insert(const T &t) { + template <class T> MutableIterator Insert(const T &t) { if (++entries_ >= buckets_) UTIL_THROW(ProbingSizeException, "Hash table with " << buckets_ << " buckets is full."); #ifdef DEBUG assert(initialized_); #endif for (MutableIterator i(begin_ + (hash_(t.GetKey()) % buckets_));;) { - if (equal_(i->GetKey(), invalid_)) { *i = t; return; } + if (equal_(i->GetKey(), invalid_)) { *i = t; return i; } if (++i == end_) { i = begin_; } } } |