diff options
Diffstat (limited to 'klm/lm')
-rw-r--r-- | klm/lm/left.hh | 66 | ||||
-rw-r--r-- | klm/lm/max_order.hh | 7 | ||||
-rw-r--r-- | klm/lm/model.cc | 3 | ||||
-rw-r--r-- | klm/lm/search_hashed.cc | 8 | ||||
-rw-r--r-- | klm/lm/search_hashed.hh | 2 | ||||
-rw-r--r-- | klm/lm/vocab.cc | 7 | ||||
-rw-r--r-- | klm/lm/vocab.hh | 5 |
7 files changed, 54 insertions, 44 deletions
diff --git a/klm/lm/left.hh b/klm/lm/left.hh index 8c27232e..85c1ea37 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -51,36 +51,36 @@ namespace ngram { template <class M> class RuleScore { public: - explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), prob_(0.0) { + explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(&out), left_done_(false), prob_(0.0) { out.left.length = 0; out.right.length = 0; } void BeginSentence() { - out_.right = model_.BeginSentenceState(); - // out_.left is empty. + out_->right = model_.BeginSentenceState(); + // out_->left is empty. left_done_ = true; } void Terminal(WordIndex word) { - State copy(out_.right); - FullScoreReturn ret(model_.FullScore(copy, word, out_.right)); + State copy(out_->right); + FullScoreReturn ret(model_.FullScore(copy, word, out_->right)); if (left_done_) { prob_ += ret.prob; return; } if (ret.independent_left) { prob_ += ret.prob; left_done_ = true; return; } - out_.left.pointers[out_.left.length++] = ret.extend_left; + out_->left.pointers[out_->left.length++] = ret.extend_left; prob_ += ret.rest; - if (out_.right.length != copy.length + 1) + if (out_->right.length != copy.length + 1) left_done_ = true; } // Faster version of NonTerminal for the case where the rule begins with a non-terminal. void BeginNonTerminal(const ChartState &in, float prob = 0.0) { prob_ = prob; - out_ = in; + *out_ = in; left_done_ = in.left.full; } @@ -89,23 +89,23 @@ template <class M> class RuleScore { if (!in.left.length) { if (in.left.full) { - for (const float *i = out_.right.backoff; i < out_.right.backoff + out_.right.length; ++i) prob_ += *i; + for (const float *i = out_->right.backoff; i < out_->right.backoff + out_->right.length; ++i) prob_ += *i; left_done_ = true; - out_.right = in.right; + out_->right = in.right; } return; } - if (!out_.right.length) { - out_.right = in.right; + if (!out_->right.length) { + out_->right = in.right; if (left_done_) { prob_ += model_.UnRest(in.left.pointers, in.left.pointers + in.left.length, 1); return; } - if (out_.left.length) { + if (out_->left.length) { left_done_ = true; } else { - out_.left = in.left; + out_->left = in.left; left_done_ = in.left.full; } return; @@ -113,10 +113,10 @@ template <class M> class RuleScore { float backoffs[KENLM_MAX_ORDER - 1], backoffs2[KENLM_MAX_ORDER - 1]; float *back = backoffs, *back2 = backoffs2; - unsigned char next_use = out_.right.length; + unsigned char next_use = out_->right.length; // First word - if (ExtendLeft(in, next_use, 1, out_.right.backoff, back)) return; + if (ExtendLeft(in, next_use, 1, out_->right.backoff, back)) return; // Words after the first, so extending a bigram to begin with for (unsigned char extend_length = 2; extend_length <= in.left.length; ++extend_length) { @@ -127,54 +127,58 @@ template <class M> class RuleScore { if (in.left.full) { for (const float *i = back; i != back + next_use; ++i) prob_ += *i; left_done_ = true; - out_.right = in.right; + out_->right = in.right; return; } // Right state was minimized, so it's already independent of the new words to the left. if (in.right.length < in.left.length) { - out_.right = in.right; + out_->right = in.right; return; } // Shift exisiting words down. - for (WordIndex *i = out_.right.words + next_use - 1; i >= out_.right.words; --i) { + for (WordIndex *i = out_->right.words + next_use - 1; i >= out_->right.words; --i) { *(i + in.right.length) = *i; } // Add words from in.right. - std::copy(in.right.words, in.right.words + in.right.length, out_.right.words); + std::copy(in.right.words, in.right.words + in.right.length, out_->right.words); // Assemble backoff composed on the existing state's backoff followed by the new state's backoff. - std::copy(in.right.backoff, in.right.backoff + in.right.length, out_.right.backoff); - std::copy(back, back + next_use, out_.right.backoff + in.right.length); - out_.right.length = in.right.length + next_use; + std::copy(in.right.backoff, in.right.backoff + in.right.length, out_->right.backoff); + std::copy(back, back + next_use, out_->right.backoff + in.right.length); + out_->right.length = in.right.length + next_use; } float Finish() { // A N-1-gram might extend left and right but we should still set full to true because it's an N-1-gram. - out_.left.full = left_done_ || (out_.left.length == model_.Order() - 1); + out_->left.full = left_done_ || (out_->left.length == model_.Order() - 1); return prob_; } void Reset() { prob_ = 0.0; left_done_ = false; - out_.left.length = 0; - out_.right.length = 0; + out_->left.length = 0; + out_->right.length = 0; + } + void Reset(ChartState &replacement) { + out_ = &replacement; + Reset(); } private: bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) { ProcessRet(model_.ExtendLeft( - out_.right.words, out_.right.words + next_use, // Words to extend into + out_->right.words, out_->right.words + next_use, // Words to extend into back_in, // Backoffs to use in.left.pointers[extend_length - 1], extend_length, // Words to be extended back_out, // Backoffs for the next score next_use)); // Length of n-gram to use in next scoring. - if (next_use != out_.right.length) { + if (next_use != out_->right.length) { left_done_ = true; if (!next_use) { // Early exit. - out_.right = in.right; + out_->right = in.right; prob_ += model_.UnRest(in.left.pointers + extend_length, in.left.pointers + in.left.length, extend_length + 1); return true; } @@ -193,13 +197,13 @@ template <class M> class RuleScore { left_done_ = true; return; } - out_.left.pointers[out_.left.length++] = ret.extend_left; + out_->left.pointers[out_->left.length++] = ret.extend_left; prob_ += ret.rest; } const M &model_; - ChartState &out_; + ChartState *out_; bool left_done_; diff --git a/klm/lm/max_order.hh b/klm/lm/max_order.hh index 989f8324..ea0dea46 100644 --- a/klm/lm/max_order.hh +++ b/klm/lm/max_order.hh @@ -4,9 +4,8 @@ * (kMaxOrder - 1) * sizeof(float) bytes instead of * sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead */ -#ifndef KENLM_MAX_ORDER -#define KENLM_MAX_ORDER 6 -#endif #ifndef KENLM_ORDER_MESSAGE -#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --kenlm-max-order=6 -a'. Otherwise, edit lm/max_order.hh." +#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'. Otherwise, edit lm/max_order.hh." #endif + +#define KENLM_MAX_ORDER 5 diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 2fd20481..fc61efee 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -87,7 +87,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT WriteWordsWrapper wrap(config.enumerate_vocab); vocab_.ConfigureEnumerate(&wrap, counts[0]); search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); - wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + backing_.search.size()); + wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + Search::Size(counts, config)); } else { vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]); search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); @@ -227,6 +227,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const { + assert(new_word < vocab_.Bound()); FullScoreReturn ret; // ret.ngram_length contains the last known non-blank ngram length. ret.ngram_length = 1; diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index a1623834..2d6f15b2 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -231,7 +231,7 @@ template <class Value> void HashedSearch<Value>::InitializeFromARPA(const char * template <> void HashedSearch<BackoffValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) { NoRestBuild build; - ApplyBuild(f, counts, config, vocab, warn, build); + ApplyBuild(f, counts, vocab, warn, build); } template <> void HashedSearch<RestValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) { @@ -239,19 +239,19 @@ template <> void HashedSearch<RestValue>::DispatchBuild(util::FilePiece &f, cons case Config::REST_MAX: { MaxRestBuild build; - ApplyBuild(f, counts, config, vocab, warn, build); + ApplyBuild(f, counts, vocab, warn, build); } break; case Config::REST_LOWER: { LowerRestBuild<ProbingModel> build(config, counts.size(), vocab); - ApplyBuild(f, counts, config, vocab, warn, build); + ApplyBuild(f, counts, vocab, warn, build); } break; } } -template <class Value> template <class Build> void HashedSearch<Value>::ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build) { +template <class Value> template <class Build> void HashedSearch<Value>::ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build) { for (WordIndex i = 0; i < counts[0]; ++i) { build.SetRest(&i, (unsigned int)1, unigram_.Raw()[i]); } diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index a52f107b..00595796 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -147,7 +147,7 @@ template <class Value> class HashedSearch { // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild. void DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn); - template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build); + template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build); class Unigram { public: diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 11c27518..fd7f96dc 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -116,7 +116,9 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) { } *end_ = hashed; if (enumerate_) { - strings_to_enumerate_[end_ - begin_].assign(str.data(), str.size()); + void *copied = string_backing_.Allocate(str.size()); + memcpy(copied, str.data(), str.size()); + strings_to_enumerate_[end_ - begin_] = StringPiece(static_cast<const char*>(copied), str.size()); } ++end_; // This is 1 + the offset where it was inserted to make room for unk. @@ -126,7 +128,7 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) { void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { if (enumerate_) { if (!strings_to_enumerate_.empty()) { - util::PairedIterator<ProbBackoff*, std::string*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin()); + util::PairedIterator<ProbBackoff*, StringPiece*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin()); util::JointSort(begin_, end_, values); } for (WordIndex i = 0; i < static_cast<WordIndex>(end_ - begin_); ++i) { @@ -134,6 +136,7 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { enumerate_->Add(i + 1, strings_to_enumerate_[i]); } strings_to_enumerate_.clear(); + string_backing_.FreeAll(); } else { util::JointSort(begin_, end_, reorder_vocab + 1); } diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index de54eb06..3902f117 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -4,6 +4,7 @@ #include "lm/enumerate_vocab.hh" #include "lm/lm_exception.hh" #include "lm/virtual_interface.hh" +#include "util/pool.hh" #include "util/probing_hash_table.hh" #include "util/sorted_uniform.hh" #include "util/string_piece.hh" @@ -96,7 +97,9 @@ class SortedVocabulary : public base::Vocabulary { EnumerateVocab *enumerate_; // Actual strings. Used only when loading from ARPA and enumerate_ != NULL - std::vector<std::string> strings_to_enumerate_; + util::Pool string_backing_; + + std::vector<StringPiece> strings_to_enumerate_; }; #pragma pack(push) |