diff options
| author | Michael Denkowski <michael.j.denkowski@gmail.com> | 2012-12-22 16:01:23 -0500 | 
|---|---|---|
| committer | Michael Denkowski <michael.j.denkowski@gmail.com> | 2012-12-22 16:01:23 -0500 | 
| commit | 778a4cec55f82bcc66b3f52de7cc871e8daaeb92 (patch) | |
| tree | 2a5bccaa85965855104c4e8ac3738b2e1c77f164 /klm | |
| parent | 57fff9eea5ba0e71fb958fdb4f32d17f2fe31108 (diff) | |
| parent | d21491daa5e50b4456c7c5f9c2e51d25afd2a757 (diff) | |
Merge branch 'master' of git://github.com/redpony/cdec
Diffstat (limited to 'klm')
48 files changed, 1438 insertions, 610 deletions
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index efa67056..39c4a9b6 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -16,11 +16,11 @@ namespace ngram {  namespace {  const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version";  const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n\0"; -// This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed).  +// This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed).  const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n";  const long int kMagicVersion = 5; -// Old binary files built on 32-bit machines have this header.   +// Old binary files built on 32-bit machines have this header.  // TODO: eliminate with next binary release.  struct OldSanity {    char magic[sizeof(kMagicBytes)]; @@ -39,7 +39,7 @@ struct OldSanity {  }; -// Test values aligned to 8 bytes.     +// Test values aligned to 8 bytes.  struct Sanity {    char magic[ALIGN8(sizeof(kMagicBytes))];    float zero_f, one_f, minus_half_f; @@ -101,7 +101,7 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_  uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) {    std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad;    if (config.write_mmap) { -    // Grow the file to accomodate the search, using zeros.   +    // Grow the file to accomodate the search, using zeros.      try {        util::ResizeOrThrow(backing.file.get(), adjusted_vocab + memory_size);      } catch (util::ErrnoException &e) { @@ -114,7 +114,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t        return reinterpret_cast<uint8_t*>(backing.search.get());      }      // mmap it now. -    // We're skipping over the header and vocab for the search space mmap.  
mmap likes page aligned offsets, so some arithmetic to round the offset down.   +    // We're skipping over the header and vocab for the search space mmap.  mmap likes page aligned offsets, so some arithmetic to round the offset down.      std::size_t page_size = util::SizePage();      std::size_t alignment_cruft = adjusted_vocab % page_size;      backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED); @@ -122,7 +122,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t    } else {      util::MapAnonymous(memory_size, backing.search);      return reinterpret_cast<uint8_t*>(backing.search.get()); -  }  +  }  }  void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) { @@ -140,7 +140,7 @@ void FinishFile(const Config &config, ModelType model_type, unsigned int search_        util::FSyncOrThrow(backing.file.get());        break;    } -  // header and vocab share the same mmap.  The header is written here because we know the counts.   +  // header and vocab share the same mmap.  The header is written here because we know the counts.    Parameters params = Parameters();    params.counts = counts;    params.fixed.order = counts.size(); @@ -160,7 +160,7 @@ namespace detail {  bool IsBinaryFormat(int fd) {    const uint64_t size = util::SizeFile(fd);    if (size == util::kBadSize || (size <= static_cast<uint64_t>(sizeof(Sanity)))) return false; -  // Try reading the header.   +  // Try reading the header.    
util::scoped_memory memory;    try {      util::MapRead(util::LAZY, fd, 0, sizeof(Sanity), memory); @@ -214,7 +214,7 @@ void SeekPastHeader(int fd, const Parameters &params) {  uint8_t *SetupBinary(const Config &config, const Parameters &params, uint64_t memory_size, Backing &backing) {    const uint64_t file_size = util::SizeFile(backing.file.get()); -  // The header is smaller than a page, so we have to map the whole header as well.   +  // The header is smaller than a page, so we have to map the whole header as well.    std::size_t total_map = util::CheckOverflow(TotalHeaderSize(params.counts.size()) + memory_size);    if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map)      UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map); @@ -233,7 +233,8 @@ void ComplainAboutARPA(const Config &config, ModelType model_type) {    if (config.write_mmap || !config.messages) return;    if (config.arpa_complain == Config::ALL) {      *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl; -  } else if (config.arpa_complain == Config::EXPENSIVE && model_type == TRIE_SORTED) { +  } else if (config.arpa_complain == Config::EXPENSIVE && +             (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) {      *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive.  Save time by building a binary format." 
<< std::endl;    }  } diff --git a/klm/lm/config.cc b/klm/lm/config.cc index f9d988ca..9520c41c 100644 --- a/klm/lm/config.cc +++ b/klm/lm/config.cc @@ -6,6 +6,7 @@ namespace lm {  namespace ngram {  Config::Config() : +  show_progress(true),    messages(&std::cerr),    enumerate_vocab(NULL),    unknown_missing(COMPLAIN), diff --git a/klm/lm/config.hh b/klm/lm/config.hh index 739cee9c..0de7b7c6 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -11,46 +11,52 @@  /* Configuration for ngram model.  Separate header to reduce pollution. */  namespace lm { -   +  class EnumerateVocab;  namespace ngram {  struct Config { -  // EFFECTIVE FOR BOTH ARPA AND BINARY READS  +  // EFFECTIVE FOR BOTH ARPA AND BINARY READS + +  // (default true) print progress bar to messages +  bool show_progress;    // Where to log messages including the progress bar.  Set to NULL for    // silence.    std::ostream *messages; +  std::ostream *ProgressMessages() const { +    return show_progress ? messages : 0; +  } +    // This will be called with every string in the vocabulary.  See    // enumerate_vocab.hh for more detail.  Config does not take ownership; you -  // are still responsible for deleting it (or stack allocating).   +  // are still responsible for deleting it (or stack allocating).    EnumerateVocab *enumerate_vocab; -    // ONLY EFFECTIVE WHEN READING ARPA -  // What to do when <unk> isn't in the provided model.  +  // What to do when <unk> isn't in the provided model.    WarningAction unknown_missing; -  // What to do when <s> or </s> is missing from the model.  -  // If THROW_UP, the exception will be of type util::SpecialWordMissingException.   +  // What to do when <s> or </s> is missing from the model. +  // If THROW_UP, the exception will be of type util::SpecialWordMissingException.    WarningAction sentence_marker_missing;    // What to do with a positive log probability.  For COMPLAIN and SILENT, map -  // to 0.   +  // to 0.    
WarningAction positive_log_probability; -  // The probability to substitute for <unk> if it's missing from the model.   +  // The probability to substitute for <unk> if it's missing from the model.    // No effect if the model has <unk> or unknown_missing == THROW_UP.    float unknown_missing_logprob;    // Size multiplier for probing hash table.  Must be > 1.  Space is linear in    // this.  Time is probing_multiplier / (probing_multiplier - 1).  No effect -  // for sorted variant.   +  // for sorted variant.    // If you find yourself setting this to a low number, consider using the -  // TrieModel which has lower memory consumption.   +  // TrieModel which has lower memory consumption.    float probing_multiplier;    // Amount of memory to use for building.  The actual memory usage will be @@ -58,10 +64,10 @@ struct Config {    // models.    std::size_t building_memory; -  // Template for temporary directory appropriate for passing to mkdtemp.   +  // Template for temporary directory appropriate for passing to mkdtemp.    // The characters XXXXXX are appended before passing to mkdtemp.  Only    // applies to trie.  If NULL, defaults to write_mmap.  If that's NULL, -  // defaults to input file name.   +  // defaults to input file name.    const char *temporary_directory_prefix;    // Level of complaining to do when loading from ARPA instead of binary format. @@ -69,49 +75,46 @@ struct Config {    ARPALoadComplain arpa_complain;    // While loading an ARPA file, also write out this binary format file.  Set -  // to NULL to disable.   +  // to NULL to disable.    const char *write_mmap;    enum WriteMethod { -    WRITE_MMAP, // Map the file directly.   -    WRITE_AFTER // Write after we're done.   +    WRITE_MMAP, // Map the file directly. +    WRITE_AFTER // Write after we're done.    };    WriteMethod write_method; -  // Include the vocab in the binary file?  Only effective if write_mmap != NULL.   +  // Include the vocab in the binary file?  
Only effective if write_mmap != NULL.    bool include_vocab; -  // Left rest options.  Only used when the model includes rest costs.   +  // Left rest options.  Only used when the model includes rest costs.    enum RestFunction {      REST_MAX,   // Maximum of any score to the left -    REST_LOWER, // Use lower-order files given below.   +    REST_LOWER, // Use lower-order files given below.    };    RestFunction rest_function; -  // Only used for REST_LOWER.   +  // Only used for REST_LOWER.    std::vector<std::string> rest_lower_files; -    // Quantization options.  Only effective for QuantTrieModel.  One value is    // reserved for each of prob and backoff, so 2^bits - 1 buckets will be used -  // to quantize (and one of the remaining backoffs will be 0).   +  // to quantize (and one of the remaining backoffs will be 0).    uint8_t prob_bits, backoff_bits;    // Bhiksha compression (simple form).  Only works with trie.    uint8_t pointer_bhiksha_bits; -   -   +    // ONLY EFFECTIVE WHEN READING BINARY -   +    // How to get the giant array into memory: lazy mmap, populate, read etc. -  // See util/mmap.hh for details of MapMethod.   +  // See util/mmap.hh for details of MapMethod.    util::LoadMethod load_method; - -  // Set defaults.  +  // Set defaults.    Config();  }; diff --git a/klm/lm/left.hh b/klm/lm/left.hh index 8c27232e..85c1ea37 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -51,36 +51,36 @@ namespace ngram {  template <class M> class RuleScore {    public: -    explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), prob_(0.0) { +    explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(&out), left_done_(false), prob_(0.0) {        out.left.length = 0;        out.right.length = 0;      }      void BeginSentence() { -      out_.right = model_.BeginSentenceState(); -      // out_.left is empty. +      out_->right = model_.BeginSentenceState(); +      // out_->left is empty.        
left_done_ = true;      }      void Terminal(WordIndex word) { -      State copy(out_.right); -      FullScoreReturn ret(model_.FullScore(copy, word, out_.right)); +      State copy(out_->right); +      FullScoreReturn ret(model_.FullScore(copy, word, out_->right));        if (left_done_) { prob_ += ret.prob; return; }        if (ret.independent_left) {          prob_ += ret.prob;          left_done_ = true;          return;        } -      out_.left.pointers[out_.left.length++] = ret.extend_left; +      out_->left.pointers[out_->left.length++] = ret.extend_left;        prob_ += ret.rest; -      if (out_.right.length != copy.length + 1) +      if (out_->right.length != copy.length + 1)          left_done_ = true;      }      // Faster version of NonTerminal for the case where the rule begins with a non-terminal.        void BeginNonTerminal(const ChartState &in, float prob = 0.0) {        prob_ = prob; -      out_ = in; +      *out_ = in;        left_done_ = in.left.full;      } @@ -89,23 +89,23 @@ template <class M> class RuleScore {        if (!in.left.length) {          if (in.left.full) { -          for (const float *i = out_.right.backoff; i < out_.right.backoff + out_.right.length; ++i) prob_ += *i; +          for (const float *i = out_->right.backoff; i < out_->right.backoff + out_->right.length; ++i) prob_ += *i;            left_done_ = true; -          out_.right = in.right; +          out_->right = in.right;          }          return;        } -      if (!out_.right.length) { -        out_.right = in.right; +      if (!out_->right.length) { +        out_->right = in.right;          if (left_done_) {            prob_ += model_.UnRest(in.left.pointers, in.left.pointers + in.left.length, 1);            return;          } -        if (out_.left.length) { +        if (out_->left.length) {            left_done_ = true;          } else { -          out_.left = in.left; +          out_->left = in.left;            left_done_ = in.left.full;          }          
return; @@ -113,10 +113,10 @@ template <class M> class RuleScore {        float backoffs[KENLM_MAX_ORDER - 1], backoffs2[KENLM_MAX_ORDER - 1];        float *back = backoffs, *back2 = backoffs2; -      unsigned char next_use = out_.right.length; +      unsigned char next_use = out_->right.length;        // First word -      if (ExtendLeft(in, next_use, 1, out_.right.backoff, back)) return; +      if (ExtendLeft(in, next_use, 1, out_->right.backoff, back)) return;        // Words after the first, so extending a bigram to begin with        for (unsigned char extend_length = 2; extend_length <= in.left.length; ++extend_length) { @@ -127,54 +127,58 @@ template <class M> class RuleScore {        if (in.left.full) {          for (const float *i = back; i != back + next_use; ++i) prob_ += *i;          left_done_ = true; -        out_.right = in.right; +        out_->right = in.right;          return;        }        // Right state was minimized, so it's already independent of the new words to the left.          if (in.right.length < in.left.length) { -        out_.right = in.right; +        out_->right = in.right;          return;        }        // Shift exisiting words down.   -      for (WordIndex *i = out_.right.words + next_use - 1; i >= out_.right.words; --i) { +      for (WordIndex *i = out_->right.words + next_use - 1; i >= out_->right.words; --i) {          *(i + in.right.length) = *i;        }        // Add words from in.right.   -      std::copy(in.right.words, in.right.words + in.right.length, out_.right.words); +      std::copy(in.right.words, in.right.words + in.right.length, out_->right.words);        // Assemble backoff composed on the existing state's backoff followed by the new state's backoff.   
-      std::copy(in.right.backoff, in.right.backoff + in.right.length, out_.right.backoff); -      std::copy(back, back + next_use, out_.right.backoff + in.right.length); -      out_.right.length = in.right.length + next_use; +      std::copy(in.right.backoff, in.right.backoff + in.right.length, out_->right.backoff); +      std::copy(back, back + next_use, out_->right.backoff + in.right.length); +      out_->right.length = in.right.length + next_use;      }      float Finish() {        // A N-1-gram might extend left and right but we should still set full to true because it's an N-1-gram.   -      out_.left.full = left_done_ || (out_.left.length == model_.Order() - 1); +      out_->left.full = left_done_ || (out_->left.length == model_.Order() - 1);        return prob_;      }      void Reset() {        prob_ = 0.0;        left_done_ = false; -      out_.left.length = 0; -      out_.right.length = 0; +      out_->left.length = 0; +      out_->right.length = 0; +    } +    void Reset(ChartState &replacement) { +      out_ = &replacement; +      Reset();      }    private:      bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) {        ProcessRet(model_.ExtendLeft( -            out_.right.words, out_.right.words + next_use, // Words to extend into +            out_->right.words, out_->right.words + next_use, // Words to extend into              back_in, // Backoffs to use              in.left.pointers[extend_length - 1], extend_length, // Words to be extended              back_out, // Backoffs for the next score              next_use)); // Length of n-gram to use in next scoring.   -      if (next_use != out_.right.length) { +      if (next_use != out_->right.length) {          left_done_ = true;          if (!next_use) {            // Early exit.   
-          out_.right = in.right; +          out_->right = in.right;            prob_ += model_.UnRest(in.left.pointers + extend_length, in.left.pointers + in.left.length, extend_length + 1);            return true;          } @@ -193,13 +197,13 @@ template <class M> class RuleScore {          left_done_ = true;          return;        } -      out_.left.pointers[out_.left.length++] = ret.extend_left; +      out_->left.pointers[out_->left.length++] = ret.extend_left;        prob_ += ret.rest;      }      const M &model_; -    ChartState &out_; +    ChartState *out_;      bool left_done_; diff --git a/klm/lm/max_order.hh b/klm/lm/max_order.hh index 989f8324..3eb97ccd 100644 --- a/klm/lm/max_order.hh +++ b/klm/lm/max_order.hh @@ -4,9 +4,6 @@   * (kMaxOrder - 1) * sizeof(float) bytes instead of   * sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead   */ -#ifndef KENLM_MAX_ORDER -#define KENLM_MAX_ORDER 6 -#endif  #ifndef KENLM_ORDER_MESSAGE -#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile.  In the KenLM tarball or Moses, use e.g. `bjam --kenlm-max-order=6 -a'.  Otherwise, edit lm/max_order.hh." +#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile.  In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'.  Otherwise, edit lm/max_order.hh."  #endif diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 2fd20481..a40fd2fb 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -37,7 +37,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT  template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) {    LoadLM(file, config, *this); -  // g++ prints warnings unless these are fully initialized.   +  // g++ prints warnings unless these are fully initialized.    
State begin_sentence = State();    begin_sentence.length = 1;    begin_sentence.words[0] = vocab_.BeginSentence(); @@ -69,8 +69,8 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT  }  template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) { -  // Backing file is the ARPA.  Steal it so we can make the backing file the mmap output if any.   -  util::FilePiece f(backing_.file.release(), file, config.messages); +  // Backing file is the ARPA.  Steal it so we can make the backing file the mmap output if any. +  util::FilePiece f(backing_.file.release(), file, config.ProgressMessages());    try {      std::vector<uint64_t> counts;      // File counts do not include pruned trigrams that extend to quadgrams etc.   These will be fixed by search_. @@ -80,14 +80,14 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT      if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");      std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config)); -    // Setup the binary file for writing the vocab lookup table.  The search_ is responsible for growing the binary file to its needs.   +    // Setup the binary file for writing the vocab lookup table.  The search_ is responsible for growing the binary file to its needs.      
vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config);      if (config.write_mmap) {        WriteWordsWrapper wrap(config.enumerate_vocab);        vocab_.ConfigureEnumerate(&wrap, counts[0]);        search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); -      wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + backing_.search.size()); +      wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + Search::Size(counts, config));      } else {        vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);        search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); @@ -95,7 +95,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT      if (!vocab_.SawUnk()) {        assert(config.unknown_missing != THROW_UP); -      // Default probabilities for unknown.   +      // Default probabilities for unknown.        search_.UnknownUnigram().backoff = 0.0;        search_.UnknownUnigram().prob = config.unknown_missing_logprob;      } @@ -147,7 +147,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,  }  template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const { -  // Generate a state from context.   +  // Generate a state from context.    context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);    if (context_rend == context_rbegin) {      out_state.length = 0; @@ -191,7 +191,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,      ret.rest = ptr.Rest();      ret.prob = ptr.Prob();      ret.extend_left = extend_pointer; -    // If this function is called, then it does depend on left words.    +    // If this function is called, then it does depend on left words.      
ret.independent_left = false;    }    float subtract_me = ret.rest; @@ -199,7 +199,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,    next_use = extend_length;    ResumeScore(add_rbegin, add_rend, extend_length - 1, node, backoff_out, next_use, ret);    next_use -= extend_length; -  // Charge backoffs.   +  // Charge backoffs.    for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b;    ret.prob -= subtract_me;    ret.rest -= subtract_me; @@ -209,7 +209,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,  namespace {  // Do a paraonoid copy of history, assuming new_word has already been copied  // (hence the -1).  out_state.length could be zero so I avoided using -// std::copy.    +// std::copy.  void CopyRemainingHistory(const WordIndex *from, State &out_state) {    WordIndex *out = out_state.words + 1;    const WordIndex *in_end = from + static_cast<ptrdiff_t>(out_state.length) - 1; @@ -217,18 +217,19 @@ void CopyRemainingHistory(const WordIndex *from, State &out_state) {  }  } // namespace -/* Ugly optimized function.  Produce a score excluding backoff.   - * The search goes in increasing order of ngram length.   +/* Ugly optimized function.  Produce a score excluding backoff. + * The search goes in increasing order of ngram length.   * Context goes backward, so context_begin is the word immediately preceeding - * new_word.   + * new_word.   */  template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ScoreExceptBackoff(      const WordIndex *const context_rbegin,      const WordIndex *const context_rend,      const WordIndex new_word,      State &out_state) const { +  assert(new_word < vocab_.Bound());    FullScoreReturn ret; -  // ret.ngram_length contains the last known non-blank ngram length.   +  // ret.ngram_length contains the last known non-blank ngram length.    
ret.ngram_length = 1;    typename Search::Node node; @@ -237,9 +238,9 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,    ret.prob = uni.Prob();    ret.rest = uni.Rest(); -  // This is the length of the context that should be used for continuation to the right.   +  // This is the length of the context that should be used for continuation to the right.    out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0; -  // We'll write the word anyway since it will probably be used and does no harm being there.   +  // We'll write the word anyway since it will probably be used and does no harm being there.    out_state.words[0] = new_word;    if (context_rbegin == context_rend) return ret; diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index a1623834..2d6f15b2 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -231,7 +231,7 @@ template <class Value> void HashedSearch<Value>::InitializeFromARPA(const char *  template <> void HashedSearch<BackoffValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) {    NoRestBuild build; -  ApplyBuild(f, counts, config, vocab, warn, build); +  ApplyBuild(f, counts, vocab, warn, build);  }  template <> void HashedSearch<RestValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) { @@ -239,19 +239,19 @@ template <> void HashedSearch<RestValue>::DispatchBuild(util::FilePiece &f, cons      case Config::REST_MAX:        {          MaxRestBuild build; -        ApplyBuild(f, counts, config, vocab, warn, build); +        ApplyBuild(f, counts, vocab, warn, build);        }        break;      case Config::REST_LOWER:        {          LowerRestBuild<ProbingModel> build(config, counts.size(), vocab); -        ApplyBuild(f, counts, config, vocab, warn, build); +        
ApplyBuild(f, counts, vocab, warn, build);        }        break;    }  } -template <class Value> template <class Build> void HashedSearch<Value>::ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build) { +template <class Value> template <class Build> void HashedSearch<Value>::ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build) {    for (WordIndex i = 0; i < counts[0]; ++i) {      build.SetRest(&i, (unsigned int)1, unigram_.Raw()[i]);    } diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index a52f107b..00595796 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -147,7 +147,7 @@ template <class Value> class HashedSearch {      // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild.        void DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn); -    template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build); +    template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build);      class Unigram {        public: diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index debcfd07..1b0d9b26 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -55,7 +55,7 @@ struct ProbPointer {    uint64_t index;  }; -// Array of n-grams and float indices.   +// Array of n-grams and float indices.  
class BackoffMessages {    public:      void Init(std::size_t entry_size) { @@ -100,7 +100,7 @@ class BackoffMessages {      void Apply(float *const *const base, RecordReader &reader) {        FinishedAdding();        if (current_ == allocated_) return; -      // We'll also use the same buffer to record messages to blanks that they extend.   +      // We'll also use the same buffer to record messages to blanks that they extend.        WordIndex *extend_out = reinterpret_cast<WordIndex*>(current_);        const unsigned char order = (entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex);        for (reader.Rewind(); reader && (current_ != allocated_); ) { @@ -109,7 +109,7 @@ class BackoffMessages {              ++reader;              break;            case 1: -            // Message but nobody to receive it.  Write it down at the beginning of the buffer so we can inform this blank that it extends.   +            // Message but nobody to receive it.  Write it down at the beginning of the buffer so we can inform this blank that it extends.              for (const WordIndex *w = reinterpret_cast<const WordIndex *>(current_); w != reinterpret_cast<const WordIndex *>(current_) + order; ++w, ++extend_out) *extend_out = *w;              current_ += entry_size_;              break; @@ -126,7 +126,7 @@ class BackoffMessages {              break;          }        } -      // Now this is a list of blanks that extend right.   +      // Now this is a list of blanks that extend right.        entry_size_ = sizeof(WordIndex) * order;        Resize(sizeof(WordIndex) * (extend_out - (const WordIndex*)backing_.get()));        current_ = (uint8_t*)backing_.get(); @@ -153,7 +153,7 @@ class BackoffMessages {    private:      void FinishedAdding() {        Resize(current_ - (uint8_t*)backing_.get()); -      // Sort requests in same order as files.   +      // Sort requests in same order as files.        
std::sort(            util::SizedIterator(util::SizedProxy(backing_.get(), entry_size_)),            util::SizedIterator(util::SizedProxy(current_, entry_size_)), @@ -220,7 +220,7 @@ class SRISucks {      }    private: -    // This used to be one array.  Then I needed to separate it by order for quantization to work.   +    // This used to be one array.  Then I needed to separate it by order for quantization to work.      std::vector<float> values_[KENLM_MAX_ORDER - 1];      BackoffMessages messages_[KENLM_MAX_ORDER - 1]; @@ -253,7 +253,7 @@ class FindBlanks {        ++counts_.back();      } -    // Unigrams wrote one past.   +    // Unigrams wrote one past.      void Cleanup() {        --counts_[0];      } @@ -270,15 +270,15 @@ class FindBlanks {      SRISucks &sri_;  }; -// Phase to actually write n-grams to the trie.   +// Phase to actually write n-grams to the trie.  template <class Quant, class Bhiksha> class WriteEntries {    public: -    WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle<Bhiksha> *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) :  +    WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle<Bhiksha> *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) :        contexts_(contexts),        quant_(quant),        unigrams_(unigrams),        middle_(middle), -      longest_(longest),  +      longest_(longest),        bigram_pack_((order == 2) ? static_cast<BitPacked&>(longest_) : static_cast<BitPacked&>(*middle_)),        order_(order),        sri_(sri) {} @@ -328,7 +328,7 @@ struct Gram {    const WordIndex *begin, *end; -  // For queue, this is the direction we want.   +  // For queue, this is the direction we want.    
bool operator<(const Gram &other) const {      return std::lexicographical_compare(other.begin, other.end, begin, end);    } @@ -353,7 +353,7 @@ template <class Doing> class BlankManager {          been_length_ = length;          return;        } -      // There are blanks to insert starting with order blank.   +      // There are blanks to insert starting with order blank.        unsigned char blank = cur - to + 1;        UTIL_THROW_IF(blank == 1, FormatLoadException, "Missing a unigram that appears as context.");        const float *lower_basis; @@ -363,7 +363,7 @@ template <class Doing> class BlankManager {          assert(*lower_basis != kBadProb);          doing_.MiddleBlank(blank, to, based_on, *lower_basis);          *pre = *cur; -        // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram.   +        // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram.          basis_[blank - 1] = kBadProb;        }        *pre = *cur; @@ -377,7 +377,7 @@ template <class Doing> class BlankManager {      unsigned char been_length_;      float basis_[KENLM_MAX_ORDER]; -     +      Doing &doing_;  }; @@ -451,7 +451,7 @@ template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, Re  }  void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &contexts, UnigramValue *unigrams) { -  // Fill unigram probabilities.   +  // Fill unigram probabilities.    
try {      rewind(file);      for (WordIndex i = 0; i < unigram_count; ++i) { @@ -486,7 +486,7 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve      util::scoped_memory unigrams;      MapRead(util::POPULATE_OR_READ, unigram_fd.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams);      FindBlanks finder(counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri); -    RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Identifying n-grams omitted by SRI", finder); +    RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Identifying n-grams omitted by SRI", finder);      fixed_counts = finder.Counts();    }    unigram_file.reset(util::FDOpenOrThrow(unigram_fd)); @@ -504,7 +504,8 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve      inputs[i-2].Rewind();    }    if (Quant::kTrain) { -    util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0), config.messages, "Quantizing"); +    util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0), +                                  config.ProgressMessages(), "Quantizing");      for (unsigned char i = 2; i < counts.size(); ++i) {        TrainQuantizer(i, counts[i-1], sri.Values(i), inputs[i-2], progress, quant);      } @@ -519,13 +520,13 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve    for (unsigned char i = 2; i <= counts.size(); ++i) {      inputs[i-2].Rewind();    } -  // Fill entries except unigram probabilities.   +  // Fill entries except unigram probabilities.    
{      WriteEntries<Quant, Bhiksha> writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri); -    RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Writing trie", writer); +    RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Writing trie", writer);    } -  // Do not disable this error message or else too little state will be returned.  Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.    +  // Do not disable this error message or else too little state will be returned.  Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.    for (unsigned char order = 2; order <= counts.size(); ++order) {      const RecordReader &context = contexts[order - 2];      if (context) { @@ -541,13 +542,13 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve    }    /* Set ending offsets so the last entry will be sized properly */ -  // Last entry for unigrams was already set.   +  // Last entry for unigrams was already set.    if (out.middle_begin_ != out.middle_end_) {      for (typename TrieSearch<Quant, Bhiksha>::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) {        i->FinishedLoading((i+1)->InsertIndex(), config);      }      (out.middle_end_ - 1)->FinishedLoading(out.longest_.InsertIndex(), config); -  }   +  }  }  template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) { @@ -595,7 +596,7 @@ template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::Initializ    } else {      temporary_prefix = file;    } -  // At least 1MB sorting memory.   +  // At least 1MB sorting memory.    
SortedFiles sorted(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_prefix, vocab);    BuildTrie(sorted, counts, config, *this, quant_, vocab, backing); diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 11c27518..fd7f96dc 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -116,7 +116,9 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) {    }    *end_ = hashed;    if (enumerate_) { -    strings_to_enumerate_[end_ - begin_].assign(str.data(), str.size()); +    void *copied = string_backing_.Allocate(str.size()); +    memcpy(copied, str.data(), str.size()); +    strings_to_enumerate_[end_ - begin_] = StringPiece(static_cast<const char*>(copied), str.size());    }    ++end_;    // This is 1 + the offset where it was inserted to make room for unk.   @@ -126,7 +128,7 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) {  void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {    if (enumerate_) {      if (!strings_to_enumerate_.empty()) { -      util::PairedIterator<ProbBackoff*, std::string*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin()); +      util::PairedIterator<ProbBackoff*, StringPiece*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin());        util::JointSort(begin_, end_, values);      }      for (WordIndex i = 0; i < static_cast<WordIndex>(end_ - begin_); ++i) { @@ -134,6 +136,7 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {        enumerate_->Add(i + 1, strings_to_enumerate_[i]);      }      strings_to_enumerate_.clear(); +    string_backing_.FreeAll();    } else {      util::JointSort(begin_, end_, reorder_vocab + 1);    } diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index de54eb06..3902f117 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -4,6 +4,7 @@  #include "lm/enumerate_vocab.hh"  #include "lm/lm_exception.hh"  #include "lm/virtual_interface.hh" +#include "util/pool.hh"  #include "util/probing_hash_table.hh"  #include 
"util/sorted_uniform.hh"  #include "util/string_piece.hh" @@ -96,7 +97,9 @@ class SortedVocabulary : public base::Vocabulary {      EnumerateVocab *enumerate_;      // Actual strings.  Used only when loading from ARPA and enumerate_ != NULL  -    std::vector<std::string> strings_to_enumerate_; +    util::Pool string_backing_; + +    std::vector<StringPiece> strings_to_enumerate_;  };  #pragma pack(push) diff --git a/klm/search/Makefile.am b/klm/search/Makefile.am index ccc5b7f6..5aea33c2 100644 --- a/klm/search/Makefile.am +++ b/klm/search/Makefile.am @@ -2,10 +2,10 @@ noinst_LIBRARIES = libksearch.a  libksearch_a_SOURCES = \    edge_generator.cc \ +	nbest.cc \    rule.cc \    vertex.cc \ -  vertex_generator.cc \ -  weights.cc +  vertex_generator.cc  AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. diff --git a/klm/search/applied.hh b/klm/search/applied.hh new file mode 100644 index 00000000..bd659e5c --- /dev/null +++ b/klm/search/applied.hh @@ -0,0 +1,86 @@ +#ifndef SEARCH_APPLIED__ +#define SEARCH_APPLIED__ + +#include "search/edge.hh" +#include "search/header.hh" +#include "util/pool.hh" + +#include <math.h> + +namespace search { + +// A full hypothesis: a score, arity of the rule, a pointer to the decoder's rule (Note), and pointers to non-terminals that were substituted.   
+template <class Below> class GenericApplied : public Header { +  public: +    GenericApplied() {} + +    GenericApplied(void *location, PartialEdge partial)  +      : Header(location) { +      memcpy(Base(), partial.Base(), kHeaderSize); +      Below *child_out = Children(); +      const PartialVertex *part = partial.NT(); +      const PartialVertex *const part_end_loop = part + partial.GetArity(); +      for (; part != part_end_loop; ++part, ++child_out) +        *child_out = Below(part->End()); +    } +     +    GenericApplied(void *location, Score score, Arity arity, Note note) : Header(location, arity) { +      SetScore(score); +      SetNote(note); +    } + +    explicit GenericApplied(History from) : Header(from) {} + + +    // These are arrays of length GetArity(). +    Below *Children() { +      return reinterpret_cast<Below*>(After()); +    } +    const Below *Children() const { +      return reinterpret_cast<const Below*>(After()); +    } + +    static std::size_t Size(Arity arity) { +      return kHeaderSize + arity * sizeof(const Below); +    } +}; + +// Applied rule that references itself.   +class Applied : public GenericApplied<Applied> { +  private: +    typedef GenericApplied<Applied> P; + +  public: +    Applied() {} +    Applied(void *location, PartialEdge partial) : P(location, partial) {} +    Applied(History from) : P(from) {} +}; + +// How to build single-best hypotheses.   
+class SingleBest { +  public: +    typedef PartialEdge Combine; + +    void Add(PartialEdge &existing, PartialEdge add) const { +      if (!existing.Valid() || existing.GetScore() < add.GetScore()) +        existing = add; +    } + +    NBestComplete Complete(PartialEdge partial) { +      if (!partial.Valid())  +        return NBestComplete(NULL, lm::ngram::ChartState(), -INFINITY); +      void *place_final = pool_.Allocate(Applied::Size(partial.GetArity())); +      Applied(place_final, partial); +      return NBestComplete( +          place_final, +          partial.CompletedState(), +          partial.GetScore()); +    } + +  private: +    util::Pool pool_; +}; + +} // namespace search + +#endif // SEARCH_APPLIED__ diff --git a/klm/search/config.hh b/klm/search/config.hh index ef8e2354..ba18c09e 100644 --- a/klm/search/config.hh +++ b/klm/search/config.hh @@ -1,23 +1,36 @@  #ifndef SEARCH_CONFIG__  #define SEARCH_CONFIG__ -#include "search/weights.hh" -#include "util/string_piece.hh" +#include "search/types.hh"  namespace search { +struct NBestConfig { +  explicit NBestConfig(unsigned int in_size) { +    keep = in_size; +    size = in_size; +  } +   +  unsigned int keep, size; +}; +  class Config {    public: -    Config(const Weights &weights, unsigned int pop_limit) : -      weights_(weights), pop_limit_(pop_limit) {} +    Config(Score lm_weight, unsigned int pop_limit, const NBestConfig &nbest) : +      lm_weight_(lm_weight), pop_limit_(pop_limit), nbest_(nbest) {} -    const Weights &GetWeights() const { return weights_; } +    Score LMWeight() const { return lm_weight_; }      unsigned int PopLimit() const { return pop_limit_; } +    const NBestConfig &GetNBest() const { return nbest_; } +    private: -    Weights weights_; +    Score lm_weight_; +      unsigned int pop_limit_; + +    NBestConfig nbest_;  };  } // namespace search diff --git a/klm/search/context.hh b/klm/search/context.hh index 62163144..08f21bbf 100644 --- a/klm/search/context.hh +++ 
b/klm/search/context.hh @@ -1,30 +1,16 @@  #ifndef SEARCH_CONTEXT__  #define SEARCH_CONTEXT__ -#include "lm/model.hh"  #include "search/config.hh" -#include "search/final.hh" -#include "search/types.hh"  #include "search/vertex.hh" -#include "util/exception.hh" -#include "util/pool.hh"  #include <boost/pool/object_pool.hpp> -#include <boost/ptr_container/ptr_vector.hpp> - -#include <vector>  namespace search { -class Weights; -  class ContextBase {    public: -    explicit ContextBase(const Config &config) : pop_limit_(config.PopLimit()), weights_(config.GetWeights()) {} - -    util::Pool &FinalPool() { -      return final_pool_; -    } +    explicit ContextBase(const Config &config) : config_(config) {}      VertexNode *NewVertexNode() {        VertexNode *ret = vertex_node_pool_.construct(); @@ -36,18 +22,16 @@ class ContextBase {        vertex_node_pool_.destroy(node);      } -    unsigned int PopLimit() const { return pop_limit_; } +    unsigned int PopLimit() const { return config_.PopLimit(); } -    const Weights &GetWeights() const { return weights_; } +    Score LMWeight() const { return config_.LMWeight(); } -  private: -    util::Pool final_pool_; +    const Config &GetConfig() const { return config_; } +  private:      boost::object_pool<VertexNode> vertex_node_pool_; -    unsigned int pop_limit_; - -    const Weights &weights_; +    Config config_;  };  template <class Model> class Context : public ContextBase { diff --git a/klm/search/dedupe.hh b/klm/search/dedupe.hh new file mode 100644 index 00000000..7eaa3b95 --- /dev/null +++ b/klm/search/dedupe.hh @@ -0,0 +1,131 @@ +#ifndef SEARCH_DEDUPE__ +#define SEARCH_DEDUPE__ + +#include "lm/state.hh" +#include "search/edge_generator.hh" + +#include <boost/pool/object_pool.hpp> +#include <boost/unordered_map.hpp> + +namespace search { + +class Dedupe { +  public: +    Dedupe() {} + +    PartialEdge AllocateEdge(Arity arity) { +      return behind_.AllocateEdge(arity); +    } + +    void AddEdge(PartialEdge 
edge) { +      edge.MutableFlags() = 0; + +      uint64_t hash = 0; +      const PartialVertex *v = edge.NT(); +      const PartialVertex *v_end = v + edge.GetArity(); +      for (; v != v_end; ++v) { +        const void *ptr = v->Identify(); +        hash = util::MurmurHashNative(&ptr, sizeof(const void*), hash); +      } +       +      const lm::ngram::ChartState *c = edge.Between(); +      const lm::ngram::ChartState *const c_end = c + edge.GetArity() + 1; +      for (; c != c_end; ++c) hash = hash_value(*c, hash); + +      std::pair<Table::iterator, bool> ret(table_.insert(std::make_pair(hash, edge))); +      if (!ret.second) FoundDupe(ret.first->second, edge); +    } + +    bool Empty() const { return behind_.Empty(); } + +    template <class Model, class Output> void Search(Context<Model> &context, Output &output) { +      for (Table::const_iterator i(table_.begin()); i != table_.end(); ++i) { +        behind_.AddEdge(i->second); +      } +      Unpack<Output> unpack(output, *this); +      behind_.Search(context, unpack); +    } + +  private: +    void FoundDupe(PartialEdge &table, PartialEdge adding) { +      if (table.GetFlags() & kPackedFlag) { +        Packed &packed = *static_cast<Packed*>(table.GetNote().mut); +        if (table.GetScore() >= adding.GetScore()) { +          packed.others.push_back(adding); +          return; +        } +        Note original(packed.original); +        packed.original = adding.GetNote(); +        adding.SetNote(table.GetNote()); +        table.SetNote(original); +        packed.others.push_back(table); +        packed.starting = adding.GetScore(); +        table = adding; +        table.MutableFlags() |= kPackedFlag; +        return; +      } +      PartialEdge loser; +      if (adding.GetScore() > table.GetScore()) { +        loser = table; +        table = adding; +      } else { +        loser = adding; +      } +      // table is winner, loser is loser... 
+      packed_.construct(table, loser); +    } + +    struct Packed { +      Packed(PartialEdge winner, PartialEdge loser)  +        : original(winner.GetNote()), starting(winner.GetScore()), others(1, loser) { +        winner.MutableNote().vp = this; +        winner.MutableFlags() |= kPackedFlag; +        loser.MutableFlags() &= ~kPackedFlag; +      } +      Note original; +      Score starting; +      std::vector<PartialEdge> others; +    }; + +    template <class Output> class Unpack { +      public: +        explicit Unpack(Output &output, Dedupe &owner) : output_(output), owner_(owner) {} + +        void NewHypothesis(PartialEdge edge) { +          if (edge.GetFlags() & kPackedFlag) { +            Packed &packed = *reinterpret_cast<Packed*>(edge.GetNote().mut); +            edge.SetNote(packed.original); +            edge.MutableFlags() = 0; +            std::size_t copy_size = sizeof(PartialVertex) * edge.GetArity() + sizeof(lm::ngram::ChartState); +            for (std::vector<PartialEdge>::iterator i = packed.others.begin(); i != packed.others.end(); ++i) { +              PartialEdge copy(owner_.AllocateEdge(edge.GetArity())); +              copy.SetScore(edge.GetScore() - packed.starting + i->GetScore()); +              copy.MutableFlags() = 0; +              copy.SetNote(i->GetNote()); +              memcpy(copy.NT(), edge.NT(), copy_size); +              output_.NewHypothesis(copy); +            } +          } +          output_.NewHypothesis(edge); +        } + +        void FinishedSearch() { +          output_.FinishedSearch(); +        } + +      private: +        Output &output_; +        Dedupe &owner_; +    }; + +    EdgeGenerator behind_; + +    typedef boost::unordered_map<uint64_t, PartialEdge> Table; +    Table table_; + +    boost::object_pool<Packed> packed_; + +    static const uint16_t kPackedFlag = 1; +}; +} // namespace search +#endif // SEARCH_DEDUPE__ diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc index 
260159b1..eacf5de5 100644 --- a/klm/search/edge_generator.cc +++ b/klm/search/edge_generator.cc @@ -1,6 +1,7 @@  #include "search/edge_generator.hh"  #include "lm/left.hh" +#include "lm/model.hh"  #include "lm/partial.hh"  #include "search/context.hh"  #include "search/vertex.hh" @@ -38,7 +39,7 @@ template <class Model> void FastScore(const Context<Model> &context, Arity victi        *cover = *(cover + 1);      }    } -  update.SetScore(update.GetScore() + adjustment * context.GetWeights().LM()); +  update.SetScore(update.GetScore() + adjustment * context.LMWeight());  }  } // namespace diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh index 582c78b7..203942c6 100644 --- a/klm/search/edge_generator.hh +++ b/klm/search/edge_generator.hh @@ -2,7 +2,6 @@  #define SEARCH_EDGE_GENERATOR__  #include "search/edge.hh" -#include "search/note.hh"  #include "search/types.hh"  #include <queue> diff --git a/klm/search/final.hh b/klm/search/final.hh deleted file mode 100644 index 50e62cf2..00000000 --- a/klm/search/final.hh +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef SEARCH_FINAL__ -#define SEARCH_FINAL__ - -#include "search/header.hh" -#include "util/pool.hh" - -namespace search { - -// A full hypothesis with pointers to children. -class Final : public Header { -  public: -    Final() {} - -    Final(util::Pool &pool, Score score, Arity arity, Note note)  -      : Header(pool.Allocate(Size(arity)), arity) { -      SetScore(score); -      SetNote(note); -    } - -    // These are arrays of length GetArity(). 
-    Final *Children() { -      return reinterpret_cast<Final*>(After()); -    } -    const Final *Children() const { -      return reinterpret_cast<const Final*>(After()); -    } - -  private: -    static std::size_t Size(Arity arity) { -      return kHeaderSize + arity * sizeof(const Final); -    } -}; - -} // namespace search - -#endif // SEARCH_FINAL__ diff --git a/klm/search/header.hh b/klm/search/header.hh index 25550dbe..69f0eed0 100644 --- a/klm/search/header.hh +++ b/klm/search/header.hh @@ -3,7 +3,6 @@  // Header consisting of Score, Arity, and Note -#include "search/note.hh"  #include "search/types.hh"  #include <stdint.h> @@ -24,6 +23,9 @@ class Header {      bool operator<(const Header &other) const {        return GetScore() < other.GetScore();      } +    bool operator>(const Header &other) const { +      return GetScore() > other.GetScore(); +    }      Arity GetArity() const {        return *reinterpret_cast<const Arity*>(base_ + sizeof(Score)); @@ -36,9 +38,14 @@ class Header {        *reinterpret_cast<Note*>(base_ + sizeof(Score) + sizeof(Arity)) = to;      } +    uint8_t *Base() { return base_; } +    const uint8_t *Base() const { return base_; } +    protected:      Header() : base_(NULL) {} +    explicit Header(void *base) : base_(static_cast<uint8_t*>(base)) {} +      Header(void *base, Arity arity) : base_(static_cast<uint8_t*>(base)) {        *reinterpret_cast<Arity*>(base_ + sizeof(Score)) = arity;      } diff --git a/klm/search/nbest.cc b/klm/search/nbest.cc new file mode 100644 index 00000000..ec3322c9 --- /dev/null +++ b/klm/search/nbest.cc @@ -0,0 +1,106 @@ +#include "search/nbest.hh" + +#include "util/pool.hh" + +#include <algorithm> +#include <functional> +#include <queue> + +#include <assert.h> +#include <math.h> + +namespace search { + +NBestList::NBestList(std::vector<PartialEdge> &partials, util::Pool &entry_pool, std::size_t keep) { +  assert(!partials.empty()); +  std::vector<PartialEdge>::iterator end; +  if (partials.size() > 
keep) { +    end = partials.begin() + keep; +    std::nth_element(partials.begin(), end, partials.end(), std::greater<PartialEdge>()); +  } else { +    end = partials.end(); +  } +  for (std::vector<PartialEdge>::const_iterator i(partials.begin()); i != end; ++i) { +    queue_.push(QueueEntry(entry_pool.Allocate(QueueEntry::Size(i->GetArity())), *i)); +  } +} + +Score NBestList::TopAfterConstructor() const { +  assert(revealed_.empty()); +  return queue_.top().GetScore(); +} + +const std::vector<Applied> &NBestList::Extract(util::Pool &pool, std::size_t n) { +  while (revealed_.size() < n && !queue_.empty()) { +    MoveTop(pool); +  } +  return revealed_; +} + +Score NBestList::Visit(util::Pool &pool, std::size_t index) { +  if (index + 1 < revealed_.size()) +    return revealed_[index + 1].GetScore() - revealed_[index].GetScore(); +  if (queue_.empty())  +    return -INFINITY; +  if (index + 1 == revealed_.size()) +    return queue_.top().GetScore() - revealed_[index].GetScore(); +  assert(index == revealed_.size()); + +  MoveTop(pool); + +  if (queue_.empty()) return -INFINITY; +  return queue_.top().GetScore() - revealed_[index].GetScore(); +} + +Applied NBestList::Get(util::Pool &pool, std::size_t index) { +  assert(index <= revealed_.size()); +  if (index == revealed_.size()) MoveTop(pool); +  return revealed_[index]; +} + +void NBestList::MoveTop(util::Pool &pool) { +  assert(!queue_.empty()); +  QueueEntry entry(queue_.top()); +  queue_.pop(); +  RevealedRef *const children_begin = entry.Children(); +  RevealedRef *const children_end = children_begin + entry.GetArity(); +  Score basis = entry.GetScore(); +  for (RevealedRef *child = children_begin; child != children_end; ++child) { +    Score change = child->in_->Visit(pool, child->index_); +    if (change != -INFINITY) { +      assert(change < 0.001); +      QueueEntry new_entry(pool.Allocate(QueueEntry::Size(entry.GetArity())), basis + change, entry.GetArity(), entry.GetNote()); +      
std::copy(children_begin, child, new_entry.Children()); +      RevealedRef *update = new_entry.Children() + (child - children_begin); +      update->in_ = child->in_; +      update->index_ = child->index_ + 1; +      std::copy(child + 1, children_end, update + 1); +      queue_.push(new_entry); +    } +    // Gesmundo, A. and Henderson, J. Faster Cube Pruning, IWSLT 2010. +    if (child->index_) break; +  } + +  // Convert QueueEntry to Applied.  This leaves some unused memory.   +  void *overwrite = entry.Children(); +  for (unsigned int i = 0; i < entry.GetArity(); ++i) { +    RevealedRef from(*(static_cast<const RevealedRef*>(overwrite) + i)); +    *(static_cast<Applied*>(overwrite) + i) = from.in_->Get(pool, from.index_); +  } +  revealed_.push_back(Applied(entry.Base())); +} + +NBestComplete NBest::Complete(std::vector<PartialEdge> &partials) { +  assert(!partials.empty()); +  NBestList *list = list_pool_.construct(partials, entry_pool_, config_.keep); +  return NBestComplete( +      list, +      partials.front().CompletedState(), // All partials have the same state +      list->TopAfterConstructor()); +} + +const std::vector<Applied> &NBest::Extract(History history) { +  return static_cast<NBestList*>(history)->Extract(entry_pool_, config_.size); +} + +} // namespace search diff --git a/klm/search/nbest.hh b/klm/search/nbest.hh new file mode 100644 index 00000000..cb7651bc --- /dev/null +++ b/klm/search/nbest.hh @@ -0,0 +1,81 @@ +#ifndef SEARCH_NBEST__ +#define SEARCH_NBEST__ + +#include "search/applied.hh" +#include "search/config.hh" +#include "search/edge.hh" + +#include <boost/pool/object_pool.hpp> + +#include <cstddef> +#include <queue> +#include <vector> + +#include <assert.h> + +namespace search { + +class NBestList; + +class NBestList { +  private: +    class RevealedRef { +      public:  +        explicit RevealedRef(History history)  +          : in_(static_cast<NBestList*>(history)), index_(0) {} + +      private: +        friend class NBestList; + 
+        NBestList *in_; +        std::size_t index_; +    }; +     +    typedef GenericApplied<RevealedRef> QueueEntry; + +  public: +    NBestList(std::vector<PartialEdge> &existing, util::Pool &entry_pool, std::size_t keep); + +    Score TopAfterConstructor() const; + +    const std::vector<Applied> &Extract(util::Pool &pool, std::size_t n); + +  private: +    Score Visit(util::Pool &pool, std::size_t index); + +    Applied Get(util::Pool &pool, std::size_t index); + +    void MoveTop(util::Pool &pool); + +    typedef std::vector<Applied> Revealed; +    Revealed revealed_; + +    typedef std::priority_queue<QueueEntry> Queue; +    Queue queue_; +}; + +class NBest { +  public: +    typedef std::vector<PartialEdge> Combine; + +    explicit NBest(const NBestConfig &config) : config_(config) {} + +    void Add(std::vector<PartialEdge> &existing, PartialEdge addition) const { +      existing.push_back(addition); +    } + +    NBestComplete Complete(std::vector<PartialEdge> &partials); + +    const std::vector<Applied> &Extract(History root); + +  private: +    const NBestConfig config_; + +    boost::object_pool<NBestList> list_pool_; + +    util::Pool entry_pool_; +}; + +} // namespace search + +#endif // SEARCH_NBEST__ diff --git a/klm/search/note.hh b/klm/search/note.hh deleted file mode 100644 index 50bed06e..00000000 --- a/klm/search/note.hh +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef SEARCH_NOTE__ -#define SEARCH_NOTE__ - -namespace search { - -union Note { -  const void *vp; -}; - -} // namespace search - -#endif // SEARCH_NOTE__ diff --git a/klm/search/rule.cc b/klm/search/rule.cc index 5b00207e..0244a09f 100644 --- a/klm/search/rule.cc +++ b/klm/search/rule.cc @@ -1,7 +1,7 @@  #include "search/rule.hh" +#include "lm/model.hh"  #include "search/context.hh" -#include "search/final.hh"  #include <ostream> @@ -9,35 +9,35 @@  namespace search { -template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool 
prepend_bos, lm::ngram::ChartState *writing) { -  unsigned int oov_count = 0; -  float prob = 0.0; -  const Model &model = context.LanguageModel(); -  const lm::WordIndex oov = model.GetVocabulary().NotFound(); -  for (std::vector<lm::WordIndex>::const_iterator word = words.begin(); ; ++word) { -    lm::ngram::RuleScore<Model> scorer(model, *(writing++)); -    // TODO: optimize -    if (prepend_bos && (word == words.begin())) { -      scorer.BeginSentence(); -    } -    for (; ; ++word) { -      if (word == words.end()) { -        prob += scorer.Finish(); -        return static_cast<float>(oov_count) * context.GetWeights().OOV() + prob * context.GetWeights().LM(); -      } -      if (*word == kNonTerminal) break; -      if (*word == oov) ++oov_count; +template <class Model> ScoreRuleRet ScoreRule(const Model &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing) { +  ScoreRuleRet ret; +  ret.prob = 0.0; +  ret.oov = 0; +  const lm::WordIndex oov = model.GetVocabulary().NotFound(), bos = model.GetVocabulary().BeginSentence(); +  lm::ngram::RuleScore<Model> scorer(model, *(writing++)); +  std::vector<lm::WordIndex>::const_iterator word = words.begin(); +  if (word != words.end() && *word == bos) { +    scorer.BeginSentence(); +    ++word; +  } +  for (; word != words.end(); ++word) { +    if (*word == kNonTerminal) { +      ret.prob += scorer.Finish(); +      scorer.Reset(*(writing++)); +    } else { +      if (*word == oov) ++ret.oov;        scorer.Terminal(*word);      } -    prob += scorer.Finish();    } +  ret.prob += scorer.Finish(); +  return ret;  } -template float ScoreRule(const Context<lm::ngram::RestProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context<lm::ngram::ProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const 
Context<lm::ngram::TrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context<lm::ngram::QuantTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context<lm::ngram::ArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context<lm::ngram::QuantArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::RestProbingModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::ProbingModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::TrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::QuantTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::ArrayTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::QuantArrayTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);  } // namespace search diff --git a/klm/search/rule.hh b/klm/search/rule.hh index 0ce2794d..43ca6162 100644 --- a/klm/search/rule.hh +++ b/klm/search/rule.hh @@ -9,11 +9,16 @@  namespace search { -template <class Model> class Context; -  const lm::WordIndex kNonTerminal = lm::kMaxWordIndex; -template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *state_out); +struct ScoreRuleRet { + 
 Score prob; +  unsigned int oov; +}; + +// Pass <s> and </s> normally.   +// Indicate non-terminals with kNonTerminal.   +template <class Model> ScoreRuleRet ScoreRule(const Model &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *state_out);  } // namespace search diff --git a/klm/search/types.hh b/klm/search/types.hh index 06eb5bfa..f9c849b3 100644 --- a/klm/search/types.hh +++ b/klm/search/types.hh @@ -3,12 +3,29 @@  #include <stdint.h> +namespace lm { namespace ngram { class ChartState; } } +  namespace search {  typedef float Score;  typedef uint32_t Arity; +union Note { +  const void *vp; +}; + +typedef void *History; + +struct NBestComplete { +  NBestComplete(History in_history, const lm::ngram::ChartState &in_state, Score in_score)  +    : history(in_history), state(&in_state), score(in_score) {} + +  History history; +  const lm::ngram::ChartState *state; +  Score score; +}; +  } // namespace search  #endif // SEARCH_TYPES__ diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc index 11f4631f..45842982 100644 --- a/klm/search/vertex.cc +++ b/klm/search/vertex.cc @@ -19,21 +19,34 @@ struct GreaterByBound : public std::binary_function<const VertexNode *, const Ve  } // namespace -void VertexNode::SortAndSet(ContextBase &context, VertexNode **parent_ptr) { +void VertexNode::RecursiveSortAndSet(ContextBase &context, VertexNode *&parent_ptr) {    if (Complete()) { -    assert(end_.Valid()); +    assert(end_);      assert(extend_.empty()); -    bound_ = end_.GetScore();      return;    } -  if (extend_.size() == 1 && parent_ptr) { -    *parent_ptr = extend_[0]; -    extend_[0]->SortAndSet(context, parent_ptr); +  if (extend_.size() == 1) { +    parent_ptr = extend_[0]; +    extend_[0]->RecursiveSortAndSet(context, parent_ptr);      context.DeleteVertexNode(this);      return;    }    for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) { -    (*i)->SortAndSet(context, &*i); +    
(*i)->RecursiveSortAndSet(context, *i); +  } +  std::sort(extend_.begin(), extend_.end(), GreaterByBound()); +  bound_ = extend_.front()->Bound(); +} + +void VertexNode::SortAndSet(ContextBase &context) { +  // This is the root.  The root might be empty.   +  if (extend_.empty()) { +    bound_ = -INFINITY; +    return; +  } +  // The root cannot be replaced.  There's always one transition.   +  for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) { +    (*i)->RecursiveSortAndSet(context, *i);    }    std::sort(extend_.begin(), extend_.end(), GreaterByBound());    bound_ = extend_.front()->Bound(); diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh index 52bc1dfe..10b3339b 100644 --- a/klm/search/vertex.hh +++ b/klm/search/vertex.hh @@ -2,7 +2,6 @@  #define SEARCH_VERTEX__  #include "lm/left.hh" -#include "search/final.hh"  #include "search/types.hh"  #include <boost/unordered_set.hpp> @@ -10,6 +9,7 @@  #include <queue>  #include <vector> +#include <math.h>  #include <stdint.h>  namespace search { @@ -18,7 +18,7 @@ class ContextBase;  class VertexNode {    public: -    VertexNode() {} +    VertexNode() : end_() {}      void InitRoot() {        extend_.clear(); @@ -26,7 +26,7 @@ class VertexNode {        state_.left.length = 0;        state_.right.length = 0;        right_full_ = false; -      end_ = Final(); +      end_ = History();      }      lm::ngram::ChartState &MutableState() { return state_; } @@ -36,20 +36,21 @@ class VertexNode {        extend_.push_back(next);      } -    void SetEnd(Final end) { -      assert(!end_.Valid()); +    void SetEnd(History end, Score score) { +      assert(!end_);        end_ = end; +      bound_ = score;      } -    void SortAndSet(ContextBase &context, VertexNode **parent_pointer); +    void SortAndSet(ContextBase &context);      // Should only happen to a root node when the entire vertex is empty.         
bool Empty() const { -      return !end_.Valid() && extend_.empty(); +      return !end_ && extend_.empty();      }      bool Complete() const { -      return end_.Valid(); +      return end_;      }      const lm::ngram::ChartState &State() const { return state_; } @@ -64,7 +65,7 @@ class VertexNode {      }      // Will be invalid unless this is a leaf.    -    const Final End() const { return end_; } +    const History End() const { return end_; }      const VertexNode &operator[](size_t index) const {        return *extend_[index]; @@ -75,13 +76,15 @@ class VertexNode {      }    private: +    void RecursiveSortAndSet(ContextBase &context, VertexNode *&parent); +      std::vector<VertexNode*> extend_;      lm::ngram::ChartState state_;      bool right_full_;      Score bound_; -    Final end_; +    History end_;  };  class PartialVertex { @@ -97,7 +100,7 @@ class PartialVertex {      const lm::ngram::ChartState &State() const { return back_->State(); }      bool RightFull() const { return back_->RightFull(); } -    Score Bound() const { return Complete() ? back_->End().GetScore() : (*back_)[index_].Bound(); } +    Score Bound() const { return Complete() ? 
back_->Bound() : (*back_)[index_].Bound(); }      unsigned char Length() const { return back_->Length(); } @@ -121,7 +124,7 @@ class PartialVertex {        return ret;      } -    const Final End() const { +    const History End() const {        return back_->End();      } @@ -130,16 +133,18 @@ class PartialVertex {      unsigned int index_;  }; +template <class Output> class VertexGenerator; +  class Vertex {    public:      Vertex() {}      PartialVertex RootPartial() const { return PartialVertex(root_); } -    const Final BestChild() const { +    const History BestChild() const {        PartialVertex top(RootPartial());        if (top.Empty()) { -        return Final(); +        return History();        } else {          PartialVertex continuation;          while (!top.Complete()) { @@ -150,8 +155,8 @@ class Vertex {      }    private: -    friend class VertexGenerator; - +    template <class Output> friend class VertexGenerator; +    template <class Output> friend class RootVertexGenerator;      VertexNode root_;  }; diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc index 0945fe55..73139ffc 100644 --- a/klm/search/vertex_generator.cc +++ b/klm/search/vertex_generator.cc @@ -4,26 +4,18 @@  #include "search/context.hh"  #include "search/edge.hh" +#include <boost/unordered_map.hpp> +#include <boost/version.hpp> +  #include <stdint.h>  namespace search { -VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) { -  gen.root_.InitRoot(); -} - +#if BOOST_VERSION > 104200  namespace {  const uint64_t kCompleteAdd = static_cast<uint64_t>(-1); -// Parallel structure to VertexNode.   
-struct Trie { -  Trie() : under(NULL) {} - -  VertexNode *under; -  boost::unordered_map<uint64_t, Trie> extend; -}; -  Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) {    Trie &next = node.extend[added];    if (!next.under) { @@ -39,19 +31,10 @@ Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::n    return next;  } -void CompleteTransition(ContextBase &context, Trie &starter, PartialEdge partial) { -  Final final(context.FinalPool(), partial.GetScore(), partial.GetArity(), partial.GetNote()); -  Final *child_out = final.Children(); -  const PartialVertex *part = partial.NT(); -  const PartialVertex *const part_end_loop = part + partial.GetArity(); -  for (; part != part_end_loop; ++part, ++child_out) -    *child_out = part->End(); - -  starter.under->SetEnd(final); -} +} // namespace -void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) { -  const lm::ngram::ChartState &state = partial.CompletedState(); +void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end) { +  const lm::ngram::ChartState &state = *end.state;    unsigned char left = 0, right = 0;    Trie *node = &root; @@ -77,18 +60,9 @@ void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) {    }    node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); -  CompleteTransition(context, *node, partial); +  node->under->SetEnd(end.history, end.score);  } -} // namespace - -void VertexGenerator::FinishedSearch() { -  Trie root; -  root.under = &gen_.root_; -  for (Existing::const_iterator i(existing_.begin()); i != existing_.end(); ++i) { -    AddHypothesis(context_, root, i->second); -  } -  root.under->SortAndSet(context_, NULL); -} +#endif // BOOST_VERSION  } // namespace search diff --git 
a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh index 60e86112..da563c2d 100644 --- a/klm/search/vertex_generator.hh +++ b/klm/search/vertex_generator.hh @@ -2,9 +2,11 @@  #define SEARCH_VERTEX_GENERATOR__  #include "search/edge.hh" +#include "search/types.hh"  #include "search/vertex.hh"  #include <boost/unordered_map.hpp> +#include <boost/version.hpp>  namespace lm {  namespace ngram { @@ -15,21 +17,44 @@ class ChartState;  namespace search {  class ContextBase; -class Final; -class VertexGenerator { +#if BOOST_VERSION > 104200 +// Parallel structure to VertexNode.   +struct Trie { +  Trie() : under(NULL) {} + +  VertexNode *under; +  boost::unordered_map<uint64_t, Trie> extend; +}; + +void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end); + +#endif // BOOST_VERSION + +// Output makes the single-best or n-best list.    +template <class Output> class VertexGenerator {    public: -    VertexGenerator(ContextBase &context, Vertex &gen); +    VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) { +      gen.root_.InitRoot(); +    }      void NewHypothesis(PartialEdge partial) { -      const lm::ngram::ChartState &state = partial.CompletedState(); -      std::pair<Existing::iterator, bool> ret(existing_.insert(std::make_pair(hash_value(state), partial))); -      if (!ret.second && ret.first->second < partial) { -        ret.first->second = partial; -      } +      nbest_.Add(existing_[hash_value(partial.CompletedState())], partial);      } -    void FinishedSearch(); +    void FinishedSearch() { +#if BOOST_VERSION > 104200 +      Trie root; +      root.under = &gen_.root_; +      for (typename Existing::iterator i(existing_.begin()); i != existing_.end(); ++i) { +        AddHypothesis(context_, root, nbest_.Complete(i->second)); +      } +      existing_.clear(); +      root.under->SortAndSet(context_); +#else +      UTIL_THROW(util::Exception, "Upgrade Boost to >= 
1.42.0 to use incremental search."); +#endif +    }      const Vertex &Generating() const { return gen_; } @@ -38,8 +63,35 @@ class VertexGenerator {      Vertex &gen_; -    typedef boost::unordered_map<uint64_t, PartialEdge> Existing; +    typedef boost::unordered_map<uint64_t, typename Output::Combine> Existing;      Existing existing_; + +    Output &nbest_; +}; + +// Special case for root vertex: everything should come together into the root +// node.  In theory, this should happen naturally due to state collapsing with +// <s> and </s>.  If that's the case, VertexGenerator is fine, though it will +// make one connection.   +template <class Output> class RootVertexGenerator { +  public: +    RootVertexGenerator(Vertex &gen, Output &out) : gen_(gen), out_(out) {} + +    void NewHypothesis(PartialEdge partial) { +      out_.Add(combine_, partial); +    } + +    void FinishedSearch() { +      gen_.root_.InitRoot(); +      NBestComplete completed(out_.Complete(combine_)); +      gen_.root_.SetEnd(completed.history, completed.score); +    } + +  private: +    Vertex &gen_; +     +    typename Output::Combine combine_; +    Output &out_;  };  } // namespace search diff --git a/klm/search/weights.cc b/klm/search/weights.cc deleted file mode 100644 index d65471ad..00000000 --- a/klm/search/weights.cc +++ /dev/null @@ -1,71 +0,0 @@ -#include "search/weights.hh" -#include "util/tokenize_piece.hh" - -#include <cstdlib> - -namespace search { - -namespace { -struct Insert { -  void operator()(boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) const { -    std::string copy(name.data(), name.size()); -    map[copy] = score; -  } -}; - -struct DotProduct { -  search::Score total; -  DotProduct() : total(0.0) {} - -  void operator()(const boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) { -    boost::unordered_map<std::string, search::Score>::const_iterator i(FindStringPiece(map, name)); -   
 if (i != map.end())  -      total += score * i->second; -  } -}; - -template <class Map, class Op> void Parse(StringPiece text, Map &map, Op &op) { -  for (util::TokenIter<util::SingleCharacter, true> spaces(text, ' '); spaces; ++spaces) { -    util::TokenIter<util::SingleCharacter> equals(*spaces, '='); -    UTIL_THROW_IF(!equals, WeightParseException, "Bad weight token " << *spaces); -    StringPiece name(*equals); -    UTIL_THROW_IF(!++equals, WeightParseException, "Bad weight token " << *spaces); -    char *end; -    // Assumes proper termination.   -    double value = std::strtod(equals->data(), &end); -    UTIL_THROW_IF(end != equals->data() + equals->size(), WeightParseException, "Failed to parse weight" << *equals); -    UTIL_THROW_IF(++equals, WeightParseException, "Too many equals in " << *spaces); -    op(map, name, value); -  } -} - -} // namespace - -Weights::Weights(StringPiece text) { -  Insert op; -  Parse<Map, Insert>(text, map_, op); -  lm_ = Steal("LanguageModel"); -  oov_ = Steal("OOV"); -  word_penalty_ = Steal("WordPenalty"); -} - -Weights::Weights(Score lm, Score oov, Score word_penalty) : lm_(lm), oov_(oov), word_penalty_(word_penalty) {} - -search::Score Weights::DotNoLM(StringPiece text) const { -  DotProduct dot; -  Parse<const Map, DotProduct>(text, map_, dot); -  return dot.total; -} - -float Weights::Steal(const std::string &str) { -  Map::iterator i(map_.find(str)); -  if (i == map_.end()) { -    return 0.0; -  } else { -    float ret = i->second; -    map_.erase(i); -    return ret; -  } -} - -} // namespace search diff --git a/klm/search/weights.hh b/klm/search/weights.hh deleted file mode 100644 index df1c419f..00000000 --- a/klm/search/weights.hh +++ /dev/null @@ -1,52 +0,0 @@ -// For now, the individual features are not kept.   
-#ifndef SEARCH_WEIGHTS__ -#define SEARCH_WEIGHTS__ - -#include "search/types.hh" -#include "util/exception.hh" -#include "util/string_piece.hh" - -#include <boost/unordered_map.hpp> - -#include <string> - -namespace search { - -class WeightParseException : public util::Exception { -  public: -    WeightParseException() {} -    ~WeightParseException() throw() {} -}; - -class Weights { -  public: -    // Parses weights, sets lm_weight_, removes it from map_. -    explicit Weights(StringPiece text); - -    // Just the three scores we care about adding.    -    Weights(Score lm, Score oov, Score word_penalty); - -    Score DotNoLM(StringPiece text) const; - -    Score LM() const { return lm_; } - -    Score OOV() const { return oov_; } - -    Score WordPenalty() const { return word_penalty_; } - -    // Mostly for testing.   -    const boost::unordered_map<std::string, Score> &GetMap() const { return map_; } - -  private: -    float Steal(const std::string &str); - -    typedef boost::unordered_map<std::string, Score> Map; - -    Map map_; - -    Score lm_, oov_, word_penalty_; -}; - -} // namespace search - -#endif // SEARCH_WEIGHTS__ diff --git a/klm/search/weights_test.cc b/klm/search/weights_test.cc deleted file mode 100644 index 4811ff06..00000000 --- a/klm/search/weights_test.cc +++ /dev/null @@ -1,38 +0,0 @@ -#include "search/weights.hh" - -#define BOOST_TEST_MODULE WeightTest -#include <boost/test/unit_test.hpp> -#include <boost/test/floating_point_comparison.hpp> - -namespace search { -namespace { - -#define CHECK_WEIGHT(value, string) \ -  i = parsed.find(string); \ -  BOOST_REQUIRE(i != parsed.end()); \ -  BOOST_CHECK_CLOSE((value), i->second, 0.001); - -BOOST_AUTO_TEST_CASE(parse) { -  // These are not real feature weights.   
-  Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5"); -  const boost::unordered_map<std::string, search::Score> &parsed = w.GetMap(); -  boost::unordered_map<std::string, search::Score>::const_iterator i; -  CHECK_WEIGHT(0.0, "rarity"); -  CHECK_WEIGHT(0.0, "phrase-SGT"); -  CHECK_WEIGHT(9.45117, "phrase-TGS"); -  CHECK_WEIGHT(2.33833, "lexical-SGT"); -  BOOST_CHECK(parsed.end() == parsed.find("lm")); -  BOOST_CHECK_CLOSE(3.0, w.LM(), 0.001); -  CHECK_WEIGHT(-28.3317, "lexical-TGS"); -  CHECK_WEIGHT(5.0, "glue?"); -} - -BOOST_AUTO_TEST_CASE(dot) { -  Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5"); -  BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0"), 0.001); -  BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0 LanguageModel=10"), 0.001); -  BOOST_CHECK_CLOSE(9.45117 * 3.0 + 28.3317 * 17.4, w.DotNoLM("rarity=5 phrase-TGS=3.0 LanguageModel=10 lexical-TGS=-17.4"), 0.001); -} - -} // namespace -} // namespace search diff --git a/klm/util/Makefile.am b/klm/util/Makefile.am index 5306850f..a676bdb3 100644 --- a/klm/util/Makefile.am +++ b/klm/util/Makefile.am @@ -27,6 +27,7 @@ libklm_util_a_SOURCES = \    mmap.cc \    murmur_hash.cc \    pool.cc \ +	read_compressed.cc \    string_piece.cc \  	usage.cc diff --git a/klm/util/exception.hh b/klm/util/exception.hh index 053a850b..0165a7a3 100644 --- a/klm/util/exception.hh +++ b/klm/util/exception.hh @@ -87,8 +87,14 @@ template <class Except, class Data> typename Except::template ExceptionTag<Excep    throw UTIL_e; \  } while (0) +#if __GNUC__ >= 3 +#define UTIL_UNLIKELY(x) __builtin_expect (!!(x), 0) +#else +#define UTIL_UNLIKELY(x) (x) +#endif +  #define UTIL_THROW_IF(Condition, Exception, Modify) do { \ -  if (Condition) { \ +  if (UTIL_UNLIKELY(Condition)) { \      Exception UTIL_e; \      
UTIL_SET_LOCATION(UTIL_e, #Exception, #Condition); \      UTIL_e << Modify; \ diff --git a/klm/util/file.cc b/klm/util/file.cc index 6bf879ac..b9a77cf9 100644 --- a/klm/util/file.cc +++ b/klm/util/file.cc @@ -15,6 +15,8 @@  #if defined(_WIN32) || defined(_WIN64)  #include <windows.h>  #include <io.h> +#include <algorithm> +#include <limits.h>  #else  #include <unistd.h>  #endif @@ -48,7 +50,7 @@ int OpenReadOrThrow(const char *name) {  int CreateOrThrow(const char *name) {    int ret;  #if defined(_WIN32) || defined(_WIN64) -  UTIL_THROW_IF(-1 == (ret = _open(name, _O_CREAT | _O_TRUNC | _O_RDWR, _S_IREAD | _S_IWRITE)), ErrnoException, "while creating " << name); +  UTIL_THROW_IF(-1 == (ret = _open(name, _O_CREAT | _O_TRUNC | _O_RDWR | _O_BINARY, _S_IREAD | _S_IWRITE)), ErrnoException, "while creating " << name);  #else    UTIL_THROW_IF(-1 == (ret = open(name, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH)), ErrnoException, "while creating " << name);  #endif @@ -74,16 +76,22 @@ void ResizeOrThrow(int fd, uint64_t to) {  #endif  } -#ifdef WIN32 -typedef int ssize_t; +std::size_t PartialRead(int fd, void *to, std::size_t amount) { +#if defined(_WIN32) || defined(_WIN64) +  amount = min(static_cast<std::size_t>(INT_MAX), amount); +  int ret = _read(fd, to, amount);  +#else +  ssize_t ret = read(fd, to, amount);  #endif +  UTIL_THROW_IF(ret < 0, ErrnoException, "Reading " << amount << " from fd " << fd << " failed."); +  return static_cast<std::size_t>(ret); +}  void ReadOrThrow(int fd, void *to_void, std::size_t amount) {    uint8_t *to = static_cast<uint8_t*>(to_void);    while (amount) { -    ssize_t ret = read(fd, to, amount); -    UTIL_THROW_IF(ret == -1, ErrnoException, "Reading " << amount << " from fd " << fd << " failed."); -    UTIL_THROW_IF(ret == 0, EndOfFileException, "Hit EOF in fd " << fd << " but there should be " << amount << " more bytes to read."); +    std::size_t ret = PartialRead(fd, to, amount); +    UTIL_THROW_IF(ret == 0, 
EndOfFileException, " in fd " << fd << " but there should be " << amount << " more bytes to read.");      amount -= ret;      to += ret;    } @@ -93,8 +101,7 @@ std::size_t ReadOrEOF(int fd, void *to_void, std::size_t amount) {    uint8_t *to = static_cast<uint8_t*>(to_void);    std::size_t remaining = amount;    while (remaining) { -    ssize_t ret = read(fd, to, remaining); -    UTIL_THROW_IF(ret == -1, ErrnoException, "Reading " << remaining << " from fd " << fd << " failed."); +    std::size_t ret = PartialRead(fd, to, remaining);      if (!ret) return amount - remaining;      remaining -= ret;      to += ret; @@ -105,7 +112,11 @@ std::size_t ReadOrEOF(int fd, void *to_void, std::size_t amount) {  void WriteOrThrow(int fd, const void *data_void, std::size_t size) {    const uint8_t *data = static_cast<const uint8_t*>(data_void);    while (size) { +#if defined(_WIN32) || defined(_WIN64) +    int ret = write(fd, data, min(static_cast<std::size_t>(INT_MAX), size)); +#else      ssize_t ret = write(fd, data, size); +#endif      if (ret < 1) UTIL_THROW(util::ErrnoException, "Write failed");      data += ret;      size -= ret; @@ -114,7 +125,7 @@ void WriteOrThrow(int fd, const void *data_void, std::size_t size) {  void WriteOrThrow(FILE *to, const void *data, std::size_t size) {    assert(size); -  if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size); +  UTIL_THROW_IF(1 != std::fwrite(data, size, 1, to), util::ErrnoException, "Short write; requested size " << size);  }  void FSyncOrThrow(int fd) { @@ -149,14 +160,15 @@ void SeekEnd(int fd) {  std::FILE *FDOpenOrThrow(scoped_fd &file) {    std::FILE *ret = fdopen(file.get(), "r+b"); -  if (!ret) UTIL_THROW(util::ErrnoException, "Could not fdopen"); +  if (!ret) UTIL_THROW(util::ErrnoException, "Could not fdopen descriptor " << file.get());    file.release();    return ret;  } -std::FILE *FOpenOrThrow(const char *path, const char *mode) { -  std::FILE *ret; - 
 UTIL_THROW_IF(!(ret = fopen(path, mode)), util::ErrnoException, "Could not fopen " << path << " for " << mode); +std::FILE *FDOpenReadOrThrow(scoped_fd &file) { +  std::FILE *ret = fdopen(file.get(), "rb"); +  if (!ret) UTIL_THROW(util::ErrnoException, "Could not fdopen descriptor " << file.get()); +  file.release();    return ret;  } diff --git a/klm/util/file.hh b/klm/util/file.hh index 185cb1f3..c24580d6 100644 --- a/klm/util/file.hh +++ b/klm/util/file.hh @@ -32,8 +32,6 @@ class scoped_fd {        return ret;      } -    operator bool() { return fd_ != -1; } -    private:      int fd_; @@ -76,8 +74,9 @@ uint64_t SizeFile(int fd);  void ResizeOrThrow(int fd, uint64_t to); +std::size_t PartialRead(int fd, void *to, std::size_t size);  void ReadOrThrow(int fd, void *to, std::size_t size); -std::size_t ReadOrEOF(int fd, void *to_void, std::size_t amount); +std::size_t ReadOrEOF(int fd, void *to_void, std::size_t size);  void WriteOrThrow(int fd, const void *data_void, std::size_t size);  void WriteOrThrow(FILE *to, const void *data, std::size_t size); @@ -90,8 +89,7 @@ void AdvanceOrThrow(int fd, int64_t off);  void SeekEnd(int fd);  std::FILE *FDOpenOrThrow(scoped_fd &file); - -std::FILE *FOpenOrThrow(const char *path, const char *mode); +std::FILE *FDOpenReadOrThrow(scoped_fd &file);  class TempMaker {    public: diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc index 280f438c..5a208eff 100644 --- a/klm/util/file_piece.cc +++ b/klm/util/file_piece.cc @@ -14,7 +14,6 @@  #include <limits>  #include <assert.h> -#include <ctype.h>  #include <fcntl.h>  #include <stdlib.h>  #include <sys/types.h> @@ -26,13 +25,6 @@ ParseNumberException::ParseNumberException(StringPiece value) throw() {    *this << "Could not parse \"" << value << "\" into a number";  } -#ifdef HAVE_ZLIB -GZException::GZException(gzFile file) { -  int num; -  *this << gzerror(file, &num) << " from zlib"; -} -#endif // HAVE_ZLIB -  // Sigh this is the only way I could come up with to do a 
_const_ bool.  It has ' ', '\f', '\n', '\r', '\t', and '\v' (same as isspace on C locale).   const bool kSpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}; @@ -48,19 +40,7 @@ FilePiece::FilePiece(int fd, const char *name, std::ostream *show_progress, std:    Initialize(name, show_progress, min_buffer);  } -FilePiece::~FilePiece() { -#ifdef HAVE_ZLIB -  if (gz_file_) { -    // zlib took ownership -    file_.release(); -    int ret; -    if (Z_OK != (ret = gzclose(gz_file_))) { -      std::cerr << "could not close file " << file_name_ << " using zlib" << std::endl; -      abort(); -    } -  } -#endif -} +FilePiece::~FilePiece() {}  StringPiece FilePiece::ReadLine(char delim) {    std::size_t skip = 0; @@ -95,9 +75,6 @@ unsigned long int FilePiece::ReadULong() {  }  void FilePiece::Initialize(const char *name, std::ostream *show_progress, std::size_t min_buffer)  { -#ifdef HAVE_ZLIB -  gz_file_ = NULL; -#endif    file_name_ = name;    default_map_size_ = page_ * std::max<std::size_t>((min_buffer / page_ + 1), 2); @@ -117,10 +94,7 @@ void FilePiece::Initialize(const char *name, std::ostream *show_progress, std::s    }    Shift();    // gzip detect. 
-  if ((position_end_ - position_) > 2 && *position_ == 0x1f && static_cast<unsigned char>(*(position_ + 1)) == 0x8b) { -#ifndef HAVE_ZLIB -    UTIL_THROW(GZException, "Looks like a gzip file but support was not compiled in."); -#endif +  if ((position_end_ - position_) >= ReadCompressed::kMagicSize && ReadCompressed::DetectCompressedMagic(position_)) {      if (!fallback_to_read_) {        at_end_ = false;        TransitionToRead(); @@ -197,7 +171,7 @@ void FilePiece::Shift() {    if (fallback_to_read_) ReadShift();    for (last_space_ = position_end_ - 1; last_space_ >= position_; --last_space_) { -    if (isspace(*last_space_))  break; +    if (kSpaces[static_cast<unsigned char>(*last_space_)])  break;    }  } @@ -248,17 +222,14 @@ void FilePiece::TransitionToRead() {    position_ = data_.begin();    position_end_ = position_; -#ifdef HAVE_ZLIB -  assert(!gz_file_); -  gz_file_ = gzdopen(file_.get(), "r"); -  UTIL_THROW_IF(!gz_file_, GZException, "zlib failed to open " << file_name_); -#endif +  try { +    fell_back_.Reset(file_.release()); +  } catch (util::Exception &e) { +    e << " in file " << file_name_; +    throw; +  }  } -#ifdef WIN32 -typedef int ssize_t; -#endif -  void FilePiece::ReadShift() {    assert(fallback_to_read_);    // Bytes [data_.begin(), position_) have been consumed.   
@@ -283,7 +254,7 @@ void FilePiece::ReadShift() {        position_ = data_.begin();        position_end_ = position_ + valid_length;      } else { -      size_t moving = position_end_ - position_; +      std::size_t moving = position_end_ - position_;        memmove(data_.get(), position_, moving);        position_ = data_.begin();        position_end_ = position_ + moving; @@ -291,20 +262,9 @@ void FilePiece::ReadShift() {      }    } -  ssize_t read_return; -#ifdef HAVE_ZLIB -  read_return = gzread(gz_file_, static_cast<char*>(data_.get()) + already_read, default_map_size_ - already_read); -  if (read_return == -1) throw GZException(gz_file_); -  if (total_size_ != kBadSize) { -    // Just get the position, don't actually seek.  Apparently this is how you do it. . .  -    off_t ret = lseek(file_.get(), 0, SEEK_CUR); -    if (ret != -1) progress_.Set(ret); -  } -#else -  read_return = read(file_.get(), static_cast<char*>(data_.get()) + already_read, default_map_size_ - already_read); -  UTIL_THROW_IF(read_return == -1, ErrnoException, "read failed"); -  progress_.Set(mapped_offset_); -#endif +  std::size_t read_return = fell_back_.Read(static_cast<uint8_t*>(data_.get()) + already_read, default_map_size_ - already_read); +  progress_.Set(fell_back_.RawAmount()); +    if (read_return == 0) {      at_end_ = true;    } diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh index af93d8aa..39bd1581 100644 --- a/klm/util/file_piece.hh +++ b/klm/util/file_piece.hh @@ -4,8 +4,8 @@  #include "util/ersatz_progress.hh"  #include "util/exception.hh"  #include "util/file.hh" -#include "util/have.hh"  #include "util/mmap.hh" +#include "util/read_compressed.hh"  #include "util/string_piece.hh"  #include <cstddef> @@ -13,10 +13,6 @@  #include <stdint.h> -#ifdef HAVE_ZLIB -#include <zlib.h> -#endif -  namespace util {  class ParseNumberException : public Exception { @@ -25,28 +21,19 @@ class ParseNumberException : public Exception {      ~ParseNumberException() throw() {}  
}; -class GZException : public Exception { -  public: -#ifdef HAVE_ZLIB -    explicit GZException(gzFile file); -#endif -    GZException() throw() {} -    ~GZException() throw() {} -}; -  extern const bool kSpaces[256]; -// Memory backing the returned StringPiece may vanish on the next call.   +// Memory backing the returned StringPiece may vanish on the next call.  class FilePiece {    public: -    // 32 MB default. -    explicit FilePiece(const char *file, std::ostream *show_progress = NULL, std::size_t min_buffer = 33554432); -    // Takes ownership of fd.  name is used for messages.   -    explicit FilePiece(int fd, const char *name, std::ostream *show_progress = NULL, std::size_t min_buffer = 33554432); +    // 1 MB default. +    explicit FilePiece(const char *file, std::ostream *show_progress = NULL, std::size_t min_buffer = 1048576); +    // Takes ownership of fd.  name is used for messages. +    explicit FilePiece(int fd, const char *name, std::ostream *show_progress = NULL, std::size_t min_buffer = 1048576);      ~FilePiece(); -      -    char get() {  + +    char get() {        if (position_ == position_end_) {          Shift();          if (at_end_) throw EndOfFileException(); @@ -54,14 +41,14 @@ class FilePiece {        return *(position_++);      } -    // Leaves the delimiter, if any, to be returned by get().  Delimiters defined by isspace().   +    // Leaves the delimiter, if any, to be returned by get().  Delimiters defined by isspace().      StringPiece ReadDelimited(const bool *delim = kSpaces) {        SkipSpaces(delim);        return Consume(FindDelimiterOrEOF(delim));      }      // Unlike ReadDelimited, this includes leading spaces and consumes the delimiter. -    // It is similar to getline in that way.   +    // It is similar to getline in that way.      
StringPiece ReadLine(char delim = '\n');      float ReadFloat(); @@ -69,7 +56,7 @@ class FilePiece {      long int ReadLong();      unsigned long int ReadULong(); -    // Skip spaces defined by isspace.   +    // Skip spaces defined by isspace.      void SkipSpaces(const bool *delim = kSpaces) {        for (; ; ++position_) {          if (position_ == position_end_) Shift(); @@ -82,7 +69,7 @@ class FilePiece {      }      const std::string &FileName() const { return file_name_; } -     +    private:      void Initialize(const char *name, std::ostream *show_progress, std::size_t min_buffer); @@ -122,9 +109,7 @@ class FilePiece {      std::string file_name_; -#ifdef HAVE_ZLIB -    gzFile gz_file_; -#endif // HAVE_ZLIB +    ReadCompressed fell_back_;  };  } // namespace util diff --git a/klm/util/file_piece_test.cc b/klm/util/file_piece_test.cc index f912e18a..e79ece7a 100644 --- a/klm/util/file_piece_test.cc +++ b/klm/util/file_piece_test.cc @@ -38,7 +38,7 @@ BOOST_AUTO_TEST_CASE(MMapReadLine) {    BOOST_CHECK_THROW(test.get(), EndOfFileException);  } -#ifndef __APPLE__ +#if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__)  /* Apple isn't happy with the popen, fileno, dup.  And I don't want to   * reimplement popen.  This is an issue with the test.     
*/ @@ -65,7 +65,7 @@ BOOST_AUTO_TEST_CASE(StreamReadLine) {    BOOST_CHECK_THROW(test.get(), EndOfFileException);    BOOST_REQUIRE(!pclose(catter));  } -#endif // __APPLE__ +#endif  #ifdef HAVE_ZLIB diff --git a/klm/util/have.hh b/klm/util/have.hh index b8181e99..85b838e4 100644 --- a/klm/util/have.hh +++ b/klm/util/have.hh @@ -2,22 +2,16 @@  #ifndef UTIL_HAVE__  #define UTIL_HAVE__ -#ifndef HAVE_ZLIB -#if !defined(_WIN32) && !defined(_WIN64) -#define HAVE_ZLIB -#endif -#endif -  #ifndef HAVE_ICU  //#define HAVE_ICU  #endif  #ifndef HAVE_BOOST -#define HAVE_BOOST +//#define HAVE_BOOST  #endif -#ifndef HAVE_THREADS -//#define HAVE_THREADS +#ifdef HAVE_CONFIG_H +#include "config.h"  #endif  #endif // UTIL_HAVE__ diff --git a/klm/util/joint_sort.hh b/klm/util/joint_sort.hh index cf3d8432..1b43ddcf 100644 --- a/klm/util/joint_sort.hh +++ b/klm/util/joint_sort.hh @@ -60,7 +60,7 @@ template <class KeyIter, class ValueIter> class JointProxy {      JointProxy(const KeyIter &key_iter, const ValueIter &value_iter) : inner_(key_iter, value_iter) {}      JointProxy(const JointProxy<KeyIter, ValueIter> &other) : inner_(other.inner_) {} -    operator const value_type() const { +    operator value_type() const {        value_type ret;        ret.key = *inner_.key_;        ret.value = *inner_.value_; @@ -121,7 +121,7 @@ template <class Proxy, class Less> class LessWrapper : public std::binary_functi  template <class KeyIter, class ValueIter> class PairedIterator : public ProxyIterator<detail::JointProxy<KeyIter, ValueIter> > {    public: -    PairedIterator(const KeyIter &key, const ValueIter &value) :  +    PairedIterator(const KeyIter &key, const ValueIter &value) :        ProxyIterator<detail::JointProxy<KeyIter, ValueIter> >(detail::JointProxy<KeyIter, ValueIter>(key, value)) {}  }; diff --git a/klm/util/read_compressed.cc b/klm/util/read_compressed.cc new file mode 100644 index 00000000..4ec94c4e --- /dev/null +++ b/klm/util/read_compressed.cc @@ -0,0 +1,403 @@ +#include 
"util/read_compressed.hh" + +#include "util/file.hh" +#include "util/have.hh" +#include "util/scoped.hh" + +#include <algorithm> +#include <iostream> + +#include <assert.h> +#include <limits.h> +#include <stdlib.h> +#include <string.h> + +#ifdef HAVE_ZLIB +#include <zlib.h> +#endif + +#ifdef HAVE_BZLIB +#include <bzlib.h> +#endif + +#ifdef HAVE_XZLIB +#include <lzma.h> +#endif + +namespace util { + +CompressedException::CompressedException() throw() {} +CompressedException::~CompressedException() throw() {} + +GZException::GZException() throw() {} +GZException::~GZException() throw() {} + +BZException::BZException() throw() {} +BZException::~BZException() throw() {} + +XZException::XZException() throw() {} +XZException::~XZException() throw() {} + +class ReadBase { +  public: +    virtual ~ReadBase() {} + +    virtual std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) = 0; + +  protected: +    static void ReplaceThis(ReadBase *with, ReadCompressed &thunk) { +      thunk.internal_.reset(with); +    } + +    static uint64_t &ReadCount(ReadCompressed &thunk) { +      return thunk.raw_amount_; +    } +}; + +namespace { + +// Completed file that other classes can thunk to.   
// Terminal state installed once a stream is fully decoded: every subsequent
// Read() reports EOF (0 bytes).
class Complete : public ReadBase {
  public:
    std::size_t Read(void *, std::size_t, ReadCompressed &) {
      return 0;
    }
};

// Pass-through reader for data that is not compressed at all.
class Uncompressed : public ReadBase {
  public:
    explicit Uncompressed(int fd) : fd_(fd) {}

    std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
      std::size_t got = PartialRead(fd_.get(), to, amount);
      ReadCount(thunk) += got;
      return got;
    }

  private:
    scoped_fd fd_;
};

// Plain reader that first replays bytes already consumed from the fd during
// magic-number sniffing, then degrades itself to Uncompressed.
class UncompressedWithHeader : public ReadBase {
  public:
    UncompressedWithHeader(int fd, void *already_data, std::size_t already_size) : fd_(fd) {
      assert(already_size);
      buf_.reset(malloc(already_size));
      if (!buf_.get()) throw std::bad_alloc();
      memcpy(buf_.get(), already_data, already_size);
      remain_ = static_cast<uint8_t*>(buf_.get());
      end_ = remain_ + already_size;
    }

    std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
      assert(buf_.get());
      // Serve only from the buffered header in this call; the fd is not
      // touched until the buffer is exhausted.
      std::size_t sending = std::min<std::size_t>(amount, end_ - remain_);
      memcpy(to, remain_, sending);
      remain_ += sending;
      if (remain_ == end_) {
        // Buffer drained: hand the fd to a plain reader.  ReplaceThis deletes
        // *this, so only the stack local sending may be used afterwards.
        ReplaceThis(new Uncompressed(fd_.release()), thunk);
      }
      return sending;
    }

  private:
    scoped_malloc buf_;
    uint8_t *remain_;  // Next unserved byte within buf_.
    uint8_t *end_;     // One past the last buffered byte.

    scoped_fd fd_;
};

#ifdef HAVE_ZLIB
// zlib-backed reader handling both gzip and raw zlib streams.
class GZip : public ReadBase {
  private:
    static const std::size_t kInputBuffer = 16384;
  public:
    // already_data/already_size: bytes sniffed from fd before construction;
    // they are prepended to the input buffer so nothing is lost.
    GZip(int fd, void *already_data, std::size_t already_size)
      : file_(fd), in_buffer_(malloc(kInputBuffer)) {
      if (!in_buffer_.get()) throw std::bad_alloc();
      assert(already_size < kInputBuffer);
      if (already_size) {
        memcpy(in_buffer_.get(), already_data, already_size);
        stream_.next_in = static_cast<Bytef *>(in_buffer_.get());
        stream_.avail_in = already_size;
        stream_.avail_in += ReadOrEOF(file_.get(), static_cast<uint8_t*>(in_buffer_.get()) + already_size, kInputBuffer - already_size);
      } else {
        stream_.avail_in = 0;
      }
      stream_.zalloc = Z_NULL;
      stream_.zfree = Z_NULL;
      stream_.opaque = Z_NULL;
      stream_.msg = NULL;
      // 32 for zlib and gzip decoding with automatic header detection.
      // 15 for maximum window size.
      UTIL_THROW_IF(Z_OK != inflateInit2(&stream_, 32 + 15), GZException, "Failed to initialize zlib.");
    }

    ~GZip() {
      if (Z_OK != inflateEnd(&stream_)) {
        // inflateEnd failing here indicates corrupted internal state; abort
        // rather than continue with a damaged heap.
        std::cerr << "zlib could not close properly." << std::endl;
        abort();
      }
    }

    std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
      if (amount == 0) return 0;
      stream_.next_out = static_cast<Bytef*>(to);
      // avail_out is a uInt, which may be narrower than size_t; clamp.
      stream_.avail_out = std::min<std::size_t>(std::numeric_limits<uInt>::max(), amount);
      do {
        if (!stream_.avail_in) ReadInput(thunk);
        int result = inflate(&stream_, 0);
        switch (result) {
          case Z_OK:
            break;
          case Z_STREAM_END:
            {
              // Compute the byte count before ReplaceThis deletes *this.
              std::size_t ret = static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to);
              ReplaceThis(new Complete(), thunk);
              return ret;
            }
          case Z_ERRNO:
            UTIL_THROW(ErrnoException, "zlib error");
          default:
            UTIL_THROW(GZException, "zlib encountered " << (stream_.msg ? stream_.msg : "an error ") << " code " << result);
        }
      } while (stream_.next_out == to);  // Loop until at least one byte was produced.
      return static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to);
    }

  private:
    // Refill the input buffer from the fd, crediting the raw byte count.
    void ReadInput(ReadCompressed &thunk) {
      assert(!stream_.avail_in);
      stream_.next_in = static_cast<Bytef *>(in_buffer_.get());
      stream_.avail_in = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer);
      ReadCount(thunk) += stream_.avail_in;
    }

    scoped_fd file_;
    scoped_malloc in_buffer_;
    z_stream stream_;
};
#endif // HAVE_ZLIB

#ifdef HAVE_BZLIB
// bzlib-backed reader.  bzlib's read API works on FILE*, hence closer_.
class BZip : public ReadBase {
  public:
    explicit BZip(int fd, void *already_data, std::size_t already_size) {
      scoped_fd hold(fd);
      closer_.reset(FDOpenReadOrThrow(hold));
      int bzerror = BZ_OK;
      // Sniffed header bytes are handed to bzlib via the unused/nUnused pair.
      file_ = BZ2_bzReadOpen(&bzerror, closer_.get(), 0, 0, already_data, already_size);
      switch (bzerror) {
        case BZ_OK:
          return;
        case BZ_CONFIG_ERROR:
          UTIL_THROW(BZException, "Looks like bzip2 was miscompiled.");
        case BZ_PARAM_ERROR:
          UTIL_THROW(BZException, "Parameter error");
        case BZ_IO_ERROR:
          UTIL_THROW(BZException, "IO error reading file");
        case BZ_MEM_ERROR:
          throw std::bad_alloc();
      }
    }

    ~BZip() {
      int bzerror = BZ_OK;
      BZ2_bzReadClose(&bzerror, file_);
      if (bzerror != BZ_OK) {
        std::cerr << "bz2 readclose error" << std::endl;
        abort();
      }
    }

    std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
      int bzerror = BZ_OK;
      // BZ2_bzRead takes an int length; clamp amount accordingly.
      int ret = BZ2_bzRead(&bzerror, file_, to, std::min<std::size_t>(static_cast<std::size_t>(INT_MAX), amount));
      long pos;
      switch (bzerror) {
        case BZ_STREAM_END:
          // bzlib reads via stdio, so the raw count comes from ftell rather
          // than our own accounting; -1 (unseekable) leaves it unchanged.
          pos = ftell(closer_.get());
          if (pos != -1) ReadCount(thunk) = pos;
          ReplaceThis(new Complete(), thunk);
          return ret;
        case BZ_OK:
          pos = ftell(closer_.get());
          if (pos != -1) ReadCount(thunk) = pos;
          return ret;
        default:
          UTIL_THROW(BZException, "bzip2 error " << BZ2_bzerror(file_, &bzerror) << " code " << bzerror);
      }
    }

  private:
    scoped_FILE closer_;
    BZFILE *file_;
};
#endif // HAVE_BZLIB

#ifdef HAVE_XZLIB
// liblzma-backed reader for xz streams.
class XZip : public ReadBase {
  private:
    static const std::size_t kInputBuffer = 16384;
  public:
    // stream_() value-initializes (zeroes) the lzma_stream, the documented
    // equivalent of LZMA_STREAM_INIT.
    XZip(int fd, void *already_data, std::size_t already_size)
      : file_(fd), in_buffer_(malloc(kInputBuffer)), stream_(), action_(LZMA_RUN) {
      if (!in_buffer_.get()) throw std::bad_alloc();
      assert(already_size < kInputBuffer);
      if (already_size) {
        memcpy(in_buffer_.get(), already_data, already_size);
        stream_.next_in = static_cast<const uint8_t*>(in_buffer_.get());
        stream_.avail_in = already_size;
        stream_.avail_in += ReadOrEOF(file_.get(), static_cast<uint8_t*>(in_buffer_.get()) + already_size, kInputBuffer - already_size);
      } else {
        stream_.avail_in = 0;
      }
      stream_.allocator = NULL;
      // LZMA_CONCATENATED: decode back-to-back xz streams as one payload.
      lzma_ret ret = lzma_stream_decoder(&stream_, UINT64_MAX, LZMA_CONCATENATED);
      switch (ret) {
        case LZMA_OK:
          break;
        case LZMA_MEM_ERROR:
          UTIL_THROW(ErrnoException, "xz open error");
        default:
          UTIL_THROW(XZException, "xz error code " << ret);
      }
    }

    ~XZip() {
      lzma_end(&stream_);
    }

    std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) {
      if (amount == 0) return 0;
      stream_.next_out = static_cast<uint8_t*>(to);
      stream_.avail_out = amount;
      do {
        if (!stream_.avail_in) ReadInput(thunk);
        lzma_ret status = lzma_code(&stream_, action_);
        switch (status) {
          case LZMA_OK:
            break;
          case LZMA_STREAM_END:
            // With LZMA_CONCATENATED, stream end is only legitimate once we
            // signalled LZMA_FINISH (i.e. the input really hit EOF).
            UTIL_THROW_IF(action_ != LZMA_FINISH, XZException, "Input not finished yet.");
            {
              // Compute the byte count before ReplaceThis deletes *this.
              std::size_t ret = static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to);
              ReplaceThis(new Complete(), thunk);
              return ret;
            }
          case LZMA_MEM_ERROR:
            throw std::bad_alloc();
          case LZMA_FORMAT_ERROR:
            UTIL_THROW(XZException, "xzlib says file format not recognized");
          case LZMA_OPTIONS_ERROR:
            UTIL_THROW(XZException, "xzlib says unsupported compression options");
          case LZMA_DATA_ERROR:
            UTIL_THROW(XZException, "xzlib says this file is corrupt");
          case LZMA_BUF_ERROR:
            UTIL_THROW(XZException, "xzlib says unexpected end of input");
          default:
            UTIL_THROW(XZException, "unrecognized xzlib error " << status);
        }
      } while (stream_.next_out == to);  // Loop until at least one byte was produced.
      return static_cast<uint8_t*>(stream_.next_out) - static_cast<uint8_t*>(to);
    }

  private:
    // Refill the input buffer; an empty read means EOF, so switch the coder
    // to LZMA_FINISH for the remaining calls.
    void ReadInput(ReadCompressed &thunk) {
      assert(!stream_.avail_in);
      stream_.next_in = static_cast<const uint8_t*>(in_buffer_.get());
      stream_.avail_in = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer);
      if (!stream_.avail_in) action_ = LZMA_FINISH;
      ReadCount(thunk) += stream_.avail_in;
    }

    scoped_fd file_;
    scoped_malloc in_buffer_;
    lzma_stream stream_;

    lzma_action action_;
};
#endif // HAVE_XZLIB

// Compression format identified from a stream's leading bytes.
enum MagicResult {
  UNKNOWN, GZIP, BZIP, XZIP
};

// from_void must point to at least ReadCompressed::kMagicSize (6) bytes.
MagicResult DetectMagic(const void *from_void) {
  const uint8_t *header = static_cast<const uint8_t*>(from_void);
  // gzip: 0x1f 0x8b
  if (header[0] == 0x1f && header[1] == 0x8b) {
    return GZIP;
  }
  // bzip2: "BZ"
  if (header[0] == 'B' && header[1] == 'Z') {
    return BZIP;
  }
  // xz: 0xfd "7zXZ" 0x00
  const uint8_t xzmagic[6] = { 0xFD, '7', 'z', 'X', 'Z', 0x00 };
  if (!memcmp(header, xzmagic, 6)) {
    return XZIP;
  }
  return UNKNOWN;
}

+ReadBase *ReadFactory(int fd, uint64_t &raw_amount) { +  scoped_fd hold(fd); +  unsigned char header[ReadCompressed::kMagicSize]; +  raw_amount = ReadOrEOF(fd, header, ReadCompressed::kMagicSize); +  if (!raw_amount) +    return new Uncompressed(hold.release()); +  if (raw_amount != ReadCompressed::kMagicSize) +    return new UncompressedWithHeader(hold.release(), header, raw_amount); +  switch (DetectMagic(header)) { +    case GZIP: +#ifdef HAVE_ZLIB +      return new GZip(hold.release(), header, ReadCompressed::kMagicSize); +#else +      UTIL_THROW(CompressedException, "This looks like a gzip file but gzip support was not compiled in."); +#endif +    case BZIP: +#ifdef HAVE_BZLIB +      return new BZip(hold.release(), header, ReadCompressed::kMagicSize); +#else +      UTIL_THROW(CompressedException, "This looks like a bzip file (it begins with BZ), but bzip support was not compiled in."); +#endif +    case XZIP: +#ifdef HAVE_XZLIB +      return new XZip(hold.release(), header, ReadCompressed::kMagicSize); +#else +      UTIL_THROW(CompressedException, "This looks like an xz file, but xz support was not compiled in."); +#endif +    case UNKNOWN: +      break; +  } +  try { +    AdvanceOrThrow(fd, -ReadCompressed::kMagicSize); +  } catch (const util::ErrnoException &e) { +    return new UncompressedWithHeader(hold.release(), header, ReadCompressed::kMagicSize); +  } +  return new Uncompressed(hold.release()); +} + +} // namespace + +bool ReadCompressed::DetectCompressedMagic(const void *from_void) { +  return DetectMagic(from_void) != UNKNOWN; +} + +ReadCompressed::ReadCompressed(int fd) { +  Reset(fd); +} + +ReadCompressed::ReadCompressed() {} + +ReadCompressed::~ReadCompressed() {} + +void ReadCompressed::Reset(int fd) { +  internal_.reset(); +  internal_.reset(ReadFactory(fd, raw_amount_)); +} + +std::size_t ReadCompressed::Read(void *to, std::size_t amount) { +  return internal_->Read(to, amount, *this); +} + +} // namespace util diff --git 
#ifndef UTIL_READ_COMPRESSED__
#define UTIL_READ_COMPRESSED__

#include "util/exception.hh"
#include "util/scoped.hh"

#include <cstddef>

#include <stdint.h>

namespace util {

// Base class for all decompression-related errors.
class CompressedException : public Exception {
  public:
    CompressedException() throw();
    virtual ~CompressedException() throw();
};

// Error reported by the zlib/gzip decoder.
class GZException : public CompressedException {
  public:
    GZException() throw();
    ~GZException() throw();
};

// Error reported by the bzip2 decoder.
class BZException : public CompressedException {
  public:
    BZException() throw();
    ~BZException() throw();
};

// Error reported by the xz (liblzma) decoder.
class XZException : public CompressedException {
  public:
    XZException() throw();
    ~XZException() throw();
};

class ReadBase;

// Reads a file descriptor, transparently decompressing gzip, bzip2, or xz
// content (detected by magic bytes); plain data passes through unchanged.
class ReadCompressed {
  public:
    // Number of leading bytes inspected for format detection.
    static const std::size_t kMagicSize = 6;
    // Must have at least kMagicSize bytes.
    static bool DetectCompressedMagic(const void *from);

    // Takes ownership of fd.
    explicit ReadCompressed(int fd);

    // Must call Reset later.
    ReadCompressed();

    ~ReadCompressed();

    // Takes ownership of fd.
    void Reset(int fd);

    // Returns the number of decompressed bytes produced; 0 means EOF.
    std::size_t Read(void *to, std::size_t amount);

    // Raw (compressed) bytes consumed from the fd so far.
    uint64_t RawAmount() const { return raw_amount_; }

  private:
    // ReadBase implementations swap themselves via internal_ and update
    // raw_amount_ directly.
    friend class ReadBase;

    scoped_ptr<ReadBase> internal_;

    uint64_t raw_amount_;

    // No copying.
    ReadCompressed(const ReadCompressed &);
    void operator=(const ReadCompressed &);
};

} // namespace util

#endif // UTIL_READ_COMPRESSED__

// ==== klm/util/read_compressed_test.cc (new file) ====

#include "util/read_compressed.hh"

#include "util/file.hh"
#include "util/have.hh"

#define BOOST_TEST_MODULE ReadCompressedTest
#include <boost/test/unit_test.hpp>
#include <boost/scoped_ptr.hpp>

#include <fstream>
#include <string>

#include <stdlib.h>

namespace util {
namespace {

// Reads exactly amount bytes, looping over short reads; fails the test if
// EOF arrives early.
void ReadLoop(ReadCompressed &reader, void *to_void, std::size_t amount) {
  uint8_t *to = static_cast<uint8_t*>(to_void);
  while (amount) {
    std::size_t ret = reader.Read(to, amount);
    BOOST_REQUIRE(ret);
    to += ret;
    amount -= ret;
  }
}

// Round trip: write kSize4 sequential uint32 values, pipe them through the
// named external compressor, then verify ReadCompressed recovers them.
void TestRandom(const char *compressor) {
  const uint32_t kSize4 = 100000 / 4;  // 25000 values = ~100 KB of payload.
  char name[] = "tempXXXXXX";

  // Write test file.
  {
    scoped_fd original(mkstemp(name));
    // NOTE(review): mkstemp success is fd >= 0; > 0 only works because fd 0
    // is normally already taken by stdin -- confirm.
    BOOST_REQUIRE(original.get() > 0);
    for (uint32_t i = 0; i < kSize4; ++i) {
      WriteOrThrow(original.get(), &i, sizeof(uint32_t));
    }
  }

  char gzname[] = "tempXXXXXX";
  scoped_fd gzipped(mkstemp(gzname));

  // Shell out to the compressor, e.g. `gzip <"temp..." >"temp..."`.
  std::string command(compressor);
#ifdef __CYGWIN__
  command += ".exe";
#endif
  command += " <\"";
  command += name;
  command += "\" >\"";
  command += gzname;
  command += "\"";
  BOOST_REQUIRE_EQUAL(0, system(command.c_str()));

  // Unlink both names now; the open gzipped fd keeps the data readable.
  BOOST_CHECK_EQUAL(0, unlink(name));
  BOOST_CHECK_EQUAL(0, unlink(gzname));

  ReadCompressed reader(gzipped.release());
  for (uint32_t i = 0; i < kSize4; ++i) {
    uint32_t got;
    ReadLoop(reader, &got, sizeof(uint32_t));
    BOOST_CHECK_EQUAL(i, got);
  }

  char ignored;
  // First read past the end must report EOF (0 bytes).
  BOOST_CHECK_EQUAL((std::size_t)0, reader.Read(&ignored, 1));
  // Test double EOF call.
+  BOOST_CHECK_EQUAL((std::size_t)0, reader.Read(&ignored, 1)); +} + +BOOST_AUTO_TEST_CASE(Uncompressed) { +  TestRandom("cat"); +} + +#ifdef HAVE_ZLIB +BOOST_AUTO_TEST_CASE(ReadGZ) { +  TestRandom("gzip"); +} +#endif // HAVE_ZLIB + +#ifdef HAVE_BZLIB +BOOST_AUTO_TEST_CASE(ReadBZ) { +  TestRandom("bzip2"); +} +#endif // HAVE_BZLIB + +#ifdef HAVE_XZLIB +BOOST_AUTO_TEST_CASE(ReadXZ) { +  TestRandom("xz"); +} +#endif + +} // namespace +} // namespace util diff --git a/klm/util/scoped.hh b/klm/util/scoped.hh index 93e2e817..d62c6df1 100644 --- a/klm/util/scoped.hh +++ b/klm/util/scoped.hh @@ -1,40 +1,13 @@  #ifndef UTIL_SCOPED__  #define UTIL_SCOPED__ +/* Other scoped objects in the style of scoped_ptr. */  #include "util/exception.hh" - -/* Other scoped objects in the style of scoped_ptr. */  #include <cstddef>  #include <cstdlib>  namespace util { -template <class T, class R, R (*Free)(T*)> class scoped_thing { -  public: -    explicit scoped_thing(T *c = static_cast<T*>(0)) : c_(c) {} - -    ~scoped_thing() { if (c_) Free(c_); } - -    void reset(T *c) { -      if (c_) Free(c_); -      c_ = c; -    } - -    T &operator*() { return *c_; } -    const T&operator*() const { return *c_; } -    T &operator->() { return *c_; } -    const T&operator->() const { return *c_; } - -    T *get() { return c_; } -    const T *get() const { return c_; } - -  private: -    T *c_; - -    scoped_thing(const scoped_thing &); -    scoped_thing &operator=(const scoped_thing &); -}; -  class scoped_malloc {    public:      scoped_malloc() : p_(NULL) {} @@ -77,9 +50,6 @@ template <class T> class scoped_array {      T &operator*() { return *c_; }      const T&operator*() const { return *c_; } -    T &operator->() { return *c_; } -    const T&operator->() const { return *c_; } -      T &operator[](std::size_t idx) { return c_[idx]; }      const T &operator[](std::size_t idx) const { return c_[idx]; } @@ -90,6 +60,39 @@ template <class T> class scoped_array {    private:      T *c_; + +    
scoped_array(const scoped_array &); +    void operator=(const scoped_array &); +}; + +template <class T> class scoped_ptr { +  public: +    explicit scoped_ptr(T *content = NULL) : c_(content) {} + +    ~scoped_ptr() { delete c_; } + +    T *get() { return c_; } +    const T* get() const { return c_; } + +    T &operator*() { return *c_; } +    const T&operator*() const { return *c_; } + +    T *operator->() { return c_; } +    const T*operator->() const { return c_; } + +    T &operator[](std::size_t idx) { return c_[idx]; } +    const T &operator[](std::size_t idx) const { return c_[idx]; } + +    void reset(T *to = NULL) { +      scoped_ptr<T> other(c_); +      c_ = to; +    } + +  private: +    T *c_; + +    scoped_ptr(const scoped_ptr &); +    void operator=(const scoped_ptr &);  };  } // namespace util diff --git a/klm/util/string_piece.hh b/klm/util/string_piece.hh index be6a643d..51481646 100644 --- a/klm/util/string_piece.hh +++ b/klm/util/string_piece.hh @@ -1,6 +1,6 @@  /* If you use ICU in your program, then compile with -DHAVE_ICU -licui18n.  If   * you don't use ICU, then this will use the Google implementation from Chrome. - * This has been modified from the original version to let you choose.   + * This has been modified from the original version to let you choose.   */  // Copyright 2008, Google Inc. @@ -62,9 +62,9 @@  #include <unicode/stringpiece.h>  #include <unicode/uversion.h> -// Old versions of ICU don't define operator== and operator!=.   +// Old versions of ICU don't define operator== and operator!=.  #if (U_ICU_VERSION_MAJOR_NUM < 4) || ((U_ICU_VERSION_MAJOR_NUM == 4) && (U_ICU_VERSION_MINOR_NUM < 4)) -#warning You are using an old version of ICU.  Consider upgrading to ICU >= 4.6.   +#warning You are using an old version of ICU.  Consider upgrading to ICU >= 4.6.  
inline bool operator==(const StringPiece& x, const StringPiece& y) {    if (x.size() != y.size())      return false; @@ -274,15 +274,28 @@ struct StringPieceCompatibleEquals : public std::binary_function<const StringPie    }  };  template <class T> typename T::const_iterator FindStringPiece(const T &t, const StringPiece &key) { +#if BOOST_VERSION < 104200 +  std::string temp(key.data(), key.size()); +  return t.find(temp); +#else    return t.find(key, StringPieceCompatibleHash(), StringPieceCompatibleEquals()); +#endif  } +  template <class T> typename T::iterator FindStringPiece(T &t, const StringPiece &key) { +#if BOOST_VERSION < 104200 +  std::string temp(key.data(), key.size()); +  return t.find(temp); +#else    return t.find(key, StringPieceCompatibleHash(), StringPieceCompatibleEquals()); +#endif  }  #endif  #ifdef HAVE_ICU  U_NAMESPACE_END +using U_NAMESPACE_QUALIFIER StringPiece;  #endif +  #endif  // BASE_STRING_PIECE_H__ diff --git a/klm/util/tokenize_piece.hh b/klm/util/tokenize_piece.hh index 4a7f5460..a588c3fc 100644 --- a/klm/util/tokenize_piece.hh +++ b/klm/util/tokenize_piece.hh @@ -20,6 +20,7 @@ class OutOfTokens : public Exception {  class SingleCharacter {    public: +    SingleCharacter() {}      explicit SingleCharacter(char delim) : delim_(delim) {}      StringPiece Find(const StringPiece &in) const { @@ -32,6 +33,8 @@ class SingleCharacter {  class MultiCharacter {    public: +    MultiCharacter() {} +      explicit MultiCharacter(const StringPiece &delimiter) : delimiter_(delimiter) {}      StringPiece Find(const StringPiece &in) const { @@ -44,6 +47,7 @@ class MultiCharacter {  class AnyCharacter {    public: +    AnyCharacter() {}      explicit AnyCharacter(const StringPiece &chars) : chars_(chars) {}      StringPiece Find(const StringPiece &in) const { @@ -56,6 +60,8 @@ class AnyCharacter {  class AnyCharacterLast {    public: +    AnyCharacterLast() {} +      explicit AnyCharacterLast(const StringPiece &chars) : chars_(chars) {}      
StringPiece Find(const StringPiece &in) const { @@ -81,8 +87,8 @@ template <class Find, bool SkipEmpty = false> class TokenIter : public boost::it        return current_.data() != 0;      } -    static TokenIter<Find> end() { -      return TokenIter<Find>(); +    static TokenIter<Find, SkipEmpty> end() { +      return TokenIter<Find, SkipEmpty>();      }    private: @@ -100,8 +106,8 @@ template <class Find, bool SkipEmpty = false> class TokenIter : public boost::it        } while (SkipEmpty && current_.data() && current_.empty()); // Compiler should optimize this away if SkipEmpty is false.        } -    bool equal(const TokenIter<Find> &other) const { -      return after_.data() == other.after_.data(); +    bool equal(const TokenIter<Find, SkipEmpty> &other) const { +      return current_.data() == other.current_.data();      }      const StringPiece &dereference() const {  | 
