diff options
| author | Kenneth Heafield <github@kheafield.com> | 2014-01-27 17:42:19 -0800 | 
|---|---|---|
| committer | Kenneth Heafield <github@kheafield.com> | 2014-01-27 17:42:19 -0800 | 
| commit | 783c57b2d3312738ddcf992ac55ff750afe7cb47 (patch) | |
| tree | c4811dab0d916836b8631f3c7df94f284a490b9b /klm/lm/search_trie.cc | |
| parent | f7e051a05d65ef25c2ada0b84cd82bfb375ef265 (diff) | |
KenLM 5cc905bc2d214efa7de2db56a9a672b749a95591
Diffstat (limited to 'klm/lm/search_trie.cc')
| -rw-r--r-- | klm/lm/search_trie.cc | 35 | 
1 files changed, 11 insertions, 24 deletions
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 1b0d9b26..4a88194e 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -253,11 +253,6 @@ class FindBlanks {        ++counts_.back();      } -    // Unigrams wrote one past. -    void Cleanup() { -      --counts_[0]; -    } -      const std::vector<uint64_t> &Counts() const {        return counts_;      } @@ -310,8 +305,6 @@ template <class Quant, class Bhiksha> class WriteEntries {        typename Quant::LongestPointer(quant_, longest_.Insert(words[order_ - 1])).Write(reinterpret_cast<const Prob*>(words + order_)->prob);      } -    void Cleanup() {} -    private:      RecordReader *contexts_;      const Quant &quant_; @@ -385,14 +378,14 @@ template <class Doing> void RecursiveInsert(const unsigned char total_order, con    util::ErsatzProgress progress(unigram_count + 1, progress_out, message);    WordIndex unigram = 0;    std::priority_queue<Gram> grams; -  grams.push(Gram(&unigram, 1)); +  if (unigram_count) grams.push(Gram(&unigram, 1));    for (unsigned char i = 2; i <= total_order; ++i) {      if (input[i-2]) grams.push(Gram(reinterpret_cast<const WordIndex*>(input[i-2].Data()), i));    }    BlankManager<Doing> blank(total_order, doing); -  while (true) { +  while (!grams.empty()) {      Gram top = grams.top();      grams.pop();      unsigned char order = top.end - top.begin; @@ -400,8 +393,7 @@ template <class Doing> void RecursiveInsert(const unsigned char total_order, con        blank.Visit(&unigram, 1, doing.UnigramProb(unigram));        doing.Unigram(unigram);        progress.Set(unigram); -      if (++unigram == unigram_count + 1) break; -      grams.push(top); +      if (++unigram < unigram_count) grams.push(top);      } else {        if (order == total_order) {          blank.Visit(top.begin, order, reinterpret_cast<const Prob*>(top.end)->prob); @@ -414,8 +406,6 @@ template <class Doing> void RecursiveInsert(const unsigned char total_order, con        if (++reader) grams.push(top);      }    } -  assert(grams.empty()); -  doing.Cleanup();  }  void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<uint64_t> &fixed) { @@ -469,7 +459,7 @@ void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &c  } // namespace -template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) { +template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing) {    RecordReader inputs[KENLM_MAX_ORDER - 1];    RecordReader contexts[KENLM_MAX_ORDER - 1]; @@ -498,7 +488,10 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve    sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs); -  out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config); +  void *vocab_relocate; +  void *search_base = backing.GrowForSearch(TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), vocab.UnkCountChangePadding(), vocab_relocate); +  vocab.Relocate(vocab_relocate); +  out.SetupMemory(reinterpret_cast<uint8_t*>(search_base), fixed_counts, config);    for (unsigned char i = 2; i <= counts.size(); ++i) {      inputs[i-2].Rewind(); @@ -524,6 +517,8 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve    {      WriteEntries<Quant, Bhiksha> writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri);      RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Writing trie", writer); +    // Write the last unigram entry, which is the end pointer for the bigrams.   +    writer.Unigram(counts[0]);    }    // Do not disable this error message or else too little state will be returned.  Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. @@ -579,15 +574,7 @@ template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::Setup    return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);  } -template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBinary() { -  unigram_.LoadedBinary(); -  for (Middle *i = middle_begin_; i != middle_end_; ++i) { -    i->LoadedBinary(); -  } -  longest_.LoadedBinary(); -} - -template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { +template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, BinaryFormat &backing) {    std::string temporary_prefix;    if (config.temporary_directory_prefix) {      temporary_prefix = config.temporary_directory_prefix;  | 
