summary | refs | log | tree | commit | diff
path: root/klm/lm/trie_sort.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/trie_sort.hh')
-rw-r--r-- klm/lm/trie_sort.hh 55
1 files changed, 39 insertions, 16 deletions
diff --git a/klm/lm/trie_sort.hh b/klm/lm/trie_sort.hh
index a6916483..3036319d 100644
--- a/klm/lm/trie_sort.hh
+++ b/klm/lm/trie_sort.hh
@@ -1,6 +1,9 @@
+// Step of trie builder: create sorted files.
+
#ifndef LM_TRIE_SORT__
#define LM_TRIE_SORT__
+#include "lm/max_order.hh"
#include "lm/word_index.hh"
#include "util/file.hh"
@@ -11,20 +14,21 @@
#include <string>
#include <vector>
-#include <inttypes.h>
+#include <stdint.h>
-namespace util { class FilePiece; }
+namespace util {
+class FilePiece;
+class TempMaker;
+} // namespace util
-// Step of trie builder: create sorted files.
namespace lm {
+class PositiveProbWarn;
namespace ngram {
class SortedVocabulary;
class Config;
namespace trie {
-extern const char *kContextSuffix;
-FILE *OpenOrThrow(const char *name, const char *mode);
void WriteOrThrow(FILE *to, const void *data, size_t size);
class EntryCompare : public std::binary_function<const void*, const void*, bool> {
@@ -49,15 +53,15 @@ class RecordReader {
public:
RecordReader() : remains_(true) {}
- void Init(const std::string &name, std::size_t entry_size);
+ void Init(FILE *file, std::size_t entry_size);
void *Data() { return data_.get(); }
const void *Data() const { return data_.get(); }
RecordReader &operator++() {
- std::size_t ret = fread(data_.get(), entry_size_, 1, file_.get());
+ std::size_t ret = fread(data_.get(), entry_size_, 1, file_);
if (!ret) {
- UTIL_THROW_IF(!feof(file_.get()), util::ErrnoException, "Error reading temporary file");
+ UTIL_THROW_IF(!feof(file_), util::ErrnoException, "Error reading temporary file");
remains_ = false;
}
return *this;
@@ -65,27 +69,46 @@ class RecordReader {
operator bool() const { return remains_; }
- void Rewind() {
- rewind(file_.get());
- remains_ = true;
- ++*this;
- }
+ void Rewind();
std::size_t EntrySize() const { return entry_size_; }
void Overwrite(const void *start, std::size_t amount);
private:
+ FILE *file_;
+
util::scoped_malloc data_;
bool remains_;
std::size_t entry_size_;
-
- util::scoped_FILE file_;
};
-void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab);
+class SortedFiles { // Groups the unigram descriptor with the per-order full_ and context_ FILE handles declared below.
+ public:
+ // Build from ARPA
+ SortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, std::size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab);
+
+ int StealUnigram() { // Returns the int produced by unigram_.release(); presumably unigram_ stops owning it afterwards -- confirm against util::scoped_fd.
+ return unigram_.release();
+ }
+
+ FILE *Full(unsigned char order) { // Looks up full_[order - 2], so order 2 maps to slot 0; no bounds check is done here.
+ return full_[order - 2].get();
+ }
+
+ FILE *Context(unsigned char of_order) { // Same (of_order - 2) indexing as Full(), but into context_; no bounds check is done here.
+ return context_[of_order - 2].get();
+ }
+
+ private:
+ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const util::TempMaker &maker, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size); // Declaration only; no body appears in this header.
+
+ util::scoped_fd unigram_; // Descriptor handed out through StealUnigram().
+
+ util::scoped_FILE full_[kMaxOrder - 1], context_[kMaxOrder - 1]; // kMaxOrder - 1 slots each; slot i is the one read for order i + 2 by Full()/Context().
+};
} // namespace trie
} // namespace ngram