summary | refs | log | tree | commit | diff
path: root/klm
diff options
context:
space:
mode:
author Michael Denkowski <michael.j.denkowski@gmail.com> 2012-12-22 16:01:23 -0500
committer Michael Denkowski <michael.j.denkowski@gmail.com> 2012-12-22 16:01:23 -0500
commit 778a4cec55f82bcc66b3f52de7cc871e8daaeb92 (patch)
tree 2a5bccaa85965855104c4e8ac3738b2e1c77f164 /klm
parent 57fff9eea5ba0e71fb958fdb4f32d17f2fe31108 (diff)
parent d21491daa5e50b4456c7c5f9c2e51d25afd2a757 (diff)
Merge branch 'master' of git://github.com/redpony/cdec
Diffstat (limited to 'klm')
-rw-r--r-- klm/lm/binary_format.cc 21
-rw-r--r-- klm/lm/config.cc 1
-rw-r--r-- klm/lm/config.hh 59
-rw-r--r-- klm/lm/left.hh 66
-rw-r--r-- klm/lm/max_order.hh 5
-rw-r--r-- klm/lm/model.cc 33
-rw-r--r-- klm/lm/search_hashed.cc 8
-rw-r--r-- klm/lm/search_hashed.hh 2
-rw-r--r-- klm/lm/search_trie.cc 47
-rw-r--r-- klm/lm/vocab.cc 7
-rw-r--r-- klm/lm/vocab.hh 5
-rw-r--r-- klm/search/Makefile.am 4
-rw-r--r-- klm/search/applied.hh 86
-rw-r--r-- klm/search/config.hh 25
-rw-r--r-- klm/search/context.hh 28
-rw-r--r-- klm/search/dedupe.hh 131
-rw-r--r-- klm/search/edge_generator.cc 3
-rw-r--r-- klm/search/edge_generator.hh 1
-rw-r--r-- klm/search/final.hh 36
-rw-r--r-- klm/search/header.hh 9
-rw-r--r-- klm/search/nbest.cc 106
-rw-r--r-- klm/search/nbest.hh 81
-rw-r--r-- klm/search/note.hh 12
-rw-r--r-- klm/search/rule.cc 52
-rw-r--r-- klm/search/rule.hh 11
-rw-r--r-- klm/search/types.hh 17
-rw-r--r-- klm/search/vertex.cc 27
-rw-r--r-- klm/search/vertex.hh 37
-rw-r--r-- klm/search/vertex_generator.cc 44
-rw-r--r-- klm/search/vertex_generator.hh 72
-rw-r--r-- klm/search/weights.cc 71
-rw-r--r-- klm/search/weights.hh 52
-rw-r--r-- klm/search/weights_test.cc 38
-rw-r--r-- klm/util/Makefile.am 1
-rw-r--r-- klm/util/exception.hh 8
-rw-r--r-- klm/util/file.cc 38
-rw-r--r-- klm/util/file.hh 8
-rw-r--r-- klm/util/file_piece.cc 66
-rw-r--r-- klm/util/file_piece.hh 41
-rw-r--r-- klm/util/file_piece_test.cc 4
-rw-r--r-- klm/util/have.hh 12
-rw-r--r-- klm/util/joint_sort.hh 4
-rw-r--r-- klm/util/read_compressed.cc 403
-rw-r--r-- klm/util/read_compressed.hh 74
-rw-r--r-- klm/util/read_compressed_test.cc 94
-rw-r--r-- klm/util/scoped.hh 65
-rw-r--r-- klm/util/string_piece.hh 19
-rw-r--r-- klm/util/tokenize_piece.hh 14
48 files changed, 1438 insertions, 610 deletions
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc
index efa67056..39c4a9b6 100644
--- a/klm/lm/binary_format.cc
+++ b/klm/lm/binary_format.cc
@@ -16,11 +16,11 @@ namespace ngram {
namespace {
const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version";
const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n\0";
-// This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed).
+// This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed).
const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n";
const long int kMagicVersion = 5;
-// Old binary files built on 32-bit machines have this header.
+// Old binary files built on 32-bit machines have this header.
// TODO: eliminate with next binary release.
struct OldSanity {
char magic[sizeof(kMagicBytes)];
@@ -39,7 +39,7 @@ struct OldSanity {
};
-// Test values aligned to 8 bytes.
+// Test values aligned to 8 bytes.
struct Sanity {
char magic[ALIGN8(sizeof(kMagicBytes))];
float zero_f, one_f, minus_half_f;
@@ -101,7 +101,7 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_
uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) {
std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad;
if (config.write_mmap) {
- // Grow the file to accomodate the search, using zeros.
+ // Grow the file to accomodate the search, using zeros.
try {
util::ResizeOrThrow(backing.file.get(), adjusted_vocab + memory_size);
} catch (util::ErrnoException &e) {
@@ -114,7 +114,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t
return reinterpret_cast<uint8_t*>(backing.search.get());
}
// mmap it now.
- // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down.
+ // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down.
std::size_t page_size = util::SizePage();
std::size_t alignment_cruft = adjusted_vocab % page_size;
backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED);
@@ -122,7 +122,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t
} else {
util::MapAnonymous(memory_size, backing.search);
return reinterpret_cast<uint8_t*>(backing.search.get());
- }
+ }
}
void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) {
@@ -140,7 +140,7 @@ void FinishFile(const Config &config, ModelType model_type, unsigned int search_
util::FSyncOrThrow(backing.file.get());
break;
}
- // header and vocab share the same mmap. The header is written here because we know the counts.
+ // header and vocab share the same mmap. The header is written here because we know the counts.
Parameters params = Parameters();
params.counts = counts;
params.fixed.order = counts.size();
@@ -160,7 +160,7 @@ namespace detail {
bool IsBinaryFormat(int fd) {
const uint64_t size = util::SizeFile(fd);
if (size == util::kBadSize || (size <= static_cast<uint64_t>(sizeof(Sanity)))) return false;
- // Try reading the header.
+ // Try reading the header.
util::scoped_memory memory;
try {
util::MapRead(util::LAZY, fd, 0, sizeof(Sanity), memory);
@@ -214,7 +214,7 @@ void SeekPastHeader(int fd, const Parameters &params) {
uint8_t *SetupBinary(const Config &config, const Parameters &params, uint64_t memory_size, Backing &backing) {
const uint64_t file_size = util::SizeFile(backing.file.get());
- // The header is smaller than a page, so we have to map the whole header as well.
+ // The header is smaller than a page, so we have to map the whole header as well.
std::size_t total_map = util::CheckOverflow(TotalHeaderSize(params.counts.size()) + memory_size);
if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map)
UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map);
@@ -233,7 +233,8 @@ void ComplainAboutARPA(const Config &config, ModelType model_type) {
if (config.write_mmap || !config.messages) return;
if (config.arpa_complain == Config::ALL) {
*config.messages << "Loading the LM will be faster if you build a binary file." << std::endl;
- } else if (config.arpa_complain == Config::EXPENSIVE && model_type == TRIE_SORTED) {
+ } else if (config.arpa_complain == Config::EXPENSIVE &&
+ (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) {
*config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl;
}
}
diff --git a/klm/lm/config.cc b/klm/lm/config.cc
index f9d988ca..9520c41c 100644
--- a/klm/lm/config.cc
+++ b/klm/lm/config.cc
@@ -6,6 +6,7 @@ namespace lm {
namespace ngram {
Config::Config() :
+ show_progress(true),
messages(&std::cerr),
enumerate_vocab(NULL),
unknown_missing(COMPLAIN),
diff --git a/klm/lm/config.hh b/klm/lm/config.hh
index 739cee9c..0de7b7c6 100644
--- a/klm/lm/config.hh
+++ b/klm/lm/config.hh
@@ -11,46 +11,52 @@
/* Configuration for ngram model. Separate header to reduce pollution. */
namespace lm {
-
+
class EnumerateVocab;
namespace ngram {
struct Config {
- // EFFECTIVE FOR BOTH ARPA AND BINARY READS
+ // EFFECTIVE FOR BOTH ARPA AND BINARY READS
+
+ // (default true) print progress bar to messages
+ bool show_progress;
// Where to log messages including the progress bar. Set to NULL for
// silence.
std::ostream *messages;
+ std::ostream *ProgressMessages() const {
+ return show_progress ? messages : 0;
+ }
+
// This will be called with every string in the vocabulary. See
// enumerate_vocab.hh for more detail. Config does not take ownership; you
- // are still responsible for deleting it (or stack allocating).
+ // are still responsible for deleting it (or stack allocating).
EnumerateVocab *enumerate_vocab;
-
// ONLY EFFECTIVE WHEN READING ARPA
- // What to do when <unk> isn't in the provided model.
+ // What to do when <unk> isn't in the provided model.
WarningAction unknown_missing;
- // What to do when <s> or </s> is missing from the model.
- // If THROW_UP, the exception will be of type util::SpecialWordMissingException.
+ // What to do when <s> or </s> is missing from the model.
+ // If THROW_UP, the exception will be of type util::SpecialWordMissingException.
WarningAction sentence_marker_missing;
// What to do with a positive log probability. For COMPLAIN and SILENT, map
- // to 0.
+ // to 0.
WarningAction positive_log_probability;
- // The probability to substitute for <unk> if it's missing from the model.
+ // The probability to substitute for <unk> if it's missing from the model.
// No effect if the model has <unk> or unknown_missing == THROW_UP.
float unknown_missing_logprob;
// Size multiplier for probing hash table. Must be > 1. Space is linear in
// this. Time is probing_multiplier / (probing_multiplier - 1). No effect
- // for sorted variant.
+ // for sorted variant.
// If you find yourself setting this to a low number, consider using the
- // TrieModel which has lower memory consumption.
+ // TrieModel which has lower memory consumption.
float probing_multiplier;
// Amount of memory to use for building. The actual memory usage will be
@@ -58,10 +64,10 @@ struct Config {
// models.
std::size_t building_memory;
- // Template for temporary directory appropriate for passing to mkdtemp.
+ // Template for temporary directory appropriate for passing to mkdtemp.
// The characters XXXXXX are appended before passing to mkdtemp. Only
// applies to trie. If NULL, defaults to write_mmap. If that's NULL,
- // defaults to input file name.
+ // defaults to input file name.
const char *temporary_directory_prefix;
// Level of complaining to do when loading from ARPA instead of binary format.
@@ -69,49 +75,46 @@ struct Config {
ARPALoadComplain arpa_complain;
// While loading an ARPA file, also write out this binary format file. Set
- // to NULL to disable.
+ // to NULL to disable.
const char *write_mmap;
enum WriteMethod {
- WRITE_MMAP, // Map the file directly.
- WRITE_AFTER // Write after we're done.
+ WRITE_MMAP, // Map the file directly.
+ WRITE_AFTER // Write after we're done.
};
WriteMethod write_method;
- // Include the vocab in the binary file? Only effective if write_mmap != NULL.
+ // Include the vocab in the binary file? Only effective if write_mmap != NULL.
bool include_vocab;
- // Left rest options. Only used when the model includes rest costs.
+ // Left rest options. Only used when the model includes rest costs.
enum RestFunction {
REST_MAX, // Maximum of any score to the left
- REST_LOWER, // Use lower-order files given below.
+ REST_LOWER, // Use lower-order files given below.
};
RestFunction rest_function;
- // Only used for REST_LOWER.
+ // Only used for REST_LOWER.
std::vector<std::string> rest_lower_files;
-
// Quantization options. Only effective for QuantTrieModel. One value is
// reserved for each of prob and backoff, so 2^bits - 1 buckets will be used
- // to quantize (and one of the remaining backoffs will be 0).
+ // to quantize (and one of the remaining backoffs will be 0).
uint8_t prob_bits, backoff_bits;
// Bhiksha compression (simple form). Only works with trie.
uint8_t pointer_bhiksha_bits;
-
-
+
// ONLY EFFECTIVE WHEN READING BINARY
-
+
// How to get the giant array into memory: lazy mmap, populate, read etc.
- // See util/mmap.hh for details of MapMethod.
+ // See util/mmap.hh for details of MapMethod.
util::LoadMethod load_method;
-
- // Set defaults.
+ // Set defaults.
Config();
};
diff --git a/klm/lm/left.hh b/klm/lm/left.hh
index 8c27232e..85c1ea37 100644
--- a/klm/lm/left.hh
+++ b/klm/lm/left.hh
@@ -51,36 +51,36 @@ namespace ngram {
template <class M> class RuleScore {
public:
- explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), prob_(0.0) {
+ explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(&out), left_done_(false), prob_(0.0) {
out.left.length = 0;
out.right.length = 0;
}
void BeginSentence() {
- out_.right = model_.BeginSentenceState();
- // out_.left is empty.
+ out_->right = model_.BeginSentenceState();
+ // out_->left is empty.
left_done_ = true;
}
void Terminal(WordIndex word) {
- State copy(out_.right);
- FullScoreReturn ret(model_.FullScore(copy, word, out_.right));
+ State copy(out_->right);
+ FullScoreReturn ret(model_.FullScore(copy, word, out_->right));
if (left_done_) { prob_ += ret.prob; return; }
if (ret.independent_left) {
prob_ += ret.prob;
left_done_ = true;
return;
}
- out_.left.pointers[out_.left.length++] = ret.extend_left;
+ out_->left.pointers[out_->left.length++] = ret.extend_left;
prob_ += ret.rest;
- if (out_.right.length != copy.length + 1)
+ if (out_->right.length != copy.length + 1)
left_done_ = true;
}
// Faster version of NonTerminal for the case where the rule begins with a non-terminal.
void BeginNonTerminal(const ChartState &in, float prob = 0.0) {
prob_ = prob;
- out_ = in;
+ *out_ = in;
left_done_ = in.left.full;
}
@@ -89,23 +89,23 @@ template <class M> class RuleScore {
if (!in.left.length) {
if (in.left.full) {
- for (const float *i = out_.right.backoff; i < out_.right.backoff + out_.right.length; ++i) prob_ += *i;
+ for (const float *i = out_->right.backoff; i < out_->right.backoff + out_->right.length; ++i) prob_ += *i;
left_done_ = true;
- out_.right = in.right;
+ out_->right = in.right;
}
return;
}
- if (!out_.right.length) {
- out_.right = in.right;
+ if (!out_->right.length) {
+ out_->right = in.right;
if (left_done_) {
prob_ += model_.UnRest(in.left.pointers, in.left.pointers + in.left.length, 1);
return;
}
- if (out_.left.length) {
+ if (out_->left.length) {
left_done_ = true;
} else {
- out_.left = in.left;
+ out_->left = in.left;
left_done_ = in.left.full;
}
return;
@@ -113,10 +113,10 @@ template <class M> class RuleScore {
float backoffs[KENLM_MAX_ORDER - 1], backoffs2[KENLM_MAX_ORDER - 1];
float *back = backoffs, *back2 = backoffs2;
- unsigned char next_use = out_.right.length;
+ unsigned char next_use = out_->right.length;
// First word
- if (ExtendLeft(in, next_use, 1, out_.right.backoff, back)) return;
+ if (ExtendLeft(in, next_use, 1, out_->right.backoff, back)) return;
// Words after the first, so extending a bigram to begin with
for (unsigned char extend_length = 2; extend_length <= in.left.length; ++extend_length) {
@@ -127,54 +127,58 @@ template <class M> class RuleScore {
if (in.left.full) {
for (const float *i = back; i != back + next_use; ++i) prob_ += *i;
left_done_ = true;
- out_.right = in.right;
+ out_->right = in.right;
return;
}
// Right state was minimized, so it's already independent of the new words to the left.
if (in.right.length < in.left.length) {
- out_.right = in.right;
+ out_->right = in.right;
return;
}
// Shift exisiting words down.
- for (WordIndex *i = out_.right.words + next_use - 1; i >= out_.right.words; --i) {
+ for (WordIndex *i = out_->right.words + next_use - 1; i >= out_->right.words; --i) {
*(i + in.right.length) = *i;
}
// Add words from in.right.
- std::copy(in.right.words, in.right.words + in.right.length, out_.right.words);
+ std::copy(in.right.words, in.right.words + in.right.length, out_->right.words);
// Assemble backoff composed on the existing state's backoff followed by the new state's backoff.
- std::copy(in.right.backoff, in.right.backoff + in.right.length, out_.right.backoff);
- std::copy(back, back + next_use, out_.right.backoff + in.right.length);
- out_.right.length = in.right.length + next_use;
+ std::copy(in.right.backoff, in.right.backoff + in.right.length, out_->right.backoff);
+ std::copy(back, back + next_use, out_->right.backoff + in.right.length);
+ out_->right.length = in.right.length + next_use;
}
float Finish() {
// A N-1-gram might extend left and right but we should still set full to true because it's an N-1-gram.
- out_.left.full = left_done_ || (out_.left.length == model_.Order() - 1);
+ out_->left.full = left_done_ || (out_->left.length == model_.Order() - 1);
return prob_;
}
void Reset() {
prob_ = 0.0;
left_done_ = false;
- out_.left.length = 0;
- out_.right.length = 0;
+ out_->left.length = 0;
+ out_->right.length = 0;
+ }
+ void Reset(ChartState &replacement) {
+ out_ = &replacement;
+ Reset();
}
private:
bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) {
ProcessRet(model_.ExtendLeft(
- out_.right.words, out_.right.words + next_use, // Words to extend into
+ out_->right.words, out_->right.words + next_use, // Words to extend into
back_in, // Backoffs to use
in.left.pointers[extend_length - 1], extend_length, // Words to be extended
back_out, // Backoffs for the next score
next_use)); // Length of n-gram to use in next scoring.
- if (next_use != out_.right.length) {
+ if (next_use != out_->right.length) {
left_done_ = true;
if (!next_use) {
// Early exit.
- out_.right = in.right;
+ out_->right = in.right;
prob_ += model_.UnRest(in.left.pointers + extend_length, in.left.pointers + in.left.length, extend_length + 1);
return true;
}
@@ -193,13 +197,13 @@ template <class M> class RuleScore {
left_done_ = true;
return;
}
- out_.left.pointers[out_.left.length++] = ret.extend_left;
+ out_->left.pointers[out_->left.length++] = ret.extend_left;
prob_ += ret.rest;
}
const M &model_;
- ChartState &out_;
+ ChartState *out_;
bool left_done_;
diff --git a/klm/lm/max_order.hh b/klm/lm/max_order.hh
index 989f8324..3eb97ccd 100644
--- a/klm/lm/max_order.hh
+++ b/klm/lm/max_order.hh
@@ -4,9 +4,6 @@
* (kMaxOrder - 1) * sizeof(float) bytes instead of
* sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead
*/
-#ifndef KENLM_MAX_ORDER
-#define KENLM_MAX_ORDER 6
-#endif
#ifndef KENLM_ORDER_MESSAGE
-#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --kenlm-max-order=6 -a'. Otherwise, edit lm/max_order.hh."
+#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'. Otherwise, edit lm/max_order.hh."
#endif
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index 2fd20481..a40fd2fb 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -37,7 +37,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) {
LoadLM(file, config, *this);
- // g++ prints warnings unless these are fully initialized.
+ // g++ prints warnings unless these are fully initialized.
State begin_sentence = State();
begin_sentence.length = 1;
begin_sentence.words[0] = vocab_.BeginSentence();
@@ -69,8 +69,8 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
}
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) {
- // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any.
- util::FilePiece f(backing_.file.release(), file, config.messages);
+ // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any.
+ util::FilePiece f(backing_.file.release(), file, config.ProgressMessages());
try {
std::vector<uint64_t> counts;
// File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_.
@@ -80,14 +80,14 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config));
- // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs.
+ // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs.
vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config);
if (config.write_mmap) {
WriteWordsWrapper wrap(config.enumerate_vocab);
vocab_.ConfigureEnumerate(&wrap, counts[0]);
search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
- wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + backing_.search.size());
+ wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + Search::Size(counts, config));
} else {
vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);
search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
@@ -95,7 +95,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
if (!vocab_.SawUnk()) {
assert(config.unknown_missing != THROW_UP);
- // Default probabilities for unknown.
+ // Default probabilities for unknown.
search_.UnknownUnigram().backoff = 0.0;
search_.UnknownUnigram().prob = config.unknown_missing_logprob;
}
@@ -147,7 +147,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
}
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const {
- // Generate a state from context.
+ // Generate a state from context.
context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
if (context_rend == context_rbegin) {
out_state.length = 0;
@@ -191,7 +191,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
ret.rest = ptr.Rest();
ret.prob = ptr.Prob();
ret.extend_left = extend_pointer;
- // If this function is called, then it does depend on left words.
+ // If this function is called, then it does depend on left words.
ret.independent_left = false;
}
float subtract_me = ret.rest;
@@ -199,7 +199,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
next_use = extend_length;
ResumeScore(add_rbegin, add_rend, extend_length - 1, node, backoff_out, next_use, ret);
next_use -= extend_length;
- // Charge backoffs.
+ // Charge backoffs.
for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b;
ret.prob -= subtract_me;
ret.rest -= subtract_me;
@@ -209,7 +209,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
namespace {
// Do a paraonoid copy of history, assuming new_word has already been copied
// (hence the -1). out_state.length could be zero so I avoided using
-// std::copy.
+// std::copy.
void CopyRemainingHistory(const WordIndex *from, State &out_state) {
WordIndex *out = out_state.words + 1;
const WordIndex *in_end = from + static_cast<ptrdiff_t>(out_state.length) - 1;
@@ -217,18 +217,19 @@ void CopyRemainingHistory(const WordIndex *from, State &out_state) {
}
} // namespace
-/* Ugly optimized function. Produce a score excluding backoff.
- * The search goes in increasing order of ngram length.
+/* Ugly optimized function. Produce a score excluding backoff.
+ * The search goes in increasing order of ngram length.
* Context goes backward, so context_begin is the word immediately preceeding
- * new_word.
+ * new_word.
*/
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ScoreExceptBackoff(
const WordIndex *const context_rbegin,
const WordIndex *const context_rend,
const WordIndex new_word,
State &out_state) const {
+ assert(new_word < vocab_.Bound());
FullScoreReturn ret;
- // ret.ngram_length contains the last known non-blank ngram length.
+ // ret.ngram_length contains the last known non-blank ngram length.
ret.ngram_length = 1;
typename Search::Node node;
@@ -237,9 +238,9 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
ret.prob = uni.Prob();
ret.rest = uni.Rest();
- // This is the length of the context that should be used for continuation to the right.
+ // This is the length of the context that should be used for continuation to the right.
out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0;
- // We'll write the word anyway since it will probably be used and does no harm being there.
+ // We'll write the word anyway since it will probably be used and does no harm being there.
out_state.words[0] = new_word;
if (context_rbegin == context_rend) return ret;
diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc
index a1623834..2d6f15b2 100644
--- a/klm/lm/search_hashed.cc
+++ b/klm/lm/search_hashed.cc
@@ -231,7 +231,7 @@ template <class Value> void HashedSearch<Value>::InitializeFromARPA(const char *
template <> void HashedSearch<BackoffValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) {
NoRestBuild build;
- ApplyBuild(f, counts, config, vocab, warn, build);
+ ApplyBuild(f, counts, vocab, warn, build);
}
template <> void HashedSearch<RestValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) {
@@ -239,19 +239,19 @@ template <> void HashedSearch<RestValue>::DispatchBuild(util::FilePiece &f, cons
case Config::REST_MAX:
{
MaxRestBuild build;
- ApplyBuild(f, counts, config, vocab, warn, build);
+ ApplyBuild(f, counts, vocab, warn, build);
}
break;
case Config::REST_LOWER:
{
LowerRestBuild<ProbingModel> build(config, counts.size(), vocab);
- ApplyBuild(f, counts, config, vocab, warn, build);
+ ApplyBuild(f, counts, vocab, warn, build);
}
break;
}
}
-template <class Value> template <class Build> void HashedSearch<Value>::ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build) {
+template <class Value> template <class Build> void HashedSearch<Value>::ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build) {
for (WordIndex i = 0; i < counts[0]; ++i) {
build.SetRest(&i, (unsigned int)1, unigram_.Raw()[i]);
}
diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh
index a52f107b..00595796 100644
--- a/klm/lm/search_hashed.hh
+++ b/klm/lm/search_hashed.hh
@@ -147,7 +147,7 @@ template <class Value> class HashedSearch {
// Interpret config's rest cost build policy and pass the right template argument to ApplyBuild.
void DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn);
- template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build);
+ template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build);
class Unigram {
public:
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index debcfd07..1b0d9b26 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -55,7 +55,7 @@ struct ProbPointer {
uint64_t index;
};
-// Array of n-grams and float indices.
+// Array of n-grams and float indices.
class BackoffMessages {
public:
void Init(std::size_t entry_size) {
@@ -100,7 +100,7 @@ class BackoffMessages {
void Apply(float *const *const base, RecordReader &reader) {
FinishedAdding();
if (current_ == allocated_) return;
- // We'll also use the same buffer to record messages to blanks that they extend.
+ // We'll also use the same buffer to record messages to blanks that they extend.
WordIndex *extend_out = reinterpret_cast<WordIndex*>(current_);
const unsigned char order = (entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex);
for (reader.Rewind(); reader && (current_ != allocated_); ) {
@@ -109,7 +109,7 @@ class BackoffMessages {
++reader;
break;
case 1:
- // Message but nobody to receive it. Write it down at the beginning of the buffer so we can inform this blank that it extends.
+ // Message but nobody to receive it. Write it down at the beginning of the buffer so we can inform this blank that it extends.
for (const WordIndex *w = reinterpret_cast<const WordIndex *>(current_); w != reinterpret_cast<const WordIndex *>(current_) + order; ++w, ++extend_out) *extend_out = *w;
current_ += entry_size_;
break;
@@ -126,7 +126,7 @@ class BackoffMessages {
break;
}
}
- // Now this is a list of blanks that extend right.
+ // Now this is a list of blanks that extend right.
entry_size_ = sizeof(WordIndex) * order;
Resize(sizeof(WordIndex) * (extend_out - (const WordIndex*)backing_.get()));
current_ = (uint8_t*)backing_.get();
@@ -153,7 +153,7 @@ class BackoffMessages {
private:
void FinishedAdding() {
Resize(current_ - (uint8_t*)backing_.get());
- // Sort requests in same order as files.
+ // Sort requests in same order as files.
std::sort(
util::SizedIterator(util::SizedProxy(backing_.get(), entry_size_)),
util::SizedIterator(util::SizedProxy(current_, entry_size_)),
@@ -220,7 +220,7 @@ class SRISucks {
}
private:
- // This used to be one array. Then I needed to separate it by order for quantization to work.
+ // This used to be one array. Then I needed to separate it by order for quantization to work.
std::vector<float> values_[KENLM_MAX_ORDER - 1];
BackoffMessages messages_[KENLM_MAX_ORDER - 1];
@@ -253,7 +253,7 @@ class FindBlanks {
++counts_.back();
}
- // Unigrams wrote one past.
+ // Unigrams wrote one past.
void Cleanup() {
--counts_[0];
}
@@ -270,15 +270,15 @@ class FindBlanks {
SRISucks &sri_;
};
-// Phase to actually write n-grams to the trie.
+// Phase to actually write n-grams to the trie.
template <class Quant, class Bhiksha> class WriteEntries {
public:
- WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle<Bhiksha> *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) :
+ WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle<Bhiksha> *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) :
contexts_(contexts),
quant_(quant),
unigrams_(unigrams),
middle_(middle),
- longest_(longest),
+ longest_(longest),
bigram_pack_((order == 2) ? static_cast<BitPacked&>(longest_) : static_cast<BitPacked&>(*middle_)),
order_(order),
sri_(sri) {}
@@ -328,7 +328,7 @@ struct Gram {
const WordIndex *begin, *end;
- // For queue, this is the direction we want.
+ // For queue, this is the direction we want.
bool operator<(const Gram &other) const {
return std::lexicographical_compare(other.begin, other.end, begin, end);
}
@@ -353,7 +353,7 @@ template <class Doing> class BlankManager {
been_length_ = length;
return;
}
- // There are blanks to insert starting with order blank.
+ // There are blanks to insert starting with order blank.
unsigned char blank = cur - to + 1;
UTIL_THROW_IF(blank == 1, FormatLoadException, "Missing a unigram that appears as context.");
const float *lower_basis;
@@ -363,7 +363,7 @@ template <class Doing> class BlankManager {
assert(*lower_basis != kBadProb);
doing_.MiddleBlank(blank, to, based_on, *lower_basis);
*pre = *cur;
- // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram.
+ // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram.
basis_[blank - 1] = kBadProb;
}
*pre = *cur;
@@ -377,7 +377,7 @@ template <class Doing> class BlankManager {
unsigned char been_length_;
float basis_[KENLM_MAX_ORDER];
-
+
Doing &doing_;
};
@@ -451,7 +451,7 @@ template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, Re
}
void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &contexts, UnigramValue *unigrams) {
- // Fill unigram probabilities.
+ // Fill unigram probabilities.
try {
rewind(file);
for (WordIndex i = 0; i < unigram_count; ++i) {
@@ -486,7 +486,7 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
util::scoped_memory unigrams;
MapRead(util::POPULATE_OR_READ, unigram_fd.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams);
FindBlanks finder(counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri);
- RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Identifying n-grams omitted by SRI", finder);
+ RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Identifying n-grams omitted by SRI", finder);
fixed_counts = finder.Counts();
}
unigram_file.reset(util::FDOpenOrThrow(unigram_fd));
@@ -504,7 +504,8 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
inputs[i-2].Rewind();
}
if (Quant::kTrain) {
- util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0), config.messages, "Quantizing");
+ util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0),
+ config.ProgressMessages(), "Quantizing");
for (unsigned char i = 2; i < counts.size(); ++i) {
TrainQuantizer(i, counts[i-1], sri.Values(i), inputs[i-2], progress, quant);
}
@@ -519,13 +520,13 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
for (unsigned char i = 2; i <= counts.size(); ++i) {
inputs[i-2].Rewind();
}
- // Fill entries except unigram probabilities.
+ // Fill entries except unigram probabilities.
{
WriteEntries<Quant, Bhiksha> writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri);
- RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Writing trie", writer);
+ RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Writing trie", writer);
}
- // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.
+ // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.
for (unsigned char order = 2; order <= counts.size(); ++order) {
const RecordReader &context = contexts[order - 2];
if (context) {
@@ -541,13 +542,13 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
}
/* Set ending offsets so the last entry will be sized properly */
- // Last entry for unigrams was already set.
+ // Last entry for unigrams was already set.
if (out.middle_begin_ != out.middle_end_) {
for (typename TrieSearch<Quant, Bhiksha>::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) {
i->FinishedLoading((i+1)->InsertIndex(), config);
}
(out.middle_end_ - 1)->FinishedLoading(out.longest_.InsertIndex(), config);
- }
+ }
}
template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
@@ -595,7 +596,7 @@ template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::Initializ
} else {
temporary_prefix = file;
}
- // At least 1MB sorting memory.
+ // At least 1MB sorting memory.
SortedFiles sorted(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_prefix, vocab);
BuildTrie(sorted, counts, config, *this, quant_, vocab, backing);
diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc
index 11c27518..fd7f96dc 100644
--- a/klm/lm/vocab.cc
+++ b/klm/lm/vocab.cc
@@ -116,7 +116,9 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) {
}
*end_ = hashed;
if (enumerate_) {
- strings_to_enumerate_[end_ - begin_].assign(str.data(), str.size());
+ void *copied = string_backing_.Allocate(str.size());
+ memcpy(copied, str.data(), str.size());
+ strings_to_enumerate_[end_ - begin_] = StringPiece(static_cast<const char*>(copied), str.size());
}
++end_;
// This is 1 + the offset where it was inserted to make room for unk.
@@ -126,7 +128,7 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) {
void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
if (enumerate_) {
if (!strings_to_enumerate_.empty()) {
- util::PairedIterator<ProbBackoff*, std::string*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin());
+ util::PairedIterator<ProbBackoff*, StringPiece*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin());
util::JointSort(begin_, end_, values);
}
for (WordIndex i = 0; i < static_cast<WordIndex>(end_ - begin_); ++i) {
@@ -134,6 +136,7 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
enumerate_->Add(i + 1, strings_to_enumerate_[i]);
}
strings_to_enumerate_.clear();
+ string_backing_.FreeAll();
} else {
util::JointSort(begin_, end_, reorder_vocab + 1);
}
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh
index de54eb06..3902f117 100644
--- a/klm/lm/vocab.hh
+++ b/klm/lm/vocab.hh
@@ -4,6 +4,7 @@
#include "lm/enumerate_vocab.hh"
#include "lm/lm_exception.hh"
#include "lm/virtual_interface.hh"
+#include "util/pool.hh"
#include "util/probing_hash_table.hh"
#include "util/sorted_uniform.hh"
#include "util/string_piece.hh"
@@ -96,7 +97,9 @@ class SortedVocabulary : public base::Vocabulary {
EnumerateVocab *enumerate_;
// Actual strings. Used only when loading from ARPA and enumerate_ != NULL
- std::vector<std::string> strings_to_enumerate_;
+ util::Pool string_backing_;
+
+ std::vector<StringPiece> strings_to_enumerate_;
};
#pragma pack(push)
diff --git a/klm/search/Makefile.am b/klm/search/Makefile.am
index ccc5b7f6..5aea33c2 100644
--- a/klm/search/Makefile.am
+++ b/klm/search/Makefile.am
@@ -2,10 +2,10 @@ noinst_LIBRARIES = libksearch.a
libksearch_a_SOURCES = \
edge_generator.cc \
+ nbest.cc \
rule.cc \
vertex.cc \
- vertex_generator.cc \
- weights.cc
+ vertex_generator.cc
AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I..
diff --git a/klm/search/applied.hh b/klm/search/applied.hh
new file mode 100644
index 00000000..bd659e5c
--- /dev/null
+++ b/klm/search/applied.hh
@@ -0,0 +1,86 @@
+#ifndef SEARCH_APPLIED__
+#define SEARCH_APPLIED__
+
+#include "search/edge.hh"
+#include "search/header.hh"
+#include "util/pool.hh"
+
+#include <math.h>
+
+namespace search {
+
+// A full hypothesis: a score, arity of the rule, a pointer to the decoder's rule (Note), and pointers to non-terminals that were substituted.
+template <class Below> class GenericApplied : public Header {
+ public:
+ GenericApplied() {}
+
+ GenericApplied(void *location, PartialEdge partial)
+ : Header(location) {
+ memcpy(Base(), partial.Base(), kHeaderSize);
+ Below *child_out = Children();
+ const PartialVertex *part = partial.NT();
+ const PartialVertex *const part_end_loop = part + partial.GetArity();
+ for (; part != part_end_loop; ++part, ++child_out)
+ *child_out = Below(part->End());
+ }
+
+ GenericApplied(void *location, Score score, Arity arity, Note note) : Header(location, arity) {
+ SetScore(score);
+ SetNote(note);
+ }
+
+ explicit GenericApplied(History from) : Header(from) {}
+
+
+ // These are arrays of length GetArity().
+ Below *Children() {
+ return reinterpret_cast<Below*>(After());
+ }
+ const Below *Children() const {
+ return reinterpret_cast<const Below*>(After());
+ }
+
+ static std::size_t Size(Arity arity) {
+ return kHeaderSize + arity * sizeof(const Below);
+ }
+};
+
+// Applied rule that references itself.
+class Applied : public GenericApplied<Applied> {
+ private:
+ typedef GenericApplied<Applied> P;
+
+ public:
+ Applied() {}
+ Applied(void *location, PartialEdge partial) : P(location, partial) {}
+ Applied(History from) : P(from) {}
+};
+
+// How to build single-best hypotheses.
+class SingleBest {
+ public:
+ typedef PartialEdge Combine;
+
+ void Add(PartialEdge &existing, PartialEdge add) const {
+ if (!existing.Valid() || existing.GetScore() < add.GetScore())
+ existing = add;
+ }
+
+ NBestComplete Complete(PartialEdge partial) {
+ if (!partial.Valid())
+ return NBestComplete(NULL, lm::ngram::ChartState(), -INFINITY);
+ void *place_final = pool_.Allocate(Applied::Size(partial.GetArity()));
+ Applied(place_final, partial);
+ return NBestComplete(
+ place_final,
+ partial.CompletedState(),
+ partial.GetScore());
+ }
+
+ private:
+ util::Pool pool_;
+};
+
+} // namespace search
+
+#endif // SEARCH_APPLIED__
diff --git a/klm/search/config.hh b/klm/search/config.hh
index ef8e2354..ba18c09e 100644
--- a/klm/search/config.hh
+++ b/klm/search/config.hh
@@ -1,23 +1,36 @@
#ifndef SEARCH_CONFIG__
#define SEARCH_CONFIG__
-#include "search/weights.hh"
-#include "util/string_piece.hh"
+#include "search/types.hh"
namespace search {
+struct NBestConfig {
+ explicit NBestConfig(unsigned int in_size) {
+ keep = in_size;
+ size = in_size;
+ }
+
+ unsigned int keep, size;
+};
+
class Config {
public:
- Config(const Weights &weights, unsigned int pop_limit) :
- weights_(weights), pop_limit_(pop_limit) {}
+ Config(Score lm_weight, unsigned int pop_limit, const NBestConfig &nbest) :
+ lm_weight_(lm_weight), pop_limit_(pop_limit), nbest_(nbest) {}
- const Weights &GetWeights() const { return weights_; }
+ Score LMWeight() const { return lm_weight_; }
unsigned int PopLimit() const { return pop_limit_; }
+ const NBestConfig &GetNBest() const { return nbest_; }
+
private:
- Weights weights_;
+ Score lm_weight_;
+
unsigned int pop_limit_;
+
+ NBestConfig nbest_;
};
} // namespace search
diff --git a/klm/search/context.hh b/klm/search/context.hh
index 62163144..08f21bbf 100644
--- a/klm/search/context.hh
+++ b/klm/search/context.hh
@@ -1,30 +1,16 @@
#ifndef SEARCH_CONTEXT__
#define SEARCH_CONTEXT__
-#include "lm/model.hh"
#include "search/config.hh"
-#include "search/final.hh"
-#include "search/types.hh"
#include "search/vertex.hh"
-#include "util/exception.hh"
-#include "util/pool.hh"
#include <boost/pool/object_pool.hpp>
-#include <boost/ptr_container/ptr_vector.hpp>
-
-#include <vector>
namespace search {
-class Weights;
-
class ContextBase {
public:
- explicit ContextBase(const Config &config) : pop_limit_(config.PopLimit()), weights_(config.GetWeights()) {}
-
- util::Pool &FinalPool() {
- return final_pool_;
- }
+ explicit ContextBase(const Config &config) : config_(config) {}
VertexNode *NewVertexNode() {
VertexNode *ret = vertex_node_pool_.construct();
@@ -36,18 +22,16 @@ class ContextBase {
vertex_node_pool_.destroy(node);
}
- unsigned int PopLimit() const { return pop_limit_; }
+ unsigned int PopLimit() const { return config_.PopLimit(); }
- const Weights &GetWeights() const { return weights_; }
+ Score LMWeight() const { return config_.LMWeight(); }
- private:
- util::Pool final_pool_;
+ const Config &GetConfig() const { return config_; }
+ private:
boost::object_pool<VertexNode> vertex_node_pool_;
- unsigned int pop_limit_;
-
- const Weights &weights_;
+ Config config_;
};
template <class Model> class Context : public ContextBase {
diff --git a/klm/search/dedupe.hh b/klm/search/dedupe.hh
new file mode 100644
index 00000000..7eaa3b95
--- /dev/null
+++ b/klm/search/dedupe.hh
@@ -0,0 +1,131 @@
+#ifndef SEARCH_DEDUPE__
+#define SEARCH_DEDUPE__
+
+#include "lm/state.hh"
+#include "search/edge_generator.hh"
+
+#include <boost/pool/object_pool.hpp>
+#include <boost/unordered_map.hpp>
+
+namespace search {
+
+class Dedupe {
+ public:
+ Dedupe() {}
+
+ PartialEdge AllocateEdge(Arity arity) {
+ return behind_.AllocateEdge(arity);
+ }
+
+ void AddEdge(PartialEdge edge) {
+ edge.MutableFlags() = 0;
+
+ uint64_t hash = 0;
+ const PartialVertex *v = edge.NT();
+ const PartialVertex *v_end = v + edge.GetArity();
+ for (; v != v_end; ++v) {
+ const void *ptr = v->Identify();
+ hash = util::MurmurHashNative(&ptr, sizeof(const void*), hash);
+ }
+
+ const lm::ngram::ChartState *c = edge.Between();
+ const lm::ngram::ChartState *const c_end = c + edge.GetArity() + 1;
+ for (; c != c_end; ++c) hash = hash_value(*c, hash);
+
+ std::pair<Table::iterator, bool> ret(table_.insert(std::make_pair(hash, edge)));
+ if (!ret.second) FoundDupe(ret.first->second, edge);
+ }
+
+ bool Empty() const { return behind_.Empty(); }
+
+ template <class Model, class Output> void Search(Context<Model> &context, Output &output) {
+ for (Table::const_iterator i(table_.begin()); i != table_.end(); ++i) {
+ behind_.AddEdge(i->second);
+ }
+ Unpack<Output> unpack(output, *this);
+ behind_.Search(context, unpack);
+ }
+
+ private:
+ void FoundDupe(PartialEdge &table, PartialEdge adding) {
+ if (table.GetFlags() & kPackedFlag) {
+ Packed &packed = *static_cast<Packed*>(table.GetNote().mut);
+ if (table.GetScore() >= adding.GetScore()) {
+ packed.others.push_back(adding);
+ return;
+ }
+ Note original(packed.original);
+ packed.original = adding.GetNote();
+ adding.SetNote(table.GetNote());
+ table.SetNote(original);
+ packed.others.push_back(table);
+ packed.starting = adding.GetScore();
+ table = adding;
+ table.MutableFlags() |= kPackedFlag;
+ return;
+ }
+ PartialEdge loser;
+ if (adding.GetScore() > table.GetScore()) {
+ loser = table;
+ table = adding;
+ } else {
+ loser = adding;
+ }
+ // table is winner, loser is loser...
+ packed_.construct(table, loser);
+ }
+
+ struct Packed {
+ Packed(PartialEdge winner, PartialEdge loser)
+ : original(winner.GetNote()), starting(winner.GetScore()), others(1, loser) {
+ winner.MutableNote().vp = this;
+ winner.MutableFlags() |= kPackedFlag;
+ loser.MutableFlags() &= ~kPackedFlag;
+ }
+ Note original;
+ Score starting;
+ std::vector<PartialEdge> others;
+ };
+
+ template <class Output> class Unpack {
+ public:
+ explicit Unpack(Output &output, Dedupe &owner) : output_(output), owner_(owner) {}
+
+ void NewHypothesis(PartialEdge edge) {
+ if (edge.GetFlags() & kPackedFlag) {
+ Packed &packed = *reinterpret_cast<Packed*>(edge.GetNote().mut);
+ edge.SetNote(packed.original);
+ edge.MutableFlags() = 0;
+ std::size_t copy_size = sizeof(PartialVertex) * edge.GetArity() + sizeof(lm::ngram::ChartState);
+ for (std::vector<PartialEdge>::iterator i = packed.others.begin(); i != packed.others.end(); ++i) {
+ PartialEdge copy(owner_.AllocateEdge(edge.GetArity()));
+ copy.SetScore(edge.GetScore() - packed.starting + i->GetScore());
+ copy.MutableFlags() = 0;
+ copy.SetNote(i->GetNote());
+ memcpy(copy.NT(), edge.NT(), copy_size);
+ output_.NewHypothesis(copy);
+ }
+ }
+ output_.NewHypothesis(edge);
+ }
+
+ void FinishedSearch() {
+ output_.FinishedSearch();
+ }
+
+ private:
+ Output &output_;
+ Dedupe &owner_;
+ };
+
+ EdgeGenerator behind_;
+
+ typedef boost::unordered_map<uint64_t, PartialEdge> Table;
+ Table table_;
+
+ boost::object_pool<Packed> packed_;
+
+ static const uint16_t kPackedFlag = 1;
+};
+} // namespace search
+#endif // SEARCH_DEDUPE__
diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc
index 260159b1..eacf5de5 100644
--- a/klm/search/edge_generator.cc
+++ b/klm/search/edge_generator.cc
@@ -1,6 +1,7 @@
#include "search/edge_generator.hh"
#include "lm/left.hh"
+#include "lm/model.hh"
#include "lm/partial.hh"
#include "search/context.hh"
#include "search/vertex.hh"
@@ -38,7 +39,7 @@ template <class Model> void FastScore(const Context<Model> &context, Arity victi
*cover = *(cover + 1);
}
}
- update.SetScore(update.GetScore() + adjustment * context.GetWeights().LM());
+ update.SetScore(update.GetScore() + adjustment * context.LMWeight());
}
} // namespace
diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh
index 582c78b7..203942c6 100644
--- a/klm/search/edge_generator.hh
+++ b/klm/search/edge_generator.hh
@@ -2,7 +2,6 @@
#define SEARCH_EDGE_GENERATOR__
#include "search/edge.hh"
-#include "search/note.hh"
#include "search/types.hh"
#include <queue>
diff --git a/klm/search/final.hh b/klm/search/final.hh
deleted file mode 100644
index 50e62cf2..00000000
--- a/klm/search/final.hh
+++ /dev/null
@@ -1,36 +0,0 @@
-#ifndef SEARCH_FINAL__
-#define SEARCH_FINAL__
-
-#include "search/header.hh"
-#include "util/pool.hh"
-
-namespace search {
-
-// A full hypothesis with pointers to children.
-class Final : public Header {
- public:
- Final() {}
-
- Final(util::Pool &pool, Score score, Arity arity, Note note)
- : Header(pool.Allocate(Size(arity)), arity) {
- SetScore(score);
- SetNote(note);
- }
-
- // These are arrays of length GetArity().
- Final *Children() {
- return reinterpret_cast<Final*>(After());
- }
- const Final *Children() const {
- return reinterpret_cast<const Final*>(After());
- }
-
- private:
- static std::size_t Size(Arity arity) {
- return kHeaderSize + arity * sizeof(const Final);
- }
-};
-
-} // namespace search
-
-#endif // SEARCH_FINAL__
diff --git a/klm/search/header.hh b/klm/search/header.hh
index 25550dbe..69f0eed0 100644
--- a/klm/search/header.hh
+++ b/klm/search/header.hh
@@ -3,7 +3,6 @@
// Header consisting of Score, Arity, and Note
-#include "search/note.hh"
#include "search/types.hh"
#include <stdint.h>
@@ -24,6 +23,9 @@ class Header {
bool operator<(const Header &other) const {
return GetScore() < other.GetScore();
}
+ bool operator>(const Header &other) const {
+ return GetScore() > other.GetScore();
+ }
Arity GetArity() const {
return *reinterpret_cast<const Arity*>(base_ + sizeof(Score));
@@ -36,9 +38,14 @@ class Header {
*reinterpret_cast<Note*>(base_ + sizeof(Score) + sizeof(Arity)) = to;
}
+ uint8_t *Base() { return base_; }
+ const uint8_t *Base() const { return base_; }
+
protected:
Header() : base_(NULL) {}
+ explicit Header(void *base) : base_(static_cast<uint8_t*>(base)) {}
+
Header(void *base, Arity arity) : base_(static_cast<uint8_t*>(base)) {
*reinterpret_cast<Arity*>(base_ + sizeof(Score)) = arity;
}
diff --git a/klm/search/nbest.cc b/klm/search/nbest.cc
new file mode 100644
index 00000000..ec3322c9
--- /dev/null
+++ b/klm/search/nbest.cc
@@ -0,0 +1,106 @@
+#include "search/nbest.hh"
+
+#include "util/pool.hh"
+
+#include <algorithm>
+#include <functional>
+#include <queue>
+
+#include <assert.h>
+#include <math.h>
+
+namespace search {
+
+NBestList::NBestList(std::vector<PartialEdge> &partials, util::Pool &entry_pool, std::size_t keep) {
+ assert(!partials.empty());
+ std::vector<PartialEdge>::iterator end;
+ if (partials.size() > keep) {
+ end = partials.begin() + keep;
+ std::nth_element(partials.begin(), end, partials.end(), std::greater<PartialEdge>());
+ } else {
+ end = partials.end();
+ }
+ for (std::vector<PartialEdge>::const_iterator i(partials.begin()); i != end; ++i) {
+ queue_.push(QueueEntry(entry_pool.Allocate(QueueEntry::Size(i->GetArity())), *i));
+ }
+}
+
+Score NBestList::TopAfterConstructor() const {
+ assert(revealed_.empty());
+ return queue_.top().GetScore();
+}
+
+const std::vector<Applied> &NBestList::Extract(util::Pool &pool, std::size_t n) {
+ while (revealed_.size() < n && !queue_.empty()) {
+ MoveTop(pool);
+ }
+ return revealed_;
+}
+
+Score NBestList::Visit(util::Pool &pool, std::size_t index) {
+ if (index + 1 < revealed_.size())
+ return revealed_[index + 1].GetScore() - revealed_[index].GetScore();
+ if (queue_.empty())
+ return -INFINITY;
+ if (index + 1 == revealed_.size())
+ return queue_.top().GetScore() - revealed_[index].GetScore();
+ assert(index == revealed_.size());
+
+ MoveTop(pool);
+
+ if (queue_.empty()) return -INFINITY;
+ return queue_.top().GetScore() - revealed_[index].GetScore();
+}
+
+Applied NBestList::Get(util::Pool &pool, std::size_t index) {
+ assert(index <= revealed_.size());
+ if (index == revealed_.size()) MoveTop(pool);
+ return revealed_[index];
+}
+
+void NBestList::MoveTop(util::Pool &pool) {
+ assert(!queue_.empty());
+ QueueEntry entry(queue_.top());
+ queue_.pop();
+ RevealedRef *const children_begin = entry.Children();
+ RevealedRef *const children_end = children_begin + entry.GetArity();
+ Score basis = entry.GetScore();
+ for (RevealedRef *child = children_begin; child != children_end; ++child) {
+ Score change = child->in_->Visit(pool, child->index_);
+ if (change != -INFINITY) {
+ assert(change < 0.001);
+ QueueEntry new_entry(pool.Allocate(QueueEntry::Size(entry.GetArity())), basis + change, entry.GetArity(), entry.GetNote());
+ std::copy(children_begin, child, new_entry.Children());
+ RevealedRef *update = new_entry.Children() + (child - children_begin);
+ update->in_ = child->in_;
+ update->index_ = child->index_ + 1;
+ std::copy(child + 1, children_end, update + 1);
+ queue_.push(new_entry);
+ }
+ // Gesmundo, A. and Henderson, J. Faster Cube Pruning, IWSLT 2010.
+ if (child->index_) break;
+ }
+
+ // Convert QueueEntry to Applied. This leaves some unused memory.
+ void *overwrite = entry.Children();
+ for (unsigned int i = 0; i < entry.GetArity(); ++i) {
+ RevealedRef from(*(static_cast<const RevealedRef*>(overwrite) + i));
+ *(static_cast<Applied*>(overwrite) + i) = from.in_->Get(pool, from.index_);
+ }
+ revealed_.push_back(Applied(entry.Base()));
+}
+
+NBestComplete NBest::Complete(std::vector<PartialEdge> &partials) {
+ assert(!partials.empty());
+ NBestList *list = list_pool_.construct(partials, entry_pool_, config_.keep);
+ return NBestComplete(
+ list,
+ partials.front().CompletedState(), // All partials have the same state
+ list->TopAfterConstructor());
+}
+
+const std::vector<Applied> &NBest::Extract(History history) {
+ return static_cast<NBestList*>(history)->Extract(entry_pool_, config_.size);
+}
+
+} // namespace search
diff --git a/klm/search/nbest.hh b/klm/search/nbest.hh
new file mode 100644
index 00000000..cb7651bc
--- /dev/null
+++ b/klm/search/nbest.hh
@@ -0,0 +1,81 @@
+#ifndef SEARCH_NBEST__
+#define SEARCH_NBEST__
+
+#include "search/applied.hh"
+#include "search/config.hh"
+#include "search/edge.hh"
+
+#include <boost/pool/object_pool.hpp>
+
+#include <cstddef>
+#include <queue>
+#include <vector>
+
+#include <assert.h>
+
+namespace search {
+
+class NBestList;
+
+class NBestList {
+ private:
+ class RevealedRef {
+ public:
+ explicit RevealedRef(History history)
+ : in_(static_cast<NBestList*>(history)), index_(0) {}
+
+ private:
+ friend class NBestList;
+
+ NBestList *in_;
+ std::size_t index_;
+ };
+
+ typedef GenericApplied<RevealedRef> QueueEntry;
+
+ public:
+ NBestList(std::vector<PartialEdge> &existing, util::Pool &entry_pool, std::size_t keep);
+
+ Score TopAfterConstructor() const;
+
+ const std::vector<Applied> &Extract(util::Pool &pool, std::size_t n);
+
+ private:
+ Score Visit(util::Pool &pool, std::size_t index);
+
+ Applied Get(util::Pool &pool, std::size_t index);
+
+ void MoveTop(util::Pool &pool);
+
+ typedef std::vector<Applied> Revealed;
+ Revealed revealed_;
+
+ typedef std::priority_queue<QueueEntry> Queue;
+ Queue queue_;
+};
+
+class NBest {
+ public:
+ typedef std::vector<PartialEdge> Combine;
+
+ explicit NBest(const NBestConfig &config) : config_(config) {}
+
+ void Add(std::vector<PartialEdge> &existing, PartialEdge addition) const {
+ existing.push_back(addition);
+ }
+
+ NBestComplete Complete(std::vector<PartialEdge> &partials);
+
+ const std::vector<Applied> &Extract(History root);
+
+ private:
+ const NBestConfig config_;
+
+ boost::object_pool<NBestList> list_pool_;
+
+ util::Pool entry_pool_;
+};
+
+} // namespace search
+
+#endif // SEARCH_NBEST__
diff --git a/klm/search/note.hh b/klm/search/note.hh
deleted file mode 100644
index 50bed06e..00000000
--- a/klm/search/note.hh
+++ /dev/null
@@ -1,12 +0,0 @@
-#ifndef SEARCH_NOTE__
-#define SEARCH_NOTE__
-
-namespace search {
-
-union Note {
- const void *vp;
-};
-
-} // namespace search
-
-#endif // SEARCH_NOTE__
diff --git a/klm/search/rule.cc b/klm/search/rule.cc
index 5b00207e..0244a09f 100644
--- a/klm/search/rule.cc
+++ b/klm/search/rule.cc
@@ -1,7 +1,7 @@
#include "search/rule.hh"
+#include "lm/model.hh"
#include "search/context.hh"
-#include "search/final.hh"
#include <ostream>
@@ -9,35 +9,35 @@
namespace search {
-template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing) {
- unsigned int oov_count = 0;
- float prob = 0.0;
- const Model &model = context.LanguageModel();
- const lm::WordIndex oov = model.GetVocabulary().NotFound();
- for (std::vector<lm::WordIndex>::const_iterator word = words.begin(); ; ++word) {
- lm::ngram::RuleScore<Model> scorer(model, *(writing++));
- // TODO: optimize
- if (prepend_bos && (word == words.begin())) {
- scorer.BeginSentence();
- }
- for (; ; ++word) {
- if (word == words.end()) {
- prob += scorer.Finish();
- return static_cast<float>(oov_count) * context.GetWeights().OOV() + prob * context.GetWeights().LM();
- }
- if (*word == kNonTerminal) break;
- if (*word == oov) ++oov_count;
+template <class Model> ScoreRuleRet ScoreRule(const Model &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing) {
+ ScoreRuleRet ret;
+ ret.prob = 0.0;
+ ret.oov = 0;
+ const lm::WordIndex oov = model.GetVocabulary().NotFound(), bos = model.GetVocabulary().BeginSentence();
+ lm::ngram::RuleScore<Model> scorer(model, *(writing++));
+ std::vector<lm::WordIndex>::const_iterator word = words.begin();
+ if (word != words.end() && *word == bos) {
+ scorer.BeginSentence();
+ ++word;
+ }
+ for (; word != words.end(); ++word) {
+ if (*word == kNonTerminal) {
+ ret.prob += scorer.Finish();
+ scorer.Reset(*(writing++));
+ } else {
+ if (*word == oov) ++ret.oov;
scorer.Terminal(*word);
}
- prob += scorer.Finish();
}
+ ret.prob += scorer.Finish();
+ return ret;
}
-template float ScoreRule(const Context<lm::ngram::RestProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
-template float ScoreRule(const Context<lm::ngram::ProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
-template float ScoreRule(const Context<lm::ngram::TrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
-template float ScoreRule(const Context<lm::ngram::QuantTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
-template float ScoreRule(const Context<lm::ngram::ArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
-template float ScoreRule(const Context<lm::ngram::QuantArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template ScoreRuleRet ScoreRule(const lm::ngram::RestProbingModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
+template ScoreRuleRet ScoreRule(const lm::ngram::ProbingModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
+template ScoreRuleRet ScoreRule(const lm::ngram::TrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
+template ScoreRuleRet ScoreRule(const lm::ngram::QuantTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
+template ScoreRuleRet ScoreRule(const lm::ngram::ArrayTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
+template ScoreRuleRet ScoreRule(const lm::ngram::QuantArrayTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
} // namespace search
diff --git a/klm/search/rule.hh b/klm/search/rule.hh
index 0ce2794d..43ca6162 100644
--- a/klm/search/rule.hh
+++ b/klm/search/rule.hh
@@ -9,11 +9,16 @@
namespace search {
-template <class Model> class Context;
-
const lm::WordIndex kNonTerminal = lm::kMaxWordIndex;
-template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *state_out);
+struct ScoreRuleRet {
+ Score prob;
+ unsigned int oov;
+};
+
+// Pass <s> and </s> normally.
+// Indicate non-terminals with kNonTerminal.
+template <class Model> ScoreRuleRet ScoreRule(const Model &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *state_out);
} // namespace search
diff --git a/klm/search/types.hh b/klm/search/types.hh
index 06eb5bfa..f9c849b3 100644
--- a/klm/search/types.hh
+++ b/klm/search/types.hh
@@ -3,12 +3,29 @@
#include <stdint.h>
+namespace lm { namespace ngram { class ChartState; } }
+
namespace search {
typedef float Score;
typedef uint32_t Arity;
+union Note {
+ const void *vp;
+};
+
+typedef void *History;
+
+struct NBestComplete {
+ NBestComplete(History in_history, const lm::ngram::ChartState &in_state, Score in_score)
+ : history(in_history), state(&in_state), score(in_score) {}
+
+ History history;
+ const lm::ngram::ChartState *state;
+ Score score;
+};
+
} // namespace search
#endif // SEARCH_TYPES__
diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc
index 11f4631f..45842982 100644
--- a/klm/search/vertex.cc
+++ b/klm/search/vertex.cc
@@ -19,21 +19,34 @@ struct GreaterByBound : public std::binary_function<const VertexNode *, const Ve
} // namespace
-void VertexNode::SortAndSet(ContextBase &context, VertexNode **parent_ptr) {
+void VertexNode::RecursiveSortAndSet(ContextBase &context, VertexNode *&parent_ptr) {
if (Complete()) {
- assert(end_.Valid());
+ assert(end_);
assert(extend_.empty());
- bound_ = end_.GetScore();
return;
}
- if (extend_.size() == 1 && parent_ptr) {
- *parent_ptr = extend_[0];
- extend_[0]->SortAndSet(context, parent_ptr);
+ if (extend_.size() == 1) {
+ parent_ptr = extend_[0];
+ extend_[0]->RecursiveSortAndSet(context, parent_ptr);
context.DeleteVertexNode(this);
return;
}
for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) {
- (*i)->SortAndSet(context, &*i);
+ (*i)->RecursiveSortAndSet(context, *i);
+ }
+ std::sort(extend_.begin(), extend_.end(), GreaterByBound());
+ bound_ = extend_.front()->Bound();
+}
+
+void VertexNode::SortAndSet(ContextBase &context) {
+ // This is the root. The root might be empty.
+ if (extend_.empty()) {
+ bound_ = -INFINITY;
+ return;
+ }
+ // The root cannot be replaced. There's always one transition.
+ for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) {
+ (*i)->RecursiveSortAndSet(context, *i);
}
std::sort(extend_.begin(), extend_.end(), GreaterByBound());
bound_ = extend_.front()->Bound();
diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh
index 52bc1dfe..10b3339b 100644
--- a/klm/search/vertex.hh
+++ b/klm/search/vertex.hh
@@ -2,7 +2,6 @@
#define SEARCH_VERTEX__
#include "lm/left.hh"
-#include "search/final.hh"
#include "search/types.hh"
#include <boost/unordered_set.hpp>
@@ -10,6 +9,7 @@
#include <queue>
#include <vector>
+#include <math.h>
#include <stdint.h>
namespace search {
@@ -18,7 +18,7 @@ class ContextBase;
class VertexNode {
public:
- VertexNode() {}
+ VertexNode() : end_() {}
void InitRoot() {
extend_.clear();
@@ -26,7 +26,7 @@ class VertexNode {
state_.left.length = 0;
state_.right.length = 0;
right_full_ = false;
- end_ = Final();
+ end_ = History();
}
lm::ngram::ChartState &MutableState() { return state_; }
@@ -36,20 +36,21 @@ class VertexNode {
extend_.push_back(next);
}
- void SetEnd(Final end) {
- assert(!end_.Valid());
+ void SetEnd(History end, Score score) {
+ assert(!end_);
end_ = end;
+ bound_ = score;
}
- void SortAndSet(ContextBase &context, VertexNode **parent_pointer);
+ void SortAndSet(ContextBase &context);
// Should only happen to a root node when the entire vertex is empty.
bool Empty() const {
- return !end_.Valid() && extend_.empty();
+ return !end_ && extend_.empty();
}
bool Complete() const {
- return end_.Valid();
+ return end_;
}
const lm::ngram::ChartState &State() const { return state_; }
@@ -64,7 +65,7 @@ class VertexNode {
}
// Will be invalid unless this is a leaf.
- const Final End() const { return end_; }
+ const History End() const { return end_; }
const VertexNode &operator[](size_t index) const {
return *extend_[index];
@@ -75,13 +76,15 @@ class VertexNode {
}
private:
+ void RecursiveSortAndSet(ContextBase &context, VertexNode *&parent);
+
std::vector<VertexNode*> extend_;
lm::ngram::ChartState state_;
bool right_full_;
Score bound_;
- Final end_;
+ History end_;
};
class PartialVertex {
@@ -97,7 +100,7 @@ class PartialVertex {
const lm::ngram::ChartState &State() const { return back_->State(); }
bool RightFull() const { return back_->RightFull(); }
- Score Bound() const { return Complete() ? back_->End().GetScore() : (*back_)[index_].Bound(); }
+ Score Bound() const { return Complete() ? back_->Bound() : (*back_)[index_].Bound(); }
unsigned char Length() const { return back_->Length(); }
@@ -121,7 +124,7 @@ class PartialVertex {
return ret;
}
- const Final End() const {
+ const History End() const {
return back_->End();
}
@@ -130,16 +133,18 @@ class PartialVertex {
unsigned int index_;
};
+template <class Output> class VertexGenerator;
+
class Vertex {
public:
Vertex() {}
PartialVertex RootPartial() const { return PartialVertex(root_); }
- const Final BestChild() const {
+ const History BestChild() const {
PartialVertex top(RootPartial());
if (top.Empty()) {
- return Final();
+ return History();
} else {
PartialVertex continuation;
while (!top.Complete()) {
@@ -150,8 +155,8 @@ class Vertex {
}
private:
- friend class VertexGenerator;
-
+ template <class Output> friend class VertexGenerator;
+ template <class Output> friend class RootVertexGenerator;
VertexNode root_;
};
diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc
index 0945fe55..73139ffc 100644
--- a/klm/search/vertex_generator.cc
+++ b/klm/search/vertex_generator.cc
@@ -4,26 +4,18 @@
#include "search/context.hh"
#include "search/edge.hh"
+#include <boost/unordered_map.hpp>
+#include <boost/version.hpp>
+
#include <stdint.h>
namespace search {
-VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) {
- gen.root_.InitRoot();
-}
-
+#if BOOST_VERSION > 104200
namespace {
const uint64_t kCompleteAdd = static_cast<uint64_t>(-1);
-// Parallel structure to VertexNode.
-struct Trie {
- Trie() : under(NULL) {}
-
- VertexNode *under;
- boost::unordered_map<uint64_t, Trie> extend;
-};
-
Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) {
Trie &next = node.extend[added];
if (!next.under) {
@@ -39,19 +31,10 @@ Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::n
return next;
}
-void CompleteTransition(ContextBase &context, Trie &starter, PartialEdge partial) {
- Final final(context.FinalPool(), partial.GetScore(), partial.GetArity(), partial.GetNote());
- Final *child_out = final.Children();
- const PartialVertex *part = partial.NT();
- const PartialVertex *const part_end_loop = part + partial.GetArity();
- for (; part != part_end_loop; ++part, ++child_out)
- *child_out = part->End();
-
- starter.under->SetEnd(final);
-}
+} // namespace
-void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) {
- const lm::ngram::ChartState &state = partial.CompletedState();
+void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end) {
+ const lm::ngram::ChartState &state = *end.state;
unsigned char left = 0, right = 0;
Trie *node = &root;
@@ -77,18 +60,9 @@ void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) {
}
node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true);
- CompleteTransition(context, *node, partial);
+ node->under->SetEnd(end.history, end.score);
}
-} // namespace
-
-void VertexGenerator::FinishedSearch() {
- Trie root;
- root.under = &gen_.root_;
- for (Existing::const_iterator i(existing_.begin()); i != existing_.end(); ++i) {
- AddHypothesis(context_, root, i->second);
- }
- root.under->SortAndSet(context_, NULL);
-}
+#endif // BOOST_VERSION
} // namespace search
diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh
index 60e86112..da563c2d 100644
--- a/klm/search/vertex_generator.hh
+++ b/klm/search/vertex_generator.hh
@@ -2,9 +2,11 @@
#define SEARCH_VERTEX_GENERATOR__
#include "search/edge.hh"
+#include "search/types.hh"
#include "search/vertex.hh"
#include <boost/unordered_map.hpp>
+#include <boost/version.hpp>
namespace lm {
namespace ngram {
@@ -15,21 +17,44 @@ class ChartState;
namespace search {
class ContextBase;
-class Final;
-class VertexGenerator {
+#if BOOST_VERSION > 104200
+// Parallel structure to VertexNode.
+struct Trie {
+ Trie() : under(NULL) {}
+
+ VertexNode *under;
+ boost::unordered_map<uint64_t, Trie> extend;
+};
+
+void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end);
+
+#endif // BOOST_VERSION
+
+// Output makes the single-best or n-best list.
+template <class Output> class VertexGenerator {
public:
- VertexGenerator(ContextBase &context, Vertex &gen);
+ VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) {
+ gen.root_.InitRoot();
+ }
void NewHypothesis(PartialEdge partial) {
- const lm::ngram::ChartState &state = partial.CompletedState();
- std::pair<Existing::iterator, bool> ret(existing_.insert(std::make_pair(hash_value(state), partial)));
- if (!ret.second && ret.first->second < partial) {
- ret.first->second = partial;
- }
+ nbest_.Add(existing_[hash_value(partial.CompletedState())], partial);
}
- void FinishedSearch();
+ void FinishedSearch() {
+#if BOOST_VERSION > 104200
+ Trie root;
+ root.under = &gen_.root_;
+ for (typename Existing::iterator i(existing_.begin()); i != existing_.end(); ++i) {
+ AddHypothesis(context_, root, nbest_.Complete(i->second));
+ }
+ existing_.clear();
+ root.under->SortAndSet(context_);
+#else
+ UTIL_THROW(util::Exception, "Upgrade Boost to >= 1.42.0 to use incremental search.");
+#endif
+ }
const Vertex &Generating() const { return gen_; }
@@ -38,8 +63,35 @@ class VertexGenerator {
Vertex &gen_;
- typedef boost::unordered_map<uint64_t, PartialEdge> Existing;
+ typedef boost::unordered_map<uint64_t, typename Output::Combine> Existing;
Existing existing_;
+
+ Output &nbest_;
+};
+
+// Special case for root vertex: everything should come together into the root
+// node. In theory, this should happen naturally due to state collapsing with
+// <s> and </s>. If that's the case, VertexGenerator is fine, though it will
+// make one connection.
+template <class Output> class RootVertexGenerator {
+ public:
+ RootVertexGenerator(Vertex &gen, Output &out) : gen_(gen), out_(out) {}
+
+ void NewHypothesis(PartialEdge partial) {
+ out_.Add(combine_, partial);
+ }
+
+ void FinishedSearch() {
+ gen_.root_.InitRoot();
+ NBestComplete completed(out_.Complete(combine_));
+ gen_.root_.SetEnd(completed.history, completed.score);
+ }
+
+ private:
+ Vertex &gen_;
+
+ typename Output::Combine combine_;
+ Output &out_;
};
} // namespace search
diff --git a/klm/search/weights.cc b/klm/search/weights.cc
deleted file mode 100644
index d65471ad..00000000
--- a/klm/search/weights.cc
+++ /dev/null
@@ -1,71 +0,0 @@
-#include "search/weights.hh"
-#include "util/tokenize_piece.hh"
-
-#include <cstdlib>
-
-namespace search {
-
-namespace {
-struct Insert {
- void operator()(boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) const {
- std::string copy(name.data(), name.size());
- map[copy] = score;
- }
-};
-
-struct DotProduct {
- search::Score total;
- DotProduct() : total(0.0) {}
-
- void operator()(const boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) {
- boost::unordered_map<std::string, search::Score>::const_iterator i(FindStringPiece(map, name));
- if (i != map.end())
- total += score * i->second;
- }
-};
-
-template <class Map, class Op> void Parse(StringPiece text, Map &map, Op &op) {
- for (util::TokenIter<util::SingleCharacter, true> spaces(text, ' '); spaces; ++spaces) {
- util::TokenIter<util::SingleCharacter> equals(*spaces, '=');
- UTIL_THROW_IF(!equals, WeightParseException, "Bad weight token " << *spaces);
- StringPiece name(*equals);
- UTIL_THROW_IF(!++equals, WeightParseException, "Bad weight token " << *spaces);
- char *end;
- // Assumes proper termination.
- double value = std::strtod(equals->data(), &end);
- UTIL_THROW_IF(end != equals->data() + equals->size(), WeightParseException, "Failed to parse weight" << *equals);
- UTIL_THROW_IF(++equals, WeightParseException, "Too many equals in " << *spaces);
- op(map, name, value);
- }
-}
-
-} // namespace
-
-Weights::Weights(StringPiece text) {
- Insert op;
- Parse<Map, Insert>(text, map_, op);
- lm_ = Steal("LanguageModel");
- oov_ = Steal("OOV");
- word_penalty_ = Steal("WordPenalty");
-}
-
-Weights::Weights(Score lm, Score oov, Score word_penalty) : lm_(lm), oov_(oov), word_penalty_(word_penalty) {}
-
-search::Score Weights::DotNoLM(StringPiece text) const {
- DotProduct dot;
- Parse<const Map, DotProduct>(text, map_, dot);
- return dot.total;
-}
-
-float Weights::Steal(const std::string &str) {
- Map::iterator i(map_.find(str));
- if (i == map_.end()) {
- return 0.0;
- } else {
- float ret = i->second;
- map_.erase(i);
- return ret;
- }
-}
-
-} // namespace search
diff --git a/klm/search/weights.hh b/klm/search/weights.hh
deleted file mode 100644
index df1c419f..00000000
--- a/klm/search/weights.hh
+++ /dev/null
@@ -1,52 +0,0 @@
-// For now, the individual features are not kept.
-#ifndef SEARCH_WEIGHTS__
-#define SEARCH_WEIGHTS__
-
-#include "search/types.hh"
-#include "util/exception.hh"
-#include "util/string_piece.hh"
-
-#include <boost/unordered_map.hpp>
-
-#include <string>
-
-namespace search {
-
-class WeightParseException : public util::Exception {
- public:
- WeightParseException() {}
- ~WeightParseException() throw() {}
-};
-
-class Weights {
- public:
- // Parses weights, sets lm_weight_, removes it from map_.
- explicit Weights(StringPiece text);
-
- // Just the three scores we care about adding.
- Weights(Score lm, Score oov, Score word_penalty);
-
- Score DotNoLM(StringPiece text) const;
-
- Score LM() const { return lm_; }
-
- Score OOV() const { return oov_; }
-
- Score WordPenalty() const { return word_penalty_; }
-
- // Mostly for testing.
- const boost::unordered_map<std::string, Score> &GetMap() const { return map_; }
-
- private:
- float Steal(const std::string &str);
-
- typedef boost::unordered_map<std::string, Score> Map;
-
- Map map_;
-
- Score lm_, oov_, word_penalty_;
-};
-
-} // namespace search
-
-#endif // SEARCH_WEIGHTS__
diff --git a/klm/search/weights_test.cc b/klm/search/weights_test.cc
deleted file mode 100644
index 4811ff06..00000000
--- a/klm/search/weights_test.cc
+++ /dev/null
@@ -1,38 +0,0 @@
-#include "search/weights.hh"
-
-#define BOOST_TEST_MODULE WeightTest
-#include <boost/test/unit_test.hpp>
-#include <boost/test/floating_point_comparison.hpp>
-
-namespace search {
-namespace {
-
-#define CHECK_WEIGHT(value, string) \
- i = parsed.find(string); \
- BOOST_REQUIRE(i != parsed.end()); \
- BOOST_CHECK_CLOSE((value), i->second, 0.001);
-
-BOOST_AUTO_TEST_CASE(parse) {
- // These are not real feature weights.
- Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5");
- const boost::unordered_map<std::string, search::Score> &parsed = w.GetMap();
- boost::unordered_map<std::string, search::Score>::const_iterator i;
- CHECK_WEIGHT(0.0, "rarity");
- CHECK_WEIGHT(0.0, "phrase-SGT");
- CHECK_WEIGHT(9.45117, "phrase-TGS");
- CHECK_WEIGHT(2.33833, "lexical-SGT");
- BOOST_CHECK(parsed.end() == parsed.find("lm"));
- BOOST_CHECK_CLOSE(3.0, w.LM(), 0.001);
- CHECK_WEIGHT(-28.3317, "lexical-TGS");
- CHECK_WEIGHT(5.0, "glue?");
-}
-
-BOOST_AUTO_TEST_CASE(dot) {
- Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5");
- BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0"), 0.001);
- BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0 LanguageModel=10"), 0.001);
- BOOST_CHECK_CLOSE(9.45117 * 3.0 + 28.3317 * 17.4, w.DotNoLM("rarity=5 phrase-TGS=3.0 LanguageModel=10 lexical-TGS=-17.4"), 0.001);
-}
-
-} // namespace
-} // namespace search
diff --git a/klm/util/Makefile.am b/klm/util/Makefile.am
index 5306850f..a676bdb3 100644
--- a/klm/util/Makefile.am
+++ b/klm/util/Makefile.am
@@ -27,6 +27,7 @@ libklm_util_a_SOURCES = \
mmap.cc \
murmur_hash.cc \
pool.cc \
+ read_compressed.cc \
string_piece.cc \
usage.cc
diff --git a/klm/util/exception.hh b/klm/util/exception.hh
index 053a850b..0165a7a3 100644
--- a/klm/util/exception.hh
+++ b/klm/util/exception.hh
@@ -87,8 +87,14 @@ template <class Except, class Data> typename Except::template ExceptionTag<Excep
throw UTIL_e; \
} while (0)
+#if __GNUC__ >= 3
+#define UTIL_UNLIKELY(x) __builtin_expect (!!(x), 0)
+#else
+#define UTIL_UNLIKELY(x) (x)
+#endif
+
#define UTIL_THROW_IF(Condition, Exception, Modify) do { \
- if (Condition) { \
+ if (UTIL_UNLIKELY(Condition)) { \
Exception UTIL_e; \
UTIL_SET_LOCATION(UTIL_e, #Exception, #Condition); \
UTIL_e << Modify; \
diff --git a/klm/util/file.cc b/klm/util/file.cc
index 6bf879ac..b9a77cf9 100644
--- a/klm/util/file.cc
+++ b/klm/util/file.cc
@@ -15,6 +15,8 @@
#if defined(_WIN32) || defined(_WIN64)
#include <windows.h>
#include <io.h>
+#include <algorithm>
+#include <limits.h>
#else
#include <unistd.h>
#endif
@@ -48,7 +50,7 @@ int OpenReadOrThrow(const char *name) {
int CreateOrThrow(const char *name) {
int ret;
#if defined(_WIN32) || defined(_WIN64)
- UTIL_THROW_IF(-1 == (ret = _open(name, _O_CREAT | _O_TRUNC | _O_RDWR, _S_IREAD | _S_IWRITE)), ErrnoException, "while creating " << name);
+ UTIL_THROW_IF(-1 == (ret = _open(name, _O_CREAT | _O_TRUNC | _O_RDWR | _O_BINARY, _S_IREAD | _S_IWRITE)), ErrnoException, "while creating " << name);
#else
UTIL_THROW_IF(-1 == (ret = open(name, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH)), ErrnoException, "while creating " << name);
#endif
@@ -74,16 +76,22 @@ void ResizeOrThrow(int fd, uint64_t to) {
#endif
}
-#ifdef WIN32
-typedef int ssize_t;
+std::size_t PartialRead(int fd, void *to, std::size_t amount) {
+#if defined(_WIN32) || defined(_WIN64)
+ amount = min(static_cast<std::size_t>(INT_MAX), amount);
+ int ret = _read(fd, to, amount);
+#else
+ ssize_t ret = read(fd, to, amount);
#endif
+ UTIL_THROW_IF(ret < 0, ErrnoException, "Reading " << amount << " from fd " << fd << " failed.");
+ return static_cast<std::size_t>(ret);
+}
void ReadOrThrow(int fd, void *to_void, std::size_t amount) {
uint8_t *to = static_cast<uint8_t*>(to_void);
while (amount) {
- ssize_t ret = read(fd, to, amount);
- UTIL_THROW_IF(ret == -1, ErrnoException, "Reading " << amount << " from fd " << fd << " failed.");
- UTIL_THROW_IF(ret == 0, EndOfFileException, "Hit EOF in fd " << fd << " but there should be " << amount << " more bytes to read.");
+ std::size_t ret = PartialRead(fd, to, amount);
+ UTIL_THROW_IF(ret == 0, EndOfFileException, " in fd " << fd << " but there should be " << amount << " more bytes to read.");
amount -= ret;
to += ret;
}
@@ -93,8 +101,7 @@ std::size_t ReadOrEOF(int fd, void *to_void, std::size_t amount) {
uint8_t *to = static_cast<uint8_t*>(to_void);
std::size_t remaining = amount;
while (remaining) {
- ssize_t ret = read(fd, to, remaining);
- UTIL_THROW_IF(ret == -1, ErrnoException, "Reading " << remaining << " from fd " << fd << " failed.");
+ std::size_t ret = PartialRead(fd, to, remaining);
if (!ret) return amount - remaining;
remaining -= ret;
to += ret;
@@ -105,7 +112,11 @@ std::size_t ReadOrEOF(int fd, void *to_void, std::size_t amount) {
void WriteOrThrow(int fd, const void *data_void, std::size_t size) {
const uint8_t *data = static_cast<const uint8_t*>(data_void);
while (size) {
+#if defined(_WIN32) || defined(_WIN64)
+ int ret = write(fd, data, min(static_cast<std::size_t>(INT_MAX), size));
+#else
ssize_t ret = write(fd, data, size);
+#endif
if (ret < 1) UTIL_THROW(util::ErrnoException, "Write failed");
data += ret;
size -= ret;
@@ -114,7 +125,7 @@ void WriteOrThrow(int fd, const void *data_void, std::size_t size) {
void WriteOrThrow(FILE *to, const void *data, std::size_t size) {
assert(size);
- if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size);
+ UTIL_THROW_IF(1 != std::fwrite(data, size, 1, to), util::ErrnoException, "Short write; requested size " << size);
}
void FSyncOrThrow(int fd) {
@@ -149,14 +160,15 @@ void SeekEnd(int fd) {
std::FILE *FDOpenOrThrow(scoped_fd &file) {
std::FILE *ret = fdopen(file.get(), "r+b");
- if (!ret) UTIL_THROW(util::ErrnoException, "Could not fdopen");
+ if (!ret) UTIL_THROW(util::ErrnoException, "Could not fdopen descriptor " << file.get());
file.release();
return ret;
}
-std::FILE *FOpenOrThrow(const char *path, const char *mode) {
- std::FILE *ret;
- UTIL_THROW_IF(!(ret = fopen(path, mode)), util::ErrnoException, "Could not fopen " << path << " for " << mode);
+std::FILE *FDOpenReadOrThrow(scoped_fd &file) {
+ std::FILE *ret = fdopen(file.get(), "rb");
+ if (!ret) UTIL_THROW(util::ErrnoException, "Could not fdopen descriptor " << file.get());
+ file.release();
return ret;
}
diff --git a/klm/util/file.hh b/klm/util/file.hh
index 185cb1f3..c24580d6 100644
--- a/klm/util/file.hh
+++ b/klm/util/file.hh
@@ -32,8 +32,6 @@ class scoped_fd {
return ret;
}
- operator bool() { return fd_ != -1; }
-
private:
int fd_;
@@ -76,8 +74,9 @@ uint64_t SizeFile(int fd);
void ResizeOrThrow(int fd, uint64_t to);
+std::size_t PartialRead(int fd, void *to, std::size_t size);
void ReadOrThrow(int fd, void *to, std::size_t size);
-std::size_t ReadOrEOF(int fd, void *to_void, std::size_t amount);
+std::size_t ReadOrEOF(int fd, void *to_void, std::size_t size);
void WriteOrThrow(int fd, const void *data_void, std::size_t size);
void WriteOrThrow(FILE *to, const void *data, std::size_t size);
@@ -90,8 +89,7 @@ void AdvanceOrThrow(int fd, int64_t off);
void SeekEnd(int fd);
std::FILE *FDOpenOrThrow(scoped_fd &file);
-
-std::FILE *FOpenOrThrow(const char *path, const char *mode);
+std::FILE *FDOpenReadOrThrow(scoped_fd &file);
class TempMaker {
public:
diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc
index 280f438c..5a208eff 100644
--- a/klm/util/file_piece.cc
+++ b/klm/util/file_piece.cc
@@ -14,7 +14,6 @@
#include <limits>
#include <assert.h>
-#include <ctype.h>
#include <fcntl.h>
#include <stdlib.h>
#include <sys/types.h>
@@ -26,13 +25,6 @@ ParseNumberException::ParseNumberException(StringPiece value) throw() {
*this << "Could not parse \"" << value << "\" into a number";
}
-#ifdef HAVE_ZLIB
-GZException::GZException(gzFile file) {
- int num;
- *this << gzerror(file, &num) << " from zlib";
-}
-#endif // HAVE_ZLIB
-
// Sigh this is the only way I could come up with to do a _const_ bool. It has ' ', '\f', '\n', '\r', '\t', and '\v' (same as isspace on C locale).
const bool kSpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
@@ -48,19 +40,7 @@ FilePiece::FilePiece(int fd, const char *name, std::ostream *show_progress, std:
Initialize(name, show_progress, min_buffer);
}
-FilePiece::~FilePiece() {
-#ifdef HAVE_ZLIB
- if (gz_file_) {
- // zlib took ownership
- file_.release();
- int ret;
- if (Z_OK != (ret = gzclose(gz_file_))) {
- std::cerr << "could not close file " << file_name_ << " using zlib" << std::endl;
- abort();
- }
- }
-#endif
-}
+FilePiece::~FilePiece() {}
StringPiece FilePiece::ReadLine(char delim) {
std::size_t skip = 0;
@@ -95,9 +75,6 @@ unsigned long int FilePiece::ReadULong() {
}
void FilePiece::Initialize(const char *name, std::ostream *show_progress, std::size_t min_buffer) {
-#ifdef HAVE_ZLIB
- gz_file_ = NULL;
-#endif
file_name_ = name;
default_map_size_ = page_ * std::max<std::size_t>((min_buffer / page_ + 1), 2);
@@ -117,10 +94,7 @@ void FilePiece::Initialize(const char *name, std::ostream *show_progress, std::s
}
Shift();
// gzip detect.
- if ((position_end_ - position_) > 2 && *position_ == 0x1f && static_cast<unsigned char>(*(position_ + 1)) == 0x8b) {
-#ifndef HAVE_ZLIB
- UTIL_THROW(GZException, "Looks like a gzip file but support was not compiled in.");
-#endif
+ if ((position_end_ - position_) >= ReadCompressed::kMagicSize && ReadCompressed::DetectCompressedMagic(position_)) {
if (!fallback_to_read_) {
at_end_ = false;
TransitionToRead();
@@ -197,7 +171,7 @@ void FilePiece::Shift() {
if (fallback_to_read_) ReadShift();
for (last_space_ = position_end_ - 1; last_space_ >= position_; --last_space_) {
- if (isspace(*last_space_)) break;
+ if (kSpaces[static_cast<unsigned char>(*last_space_)]) break;
}
}
@@ -248,17 +222,14 @@ void FilePiece::TransitionToRead() {
position_ = data_.begin();
position_end_ = position_;
-#ifdef HAVE_ZLIB
- assert(!gz_file_);
- gz_file_ = gzdopen(file_.get(), "r");
- UTIL_THROW_IF(!gz_file_, GZException, "zlib failed to open " << file_name_);
-#endif
+ try {
+ fell_back_.Reset(file_.release());
+ } catch (util::Exception &e) {
+ e << " in file " << file_name_;
+ throw;
+ }
}
-#ifdef WIN32
-typedef int ssize_t;
-#endif
-
void FilePiece::ReadShift() {
assert(fallback_to_read_);
// Bytes [data_.begin(), position_) have been consumed.
@@ -283,7 +254,7 @@ void FilePiece::ReadShift() {
position_ = data_.begin();
position_end_ = position_ + valid_length;
} else {
- size_t moving = position_end_ - position_;
+ std::size_t moving = position_end_ - position_;
memmove(data_.get(), position_, moving);
position_ = data_.begin();
position_end_ = position_ + moving;
@@ -291,20 +262,9 @@ void FilePiece::ReadShift() {
}
}
- ssize_t read_return;
-#ifdef HAVE_ZLIB
- read_return = gzread(gz_file_, static_cast<char*>(data_.get()) + already_read, default_map_size_ - already_read);
- if (read_return == -1) throw GZException(gz_file_);
- if (total_size_ != kBadSize) {
- // Just get the position, don't actually seek. Apparently this is how you do it. . .
- off_t ret = lseek(file_.get(), 0, SEEK_CUR);
- if (ret != -1) progress_.Set(ret);
- }
-#else
- read_return = read(file_.get(), static_cast<char*>(data_.get()) + already_read, default_map_size_ - already_read);
- UTIL_THROW_IF(read_return == -1, ErrnoException, "read failed");
- progress_.Set(mapped_offset_);
-#endif
+ std::size_t read_return = fell_back_.Read(static_cast<uint8_t*>(data_.get()) + already_read, default_map_size_ - already_read);
+ progress_.Set(fell_back_.RawAmount());
+
if (read_return == 0) {
at_end_ = true;
}
diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh
index af93d8aa..39bd1581 100644
--- a/klm/util/file_piece.hh
+++ b/klm/util/file_piece.hh
@@ -4,8 +4,8 @@
#include "util/ersatz_progress.hh"
#include "util/exception.hh"
#include "util/file.hh"
-#include "util/have.hh"
#include "util/mmap.hh"
+#include "util/read_compressed.hh"
#include "util/string_piece.hh"
#include <cstddef>
@@ -13,10 +13,6 @@
#include <stdint.h>
-#ifdef HAVE_ZLIB
-#include <zlib.h>
-#endif
-
namespace util {
class ParseNumberException : public Exception {
@@ -25,28 +21,19 @@ class ParseNumberException : public Exception {
~ParseNumberException() throw() {}
};
-class GZException : public Exception {
- public:
-#ifdef HAVE_ZLIB
- explicit GZException(gzFile file);
-#endif
- GZException() throw() {}
- ~GZException() throw() {}
-};
-
extern const bool kSpaces[256];
-// Memory backing the returned StringPiece may vanish on the next call.
+// Memory backing the returned StringPiece may vanish on the next call.
class FilePiece {
public:
- // 32 MB default.
- explicit FilePiece(const char *file, std::ostream *show_progress = NULL, std::size_t min_buffer = 33554432);
- // Takes ownership of fd. name is used for messages.
- explicit FilePiece(int fd, const char *name, std::ostream *show_progress = NULL, std::size_t min_buffer = 33554432);
+ // 1 MB default.
+ explicit FilePiece(const char *file, std::ostream *show_progress = NULL, std::size_t min_buffer = 1048576);
+ // Takes ownership of fd. name is used for messages.
+ explicit FilePiece(int fd, const char *name, std::ostream *show_progress = NULL, std::size_t min_buffer = 1048576);
~FilePiece();
-
- char get() {
+
+ char get() {
if (position_ == position_end_) {
Shift();
if (at_end_) throw EndOfFileException();
@@ -54,14 +41,14 @@ class FilePiece {
return *(position_++);
}
- // Leaves the delimiter, if any, to be returned by get(). Delimiters defined by isspace().
+ // Leaves the delimiter, if any, to be returned by get(). Delimiters defined by isspace().
StringPiece ReadDelimited(const bool *delim = kSpaces) {
SkipSpaces(delim);
return Consume(FindDelimiterOrEOF(delim));
}
// Unlike ReadDelimited, this includes leading spaces and consumes the delimiter.
- // It is similar to getline in that way.
+ // It is similar to getline in that way.
StringPiece ReadLine(char delim = '\n');
float ReadFloat();
@@ -69,7 +56,7 @@ class FilePiece {
long int ReadLong();
unsigned long int ReadULong();
- // Skip spaces defined by isspace.
+ // Skip spaces defined by isspace.
void SkipSpaces(const bool *delim = kSpaces) {
for (; ; ++position_) {
if (position_ == position_end_) Shift();
@@ -82,7 +69,7 @@ class FilePiece {
}
const std::string &FileName() const { return file_name_; }
-
+
private:
void Initialize(const char *name, std::ostream *show_progress, std::size_t min_buffer);
@@ -122,9 +109,7 @@ class FilePiece {
std::string file_name_;
-#ifdef HAVE_ZLIB
- gzFile gz_file_;
-#endif // HAVE_ZLIB
+ ReadCompressed fell_back_;
};
} // namespace util
diff --git a/klm/util/file_piece_test.cc b/klm/util/file_piece_test.cc
index f912e18a..e79ece7a 100644
--- a/klm/util/file_piece_test.cc
+++ b/klm/util/file_piece_test.cc
@@ -38,7 +38,7 @@ BOOST_AUTO_TEST_CASE(MMapReadLine) {
BOOST_CHECK_THROW(test.get(), EndOfFileException);
}
-#ifndef __APPLE__
+#if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)
/* Apple isn't happy with the popen, fileno, dup. And I don't want to
* reimplement popen. This is an issue with the test.
*/
@@ -65,7 +65,7 @@ BOOST_AUTO_TEST_CASE(StreamReadLine) {
BOOST_CHECK_THROW(test.get(), EndOfFileException);
BOOST_REQUIRE(!pclose(catter));
}
-#endif // __APPLE__
+#endif
#ifdef HAVE_ZLIB
diff --git a/klm/util/have.hh b/klm/util/have.hh
index b8181e99..85b838e4 100644
--- a/klm/util/have.hh
+++ b/klm/util/have.hh
@@ -2,22 +2,16 @@
#ifndef UTIL_HAVE__
#define UTIL_HAVE__
-#ifndef HAVE_ZLIB
-#if !defined(_WIN32) && !defined(_WIN64)
-#define HAVE_ZLIB
-#endif
-#endif
-
#ifndef HAVE_ICU
//#define HAVE_ICU
#endif
#ifndef HAVE_BOOST
-#define HAVE_BOOST
+//#define HAVE_BOOST
#endif
-#ifndef HAVE_THREADS
-//#define HAVE_THREADS
+#ifdef HAVE_CONFIG_H
+#include "config.h"
#endif
#endif // UTIL_HAVE__
diff --git a/klm/util/joint_sort.hh b/klm/util/joint_sort.hh
index cf3d8432..1b43ddcf 100644
--- a/klm/util/joint_sort.hh
+++ b/klm/util/joint_sort.hh
@@ -60,7 +60,7 @@ template <class KeyIter, class ValueIter> class JointProxy {
JointProxy(const KeyIter &key_iter, const ValueIter &value_iter) : inner_(key_iter, value_iter) {}
JointProxy(const JointProxy<KeyIter, ValueIter> &other) : inner_(other.inner_) {}
- operator const value_type() const {
+ operator value_type() const {
value_type ret;
ret.key = *inner_.key_;
ret.value = *inner_.value_;
@@ -121,7 +121,7 @@ template <class Proxy, class Less> class LessWrapper : public std::binary_functi
template <class KeyIter, class ValueIter> class PairedIterator : public ProxyIterator<detail::JointProxy<KeyIter, ValueIter> > {
public:
- PairedIterator(const KeyIter &key, const ValueIter &value) :
+ PairedIterator(const KeyIter &key, const ValueIter &value) :
ProxyIterator<detail::JointProxy<KeyIter, ValueIter> >(detail::JointProxy<KeyIter, ValueIter>(key, value)) {}
};
diff --git a/klm/util/read_compressed.cc b/klm/util/read_compressed.cc
new file mode 100644
index 00000000..4ec94c4e
--- /dev/null
+++ b/klm/util/read_compressed.cc
@@ -0,0 +1,403 @@
+#include "util/read_compressed.hh"
+
+#include "util/file.hh"
+#include "util/have.hh"
+#include "util/scoped.hh"
+
+#include <algorithm>
+#include <iostream>
+#include <limits>
+#include <new>
+
+#include <assert.h>
+#include <limits.h>
+#include <stdlib.h>
+#include <string.h>
+
+#ifdef HAVE_ZLIB
+#include <zlib.h>
+#endif
+
+#ifdef HAVE_BZLIB
+#include <bzlib.h>
+#endif
+
+#ifdef HAVE_XZLIB
+#include <lzma.h>
+#endif
+
+namespace util {
+/* Out-of-line constructors and destructors for the exception classes declared in read_compressed.hh; all have empty bodies. */
+CompressedException::CompressedException() throw() {}
+CompressedException::~CompressedException() throw() {}
+
+GZException::GZException() throw() {}
+GZException::~GZException() throw() {}
+
+BZException::BZException() throw() {}
+BZException::~BZException() throw() {}
+
+XZException::XZException() throw() {}
+XZException::~XZException() throw() {}
+/* Abstract interface implemented by every reader in this file (plain,
+ * gzip, bzip2, xz).  ReadCompressed forwards Read() to whichever ReadBase
+ * it currently owns.  ReadCompressed befriends ReadBase only, so derived
+ * readers reach its private members through the static helpers below.
+ */
+class ReadBase {
+ public:
+ virtual ~ReadBase() {}
+ /* Copy up to amount bytes into to and return how many were written; thunk is the owning ReadCompressed. */
+ virtual std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) = 0;
+
+ protected:
+ static void ReplaceThis(ReadBase *with, ReadCompressed &thunk) {
+ thunk.internal_.reset(with); // reset() deletes the reader thunk held before, which is normally the caller itself.
+ }
+ /* Mutable access to thunk's raw_amount_ counter. */
+ static uint64_t &ReadCount(ReadCompressed &thunk) {
+ return thunk.raw_amount_;
+ }
+};
+
+namespace {
+
+/* Terminal reader: the decoders replace themselves with a Complete once their stream has ended, so every later Read() simply returns 0. */
+class Complete : public ReadBase {
+ public:
+ std::size_t Read(void *, std::size_t, ReadCompressed &) {
+ return 0;
+ }
+};
+/* Pass-through reader for data that is not compressed.  Holds fd in a scoped_fd and adds every byte returned by PartialRead to the raw count. */
+class Uncompressed : public ReadBase {
+ public:
+ explicit Uncompressed(int fd) : fd_(fd) {}
+
+ std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
+ std::size_t got = PartialRead(fd_.get(), to, amount);
+ ReadCount(thunk) += got; // Raw and delivered byte counts are the same for uncompressed data.
+ return got;
+ }
+
+ private:
+ scoped_fd fd_;
+};
+/* Reader for uncompressed input whose first bytes were already consumed
+ * from fd by the caller.  It serves a private copy of those bytes first,
+ * then replaces itself with a plain Uncompressed reader on the same fd.
+ * The saved bytes are not added to the raw count here.
+ */
+class UncompressedWithHeader : public ReadBase {
+ public:
+ UncompressedWithHeader(int fd, void *already_data, std::size_t already_size) : fd_(fd) {
+ assert(already_size);
+ buf_.reset(malloc(already_size));
+ if (!buf_.get()) throw std::bad_alloc();
+ memcpy(buf_.get(), already_data, already_size); // Keep a private copy; already_data is not referenced after construction.
+ remain_ = static_cast<uint8_t*>(buf_.get());
+ end_ = remain_ + already_size;
+ }
+ /* Returns at most the saved bytes still pending; never reads from fd itself. */
+ std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
+ assert(buf_.get());
+ std::size_t sending = std::min<std::size_t>(amount, end_ - remain_);
+ memcpy(to, remain_, sending);
+ remain_ += sending;
+ if (remain_ == end_) {
+ ReplaceThis(new Uncompressed(fd_.release()), thunk); // Destroys *this; no member may be touched after this line.
+ }
+ return sending; // Local copy, still valid after the replacement above.
+ }
+
+ private:
+ scoped_malloc buf_;
+ uint8_t *remain_; // Next saved byte to hand out.
+ uint8_t *end_; // One past the last saved byte.
+
+ scoped_fd fd_;
+};
+
+#ifdef HAVE_ZLIB
+/* Streaming reader for gzip- or zlib-wrapped data, built on zlib's inflate.
+ * Compressed bytes are pulled from the file in kInputBuffer-sized chunks by
+ * ReadInput, which adds every byte it obtains to the raw count.
+ */
+class GZip : public ReadBase {
+  private:
+    static const std::size_t kInputBuffer = 16384;
+  public:
+    // Takes ownership of fd.  already_data/already_size are bytes the caller
+    // has already consumed from fd; they must fit in the input buffer.
+    // Throws std::bad_alloc or GZException.
+    GZip(int fd, void *already_data, std::size_t already_size)
+      : file_(fd), in_buffer_(malloc(kInputBuffer)) {
+      if (!in_buffer_.get()) throw std::bad_alloc();
+      assert(already_size < kInputBuffer);
+      if (already_size) memcpy(in_buffer_.get(), already_data, already_size);
+      // zlib requires next_in and avail_in to be set before inflateInit2,
+      // including when there is no input yet.
+      stream_.next_in = static_cast<Bytef *>(in_buffer_.get());
+      stream_.avail_in = static_cast<uInt>(already_size);
+      // The buffer is deliberately not topped up from the file here: the
+      // constructor has no access to the raw byte counter, so bytes read at
+      // this point would never be counted.  Read() refills via ReadInput.
+      stream_.zalloc = Z_NULL;
+      stream_.zfree = Z_NULL;
+      stream_.opaque = Z_NULL;
+      stream_.msg = NULL;
+      // 32 for zlib and gzip decoding with automatic header detection.
+      // 15 for maximum window size.
+      UTIL_THROW_IF(Z_OK != inflateInit2(&stream_, 32 + 15), GZException, "Failed to initialize zlib.");
+    }
+
+    ~GZip() {
+      if (Z_OK != inflateEnd(&stream_)) {
+        std::cerr << "zlib could not close properly." << std::endl;
+        abort();
+      }
+    }
+
+    // Decompress into to until at least one byte is produced or the stream
+    // ends.  Returns the number of bytes written.  At Z_STREAM_END the reader
+    // replaces itself with Complete.  Throws ErrnoException or GZException.
+    std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
+      if (amount == 0) return 0;
+      stream_.next_out = static_cast<Bytef*>(to);
+      // avail_out is a uInt, so clamp oversized requests.
+      stream_.avail_out = std::min<std::size_t>(std::numeric_limits<uInt>::max(), amount);
+      do {
+        if (!stream_.avail_in) ReadInput(thunk);
+        int result = inflate(&stream_, Z_NO_FLUSH);
+        switch (result) {
+          case Z_OK:
+            break;
+          case Z_STREAM_END:
+            {
+              // Compute the count first: ReplaceThis destroys *this.
+              std::size_t ret = static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to);
+              ReplaceThis(new Complete(), thunk);
+              return ret;
+            }
+          case Z_ERRNO:
+            UTIL_THROW(ErrnoException, "zlib error");
+          default:
+            UTIL_THROW(GZException, "zlib encountered " << (stream_.msg ? stream_.msg : "an error ") << " code " << result);
+        }
+      } while (stream_.next_out == to);
+      return static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to);
+    }
+
+  private:
+    // Refill the (empty) input buffer from the file and count the raw bytes.
+    void ReadInput(ReadCompressed &thunk) {
+      assert(!stream_.avail_in);
+      stream_.next_in = static_cast<Bytef *>(in_buffer_.get());
+      stream_.avail_in = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer);
+      ReadCount(thunk) += stream_.avail_in;
+    }
+
+    scoped_fd file_;
+    scoped_malloc in_buffer_;
+    z_stream stream_;
+};
+#endif // HAVE_ZLIB
+
+#ifdef HAVE_BZLIB
+/* Reader for bzip2 data using libbz2's FILE*-based BZ2_bzRead interface.
+ * The raw count is taken from ftell() on the underlying FILE after each read.
+ */
+class BZip : public ReadBase {
+  public:
+    // Takes ownership of fd.  already_data/already_size are bytes already
+    // consumed from fd; they are handed to BZ2_bzReadOpen as unused input.
+    // Throws BZException or std::bad_alloc.
+    explicit BZip(int fd, void *already_data, std::size_t already_size) {
+      scoped_fd hold(fd);
+      closer_.reset(FDOpenReadOrThrow(hold));
+      int bzerror = BZ_OK;
+      file_ = BZ2_bzReadOpen(&bzerror, closer_.get(), 0, 0, already_data, already_size);
+      switch (bzerror) {
+        case BZ_OK:
+          return;
+        case BZ_CONFIG_ERROR:
+          UTIL_THROW(BZException, "Looks like bzip2 was miscompiled.");
+        case BZ_PARAM_ERROR:
+          UTIL_THROW(BZException, "Parameter error");
+        case BZ_IO_ERROR:
+          UTIL_THROW(BZException, "IO error reading file");
+        case BZ_MEM_ERROR:
+          throw std::bad_alloc();
+        default:
+          // Never hand back a reader whose open failed with a code we do not know.
+          UTIL_THROW(BZException, "Unknown bzip2 error code " << bzerror << " from BZ2_bzReadOpen");
+      }
+    }
+
+    ~BZip() {
+      int bzerror = BZ_OK;
+      BZ2_bzReadClose(&bzerror, file_);
+      if (bzerror != BZ_OK) {
+        std::cerr << "bz2 readclose error" << std::endl;
+        abort();
+      }
+    }
+
+    // Decompress up to amount bytes (clamped to INT_MAX) into to and return
+    // how many were produced.  At BZ_STREAM_END the reader replaces itself
+    // with Complete.  Throws BZException on any other bzip2 status.
+    std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
+      int bzerror = BZ_OK;
+      int ret = BZ2_bzRead(&bzerror, file_, to, std::min<std::size_t>(static_cast<std::size_t>(INT_MAX), amount));
+      switch (bzerror) {
+        case BZ_STREAM_END:
+          RecordPosition(thunk);
+          // Destroys *this; ret is a local and stays valid.
+          ReplaceThis(new Complete(), thunk);
+          return ret;
+        case BZ_OK:
+          RecordPosition(thunk);
+          return ret;
+        default:
+          UTIL_THROW(BZException, "bzip2 error " << BZ2_bzerror(file_, &bzerror) << " code " << bzerror);
+      }
+    }
+
+  private:
+    // Publish the FILE position as the raw byte count; skipped when ftell fails.
+    void RecordPosition(ReadCompressed &thunk) {
+      long pos = ftell(closer_.get());
+      if (pos != -1) ReadCount(thunk) = pos;
+    }
+
+    scoped_FILE closer_;
+    BZFILE *file_;
+};
+#endif // HAVE_BZLIB
+
+#ifdef HAVE_XZLIB
+/* Streaming reader for xz data built on liblzma's stream decoder.
+ * Compressed bytes are pulled from the file in kInputBuffer-sized chunks by
+ * ReadInput, which adds every byte it obtains to the raw count and switches
+ * the decoder to LZMA_FINISH once the file is exhausted.
+ */
+class XZip : public ReadBase {
+  private:
+    static const std::size_t kInputBuffer = 16384;
+  public:
+    // Takes ownership of fd.  already_data/already_size are bytes the caller
+    // has already consumed from fd; they must fit in the input buffer.
+    XZip(int fd, void *already_data, std::size_t already_size)
+      : file_(fd), in_buffer_(malloc(kInputBuffer)), stream_(), action_(LZMA_RUN) {
+      if (!in_buffer_.get()) throw std::bad_alloc();
+      assert(already_size < kInputBuffer);
+      if (already_size) memcpy(in_buffer_.get(), already_data, already_size);
+      stream_.next_in = static_cast<const uint8_t*>(in_buffer_.get());
+      stream_.avail_in = already_size;
+      // The buffer is deliberately not topped up from the file here: the
+      // constructor has no access to the raw byte counter, so bytes read at
+      // this point would never be counted.  Read() refills via ReadInput.
+      stream_.allocator = NULL;
+      lzma_ret ret = lzma_stream_decoder(&stream_, UINT64_MAX, LZMA_CONCATENATED);
+      switch (ret) {
+        case LZMA_OK:
+          break;
+        case LZMA_MEM_ERROR:
+          // NOTE(review): Read() throws std::bad_alloc for the same status;
+          // the exception type is kept here so existing callers are unaffected.
+          UTIL_THROW(ErrnoException, "xz open error");
+        default:
+          UTIL_THROW(XZException, "xz error code " << ret);
+      }
+    }
+
+    ~XZip() {
+      lzma_end(&stream_);
+    }
+
+    // Decompress into to until at least one byte is produced or the stream
+    // ends.  Returns the number of bytes written.  At LZMA_STREAM_END the
+    // reader replaces itself with Complete.
+    std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
+      if (amount == 0) return 0;
+      stream_.next_out = static_cast<uint8_t*>(to);
+      stream_.avail_out = amount;
+      do {
+        if (!stream_.avail_in) ReadInput(thunk);
+        lzma_ret status = lzma_code(&stream_, action_);
+        switch (status) {
+          case LZMA_OK:
+            break;
+          case LZMA_STREAM_END:
+            UTIL_THROW_IF(action_ != LZMA_FINISH, XZException, "Input not finished yet.");
+            {
+              // Compute the count first: ReplaceThis destroys *this.
+              std::size_t ret = static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to);
+              ReplaceThis(new Complete(), thunk);
+              return ret;
+            }
+          case LZMA_MEM_ERROR:
+            throw std::bad_alloc();
+          case LZMA_FORMAT_ERROR:
+            UTIL_THROW(XZException, "xzlib says file format not recognized");
+          case LZMA_OPTIONS_ERROR:
+            UTIL_THROW(XZException, "xzlib says unsupported compression options");
+          case LZMA_DATA_ERROR:
+            UTIL_THROW(XZException, "xzlib says this file is corrupt");
+          case LZMA_BUF_ERROR:
+            UTIL_THROW(XZException, "xzlib says unexpected end of input");
+          default:
+            UTIL_THROW(XZException, "unrecognized xzlib error " << status);
+        }
+      } while (stream_.next_out == to);
+      return static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to);
+    }
+
+  private:
+    // Refill the (empty) input buffer from the file and count the raw bytes.
+    // A zero-byte read means end of file, so the decoder is told to finish.
+    void ReadInput(ReadCompressed &thunk) {
+      assert(!stream_.avail_in);
+      stream_.next_in = static_cast<const uint8_t*>(in_buffer_.get());
+      stream_.avail_in = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer);
+      if (!stream_.avail_in) action_ = LZMA_FINISH;
+      ReadCount(thunk) += stream_.avail_in;
+    }
+
+    scoped_fd file_;
+    scoped_malloc in_buffer_;
+    lzma_stream stream_;
+
+    lzma_action action_;
+};
+#endif // HAVE_XZLIB
+
+enum MagicResult {
+  UNKNOWN, GZIP, BZIP, XZIP
+};
+
+// Classify a file by its leading bytes.  from_void must point at no fewer
+// than 6 readable bytes (ReadCompressed::kMagicSize).  Returns UNKNOWN when
+// no supported compression signature matches.
+MagicResult DetectMagic(const void *from_void) {
+  const uint8_t *header = static_cast<const uint8_t*>(from_void);
+  if (header[0] == 0x1f && header[1] == 0x8b) {
+    return GZIP;
+  }
+  // A bzip2 stream starts with "BZh".  Requiring the 'h' keeps plain text
+  // that merely begins with "BZ" from being routed to the bzip2 decoder.
+  if (header[0] == 'B' && header[1] == 'Z' && header[2] == 'h') {
+    return BZIP;
+  }
+  const uint8_t xzmagic[6] = { 0xFD, '7', 'z', 'X', 'Z', 0x00 };
+  if (!memcmp(header, xzmagic, 6)) {
+    return XZIP;
+  }
+  return UNKNOWN;
+}
+
+// Probe fd for a compression signature and build the matching reader.
+// Takes ownership of fd.  raw_amount is overwritten with the number of bytes
+// of the file that have been consumed but will not be read again by the
+// returned reader.  Throws CompressedException when the signature names a
+// format whose support was not compiled in.
+ReadBase *ReadFactory(int fd, uint64_t &raw_amount) {
+  scoped_fd hold(fd);
+  unsigned char header[ReadCompressed::kMagicSize];
+  raw_amount = ReadOrEOF(fd, header, ReadCompressed::kMagicSize);
+  if (!raw_amount)
+    return new Uncompressed(hold.release());
+  // Too short to carry any signature: serve what was read, then the rest.
+  if (raw_amount != ReadCompressed::kMagicSize)
+    return new UncompressedWithHeader(hold.release(), header, raw_amount);
+  switch (DetectMagic(header)) {
+    case GZIP:
+#ifdef HAVE_ZLIB
+      return new GZip(hold.release(), header, ReadCompressed::kMagicSize);
+#else
+      UTIL_THROW(CompressedException, "This looks like a gzip file but gzip support was not compiled in.");
+#endif
+    case BZIP:
+#ifdef HAVE_BZLIB
+      return new BZip(hold.release(), header, ReadCompressed::kMagicSize);
+#else
+      UTIL_THROW(CompressedException, "This looks like a bzip file (it begins with BZ), but bzip support was not compiled in.");
+#endif
+    case XZIP:
+#ifdef HAVE_XZLIB
+      return new XZip(hold.release(), header, ReadCompressed::kMagicSize);
+#else
+      UTIL_THROW(CompressedException, "This looks like an xz file, but xz support was not compiled in.");
+#endif
+    case UNKNOWN:
+      break;
+  }
+  try {
+    // kMagicSize is an unsigned std::size_t; convert to a signed type before
+    // negating so the offset really is -6 regardless of the width of size_t.
+    AdvanceOrThrow(fd, -static_cast<int64_t>(ReadCompressed::kMagicSize));
+  } catch (const util::ErrnoException &) {
+    // Not seekable (e.g. a pipe): replay the header from memory instead.
+    return new UncompressedWithHeader(hold.release(), header, ReadCompressed::kMagicSize);
+  }
+  // The rewind succeeded, so Uncompressed will read and count the header
+  // bytes again; forget the probe to avoid counting them twice.
+  raw_amount = 0;
+  return new Uncompressed(hold.release());
+}
+
+} // namespace
+
+// True when the bytes at from_void (at least kMagicSize of them) begin with
+// a signature recognised by DetectMagic.
+bool ReadCompressed::DetectCompressedMagic(const void *from_void) {
+  return DetectMagic(from_void) != UNKNOWN;
+}
+
+// Takes ownership of fd and selects a reader for it immediately.
+ReadCompressed::ReadCompressed(int fd) : raw_amount_(0) {
+  Reset(fd);
+}
+
+// No reader yet; Reset must be called before Read.  raw_amount_ starts at
+// zero so RawAmount() never returns an indeterminate value.
+ReadCompressed::ReadCompressed() : raw_amount_(0) {}
+
+ReadCompressed::~ReadCompressed() {}
+
+// Drop the current reader first (releasing whatever it owns), then build a
+// reader for fd; ReadFactory also resets raw_amount_.
+void ReadCompressed::Reset(int fd) {
+  internal_.reset();
+  internal_.reset(ReadFactory(fd, raw_amount_));
+}
+
+// Forward to the current reader, which may replace itself through *this.
+std::size_t ReadCompressed::Read(void *to, std::size_t amount) {
+  return internal_->Read(to, amount, *this);
+}
+
+} // namespace util
diff --git a/klm/util/read_compressed.hh b/klm/util/read_compressed.hh
new file mode 100644
index 00000000..83ca9fb2
--- /dev/null
+++ b/klm/util/read_compressed.hh
@@ -0,0 +1,74 @@
+#ifndef UTIL_READ_COMPRESSED__
+#define UTIL_READ_COMPRESSED__
+
+#include "util/exception.hh"
+#include "util/scoped.hh"
+
+#include <cstddef>
+
+#include <stdint.h>
+
+namespace util {
+/* Base class for decompression failures; the format-specific exceptions below derive from it. */
+class CompressedException : public Exception {
+ public:
+ CompressedException() throw();
+ virtual ~CompressedException() throw();
+};
+/* Failure reported while decoding with zlib. */
+class GZException : public CompressedException {
+ public:
+ GZException() throw();
+ ~GZException() throw();
+};
+/* Failure reported while decoding with bzip2. */
+class BZException : public CompressedException {
+ public:
+ BZException() throw();
+ ~BZException() throw();
+};
+/* Failure reported while decoding with liblzma (xz). */
+class XZException : public CompressedException {
+ public:
+ XZException() throw();
+ ~XZException() throw();
+};
+
+class ReadBase;
+/* Reads a file descriptor, decompressing when the leading bytes carry a
+ * recognised signature and passing the data through unchanged otherwise.
+ * The actual work is delegated to a ReadBase implementation chosen when the
+ * descriptor is attached; that implementation may swap itself out later.
+ */
+class ReadCompressed {
+ public:
+ static const std::size_t kMagicSize = 6;
+ // from must point to at least kMagicSize bytes; true if they start with a known compression signature.
+ static bool DetectCompressedMagic(const void *from);
+
+ // Takes ownership of fd.
+ explicit ReadCompressed(int fd);
+
+ // Leaves the object without a reader; Reset must be called before Read.
+ ReadCompressed();
+
+ ~ReadCompressed();
+
+ // Takes ownership of fd, replacing any reader attached earlier.
+ void Reset(int fd);
+ /* Copies up to amount bytes into to and returns how many were written. */
+ std::size_t Read(void *to, std::size_t amount);
+ /* Current value of the raw byte counter maintained by the readers. */
+ uint64_t RawAmount() const { return raw_amount_; }
+
+ private:
+ friend class ReadBase; // Lets ReadBase's static helpers swap internal_ and update raw_amount_.
+
+ scoped_ptr<ReadBase> internal_;
+
+ uint64_t raw_amount_;
+
+ // No copying.
+ ReadCompressed(const ReadCompressed &);
+ void operator=(const ReadCompressed &);
+};
+
+} // namespace util
+
+#endif // UTIL_READ_COMPRESSED__
diff --git a/klm/util/read_compressed_test.cc b/klm/util/read_compressed_test.cc
new file mode 100644
index 00000000..6fd97e5e
--- /dev/null
+++ b/klm/util/read_compressed_test.cc
@@ -0,0 +1,94 @@
+#include "util/read_compressed.hh"
+
+#include "util/file.hh"
+#include "util/have.hh"
+
+#define BOOST_TEST_MODULE ReadCompressedTest
+#include <boost/test/unit_test.hpp>
+#include <boost/scoped_ptr.hpp>
+
+#include <fstream>
+#include <string>
+
+#include <stdlib.h>
+
+namespace util {
+namespace {
+
+// Fill exactly amount bytes at to_void from reader, issuing as many Read
+// calls as needed.  A zero-byte read before the request is satisfied fails
+// the enclosing test via BOOST_REQUIRE.
+void ReadLoop(ReadCompressed &reader, void *to_void, std::size_t amount) {
+  for (uint8_t *cursor = static_cast<uint8_t*>(to_void); amount != 0; ) {
+    const std::size_t got = reader.Read(cursor, amount);
+    BOOST_REQUIRE(got);
+    cursor += got;
+    amount -= got;
+  }
+}
+
+// Round-trip test: write 25000 consecutive uint32_t values to a temporary
+// file, pipe it through the external command `compressor` into a second
+// temporary file, then check that ReadCompressed returns the same values
+// followed by a stable end of file.
+void TestRandom(const char *compressor) {
+  const uint32_t kSize4 = 100000 / 4;
+  char name[] = "tempXXXXXX";
+
+  // Write test file.
+  {
+    scoped_fd original(mkstemp(name));
+    // mkstemp reports failure with -1; any non-negative descriptor is valid.
+    BOOST_REQUIRE(original.get() >= 0);
+    for (uint32_t i = 0; i < kSize4; ++i) {
+      WriteOrThrow(original.get(), &i, sizeof(uint32_t));
+    }
+  }
+
+  char gzname[] = "tempXXXXXX";
+  scoped_fd gzipped(mkstemp(gzname));
+  BOOST_REQUIRE(gzipped.get() >= 0);
+
+  std::string command(compressor);
+#ifdef __CYGWIN__
+  command += ".exe";
+#endif
+  command += " <\"";
+  command += name;
+  command += "\" >\"";
+  command += gzname;
+  command += "\"";
+  const int status = system(command.c_str());
+
+  // Remove both temporary files before judging the command, so a failed
+  // compressor does not leave them behind.  The data stays reachable through
+  // the descriptor still held in gzipped.
+  BOOST_CHECK_EQUAL(0, unlink(name));
+  BOOST_CHECK_EQUAL(0, unlink(gzname));
+  BOOST_REQUIRE_EQUAL(0, status);
+
+  ReadCompressed reader(gzipped.release());
+  for (uint32_t i = 0; i < kSize4; ++i) {
+    uint32_t got;
+    ReadLoop(reader, &got, sizeof(uint32_t));
+    BOOST_CHECK_EQUAL(i, got);
+  }
+
+  char ignored;
+  BOOST_CHECK_EQUAL((std::size_t)0, reader.Read(&ignored, 1));
+  // Test double EOF call.
+  BOOST_CHECK_EQUAL((std::size_t)0, reader.Read(&ignored, 1));
+}
+/* One case per external command.  "cat" exercises the uncompressed path and always runs; the others are compiled only when the matching HAVE_* macro is defined. */
+BOOST_AUTO_TEST_CASE(Uncompressed) {
+ TestRandom("cat");
+}
+
+#ifdef HAVE_ZLIB
+BOOST_AUTO_TEST_CASE(ReadGZ) {
+ TestRandom("gzip");
+}
+#endif // HAVE_ZLIB
+
+#ifdef HAVE_BZLIB
+BOOST_AUTO_TEST_CASE(ReadBZ) {
+ TestRandom("bzip2");
+}
+#endif // HAVE_BZLIB
+
+#ifdef HAVE_XZLIB
+BOOST_AUTO_TEST_CASE(ReadXZ) {
+ TestRandom("xz");
+}
+#endif // HAVE_XZLIB
+
+} // namespace
+} // namespace util
diff --git a/klm/util/scoped.hh b/klm/util/scoped.hh
index 93e2e817..d62c6df1 100644
--- a/klm/util/scoped.hh
+++ b/klm/util/scoped.hh
@@ -1,40 +1,13 @@
#ifndef UTIL_SCOPED__
#define UTIL_SCOPED__
+/* Other scoped objects in the style of scoped_ptr. */
#include "util/exception.hh"
-
-/* Other scoped objects in the style of scoped_ptr. */
#include <cstddef>
#include <cstdlib>
namespace util {
-template <class T, class R, R (*Free)(T*)> class scoped_thing {
- public:
- explicit scoped_thing(T *c = static_cast<T*>(0)) : c_(c) {}
-
- ~scoped_thing() { if (c_) Free(c_); }
-
- void reset(T *c) {
- if (c_) Free(c_);
- c_ = c;
- }
-
- T &operator*() { return *c_; }
- const T&operator*() const { return *c_; }
- T &operator->() { return *c_; }
- const T&operator->() const { return *c_; }
-
- T *get() { return c_; }
- const T *get() const { return c_; }
-
- private:
- T *c_;
-
- scoped_thing(const scoped_thing &);
- scoped_thing &operator=(const scoped_thing &);
-};
-
class scoped_malloc {
public:
scoped_malloc() : p_(NULL) {}
@@ -77,9 +50,6 @@ template <class T> class scoped_array {
T &operator*() { return *c_; }
const T&operator*() const { return *c_; }
- T &operator->() { return *c_; }
- const T&operator->() const { return *c_; }
-
T &operator[](std::size_t idx) { return c_[idx]; }
const T &operator[](std::size_t idx) const { return c_[idx]; }
@@ -90,6 +60,39 @@ template <class T> class scoped_array {
private:
T *c_;
+
+ scoped_array(const scoped_array &);
+ void operator=(const scoped_array &);
+};
+/* Minimal non-copyable owning pointer: holds a single T* and deletes it in the destructor. */
+template <class T> class scoped_ptr {
+ public:
+ explicit scoped_ptr(T *content = NULL) : c_(content) {}
+
+ ~scoped_ptr() { delete c_; }
+
+ T *get() { return c_; }
+ const T* get() const { return c_; }
+
+ T &operator*() { return *c_; }
+ const T&operator*() const { return *c_; }
+
+ T *operator->() { return c_; }
+ const T*operator->() const { return c_; }
+ /* NOTE(review): indexing is offered although the destructor uses scalar delete, so c_ must not come from new[] -- confirm intended use. */
+ T &operator[](std::size_t idx) { return c_[idx]; }
+ const T &operator[](std::size_t idx) const { return c_[idx]; }
+
+ void reset(T *to = NULL) {
+ scoped_ptr<T> other(c_); // The temporary takes the old pointer and deletes it when reset returns.
+ c_ = to;
+ }
+
+ private:
+ T *c_;
+ /* Declared but not defined here, to forbid copying. */
+ scoped_ptr(const scoped_ptr &);
+ void operator=(const scoped_ptr &);
+};
} // namespace util
diff --git a/klm/util/string_piece.hh b/klm/util/string_piece.hh
index be6a643d..51481646 100644
--- a/klm/util/string_piece.hh
+++ b/klm/util/string_piece.hh
@@ -1,6 +1,6 @@
/* If you use ICU in your program, then compile with -DHAVE_ICU -licui18n. If
* you don't use ICU, then this will use the Google implementation from Chrome.
- * This has been modified from the original version to let you choose.
+ * This has been modified from the original version to let you choose.
*/
// Copyright 2008, Google Inc.
@@ -62,9 +62,9 @@
#include <unicode/stringpiece.h>
#include <unicode/uversion.h>
-// Old versions of ICU don't define operator== and operator!=.
+// Old versions of ICU don't define operator== and operator!=.
#if (U_ICU_VERSION_MAJOR_NUM < 4) || ((U_ICU_VERSION_MAJOR_NUM == 4) && (U_ICU_VERSION_MINOR_NUM < 4))
-#warning You are using an old version of ICU. Consider upgrading to ICU >= 4.6.
+#warning You are using an old version of ICU. Consider upgrading to ICU >= 4.6.
inline bool operator==(const StringPiece& x, const StringPiece& y) {
if (x.size() != y.size())
return false;
@@ -274,15 +274,28 @@ struct StringPieceCompatibleEquals : public std::binary_function<const StringPie
}
};
template <class T> typename T::const_iterator FindStringPiece(const T &t, const StringPiece &key) {
+#if BOOST_VERSION < 104200
+ std::string temp(key.data(), key.size());
+ return t.find(temp);
+#else
return t.find(key, StringPieceCompatibleHash(), StringPieceCompatibleEquals());
+#endif
}
+
template <class T> typename T::iterator FindStringPiece(T &t, const StringPiece &key) {
+#if BOOST_VERSION < 104200
+ std::string temp(key.data(), key.size());
+ return t.find(temp);
+#else
return t.find(key, StringPieceCompatibleHash(), StringPieceCompatibleEquals());
+#endif
}
#endif
#ifdef HAVE_ICU
U_NAMESPACE_END
+using U_NAMESPACE_QUALIFIER StringPiece;
#endif
+
#endif // BASE_STRING_PIECE_H__
diff --git a/klm/util/tokenize_piece.hh b/klm/util/tokenize_piece.hh
index 4a7f5460..a588c3fc 100644
--- a/klm/util/tokenize_piece.hh
+++ b/klm/util/tokenize_piece.hh
@@ -20,6 +20,7 @@ class OutOfTokens : public Exception {
class SingleCharacter {
public:
+ SingleCharacter() {}
explicit SingleCharacter(char delim) : delim_(delim) {}
StringPiece Find(const StringPiece &in) const {
@@ -32,6 +33,8 @@ class SingleCharacter {
class MultiCharacter {
public:
+ MultiCharacter() {}
+
explicit MultiCharacter(const StringPiece &delimiter) : delimiter_(delimiter) {}
StringPiece Find(const StringPiece &in) const {
@@ -44,6 +47,7 @@ class MultiCharacter {
class AnyCharacter {
public:
+ AnyCharacter() {}
explicit AnyCharacter(const StringPiece &chars) : chars_(chars) {}
StringPiece Find(const StringPiece &in) const {
@@ -56,6 +60,8 @@ class AnyCharacter {
class AnyCharacterLast {
public:
+ AnyCharacterLast() {}
+
explicit AnyCharacterLast(const StringPiece &chars) : chars_(chars) {}
StringPiece Find(const StringPiece &in) const {
@@ -81,8 +87,8 @@ template <class Find, bool SkipEmpty = false> class TokenIter : public boost::it
return current_.data() != 0;
}
- static TokenIter<Find> end() {
- return TokenIter<Find>();
+ static TokenIter<Find, SkipEmpty> end() {
+ return TokenIter<Find, SkipEmpty>();
}
private:
@@ -100,8 +106,8 @@ template <class Find, bool SkipEmpty = false> class TokenIter : public boost::it
} while (SkipEmpty && current_.data() && current_.empty()); // Compiler should optimize this away if SkipEmpty is false.
}
- bool equal(const TokenIter<Find> &other) const {
- return after_.data() == other.after_.data();
+ bool equal(const TokenIter<Find, SkipEmpty> &other) const {
+ return current_.data() == other.current_.data();
}
const StringPiece &dereference() const {