summary | refs | log | tree | commit | diff
path: root/klm/lm
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm')
-rw-r--r-- klm/lm/left.hh 26
-rw-r--r-- klm/lm/search_trie.cc 7
-rw-r--r-- klm/lm/search_trie.hh 2
-rw-r--r-- klm/lm/trie_sort.cc 2
4 files changed, 23 insertions, 14 deletions
diff --git a/klm/lm/left.hh b/klm/lm/left.hh
index effa0560..bb3f5539 100644
--- a/klm/lm/left.hh
+++ b/klm/lm/left.hh
@@ -117,7 +117,7 @@ inline size_t hash_value(const ChartState &state) {
template <class M> class RuleScore {
public:
- explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), left_write_(out.left.pointers), prob_(0.0) {
+ explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), prob_(0.0) {
out.left.length = 0;
out.right.length = 0;
}
@@ -130,15 +130,22 @@ template <class M> class RuleScore {
void Terminal(WordIndex word) {
State copy(out_.right);
- ProcessRet(model_.FullScore(copy, word, out_.right));
- if (out_.right.length != copy.length + 1) left_done_ = true;
+ FullScoreReturn ret(model_.FullScore(copy, word, out_.right));
+ prob_ += ret.prob;
+ if (left_done_) return;
+ if (ret.independent_left) {
+ left_done_ = true;
+ return;
+ }
+ out_.left.pointers[out_.left.length++] = ret.extend_left;
+ if (out_.right.length != copy.length + 1)
+ left_done_ = true;
}
// Faster version of NonTerminal for the case where the rule begins with a non-terminal.
void BeginNonTerminal(const ChartState &in, float prob) {
prob_ = prob;
out_ = in;
- left_write_ = out_.left.pointers + out_.left.length;
left_done_ = in.full;
}
@@ -157,11 +164,10 @@ template <class M> class RuleScore {
if (!out_.right.length) {
out_.right = in.right;
if (left_done_) return;
- if (left_write_ != out_.left.pointers) {
+ if (out_.left.length) {
left_done_ = true;
} else {
out_.left = in.left;
- left_write_ = out_.left.pointers + in.left.length;
left_done_ = in.full;
}
return;
@@ -214,8 +220,8 @@ template <class M> class RuleScore {
}
float Finish() {
- out_.left.length = left_write_ - out_.left.pointers;
- out_.full = left_done_;
+ // A N-1-gram might extend left and right but we should still set full to true because it's an N-1-gram.
+ out_.full = left_done_ || (out_.left.length == model_.Order() - 1);
return prob_;
}
@@ -227,7 +233,7 @@ template <class M> class RuleScore {
left_done_ = true;
return;
}
- *(left_write_++) = ret.extend_left;
+ out_.left.pointers[out_.left.length++] = ret.extend_left;
}
const M &model_;
@@ -236,8 +242,6 @@ template <class M> class RuleScore {
bool left_done_;
- uint64_t *left_write_;
-
float prob_;
};
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index 6479813b..5d8c70db 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -151,6 +151,11 @@ class BackoffMessages {
private:
void FinishedAdding() {
Resize(current_ - (uint8_t*)backing_.get());
+ // Sort requests in same order as files.
+ std::sort(
+ util::SizedIterator(util::SizedProxy(backing_.get(), entry_size_)),
+ util::SizedIterator(util::SizedProxy(current_, entry_size_)),
+ util::SizedCompare<EntryCompare>(EntryCompare((entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex))));
current_ = (uint8_t*)backing_.get();
}
@@ -525,7 +530,7 @@ template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_pre
const RecordReader &context = contexts[order - 2];
if (context) {
FormatLoadException e;
- e << "An " << static_cast<unsigned int>(order) << "-gram has context";
+ e << "A " << static_cast<unsigned int>(order) << "-gram has context";
const WordIndex *ctx = reinterpret_cast<const WordIndex*>(context.Data());
for (const WordIndex *i = ctx; i != ctx + order - 1; ++i) {
e << ' ' << *i;
diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh
index c3e02a98..33ae8cff 100644
--- a/klm/lm/search_trie.hh
+++ b/klm/lm/search_trie.hh
@@ -36,7 +36,7 @@ template <class Quant, class Bhiksha> class TrieSearch {
static const ModelType kModelType = static_cast<ModelType>(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd);
- static const unsigned int kVersion = 0;
+ static const unsigned int kVersion = 1;
static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
Quant::UpdateConfigFromBinary(fd, counts, config);
diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc
index 86f28493..bb126f18 100644
--- a/klm/lm/trie_sort.cc
+++ b/klm/lm/trie_sort.cc
@@ -191,7 +191,7 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st
assembled << file_prefix << static_cast<unsigned int>(order) << "_merge_" << (merge_count++);
files.push_back(assembled.str());
MergeSortedFiles(files[0], files[1], files.back(), weights_size, order, ThrowCombine());
- MergeSortedFiles(files[0] + kContextSuffix, files[1] + kContextSuffix, files.back() + kContextSuffix, 0, order, FirstCombine());
+ MergeSortedFiles(files[0] + kContextSuffix, files[1] + kContextSuffix, files.back() + kContextSuffix, 0, order - 1, FirstCombine());
files.pop_front();
files.pop_front();
}