diff options
author | Patrick Simianer <p@simianer.de> | 2011-11-13 12:26:23 +0100 |
---|---|---|
committer | Patrick Simianer <p@simianer.de> | 2011-11-13 12:26:23 +0100 |
commit | effc9bfc40a0559ce36a155daa15e0dc53e93b75 (patch) | |
tree | 768a29ebad48089e3445c515d47f49c942f09124 /klm | |
parent | ed8ca37550910a540e755ada119e814f13eeef03 (diff) | |
parent | a5592c9ab0266dbf4993e42e82e5a113316990ad (diff) |
merge upstream/master
Diffstat (limited to 'klm')
-rw-r--r-- | klm/lm/bhiksha.hh | 5 | ||||
-rw-r--r-- | klm/lm/build_binary.cc | 2 | ||||
-rw-r--r-- | klm/lm/config.hh | 6 | ||||
-rw-r--r-- | klm/lm/enumerate_vocab.hh | 2 | ||||
-rw-r--r-- | klm/lm/left.hh | 43 | ||||
-rw-r--r-- | klm/lm/model.cc | 4 | ||||
-rw-r--r-- | klm/lm/read_arpa.cc | 2 | ||||
-rw-r--r-- | klm/lm/search_hashed.cc | 16 | ||||
-rw-r--r-- | klm/lm/search_trie.cc | 3 | ||||
-rw-r--r-- | klm/lm/sri.cc | 108 | ||||
-rw-r--r-- | klm/lm/sri.hh | 102 | ||||
-rw-r--r-- | klm/lm/vocab.cc | 1 | ||||
-rw-r--r-- | klm/lm/vocab.hh | 3 | ||||
-rw-r--r-- | klm/util/mmap.cc | 2 | ||||
-rw-r--r-- | klm/util/murmur_hash.cc | 15 | ||||
-rw-r--r-- | klm/util/probing_hash_table.hh | 4 | ||||
-rw-r--r-- | klm/util/tokenize_piece.hh | 75 | ||||
-rw-r--r-- | klm/util/tokenize_piece_test.cc | 94 |
18 files changed, 235 insertions, 252 deletions
diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh index bc705959..3df43dda 100644 --- a/klm/lm/bhiksha.hh +++ b/klm/lm/bhiksha.hh @@ -10,6 +10,9 @@ * Currently only used for next pointers. */ +#ifndef LM_BHIKSHA__ +#define LM_BHIKSHA__ + #include <inttypes.h> #include <assert.h> @@ -108,3 +111,5 @@ class ArrayBhiksha { } // namespace trie } // namespace ngram } // namespace lm + +#endif // LM_BHIKSHA__ diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index b7aee4de..fdb62a71 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -15,7 +15,7 @@ namespace ngram { namespace { void Usage(const char *name) { - std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-c bits] [type] input.arpa [output.mmap]\n\n" + std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n" "-u sets the log10 probability for <unk> if the ARPA file does not have one.\n" " Default is -100. The ARPA file will always take precedence.\n" "-s allows models to be built even if they do not have <s> and </s>.\n" diff --git a/klm/lm/config.hh b/klm/lm/config.hh index 227b8512..8564661b 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -8,10 +8,12 @@ /* Configuration for ngram model. Separate header to reduce pollution. */ -namespace lm { namespace ngram { - +namespace lm { + class EnumerateVocab; +namespace ngram { + struct Config { // EFFECTIVE FOR BOTH ARPA AND BINARY READS diff --git a/klm/lm/enumerate_vocab.hh b/klm/lm/enumerate_vocab.hh index e734316b..27263621 100644 --- a/klm/lm/enumerate_vocab.hh +++ b/klm/lm/enumerate_vocab.hh @@ -5,7 +5,6 @@ #include "util/string_piece.hh" namespace lm { -namespace ngram { /* If you need the actual strings in the vocabulary, inherit from this class * and implement Add. Then put a pointer in Config.enumerate_vocab; it does @@ -23,7 +22,6 @@ class EnumerateVocab { EnumerateVocab() {} }; -} // namespace ngram } // namespace lm #endif // LM_ENUMERATE_VOCAB__ diff --git a/klm/lm/left.hh b/klm/lm/left.hh index bb3f5539..41f71f84 100644 --- a/klm/lm/left.hh +++ b/klm/lm/left.hh @@ -175,22 +175,14 @@ template <class M> class RuleScore { float backoffs[kMaxOrder - 1], backoffs2[kMaxOrder - 1]; float *back = backoffs, *back2 = backoffs2; - unsigned char next_use; - FullScoreReturn ret; - ProcessRet(ret = model_.ExtendLeft(out_.right.words, out_.right.words + out_.right.length, out_.right.backoff, in.left.pointers[0], 1, back, next_use)); - if (!next_use) { - left_done_ = true; - out_.right = in.right; - return; - } - unsigned char extend_length = 2; - for (const uint64_t *i = in.left.pointers + 1; i < in.left.pointers + in.left.length; ++i, ++extend_length) { - ProcessRet(ret = model_.ExtendLeft(out_.right.words, out_.right.words + next_use, back, *i, extend_length, back2, next_use)); - if (!next_use) { - left_done_ = true; - out_.right = in.right; - return; - } + unsigned char next_use = out_.right.length; + + // First word + if (ExtendLeft(in, next_use, 1, out_.right.backoff, back)) return; + + // Words after the first, so extending a bigram to begin with + for (unsigned char extend_length = 2; extend_length <= in.left.length; ++extend_length) { + if (ExtendLeft(in, next_use, extend_length, back, back2)) return; std::swap(back, back2); } @@ -226,6 +218,25 @@ template <class M> class RuleScore { } private: + bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) { + ProcessRet(model_.ExtendLeft( + out_.right.words, out_.right.words + next_use, // Words to extend into + back_in, // Backoffs to use + in.left.pointers[extend_length - 1], extend_length, // Words to be extended + back_out, // Backoffs for the next score + next_use)); // Length of n-gram to use in next scoring. + if (next_use != out_.right.length) { + left_done_ = true; + if (!next_use) { + out_.right = in.right; + // Early exit. + return true; + } + } + // Continue scoring. + return false; + } + void ProcessRet(const FullScoreReturn &ret) { prob_ += ret.prob; if (left_done_) return; diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 25f1ab7c..e4c1ec1d 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -91,8 +91,8 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { FullScoreReturn ret = ScoreExceptBackoff(in_state.words, in_state.words + in_state.length, new_word, out_state); - if (ret.ngram_length - 1 < in_state.length) { - ret.prob = std::accumulate(in_state.backoff + ret.ngram_length - 1, in_state.backoff + in_state.length, ret.prob); + for (const float *i = in_state.backoff + ret.ngram_length - 1; i < in_state.backoff + in_state.length; ++i) { + ret.prob += *i; } return ret; } diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index 455bc4ba..dce73f77 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -38,6 +38,8 @@ void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) { } if (static_cast<size_t>(line.size()) >= strlen(kBinaryMagic) && StringPiece(line.data(), strlen(kBinaryMagic)) == kBinaryMagic) UTIL_THROW(FormatLoadException, "This looks like a binary file but got sent to the ARPA parser. Did you compress the binary file or pass a binary file where only ARPA files are accepted?"); + UTIL_THROW_IF(line.size() >= 4 && StringPiece(line.data(), 4) == "blmt", FormatLoadException, "This looks like an IRSTLM binary file. Did you forget to pass --text yes to compile-lm?"); + UTIL_THROW_IF(line == "iARPA", FormatLoadException, "This looks like an IRSTLM iARPA file. You need an ARPA file. Run\n compile-lm --text yes " << in.FileName() << " " << in.FileName() << ".arpa\nfirst."); UTIL_THROW(FormatLoadException, "first non-empty line was \"" << line << "\" not \\data\\."); } while (!IsEntirelyWhiteSpace(line = in.ReadLine())) { diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index 334adf12..247832b0 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -87,14 +87,14 @@ template <class Voc, class Store, class Middle, class Activate> void ReadNGrams( ReadNGramHeader(f, n); // vocab ids of words in reverse order - WordIndex vocab_ids[n]; - uint64_t keys[n - 1]; + std::vector<WordIndex> vocab_ids(n); + std::vector<uint64_t> keys(n-1); typename Store::Packing::Value value; typename Middle::MutableIterator found; for (size_t i = 0; i < count; ++i) { - ReadNGram(f, n, vocab, vocab_ids, value, warn); + ReadNGram(f, n, vocab, &*vocab_ids.begin(), value, warn); - keys[0] = detail::CombineWordHash(static_cast<uint64_t>(*vocab_ids), vocab_ids[1]); + keys[0] = detail::CombineWordHash(static_cast<uint64_t>(vocab_ids.front()), vocab_ids[1]); for (unsigned int h = 1; h < n - 1; ++h) { keys[h] = detail::CombineWordHash(keys[h-1], vocab_ids[h+1]); } @@ -106,9 +106,9 @@ template <class Voc, class Store, class Middle, class Activate> void ReadNGrams( util::FloatEnc fix_prob; for (lower = n - 3; ; --lower) { if (lower == -1) { - fix_prob.f = unigrams[vocab_ids[0]].prob; + fix_prob.f = unigrams[vocab_ids.front()].prob; fix_prob.i &= ~util::kSignBit; - unigrams[vocab_ids[0]].prob = fix_prob.f; + unigrams[vocab_ids.front()].prob = fix_prob.f; break; } if (middle[lower].UnsafeMutableFind(keys[lower], found)) { @@ -120,8 +120,8 @@ template <class Voc, class Store, class Middle, class Activate> void ReadNGrams( break; } } - if (lower != static_cast<int>(n) - 3) FixSRI(lower, fix_prob.f, n, keys, vocab_ids, unigrams, middle); - activate(vocab_ids, n); + if (lower != static_cast<int>(n) - 3) FixSRI(lower, fix_prob.f, n, &*keys.begin(), &*vocab_ids.begin(), unigrams, middle); + activate(&*vocab_ids.begin(), n); } store.FinishedInserting(); diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 5d8c70db..4bd3f4ee 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -358,6 +358,7 @@ template <class Doing> class BlankManager { // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram. basis_[blank - 1] = kBadProb; } + *pre = *cur; been_length_ = length; } @@ -493,7 +494,7 @@ template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_pre util::scoped_FILE unigram_file; { std::string name(file_prefix + "unigrams"); - unigram_file.reset(OpenOrThrow(name.c_str(), "r")); + unigram_file.reset(OpenOrThrow(name.c_str(), "r+")); util::RemoveOrThrow(name.c_str()); } sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs); diff --git a/klm/lm/sri.cc b/klm/lm/sri.cc deleted file mode 100644 index 825f699b..00000000 --- a/klm/lm/sri.cc +++ /dev/null @@ -1,108 +0,0 @@ -#include "lm/lm_exception.hh" -#include "lm/sri.hh" - -#include <Ngram.h> -#include <Vocab.h> - -#include <errno.h> - -namespace lm { -namespace sri { - -Vocabulary::Vocabulary() : sri_(new Vocab) {} - -Vocabulary::~Vocabulary() {} - -WordIndex Vocabulary::Index(const char *str) const { - WordIndex ret = sri_->getIndex(str); - // NGram wants the index of Vocab_Unknown for unknown words, but for some reason SRI returns Vocab_None here :-(. - if (ret == Vocab_None) { - return not_found_; - } else { - return ret; - } -} - -const char *Vocabulary::Word(WordIndex index) const { - return sri_->getWord(index); -} - -void Vocabulary::FinishedLoading() { - SetSpecial( - sri_->ssIndex(), - sri_->seIndex(), - sri_->unkIndex()); -} - -namespace { -Ngram *MakeSRIModel(const char *file_name, unsigned int ngram_length, Vocab &sri_vocab) { - sri_vocab.unkIsWord() = true; - std::auto_ptr<Ngram> ret(new Ngram(sri_vocab, ngram_length)); - File file(file_name, "r"); - errno = 0; - if (!ret->read(file)) { - UTIL_THROW(FormatLoadException, "reading file " << file_name << " with SRI failed."); - } - return ret.release(); -} -} // namespace - -Model::Model(const char *file_name, unsigned int ngram_length) : sri_(MakeSRIModel(file_name, ngram_length, *vocab_.sri_)) { - if (!sri_->setorder()) { - UTIL_THROW(FormatLoadException, "Can't have an SRI model with order 0."); - } - vocab_.FinishedLoading(); - State begin_state = State(); - begin_state.valid_length_ = 1; - if (kMaxOrder > 1) { - begin_state.history_[0] = vocab_.BeginSentence(); - if (kMaxOrder > 2) begin_state.history_[1] = Vocab_None; - } - State null_state = State(); - null_state.valid_length_ = 0; - if (kMaxOrder > 1) null_state.history_[0] = Vocab_None; - Init(begin_state, null_state, vocab_, sri_->setorder()); - not_found_ = vocab_.NotFound(); -} - -Model::~Model() {} - -namespace { - -/* Argh SRI's wordProb knows the ngram length but doesn't return it. One more - * reason you should use my model. */ -// TODO(stolcke): fix SRILM so I don't have to do this. -unsigned int MatchedLength(Ngram &model, const WordIndex new_word, const SRIVocabIndex *const_history) { - unsigned int out_length = 0; - // This gets the length of context used, which is ngram_length - 1 unless new_word is OOV in which case it is 0. - model.contextID(new_word, const_history, out_length); - return out_length + 1; -} - -} // namespace - -FullScoreReturn Model::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { - // If you get a compiler in this function, change SRIVocabIndex in sri.hh to match the one found in SRI's Vocab.h. - const SRIVocabIndex *const_history; - SRIVocabIndex local_history[Order()]; - if (in_state.valid_length_ < kMaxOrder - 1) { - const_history = in_state.history_; - } else { - std::copy(in_state.history_, in_state.history_ + in_state.valid_length_, local_history); - local_history[in_state.valid_length_] = Vocab_None; - const_history = local_history; - } - FullScoreReturn ret; - ret.ngram_length = MatchedLength(*sri_, new_word, const_history); - out_state.history_[0] = new_word; - out_state.valid_length_ = std::min<unsigned char>(ret.ngram_length, Order() - 1); - std::copy(const_history, const_history + out_state.valid_length_ - 1, out_state.history_ + 1); - if (out_state.valid_length_ < kMaxOrder - 1) { - out_state.history_[out_state.valid_length_] = Vocab_None; - } - ret.prob = sri_->wordProb(new_word, const_history); - return ret; -} - -} // namespace sri -} // namespace lm diff --git a/klm/lm/sri.hh b/klm/lm/sri.hh deleted file mode 100644 index b57e9b73..00000000 --- a/klm/lm/sri.hh +++ /dev/null @@ -1,102 +0,0 @@ -#ifndef LM_SRI__ -#define LM_SRI__ - -#include "lm/facade.hh" -#include "util/murmur_hash.hh" - -#include <cmath> -#include <exception> -#include <memory> - -class Ngram; -class Vocab; - -/* The ngram length reported uses some random API I found and may be wrong. - * - * See ngram, which should return equivalent results. - */ - -namespace lm { -namespace sri { - -static const unsigned int kMaxOrder = 6; - -/* This should match VocabIndex found in SRI's Vocab.h - * The reason I define this here independently is that SRI's headers - * pollute and increase compile time. - * It's difficult to extract this from their header and anyway would - * break packaging. - * If these differ there will be a compiler error in ActuallyCall. - */ -typedef unsigned int SRIVocabIndex; - -class State { - public: - // You shouldn't need to touch these, but they're public so State will be a POD. - // If valid_length_ < kMaxOrder - 1 then history_[valid_length_] == Vocab_None. - SRIVocabIndex history_[kMaxOrder - 1]; - unsigned char valid_length_; -}; - -inline bool operator==(const State &left, const State &right) { - if (left.valid_length_ != right.valid_length_) { - return false; - } - for (const SRIVocabIndex *l = left.history_, *r = right.history_; - l != left.history_ + left.valid_length_; - ++l, ++r) { - if (*l != *r) return false; - } - return true; -} - -inline size_t hash_value(const State &state) { - return util::MurmurHashNative(&state.history_, sizeof(SRIVocabIndex) * state.valid_length_); -} - -class Vocabulary : public base::Vocabulary { - public: - Vocabulary(); - - ~Vocabulary(); - - WordIndex Index(const StringPiece &str) const { - std::string temp(str.data(), str.length()); - return Index(temp.c_str()); - } - WordIndex Index(const std::string &str) const { - return Index(str.c_str()); - } - WordIndex Index(const char *str) const; - - const char *Word(WordIndex index) const; - - private: - friend class Model; - void FinishedLoading(); - - // The parent class isn't copyable so auto_ptr is the same as scoped_ptr - // but without the boost dependence. - mutable std::auto_ptr<Vocab> sri_; -}; - -class Model : public base::ModelFacade<Model, State, Vocabulary> { - public: - Model(const char *file_name, unsigned int ngram_length); - - ~Model(); - - FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const; - - private: - Vocabulary vocab_; - - mutable std::auto_ptr<Ngram> sri_; - - WordIndex not_found_; -}; - -} // namespace sri -} // namespace lm - -#endif // LM_SRI__ diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 03b0767a..ffec41ca 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -135,6 +135,7 @@ void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1); ReadWords(fd, to); SetSpecial(Index("<s>"), Index("</s>"), 0); + bound_ = end_ - begin_ + 1; } namespace { diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 41e97052..3c3414fb 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -15,10 +15,10 @@ namespace lm { class ProbBackoff; +class EnumerateVocab; namespace ngram { class Config; -class EnumerateVocab; namespace detail { uint64_t HashForVocab(const char *str, std::size_t len); @@ -66,7 +66,6 @@ class SortedVocabulary : public base::Vocabulary { static size_t Size(std::size_t entries, const Config &config); // Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary. - // While this number is correct, ProbingVocabulary::Bound might not be correct in some cases. WordIndex Bound() const { return bound_; } // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. diff --git a/klm/util/mmap.cc b/klm/util/mmap.cc index 5ce7adc9..279bafa8 100644 --- a/klm/util/mmap.cc +++ b/klm/util/mmap.cc @@ -15,7 +15,7 @@ namespace util { scoped_mmap::~scoped_mmap() { if (data_ != (void*)-1) { - // Thanks Denis Filimonov for pointing on NFS likes msync first. + // Thanks Denis Filimonov for pointing out NFS likes msync first. if (msync(data_, size_, MS_SYNC) || munmap(data_, size_)) { std::cerr << "msync or mmap failed for " << size_ << " bytes." << std::endl; abort(); diff --git a/klm/util/murmur_hash.cc b/klm/util/murmur_hash.cc index fec47fd9..ef5783fe 100644 --- a/klm/util/murmur_hash.cc +++ b/klm/util/murmur_hash.cc @@ -117,13 +117,18 @@ uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed ) return h; } +// Trick to test for 64-bit architecture at compile time. +namespace { +template <unsigned L> uint64_t MurmurHashNativeBackend(const void * key, std::size_t len, unsigned int seed) { + return MurmurHash64A(key, len, seed); +} +template <> uint64_t MurmurHashNativeBackend<4>(const void * key, std::size_t len, unsigned int seed) { + return MurmurHash64B(key, len, seed); +} +} // namespace uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) { - if (sizeof(int) == 4) { - return MurmurHash64B(key, len, seed); - } else { - return MurmurHash64A(key, len, seed); - } + return MurmurHashNativeBackend<sizeof(void*)>(key, len, seed); } } // namespace util diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh index 2ec342a6..8122d69c 100644 --- a/klm/util/probing_hash_table.hh +++ b/klm/util/probing_hash_table.hh @@ -61,14 +61,14 @@ template <class PackingT, class HashT, class EqualT = std::equal_to<typename Pac #endif {} - template <class T> void Insert(const T &t) { + template <class T> MutableIterator Insert(const T &t) { if (++entries_ >= buckets_) UTIL_THROW(ProbingSizeException, "Hash table with " << buckets_ << " buckets is full."); #ifdef DEBUG assert(initialized_); #endif for (MutableIterator i(begin_ + (hash_(t.GetKey()) % buckets_));;) { - if (equal_(i->GetKey(), invalid_)) { *i = t; return; } + if (equal_(i->GetKey(), invalid_)) { *i = t; return i; } if (++i == end_) { i = begin_; } } } diff --git a/klm/util/tokenize_piece.hh b/klm/util/tokenize_piece.hh index ee1c7ab2..413bda0b 100644 --- a/klm/util/tokenize_piece.hh +++ b/klm/util/tokenize_piece.hh @@ -5,6 +5,9 @@ #include <boost/iterator/iterator_facade.hpp> +#include <algorithm> +#include <iostream> + /* Usage: * * for (PieceIterator<' '> i(" foo \r\n bar "); i; ++i) { @@ -64,6 +67,78 @@ template <char d> class PieceIterator : public boost::iterator_facade<PieceItera StringPiece after_; }; +class MultiCharacter { + public: + explicit MultiCharacter(const StringPiece &delimiter) : delimiter_(delimiter) {} + + StringPiece Find(const StringPiece &in) const { + return StringPiece(std::search(in.data(), in.data() + in.size(), delimiter_.data(), delimiter_.data() + delimiter_.size()), delimiter_.size()); + } + + private: + StringPiece delimiter_; +}; + +class AnyCharacter { + public: + explicit AnyCharacter(const StringPiece &chars) : chars_(chars) {} + + StringPiece Find(const StringPiece &in) const { + return StringPiece(std::find_first_of(in.data(), in.data() + in.size(), chars_.data(), chars_.data() + chars_.size()), 1); + } + + private: + StringPiece chars_; +}; + +template <class Find, bool SkipEmpty = false> class TokenIter : public boost::iterator_facade<TokenIter<Find, SkipEmpty>, const StringPiece, boost::forward_traversal_tag> { + public: + TokenIter() {} + + TokenIter(const StringPiece &str, const Find &finder) : after_(str), finder_(finder) { + increment(); + } + + bool operator!() const { + return current_.data() == 0; + } + operator bool() const { + return current_.data() != 0; + } + + static TokenIter<Find> end() { + return TokenIter<Find>(); + } + + private: + friend class boost::iterator_core_access; + + void increment() { + do { + StringPiece found(finder_.Find(after_)); + current_ = StringPiece(after_.data(), found.data() - after_.data()); + if (found.data() == after_.data() + after_.size()) { + after_ = StringPiece(NULL, 0); + } else { + after_ = StringPiece(found.data() + found.size(), after_.data() - found.data() + after_.size() - found.size()); + } + } while (SkipEmpty && current_.data() && current_.empty()); // Compiler should optimize this away if SkipEmpty is false. + } + + bool equal(const TokenIter<Find> &other) const { + return after_.data() == other.after_.data(); + } + + const StringPiece &dereference() const { + return current_; + } + + StringPiece current_; + StringPiece after_; + + Find finder_; +}; + } // namespace util #endif // UTIL_TOKENIZE_PIECE__ diff --git a/klm/util/tokenize_piece_test.cc b/klm/util/tokenize_piece_test.cc new file mode 100644 index 00000000..e07ebcf5 --- /dev/null +++ b/klm/util/tokenize_piece_test.cc @@ -0,0 +1,94 @@ +#include "util/tokenize_piece.hh" +#include "util/string_piece.hh" + +#define BOOST_TEST_MODULE TokenIteratorTest +#include <boost/test/unit_test.hpp> + +#include <iostream> + +namespace util { +namespace { + +BOOST_AUTO_TEST_CASE(simple) { + PieceIterator<' '> it("single spaced words."); + BOOST_REQUIRE(it); + BOOST_CHECK_EQUAL(StringPiece("single"), *it); + ++it; + BOOST_REQUIRE(it); + BOOST_CHECK_EQUAL(StringPiece("spaced"), *it); + ++it; + BOOST_REQUIRE(it); + BOOST_CHECK_EQUAL(StringPiece("words."), *it); + ++it; + BOOST_CHECK(!it); +} + +BOOST_AUTO_TEST_CASE(null_delimiter) { + const char str[] = "\0first\0\0second\0\0\0third\0fourth\0\0\0"; + PieceIterator<'\0'> it(StringPiece(str, sizeof(str) - 1)); + BOOST_REQUIRE(it); + BOOST_CHECK_EQUAL(StringPiece("first"), *it); + ++it; + BOOST_REQUIRE(it); + BOOST_CHECK_EQUAL(StringPiece("second"), *it); + ++it; + BOOST_REQUIRE(it); + BOOST_CHECK_EQUAL(StringPiece("third"), *it); + ++it; + BOOST_REQUIRE(it); + BOOST_CHECK_EQUAL(StringPiece("fourth"), *it); + ++it; + BOOST_CHECK(!it); +} + +BOOST_AUTO_TEST_CASE(null_entries) { + const char str[] = "\0split\0\0 \0me\0 "; + PieceIterator<' '> it(StringPiece(str, sizeof(str) - 1)); + BOOST_REQUIRE(it); + const char first[] = "\0split\0\0"; + BOOST_CHECK_EQUAL(StringPiece(first, sizeof(first) - 1), *it); + ++it; + BOOST_REQUIRE(it); + const char second[] = "\0me\0"; + BOOST_CHECK_EQUAL(StringPiece(second, sizeof(second) - 1), *it); + ++it; + BOOST_CHECK(!it); +} + +/*BOOST_AUTO_TEST_CASE(pipe_pipe_none) { + const char str[] = "nodelimit at all"; + TokenIter<MultiCharacter> it(str, MultiCharacter("|||")); + BOOST_REQUIRE(it); + BOOST_CHECK_EQUAL(StringPiece(str), *it); + ++it; + BOOST_CHECK(!it); +} +BOOST_AUTO_TEST_CASE(pipe_pipe_two) { + const char str[] = "|||"; + TokenIter<MultiCharacter> it(str, MultiCharacter("|||")); + BOOST_REQUIRE(it); + BOOST_CHECK_EQUAL(StringPiece(), *it); + ++it; + BOOST_REQUIRE(it); + BOOST_CHECK_EQUAL(StringPiece(), *it); + ++it; + BOOST_CHECK(!it); +} + +BOOST_AUTO_TEST_CASE(remove_empty) { + const char str[] = "|||"; + TokenIter<MultiCharacter, true> it(str, MultiCharacter("|||")); + BOOST_CHECK(!it); +}*/ + +BOOST_AUTO_TEST_CASE(remove_empty_keep) { + const char str[] = " |||"; + TokenIter<MultiCharacter, true> it(str, MultiCharacter("|||")); + BOOST_REQUIRE(it); + BOOST_CHECK_EQUAL(StringPiece(" "), *it); + ++it; + BOOST_CHECK(!it); +} + +} // namespace +} // namespace util |