diff options
40 files changed, 2261 insertions, 1183 deletions
diff --git a/klm/compile.sh b/klm/compile.sh index abe3473a..56f2e9b2 100755 --- a/klm/compile.sh +++ b/klm/compile.sh @@ -3,10 +3,12 @@ #If your code uses ICU, edit util/string_piece.hh and uncomment #define USE_ICU #I use zlib by default. If you don't want to depend on zlib, remove #define USE_ZLIB from util/file_piece.hh +#don't need to use if compiling with moses Makefiles already + set -e -for i in util/{bit_packing,ersatz_progress,exception,file_piece,murmur_hash,scoped,mmap} lm/{bhiksha,binary_format,config,lm_exception,model,quantize,read_arpa,search_hashed,search_trie,trie,virtual_interface,vocab}; do - g++ -I. -O3 $CXXFLAGS -c $i.cc -o $i.o +for i in util/{bit_packing,ersatz_progress,exception,file_piece,murmur_hash,file,mmap} lm/{bhiksha,binary_format,config,lm_exception,model,quantize,read_arpa,search_hashed,search_trie,trie,trie_sort,virtual_interface,vocab}; do + g++ -I. -O3 -DNDEBUG $CXXFLAGS -c $i.cc -o $i.o done -g++ -I. -O3 $CXXFLAGS lm/build_binary.cc {lm,util}/*.o -lz -o build_binary -g++ -I. -O3 $CXXFLAGS lm/ngram_query.cc {lm,util}/*.o -lz -o query +g++ -I. -O3 -DNDEBUG $CXXFLAGS lm/build_binary.cc {lm,util}/*.o -lz -o build_binary +g++ -I. -O3 -DNDEBUG $CXXFLAGS lm/ngram_query.cc {lm,util}/*.o -lz -o query diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh index cfb2b053..ff7fe452 100644 --- a/klm/lm/bhiksha.hh +++ b/klm/lm/bhiksha.hh @@ -12,7 +12,7 @@ #include <inttypes.h> -#include "lm/binary_format.hh" +#include "lm/model_type.hh" #include "lm/trie.hh" #include "util/bit_packing.hh" #include "util/sorted_uniform.hh" diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index e02e621a..27cada13 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -19,10 +19,10 @@ namespace lm { namespace ngram { namespace { const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version"; -const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 4\n\0"; +const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n\0"; // This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed). const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n"; -const long int kMagicVersion = 4; +const long int kMagicVersion = 5; // Test values. struct Sanity { @@ -42,12 +42,6 @@ struct Sanity { const char *kModelNames[6] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"}; -std::size_t Align8(std::size_t in) { - std::size_t off = in % 8; - if (!off) return in; - return in + 8 - off; -} - std::size_t TotalHeaderSize(unsigned char order) { return Align8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order); } @@ -119,7 +113,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t } } -void FinishFile(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, Backing &backing) { +void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, Backing &backing) { if (config.write_mmap) { if (msync(backing.search.get(), backing.search.size(), MS_SYNC) || msync(backing.vocab.get(), backing.vocab.size(), MS_SYNC)) UTIL_THROW(util::ErrnoException, "msync failed for " << config.write_mmap); @@ -130,6 +124,7 @@ void FinishFile(const Config &config, ModelType model_type, const std::vector<ui params.fixed.probing_multiplier = config.probing_multiplier; params.fixed.model_type = model_type; params.fixed.has_vocabulary = config.include_vocab; + params.fixed.search_version = search_version; WriteHeader(backing.vocab.get(), params); } } @@ -174,12 +169,13 @@ void ReadHeader(int fd, Parameters &out) { ReadLoop(fd, &*out.counts.begin(), sizeof(uint64_t) * out.fixed.order); } -void MatchCheck(ModelType model_type, const Parameters ¶ms) { +void MatchCheck(ModelType model_type, unsigned int search_version, const Parameters ¶ms) { if (params.fixed.model_type != model_type) { if (static_cast<unsigned int>(params.fixed.model_type) >= (sizeof(kModelNames) / sizeof(const char *))) UTIL_THROW(FormatLoadException, "The binary file claims to be model type " << static_cast<unsigned int>(params.fixed.model_type) << " but this is not implemented for in this inference code."); UTIL_THROW(FormatLoadException, "The binary file was built for " << kModelNames[params.fixed.model_type] << " but the inference code is trying to load " << kModelNames[model_type]); } + UTIL_THROW_IF(search_version != params.fixed.search_version, FormatLoadException, "The binary file has " << kModelNames[params.fixed.model_type] << " version " << params.fixed.search_version << " but this code expects " << kModelNames[params.fixed.model_type] << " version " << search_version); } void SeekPastHeader(int fd, const Parameters ¶ms) { diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index d28cb6c5..e9df0892 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -2,6 +2,7 @@ #define LM_BINARY_FORMAT__ #include "lm/config.hh" +#include "lm/model_type.hh" #include "lm/read_arpa.hh" #include "util/file_piece.hh" @@ -16,13 +17,6 @@ namespace lm { namespace ngram { -/* Not the best numbering system, but it grew this way for historical reasons - * and I want to preserve existing binary files. */ -typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3, ARRAY_TRIE_SORTED=4, QUANT_ARRAY_TRIE_SORTED=5} ModelType; - -const static ModelType kQuantAdd = static_cast<ModelType>(QUANT_TRIE_SORTED - TRIE_SORTED); -const static ModelType kArrayAdd = static_cast<ModelType>(ARRAY_TRIE_SORTED - TRIE_SORTED); - /*Inspect a file to determine if it is a binary lm. If not, return false. * If so, return true and set recognized to the type. This is the only API in * this header designed for use by decoder authors. @@ -36,8 +30,14 @@ struct FixedWidthParameters { ModelType model_type; // Does the end of the file have the actual strings in the vocabulary? bool has_vocabulary; + unsigned int search_version; }; +inline std::size_t Align8(std::size_t in) { + std::size_t off = in % 8; + return off ? (in + 8 - off) : in; +} + // Parameters stored in the header of a binary file. struct Parameters { FixedWidthParameters fixed; @@ -64,7 +64,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t // Write header to binary file. This is done last to prevent incomplete files // from loading. -void FinishFile(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, Backing &backing); +void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, Backing &backing); namespace detail { @@ -72,7 +72,7 @@ bool IsBinaryFormat(int fd); void ReadHeader(int fd, Parameters ¶ms); -void MatchCheck(ModelType model_type, const Parameters ¶ms); +void MatchCheck(ModelType model_type, unsigned int search_version, const Parameters ¶ms); void SeekPastHeader(int fd, const Parameters ¶ms); @@ -90,7 +90,7 @@ template <class To> void LoadLM(const char *file, const Config &config, To &to) if (detail::IsBinaryFormat(backing.file.get())) { Parameters params; detail::ReadHeader(backing.file.get(), params); - detail::MatchCheck(To::kModelType, params); + detail::MatchCheck(To::kModelType, To::kVersion, params); // Replace the run-time configured probing_multiplier with the one in the file. Config new_config(config); new_config.probing_multiplier = params.fixed.probing_multiplier; diff --git a/klm/lm/blank.hh b/klm/lm/blank.hh index 162411a9..2fb64cd0 100644 --- a/klm/lm/blank.hh +++ b/klm/lm/blank.hh @@ -38,20 +38,6 @@ inline bool HasExtension(const float &backoff) { return compare.i != interpret.i; } -/* Suppose "foo bar baz quux" appears in the ARPA but not "bar baz quux" or - * "baz quux" (because they were pruned). 1.2% of n-grams generated by SRI - * with default settings on the benchmark data set are like this. Since search - * proceeds by finding "quux", "baz quux", "bar baz quux", and finally - * "foo bar baz quux" and the trie needs pointer nodes anyway, blanks are - * inserted. The blanks have probability kBlankProb and backoff kBlankBackoff. - * A blank is recognized by kBlankProb in the probability field; kBlankBackoff - * must be 0 so that inference asseses zero backoff from these blanks. - */ -const float kBlankProb = -std::numeric_limits<float>::infinity(); -const float kBlankBackoff = kNoExtensionBackoff; -const uint32_t kBlankProbQuant = 0; -const uint32_t kBlankBackoffQuant = 0; - } // namespace ngram } // namespace lm #endif // LM_BLANK__ diff --git a/klm/lm/left.hh b/klm/lm/left.hh new file mode 100644 index 00000000..df69e97a --- /dev/null +++ b/klm/lm/left.hh @@ -0,0 +1,181 @@ +#ifndef LM_LEFT__ +#define LM_LEFT__ + +#include "lm/max_order.hh" +#include "lm/model.hh" +#include "lm/return.hh" + +#include <algorithm> + +namespace lm { +namespace ngram { + +struct Left { + bool operator==(const Left &other) const { + return + (length == other.length) && + pointers[length - 1] == other.pointers[length - 1]; + } + + int Compare(const Left &other) const { + if (length != other.length) { + return (int)length - (int)other.length; + } + if (pointers[length - 1] > other.pointers[length - 1]) return 1; + if (pointers[length - 1] < other.pointers[length - 1]) return -1; + return 0; + } + + uint64_t pointers[kMaxOrder - 1]; + unsigned char length; +}; + +struct ChartState { + bool operator==(const ChartState &other) { + return (left == other.left) && (right == other.right) && (full == other.full); + } + + int Compare(const ChartState &other) const { + int lres = left.Compare(other.left); + if (lres) return lres; + int rres = right.Compare(other.right); + if (rres) return rres; + return (int)full - (int)other.full; + } + + Left left; + State right; + bool full; +}; + +template <class M> class RuleScore { + public: + explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), left_write_(out.left.pointers), prob_(0.0) { + out.left.length = 0; + out.right.length = 0; + } + + void BeginSentence() { + out_.right = model_.BeginSentenceState(); + // out_.left is empty. + left_done_ = true; + } + + void Terminal(WordIndex word) { + State copy(out_.right); + FullScoreReturn ret = model_.FullScore(copy, word, out_.right); + ProcessRet(ret); + if (out_.right.length != copy.length + 1) left_done_ = true; + } + + // Faster version of NonTerminal for the case where the rule begins with a non-terminal. + void BeginNonTerminal(const ChartState &in, float prob) { + prob_ = prob; + out_ = in; + left_write_ = out_.left.pointers + out_.left.length; + left_done_ = in.full; + } + + void NonTerminal(const ChartState &in, float prob) { + prob_ += prob; + + if (!in.left.length) { + if (in.full) { + for (const float *i = out_.right.backoff; i < out_.right.backoff + out_.right.length; ++i) prob_ += *i; + left_done_ = true; + out_.right = in.right; + } + return; + } + + if (!out_.right.length) { + out_.right = in.right; + if (left_done_) return; + if (left_write_ != out_.left.pointers) { + left_done_ = true; + } else { + out_.left = in.left; + left_write_ = out_.left.pointers + in.left.length; + left_done_ = in.full; + } + return; + } + + float backoffs[kMaxOrder - 1], backoffs2[kMaxOrder - 1]; + float *back = backoffs, *back2 = backoffs2; + unsigned char next_use; + FullScoreReturn ret; + ProcessRet(ret = model_.ExtendLeft(out_.right.words, out_.right.words + out_.right.length, out_.right.backoff, in.left.pointers[0], 1, back, next_use)); + if (!next_use) { + left_done_ = true; + out_.right = in.right; + return; + } + unsigned char extend_length = 2; + for (const uint64_t *i = in.left.pointers + 1; i < in.left.pointers + in.left.length; ++i, ++extend_length) { + ProcessRet(ret = model_.ExtendLeft(out_.right.words, out_.right.words + next_use, back, *i, extend_length, back2, next_use)); + if (!next_use) { + left_done_ = true; + out_.right = in.right; + return; + } + std::swap(back, back2); + } + + if (in.full) { + for (const float *i = back; i != back + next_use; ++i) prob_ += *i; + left_done_ = true; + out_.right = in.right; + return; + } + + // Right state was minimized, so it's already independent of the new words to the left. + if (in.right.length < in.left.length) { + out_.right = in.right; + return; + } + + // Shift exisiting words down. + for (WordIndex *i = out_.right.words + next_use - 1; i >= out_.right.words; --i) { + *(i + in.right.length) = *i; + } + // Add words from in.right. + std::copy(in.right.words, in.right.words + in.right.length, out_.right.words); + // Assemble backoff composed on the existing state's backoff followed by the new state's backoff. + std::copy(in.right.backoff, in.right.backoff + in.right.length, out_.right.backoff); + std::copy(back, back + next_use, out_.right.backoff + in.right.length); + out_.right.length = in.right.length + next_use; + } + + float Finish() { + out_.left.length = left_write_ - out_.left.pointers; + out_.full = left_done_; + return prob_; + } + + private: + void ProcessRet(const FullScoreReturn &ret) { + prob_ += ret.prob; + if (left_done_) return; + if (ret.independent_left) { + left_done_ = true; + return; + } + *(left_write_++) = ret.extend_left; + } + + const M &model_; + + ChartState &out_; + + bool left_done_; + + uint64_t *left_write_; + + float prob_; +}; + +} // namespace ngram +} // namespace lm + +#endif // LM_LEFT__ diff --git a/klm/lm/left_test.cc b/klm/lm/left_test.cc new file mode 100644 index 00000000..8bb91cb3 --- /dev/null +++ b/klm/lm/left_test.cc @@ -0,0 +1,360 @@ +#include "lm/left.hh" +#include "lm/model.hh" + +#include "util/tokenize_piece.hh" + +#include <vector> + +#define BOOST_TEST_MODULE LeftTest +#include <boost/test/unit_test.hpp> +#include <boost/test/floating_point_comparison.hpp> + +namespace lm { +namespace ngram { +namespace { + +#define Term(word) score.Terminal(m.GetVocabulary().Index(word)); +#define VCheck(word, value) BOOST_CHECK_EQUAL(m.GetVocabulary().Index(word), value); + +template <class M> void Short(const M &m) { + ChartState base; + { + RuleScore<M> score(m, base); + Term("more"); + Term("loin"); + BOOST_CHECK_CLOSE(-1.206319 - 0.3561665, score.Finish(), 0.001); + } + BOOST_CHECK(base.full); + BOOST_CHECK_EQUAL(2, base.left.length); + BOOST_CHECK_EQUAL(1, base.right.length); + VCheck("loin", base.right.words[0]); + + ChartState more_left; + { + RuleScore<M> score(m, more_left); + Term("little"); + score.NonTerminal(base, -1.206319 - 0.3561665); + // p(little more loin | null context) + BOOST_CHECK_CLOSE(-1.56538, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(3, more_left.left.length); + BOOST_CHECK_EQUAL(1, more_left.right.length); + VCheck("loin", more_left.right.words[0]); + BOOST_CHECK(more_left.full); + + ChartState shorter; + { + RuleScore<M> score(m, shorter); + Term("to"); + score.NonTerminal(base, -1.206319 - 0.3561665); + BOOST_CHECK_CLOSE(-0.30103 - 1.687872 - 1.206319 - 0.3561665, score.Finish(), 0.01); + } + BOOST_CHECK_EQUAL(1, shorter.left.length); + BOOST_CHECK_EQUAL(1, shorter.right.length); + VCheck("loin", shorter.right.words[0]); + BOOST_CHECK(shorter.full); +} + +template <class M> void Charge(const M &m) { + ChartState base; + { + RuleScore<M> score(m, base); + Term("on"); + Term("more"); + BOOST_CHECK_CLOSE(-1.509559 -0.4771212 -1.206319, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(1, base.left.length); + BOOST_CHECK_EQUAL(1, base.right.length); + VCheck("more", base.right.words[0]); + BOOST_CHECK(base.full); + + ChartState extend; + { + RuleScore<M> score(m, extend); + Term("looking"); + score.NonTerminal(base, -1.509559 -0.4771212 -1.206319); + BOOST_CHECK_CLOSE(-3.91039, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(2, extend.left.length); + BOOST_CHECK_EQUAL(1, extend.right.length); + VCheck("more", extend.right.words[0]); + BOOST_CHECK(extend.full); + + ChartState tobos; + { + RuleScore<M> score(m, tobos); + score.BeginSentence(); + score.NonTerminal(extend, -3.91039); + BOOST_CHECK_CLOSE(-3.471169, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(0, tobos.left.length); + BOOST_CHECK_EQUAL(1, tobos.right.length); +} + +template <class M> float LeftToRight(const M &m, const std::vector<WordIndex> &words) { + float ret = 0.0; + State right = m.NullContextState(); + for (std::vector<WordIndex>::const_iterator i = words.begin(); i != words.end(); ++i) { + State copy(right); + ret += m.Score(copy, *i, right); + } + return ret; +} + +template <class M> float RightToLeft(const M &m, const std::vector<WordIndex> &words) { + float ret = 0.0; + ChartState state; + state.left.length = 0; + state.right.length = 0; + state.full = false; + for (std::vector<WordIndex>::const_reverse_iterator i = words.rbegin(); i != words.rend(); ++i) { + ChartState copy(state); + RuleScore<M> score(m, state); + score.Terminal(*i); + score.NonTerminal(copy, ret); + ret = score.Finish(); + } + return ret; +} + +template <class M> float TreeMiddle(const M &m, const std::vector<WordIndex> &words) { + std::vector<std::pair<ChartState, float> > states(words.size()); + for (unsigned int i = 0; i < words.size(); ++i) { + RuleScore<M> score(m, states[i].first); + score.Terminal(words[i]); + states[i].second = score.Finish(); + } + while (states.size() > 1) { + std::vector<std::pair<ChartState, float> > upper((states.size() + 1) / 2); + for (unsigned int i = 0; i < states.size() / 2; ++i) { + RuleScore<M> score(m, upper[i].first); + score.NonTerminal(states[i*2].first, states[i*2].second); + score.NonTerminal(states[i*2+1].first, states[i*2+1].second); + upper[i].second = score.Finish(); + } + if (states.size() % 2) { + upper.back() = states.back(); + } + std::swap(states, upper); + } + return states.empty() ? 0 : states.back().second; +} + +template <class M> void LookupVocab(const M &m, const StringPiece &str, std::vector<WordIndex> &out) { + out.clear(); + for (util::PieceIterator<' '> i(str); i; ++i) { + out.push_back(m.GetVocabulary().Index(*i)); + } +} + +#define TEXT_TEST(str) \ +{ \ + std::vector<WordIndex> words; \ + LookupVocab(m, str, words); \ + float expect = LeftToRight(m, words); \ + BOOST_CHECK_CLOSE(expect, RightToLeft(m, words), 0.001); \ + BOOST_CHECK_CLOSE(expect, TreeMiddle(m, words), 0.001); \ +} + +// Build sentences, or parts thereof, from right to left. +template <class M> void GrowBig(const M &m) { + TEXT_TEST("in biarritz watching considering looking . on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown </s>"); + TEXT_TEST("on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown </s>"); + TEXT_TEST("on a little more loin also would consider higher to look good"); + TEXT_TEST("more loin also would consider higher to look good"); + TEXT_TEST("more loin also would consider higher to look"); + TEXT_TEST("also would consider higher to look"); + TEXT_TEST("also would consider higher"); + TEXT_TEST("would consider higher to look"); + TEXT_TEST("consider higher to look"); + TEXT_TEST("consider higher to"); + TEXT_TEST("consider higher"); +} + +template <class M> void AlsoWouldConsiderHigher(const M &m) { + ChartState also; + { + RuleScore<M> score(m, also); + score.Terminal(m.GetVocabulary().Index("also")); + BOOST_CHECK_CLOSE(-1.687872, score.Finish(), 0.001); + } + ChartState would; + { + RuleScore<M> score(m, would); + score.Terminal(m.GetVocabulary().Index("would")); + BOOST_CHECK_CLOSE(-1.687872, score.Finish(), 0.001); + } + ChartState combine_also_would; + { + RuleScore<M> score(m, combine_also_would); + score.NonTerminal(also, -1.687872); + score.NonTerminal(would, -1.687872); + BOOST_CHECK_CLOSE(-1.687872 - 2.0, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(2, combine_also_would.right.length); + + ChartState also_would; + { + RuleScore<M> score(m, also_would); + score.Terminal(m.GetVocabulary().Index("also")); + score.Terminal(m.GetVocabulary().Index("would")); + BOOST_CHECK_CLOSE(-1.687872 - 2.0, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(2, also_would.right.length); + + ChartState consider; + { + RuleScore<M> score(m, consider); + score.Terminal(m.GetVocabulary().Index("consider")); + BOOST_CHECK_CLOSE(-1.687872, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(1, consider.left.length); + BOOST_CHECK_EQUAL(1, consider.right.length); + BOOST_CHECK(!consider.full); + + ChartState higher; + float higher_score; + { + RuleScore<M> score(m, higher); + score.Terminal(m.GetVocabulary().Index("higher")); + higher_score = score.Finish(); + } + BOOST_CHECK_CLOSE(-1.509559, higher_score, 0.001); + BOOST_CHECK_EQUAL(1, higher.left.length); + BOOST_CHECK_EQUAL(1, higher.right.length); + BOOST_CHECK(!higher.full); + VCheck("higher", higher.right.words[0]); + BOOST_CHECK_CLOSE(-0.30103, higher.right.backoff[0], 0.001); + + ChartState consider_higher; + { + RuleScore<M> score(m, consider_higher); + score.NonTerminal(consider, -1.687872); + score.NonTerminal(higher, higher_score); + BOOST_CHECK_CLOSE(-1.509559 - 1.687872 - 0.30103, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(2, consider_higher.left.length); + BOOST_CHECK(!consider_higher.full); + + ChartState full; + { + RuleScore<M> score(m, full); + score.NonTerminal(combine_also_would, -1.687872 - 2.0); + score.NonTerminal(consider_higher, -1.509559 - 1.687872 - 0.30103); + BOOST_CHECK_CLOSE(-10.6879, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(4, full.right.length); +} + +template <class M> void GrowSmall(const M &m) { + TEXT_TEST("in biarritz watching considering looking . </s>"); + TEXT_TEST("in biarritz watching considering looking ."); + TEXT_TEST("in biarritz"); +} + +#define CHECK_SCORE(str, val) \ +{ \ + float got = val; \ + std::vector<WordIndex> indices; \ + LookupVocab(m, str, indices); \ + BOOST_CHECK_CLOSE(LeftToRight(m, indices), got, 0.001); \ +} + +template <class M> void FullGrow(const M &m) { + std::vector<WordIndex> words; + LookupVocab(m, "in biarritz watching considering looking . </s>", words); + + ChartState lexical[7]; + float lexical_scores[7]; + for (unsigned int i = 0; i < 7; ++i) { + RuleScore<M> score(m, lexical[i]); + score.Terminal(words[i]); + lexical_scores[i] = score.Finish(); + } + CHECK_SCORE("in", lexical_scores[0]); + CHECK_SCORE("biarritz", lexical_scores[1]); + CHECK_SCORE("watching", lexical_scores[2]); + CHECK_SCORE("</s>", lexical_scores[6]); + + ChartState l1[4]; + float l1_scores[4]; + { + RuleScore<M> score(m, l1[0]); + score.NonTerminal(lexical[0], lexical_scores[0]); + score.NonTerminal(lexical[1], lexical_scores[1]); + CHECK_SCORE("in biarritz", l1_scores[0] = score.Finish()); + } + { + RuleScore<M> score(m, l1[1]); + score.NonTerminal(lexical[2], lexical_scores[2]); + score.NonTerminal(lexical[3], lexical_scores[3]); + CHECK_SCORE("watching considering", l1_scores[1] = score.Finish()); + } + { + RuleScore<M> score(m, l1[2]); + score.NonTerminal(lexical[4], lexical_scores[4]); + score.NonTerminal(lexical[5], lexical_scores[5]); + CHECK_SCORE("looking .", l1_scores[2] = score.Finish()); + } + BOOST_CHECK_EQUAL(l1[2].left.length, 1); + l1[3] = lexical[6]; + l1_scores[3] = lexical_scores[6]; + + ChartState l2[2]; + float l2_scores[2]; + { + RuleScore<M> score(m, l2[0]); + score.NonTerminal(l1[0], l1_scores[0]); + score.NonTerminal(l1[1], l1_scores[1]); + CHECK_SCORE("in biarritz watching considering", l2_scores[0] = score.Finish()); + } + { + RuleScore<M> score(m, l2[1]); + score.NonTerminal(l1[2], l1_scores[2]); + score.NonTerminal(l1[3], l1_scores[3]); + CHECK_SCORE("looking . </s>", l2_scores[1] = score.Finish()); + } + BOOST_CHECK_EQUAL(l2[1].left.length, 1); + BOOST_CHECK(l2[1].full); + + ChartState top; + { + RuleScore<M> score(m, top); + score.NonTerminal(l2[0], l2_scores[0]); + score.NonTerminal(l2[1], l2_scores[1]); + CHECK_SCORE("in biarritz watching considering looking . </s>", score.Finish()); + } +} + +template <class M> void Everything() { + Config config; + config.messages = NULL; + M m("test.arpa", config); + + Short(m); + Charge(m); + GrowBig(m); + AlsoWouldConsiderHigher(m); + GrowSmall(m); + FullGrow(m); +} + +BOOST_AUTO_TEST_CASE(ProbingAll) { + Everything<Model>(); +} +BOOST_AUTO_TEST_CASE(TrieAll) { + Everything<TrieModel>(); +} +BOOST_AUTO_TEST_CASE(QuantTrieAll) { + Everything<QuantTrieModel>(); +} +BOOST_AUTO_TEST_CASE(ArrayQuantTrieAll) { + Everything<QuantArrayTrieModel>(); +} +BOOST_AUTO_TEST_CASE(ArrayTrieAll) { + Everything<ArrayTrieModel>(); +} + +} // namespace +} // namespace ngram +} // namespace lm diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 27e24b1c..ca581d8a 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -16,7 +16,7 @@ namespace lm { namespace ngram { size_t hash_value(const State &state) { - return util::MurmurHashNative(state.history_, sizeof(WordIndex) * state.valid_length_); + return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length); } namespace detail { @@ -41,11 +41,11 @@ template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::Ge // g++ prints warnings unless these are fully initialized. State begin_sentence = State(); - begin_sentence.valid_length_ = 1; - begin_sentence.history_[0] = vocab_.BeginSentence(); - begin_sentence.backoff_[0] = search_.unigram.Lookup(begin_sentence.history_[0]).backoff; + begin_sentence.length = 1; + begin_sentence.words[0] = vocab_.BeginSentence(); + begin_sentence.backoff[0] = search_.unigram.Lookup(begin_sentence.words[0]).backoff; State null_context = State(); - null_context.valid_length_ = 0; + null_context.length = 0; P::Init(begin_sentence, null_context, vocab_, search_.MiddleEnd() - search_.MiddleBegin() + 2); } @@ -87,7 +87,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT search_.unigram.Unknown().backoff = 0.0; search_.unigram.Unknown().prob = config.unknown_missing_logprob; } - FinishFile(config, kModelType, counts, backing_); + FinishFile(config, kModelType, kVersion, counts, backing_); } catch (util::Exception &e) { e << " Byte: " << f.Offset(); throw; @@ -95,9 +95,9 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT } template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { - FullScoreReturn ret = ScoreExceptBackoff(in_state.history_, in_state.history_ + in_state.valid_length_, new_word, out_state); - if (ret.ngram_length - 1 < in_state.valid_length_) { - ret.prob = std::accumulate(in_state.backoff_ + ret.ngram_length - 1, in_state.backoff_ + in_state.valid_length_, ret.prob); + FullScoreReturn ret = ScoreExceptBackoff(in_state.words, in_state.words + in_state.length, new_word, out_state); + if (ret.ngram_length - 1 < in_state.length) { + ret.prob = std::accumulate(in_state.backoff + ret.ngram_length - 1, in_state.backoff + in_state.length, ret.prob); } return ret; } @@ -131,32 +131,80 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT // Generate a state from context. context_rend = std::min(context_rend, context_rbegin + P::Order() - 1); if (context_rend == context_rbegin) { - out_state.valid_length_ = 0; + out_state.length = 0; return; } - float ignored_prob; + FullScoreReturn ignored; typename Search::Node node; - search_.LookupUnigram(*context_rbegin, ignored_prob, out_state.backoff_[0], node); - out_state.valid_length_ = HasExtension(out_state.backoff_[0]) ? 1 : 0; - float *backoff_out = out_state.backoff_ + 1; + search_.LookupUnigram(*context_rbegin, out_state.backoff[0], node, ignored); + out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0; + float *backoff_out = out_state.backoff + 1; const typename Search::Middle *mid = search_.MiddleBegin(); for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++mid) { if (!search_.LookupMiddleNoProb(*mid, *i, *backoff_out, node)) { - std::copy(context_rbegin, context_rbegin + out_state.valid_length_, out_state.history_); + std::copy(context_rbegin, context_rbegin + out_state.length, out_state.words); return; } - if (HasExtension(*backoff_out)) out_state.valid_length_ = i - context_rbegin + 1; + if (HasExtension(*backoff_out)) out_state.length = i - context_rbegin + 1; } - std::copy(context_rbegin, context_rbegin + out_state.valid_length_, out_state.history_); + std::copy(context_rbegin, context_rbegin + out_state.length, out_state.words); +} + +template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ExtendLeft( + const WordIndex *add_rbegin, const WordIndex *add_rend, + const float *backoff_in, + uint64_t extend_pointer, + unsigned char extend_length, + float *backoff_out, + unsigned char &next_use) const { + FullScoreReturn ret; + float subtract_me; + typename Search::Node node(search_.Unpack(extend_pointer, extend_length, subtract_me)); + ret.prob = subtract_me; + ret.ngram_length = extend_length; + next_use = 0; + // If this function is called, then it does depend on left words. + ret.independent_left = false; + ret.extend_left = extend_pointer; + const typename Search::Middle *mid_iter = search_.MiddleBegin() + extend_length - 1; + const WordIndex *i = add_rbegin; + for (; ; ++i, ++backoff_out, ++mid_iter) { + if (i == add_rend) { + // Ran out of words. + for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b; + ret.prob -= subtract_me; + return ret; + } + if (mid_iter == search_.MiddleEnd()) break; + if (ret.independent_left || !search_.LookupMiddle(*mid_iter, *i, *backoff_out, node, ret)) { + // Didn't match a word. + ret.independent_left = true; + for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b; + ret.prob -= subtract_me; + return ret; + } + ret.ngram_length = mid_iter - search_.MiddleBegin() + 2; + if (HasExtension(*backoff_out)) next_use = i - add_rbegin + 1; + } + + if (ret.independent_left || !search_.LookupLongest(*i, ret.prob, node)) { + // The last backoff weight, for Order() - 1. + ret.prob += backoff_in[i - add_rbegin]; + } else { + ret.ngram_length = P::Order(); + } + ret.independent_left = true; + ret.prob -= subtract_me; + return ret; } namespace { // Do a paraonoid copy of history, assuming new_word has already been copied -// (hence the -1). out_state.valid_length_ could be zero so I avoided using +// (hence the -1). out_state.length could be zero so I avoided using // std::copy. void CopyRemainingHistory(const WordIndex *from, State &out_state) { - WordIndex *out = out_state.history_ + 1; - const WordIndex *in_end = from + static_cast<ptrdiff_t>(out_state.valid_length_) - 1; + WordIndex *out = out_state.words + 1; + const WordIndex *in_end = from + static_cast<ptrdiff_t>(out_state.length) - 1; for (const WordIndex *in = from; in < in_end; ++in, ++out) *out = *in; } } // namespace @@ -175,17 +223,17 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, // ret.ngram_length contains the last known non-blank ngram length. ret.ngram_length = 1; + float *backoff_out(out_state.backoff); typename Search::Node node; - float *backoff_out(out_state.backoff_); - search_.LookupUnigram(new_word, ret.prob, *backoff_out, node); - // This is the length of the context that should be used for continuation. - out_state.valid_length_ = HasExtension(*backoff_out) ? 1 : 0; + search_.LookupUnigram(new_word, *backoff_out, node, ret); + // This is the length of the context that should be used for continuation to the right. + out_state.length = HasExtension(*backoff_out) ? 1 : 0; // We'll write the word anyway since it will probably be used and does no harm being there. - out_state.history_[0] = new_word; + out_state.words[0] = new_word; if (context_rbegin == context_rend) return ret; ++backoff_out; - // Ok now we now that the bigram contains known words. Start by looking it up. + // Ok start by looking up the bigram. const WordIndex *hist_iter = context_rbegin; const typename Search::Middle *mid_iter = search_.MiddleBegin(); for (; ; ++mid_iter, ++hist_iter, ++backoff_out) { @@ -198,36 +246,28 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, if (mid_iter == search_.MiddleEnd()) break; - float revert = ret.prob; - if (!search_.LookupMiddle(*mid_iter, *hist_iter, ret.prob, *backoff_out, node)) { + if (ret.independent_left || !search_.LookupMiddle(*mid_iter, *hist_iter, *backoff_out, node, ret)) { // Didn't find an ngram using hist_iter. CopyRemainingHistory(context_rbegin, out_state); - // ret.prob was already set. + // ret.prob was already set. + ret.independent_left = true; return ret; } - if (ret.prob == kBlankProb) { - // It's a blank. Go back to the old probability. - ret.prob = revert; - } else { - ret.ngram_length = hist_iter - context_rbegin + 2; - if (HasExtension(*backoff_out)) { - out_state.valid_length_ = ret.ngram_length; - } + ret.ngram_length = hist_iter - context_rbegin + 2; + if (HasExtension(*backoff_out)) { + out_state.length = ret.ngram_length; } } // It passed every lookup in search_.middle. All that's left is to check search_.longest. - - if (!search_.LookupLongest(*hist_iter, ret.prob, node)) { - // Failed to find a longest n-gram. Fall back to the most recent non-blank. - CopyRemainingHistory(context_rbegin, out_state); - // ret.prob was already set. - return ret; + if (!ret.independent_left && search_.LookupLongest(*hist_iter, ret.prob, node)) { + // It's an P::Order()-gram. + // There is no blank in longest_. + ret.ngram_length = P::Order(); } - // It's an P::Order()-gram. + // This handles (N-1)-grams and N-grams. CopyRemainingHistory(context_rbegin, out_state); - // There is no blank in longest_. - ret.ngram_length = P::Order(); + ret.independent_left = true; return ret; } diff --git a/klm/lm/model.hh b/klm/lm/model.hh index 21595321..fe91af2e 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -27,9 +27,9 @@ namespace ngram { class State { public: bool operator==(const State &other) const { - if (valid_length_ != other.valid_length_) return false; - const WordIndex *end = history_ + valid_length_; - for (const WordIndex *first = history_, *second = other.history_; + if (length != other.length) return false; + const WordIndex *end = words + length; + for (const WordIndex *first = words, *second = other.words; first != end; ++first, ++second) { if (*first != *second) return false; } @@ -39,27 +39,27 @@ class State { // Three way comparison function. int Compare(const State &other) const { - if (valid_length_ == other.valid_length_) { - return memcmp(history_, other.history_, valid_length_ * sizeof(WordIndex)); + if (length == other.length) { + return memcmp(words, other.words, length * sizeof(WordIndex)); } - return (valid_length_ < other.valid_length_) ? -1 : 1; + return (length < other.length) ? -1 : 1; } // Call this before using raw memcmp. void ZeroRemaining() { - for (unsigned char i = valid_length_; i < kMaxOrder - 1; ++i) { - history_[i] = 0; - backoff_[i] = 0.0; + for (unsigned char i = length; i < kMaxOrder - 1; ++i) { + words[i] = 0; + backoff[i] = 0.0; } } - unsigned char ValidLength() const { return valid_length_; } + unsigned char Length() const { return length; } // You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD. // This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit. - WordIndex history_[kMaxOrder - 1]; - float backoff_[kMaxOrder - 1]; - unsigned char valid_length_; + WordIndex words[kMaxOrder - 1]; + float backoff[kMaxOrder - 1]; + unsigned char length; }; size_t hash_value(const State &state); @@ -75,6 +75,8 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod // This is the model type returned by RecognizeBinary. static const ModelType kModelType; + static const unsigned int kVersion = Search::kVersion; + /* Get the size of memory that will be mapped given ngram counts. This * does not include small non-mapped control structures, such as this class * itself. @@ -114,6 +116,25 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod */ void GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const; + /* More efficient version of FullScore where a partial n-gram has already + * been scored. + * NOTE: THE RETURNED .prob IS RELATIVE, NOT ABSOLUTE. So for example, if + * the n-gram does not end up extending further left, then 0 is returned. + */ + FullScoreReturn ExtendLeft( + // Additional context in reverse order. This will update add_rend to + const WordIndex *add_rbegin, const WordIndex *add_rend, + // Backoff weights to use. + const float *backoff_in, + // extend_left returned by a previous query. + uint64_t extend_pointer, + // Length of n-gram that the pointer corresponds to. + unsigned char extend_length, + // Where to write additional backoffs for [extend_length + 1, min(Order() - 1, return.ngram_length)] + float *backoff_out, + // Amount of additional content that should be considered by the next call. + unsigned char &next_use) const; + private: friend void LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to); diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 57c7291c..2654071f 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -10,8 +10,8 @@ namespace lm { namespace ngram { std::ostream &operator<<(std::ostream &o, const State &state) { - o << "State length " << static_cast<unsigned int>(state.valid_length_) << ':'; - for (const WordIndex *i = state.history_; i < state.history_ + state.valid_length_; ++i) { + o << "State length " << static_cast<unsigned int>(state.length) << ':'; + for (const WordIndex *i = state.words; i < state.words + state.length; ++i) { o << ' ' << *i; } return o; @@ -19,25 +19,26 @@ std::ostream &operator<<(std::ostream &o, const State &state) { namespace { -#define StartTest(word, ngram, score) \ +#define StartTest(word, ngram, score, indep_left) \ ret = model.FullScore( \ state, \ model.GetVocabulary().Index(word), \ out);\ BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \ BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \ - BOOST_CHECK_GE(std::min<unsigned char>(ngram, 5 - 1), out.valid_length_); \ + BOOST_CHECK_GE(std::min<unsigned char>(ngram, 5 - 1), out.length); \ + BOOST_CHECK_EQUAL(indep_left, ret.independent_left); \ {\ - WordIndex context[state.valid_length_ + 1]; \ + WordIndex context[state.length + 1]; \ context[0] = model.GetVocabulary().Index(word); \ - std::copy(state.history_, state.history_ + state.valid_length_, context + 1); \ + std::copy(state.words, state.words + state.length, context + 1); \ State get_state; \ - model.GetState(context, context + state.valid_length_ + 1, get_state); \ + model.GetState(context, context + state.length + 1, get_state); \ BOOST_CHECK_EQUAL(out, get_state); \ } -#define AppendTest(word, ngram, score) \ - StartTest(word, ngram, score) \ +#define AppendTest(word, ngram, score, indep_left) \ + StartTest(word, ngram, score, indep_left) \ state = out; template <class M> void Starters(const M &model) { @@ -45,12 +46,12 @@ template <class M> void Starters(const M &model) { Model::State state(model.BeginSentenceState()); Model::State out; - StartTest("looking", 2, -0.4846522); + StartTest("looking", 2, -0.4846522, true); // , probability plus <s> backoff - StartTest(",", 1, -1.383514 + -0.4149733); + StartTest(",", 1, -1.383514 + -0.4149733, true); // <unk> probability plus <s> backoff - StartTest("this_is_not_found", 1, -1.995635 + -0.4149733); + StartTest("this_is_not_found", 1, -1.995635 + -0.4149733, true); } template <class M> void Continuation(const M &model) { @@ -58,46 +59,64 @@ template <class M> void Continuation(const M &model) { Model::State state(model.BeginSentenceState()); Model::State out; - AppendTest("looking", 2, -0.484652); - AppendTest("on", 3, -0.348837); - AppendTest("a", 4, -0.0155266); - AppendTest("little", 5, -0.00306122); + AppendTest("looking", 2, -0.484652, true); + AppendTest("on", 3, -0.348837, true); + AppendTest("a", 4, -0.0155266, true); + AppendTest("little", 5, -0.00306122, true); State preserve = state; - AppendTest("the", 1, -4.04005); - AppendTest("biarritz", 1, -1.9889); - AppendTest("not_found", 1, -2.29666); - AppendTest("more", 1, -1.20632 - 20.0); - AppendTest(".", 2, -0.51363); - AppendTest("</s>", 3, -0.0191651); - BOOST_CHECK_EQUAL(0, state.valid_length_); + AppendTest("the", 1, -4.04005, true); + AppendTest("biarritz", 1, -1.9889, true); + AppendTest("not_found", 1, -2.29666, true); + AppendTest("more", 1, -1.20632 - 20.0, true); + AppendTest(".", 2, -0.51363, true); + AppendTest("</s>", 3, -0.0191651, true); + BOOST_CHECK_EQUAL(0, state.length); state = preserve; - AppendTest("more", 5, -0.00181395); - BOOST_CHECK_EQUAL(4, state.valid_length_); - AppendTest("loin", 5, -0.0432557); - BOOST_CHECK_EQUAL(1, state.valid_length_); + AppendTest("more", 5, -0.00181395, true); + BOOST_CHECK_EQUAL(4, state.length); + AppendTest("loin", 5, -0.0432557, true); + BOOST_CHECK_EQUAL(1, state.length); } template <class M> void Blanks(const M &model) { FullScoreReturn ret; State state(model.NullContextState()); State out; - AppendTest("also", 1, -1.687872); - AppendTest("would", 2, -2); - AppendTest("consider", 3, -3); + AppendTest("also", 1, -1.687872, false); + AppendTest("would", 2, -2, true); + AppendTest("consider", 3, -3, true); State preserve = state; - AppendTest("higher", 4, -4); - AppendTest("looking", 5, -5); - BOOST_CHECK_EQUAL(1, state.valid_length_); + AppendTest("higher", 4, -4, true); + AppendTest("looking", 5, -5, true); + BOOST_CHECK_EQUAL(1, state.length); state = preserve; - AppendTest("not_found", 1, -1.995635 - 7.0 - 0.30103); + // also would consider not_found + AppendTest("not_found", 1, -1.995635 - 7.0 - 0.30103, true); state = model.NullContextState(); // higher looking is a blank. - AppendTest("higher", 1, -1.509559); - AppendTest("looking", 1, -1.285941 - 0.30103); - AppendTest("not_found", 1, -1.995635 - 0.4771212); + AppendTest("higher", 1, -1.509559, false); + AppendTest("looking", 2, -1.285941 - 0.30103, false); + + State higher_looking = state; + + BOOST_CHECK_EQUAL(1, state.length); + AppendTest("not_found", 1, -1.995635 - 0.4771212, true); + + state = higher_looking; + // higher looking consider + AppendTest("consider", 1, -1.687872 - 0.4771212, true); + + state = model.NullContextState(); + AppendTest("would", 1, -1.687872, false); + BOOST_CHECK_EQUAL(1, state.length); + AppendTest("consider", 2, -1.687872 -0.30103, false); + BOOST_CHECK_EQUAL(2, state.length); + AppendTest("higher", 3, -1.509559 - 0.30103, false); + BOOST_CHECK_EQUAL(3, state.length); + AppendTest("looking", 4, -1.285941 - 0.30103, false); } template <class M> void Unknowns(const M &model) { @@ -105,14 +124,14 @@ template <class M> void Unknowns(const M &model) { State state(model.NullContextState()); State out; - AppendTest("not_found", 1, -1.995635); + AppendTest("not_found", 1, -1.995635, false); State preserve = state; - AppendTest("not_found2", 2, -15.0); - AppendTest("not_found3", 2, -15.0 - 2.0); + AppendTest("not_found2", 2, -15.0, true); + AppendTest("not_found3", 2, -15.0 - 2.0, true); state = preserve; - AppendTest("however", 2, -4); - AppendTest("not_found3", 3, -6); + AppendTest("however", 2, -4, true); + AppendTest("not_found3", 3, -6, true); } template <class M> void MinimalState(const M &model) { @@ -120,22 +139,66 @@ template <class M> void MinimalState(const M &model) { State state(model.NullContextState()); State out; - AppendTest("baz", 1, -6.535897); - BOOST_CHECK_EQUAL(0, state.valid_length_); + AppendTest("baz", 1, -6.535897, true); + BOOST_CHECK_EQUAL(0, state.length); state = model.NullContextState(); - AppendTest("foo", 1, -3.141592); - BOOST_CHECK_EQUAL(1, state.valid_length_); - AppendTest("bar", 2, -6.0); + AppendTest("foo", 1, -3.141592, true); + BOOST_CHECK_EQUAL(1, state.length); + AppendTest("bar", 2, -6.0, true); // Has to include the backoff weight. - BOOST_CHECK_EQUAL(1, state.valid_length_); - AppendTest("bar", 1, -2.718281 + 3.0); - BOOST_CHECK_EQUAL(1, state.valid_length_); + BOOST_CHECK_EQUAL(1, state.length); + AppendTest("bar", 1, -2.718281 + 3.0, true); + BOOST_CHECK_EQUAL(1, state.length); state = model.NullContextState(); - AppendTest("to", 1, -1.687872); - AppendTest("look", 2, -0.2922095); - BOOST_CHECK_EQUAL(2, state.valid_length_); - AppendTest("good", 3, -7); + AppendTest("to", 1, -1.687872, false); + AppendTest("look", 2, -0.2922095, true); + BOOST_CHECK_EQUAL(2, state.length); + AppendTest("good", 3, -7, true); +} + +template <class M> void ExtendLeftTest(const M &model) { + State right; + FullScoreReturn little(model.FullScore(model.NullContextState(), model.GetVocabulary().Index("little"), right)); + const float kLittleProb = -1.285941; + BOOST_CHECK_CLOSE(kLittleProb, little.prob, 0.001); + unsigned char next_use; + float backoff_out[4]; + + FullScoreReturn extend_none(model.ExtendLeft(NULL, NULL, NULL, little.extend_left, 1, NULL, next_use)); + BOOST_CHECK_EQUAL(0, next_use); + BOOST_CHECK_EQUAL(little.extend_left, extend_none.extend_left); + BOOST_CHECK_CLOSE(0.0, extend_none.prob, 0.001); + BOOST_CHECK_EQUAL(1, extend_none.ngram_length); + + const WordIndex a = model.GetVocabulary().Index("a"); + float backoff_in = 3.14; + // a little + FullScoreReturn extend_a(model.ExtendLeft(&a, &a + 1, &backoff_in, little.extend_left, 1, backoff_out, next_use)); + BOOST_CHECK_EQUAL(1, next_use); + BOOST_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001); + BOOST_CHECK_CLOSE(-0.09132547 - kLittleProb, extend_a.prob, 0.001); + BOOST_CHECK_EQUAL(2, extend_a.ngram_length); + BOOST_CHECK(!extend_a.independent_left); + + const WordIndex on = model.GetVocabulary().Index("on"); + FullScoreReturn extend_on(model.ExtendLeft(&on, &on + 1, &backoff_in, extend_a.extend_left, 2, backoff_out, next_use)); + BOOST_CHECK_EQUAL(1, next_use); + BOOST_CHECK_CLOSE(-0.4771212, backoff_out[0], 0.001); + BOOST_CHECK_CLOSE(-0.0283603 - -0.09132547, extend_on.prob, 0.001); + BOOST_CHECK_EQUAL(3, extend_on.ngram_length); + BOOST_CHECK(!extend_on.independent_left); + + const WordIndex both[2] = {a, on}; + float backoff_in_arr[4]; + FullScoreReturn extend_both(model.ExtendLeft(both, both + 2, backoff_in_arr, little.extend_left, 1, backoff_out, next_use)); + BOOST_CHECK_EQUAL(2, next_use); + BOOST_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001); + BOOST_CHECK_CLOSE(-0.4771212, backoff_out[1], 0.001); + BOOST_CHECK_CLOSE(-0.0283603 - kLittleProb, extend_both.prob, 0.001); + BOOST_CHECK_EQUAL(3, extend_both.ngram_length); + BOOST_CHECK(!extend_both.independent_left); + BOOST_CHECK_EQUAL(extend_on.extend_left, extend_both.extend_left); } #define StatelessTest(word, provide, ngram, score) \ @@ -166,17 +229,17 @@ template <class M> void Stateless(const M &model) { // looking StatelessTest(1, 2, 2, -0.484652); // on - AppendTest("on", 3, -0.348837); + AppendTest("on", 3, -0.348837, true); StatelessTest(2, 3, 3, -0.348837); StatelessTest(2, 2, 3, -0.348837); StatelessTest(2, 1, 2, -0.4638903); // a StatelessTest(3, 4, 4, -0.0155266); // little - AppendTest("little", 5, -0.00306122); + AppendTest("little", 5, -0.00306122, true); StatelessTest(4, 5, 5, -0.00306122); // the - AppendTest("the", 1, -4.04005); + AppendTest("the", 1, -4.04005, true); StatelessTest(5, 5, 1, -4.04005); // No context of the. StatelessTest(5, 0, 1, -1.687872); @@ -189,8 +252,8 @@ template <class M> void Stateless(const M &model) { WordIndex unk[1]; unk[0] = 0; model.GetState(unk, unk + 1, state); - BOOST_CHECK_EQUAL(1, state.valid_length_); - BOOST_CHECK_EQUAL(static_cast<WordIndex>(0), state.history_[0]); + BOOST_CHECK_EQUAL(1, state.length); + BOOST_CHECK_EQUAL(static_cast<WordIndex>(0), state.words[0]); } template <class M> void NoUnkCheck(const M &model) { @@ -207,6 +270,7 @@ template <class M> void Everything(const M &m) { Blanks(m); Unknowns(m); MinimalState(m); + ExtendLeftTest(m); Stateless(m); } @@ -245,6 +309,7 @@ template <class ModelT> void LoadingTest() { config.enumerate_vocab = &enumerate; ModelT m("test.arpa", config); enumerate.Check(m.GetVocabulary()); + BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound()); Everything(m); } { @@ -252,6 +317,7 @@ template <class ModelT> void LoadingTest() { config.enumerate_vocab = &enumerate; ModelT m("test_nounk.arpa", config); enumerate.Check(m.GetVocabulary()); + BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound()); NoUnkCheck(m); } } diff --git a/klm/lm/model_type.hh b/klm/lm/model_type.hh new file mode 100644 index 00000000..5057ed25 --- /dev/null +++ b/klm/lm/model_type.hh @@ -0,0 +1,16 @@ +#ifndef LM_MODEL_TYPE__ +#define LM_MODEL_TYPE__ + +namespace lm { +namespace ngram { + +/* Not the best numbering system, but it grew this way for historical reasons + * and I want to preserve existing binary files. */ +typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3, ARRAY_TRIE_SORTED=4, QUANT_ARRAY_TRIE_SORTED=5} ModelType; + +const static ModelType kQuantAdd = static_cast<ModelType>(QUANT_TRIE_SORTED - TRIE_SORTED); +const static ModelType kArrayAdd = static_cast<ModelType>(ARRAY_TRIE_SORTED - TRIE_SORTED); + +} // namespace ngram +} // namespace lm +#endif // LM_MODEL_TYPE__ diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc index fd371cc8..98a5d048 100644 --- a/klm/lm/quantize.cc +++ b/klm/lm/quantize.cc @@ -1,5 +1,6 @@ #include "lm/quantize.hh" +#include "lm/binary_format.hh" #include "lm/lm_exception.hh" #include <algorithm> @@ -70,8 +71,7 @@ void SeparatelyQuantize::Train(uint8_t order, std::vector<float> &prob, std::vec void SeparatelyQuantize::TrainProb(uint8_t order, std::vector<float> &prob) { float *centers = start_ + TableStart(order); - *(centers++) = kBlankProb; - MakeBins(&*prob.begin(), &*prob.end(), centers, (1ULL << prob_bits_) - 1); + MakeBins(&*prob.begin(), &*prob.end(), centers, (1ULL << prob_bits_)); } void SeparatelyQuantize::FinishedLoading(const Config &config) { diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh index 0b71d14a..4cf4236e 100644 --- a/klm/lm/quantize.hh +++ b/klm/lm/quantize.hh @@ -1,9 +1,9 @@ #ifndef LM_QUANTIZE_H__ #define LM_QUANTIZE_H__ -#include "lm/binary_format.hh" // for ModelType #include "lm/blank.hh" #include "lm/config.hh" +#include "lm/model_type.hh" #include "util/bit_packing.hh" #include <algorithm> @@ -36,6 +36,9 @@ class DontQuantize { prob = util::ReadNonPositiveFloat31(base, bit_offset); backoff = util::ReadFloat32(base, bit_offset + 31); } + void ReadProb(const void *base, uint64_t bit_offset, float &prob) const { + prob = util::ReadNonPositiveFloat31(base, bit_offset); + } void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const { backoff = util::ReadFloat32(base, bit_offset + 31); } @@ -77,7 +80,7 @@ class SeparatelyQuantize { Bins(uint8_t bits, const float *const begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {} uint64_t EncodeProb(float value) const { - return(value == kBlankProb ? kBlankProbQuant : Encode(value, 1)); + return Encode(value, 0); } uint64_t EncodeBackoff(float value) const { @@ -132,6 +135,10 @@ class SeparatelyQuantize { (prob_.EncodeProb(prob) << backoff_.Bits()) | backoff_.EncodeBackoff(backoff)); } + void ReadProb(const void *base, uint64_t bit_offset, float &prob) const { + prob = prob_.Decode(util::ReadInt25(base, bit_offset + backoff_.Bits(), prob_.Bits(), prob_.Mask())); + } + void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const { uint64_t both = util::ReadInt57(base, bit_offset, total_bits_, total_mask_); prob = prob_.Decode(both >> backoff_.Bits()); @@ -179,7 +186,7 @@ class SeparatelyQuantize { void SetupMemory(void *start, const Config &config); static const bool kTrain = true; - // Assumes kBlankProb is removed from prob and 0.0 is removed from backoff. + // Assumes 0.0 is removed from backoff. void Train(uint8_t order, std::vector<float> &prob, std::vector<float> &backoff); // Train just probabilities (for longest order). void TrainProb(uint8_t order, std::vector<float> &prob); diff --git a/klm/lm/return.hh b/klm/lm/return.hh new file mode 100644 index 00000000..15571960 --- /dev/null +++ b/klm/lm/return.hh @@ -0,0 +1,39 @@ +#ifndef LM_RETURN__ +#define LM_RETURN__ + +#include <inttypes.h> + +namespace lm { +/* Structure returned by scoring routines. */ +struct FullScoreReturn { + // log10 probability + float prob; + + /* The length of n-gram matched. Do not use this for recombination. + * Consider a model containing only the following n-grams: + * -1 foo + * -3.14 bar + * -2.718 baz -5 + * -6 foo bar + * + * If you score ``bar'' then ngram_length is 1 and recombination state is the + * empty string because bar has zero backoff and does not extend to the + * right. + * If you score ``foo'' then ngram_length is 1 and recombination state is + * ``foo''. + * + * Ideally, keep output states around and compare them. Failing that, + * get out_state.ValidLength() and use that length for recombination. + */ + unsigned char ngram_length; + + /* Left extension information. If independent_left is set, then prob is + * independent of words to the left (up to additional backoff). Otherwise, + * extend_left indicates how to efficiently extend further to the left. + */ + bool independent_left; + uint64_t extend_left; // Defined only if independent_left +}; + +} // namespace lm +#endif // LM_RETURN__ diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index 82c53ec8..334adf12 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -1,10 +1,12 @@ #include "lm/search_hashed.hh" +#include "lm/binary_format.hh" #include "lm/blank.hh" #include "lm/lm_exception.hh" #include "lm/read_arpa.hh" #include "lm/vocab.hh" +#include "util/bit_packing.hh" #include "util/file_piece.hh" #include <string> @@ -48,30 +50,77 @@ class ActivateUnigram { ProbBackoff *modify_; }; -template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, std::vector<Middle> &middle, Activate activate, Store &store, PositiveProbWarn &warn) { - - ReadNGramHeader(f, n); +template <class Middle> void FixSRI(int lower, float negative_lower_prob, unsigned int n, const uint64_t *keys, const WordIndex *vocab_ids, ProbBackoff *unigrams, std::vector<Middle> &middle) { ProbBackoff blank; - blank.prob = kBlankProb; - blank.backoff = kBlankBackoff; + blank.backoff = kNoExtensionBackoff; + // Fix SRI's stupidity. + // Note that negative_lower_prob is the negative of the probability (so it's currently >= 0). We still want the sign bit off to indicate left extension, so I just do -= on the backoffs. + blank.prob = negative_lower_prob; + // An entry was found at lower (order lower + 2). + // We need to insert blanks starting at lower + 1 (order lower + 3). + unsigned int fix = static_cast<unsigned int>(lower + 1); + uint64_t backoff_hash = detail::CombineWordHash(static_cast<uint64_t>(vocab_ids[1]), vocab_ids[2]); + if (fix == 0) { + // Insert a missing bigram. + blank.prob -= unigrams[vocab_ids[1]].backoff; + SetExtension(unigrams[vocab_ids[1]].backoff); + // Bigram including a unigram's backoff + middle[0].Insert(Middle::Packing::Make(keys[0], blank)); + fix = 1; + } else { + for (unsigned int i = 3; i < fix + 2; ++i) backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[i]); + } + // fix >= 1. Insert trigrams and above. + for (; fix <= n - 3; ++fix) { + typename Middle::MutableIterator gotit; + if (middle[fix - 1].UnsafeMutableFind(backoff_hash, gotit)) { + float &backoff = gotit->MutableValue().backoff; + SetExtension(backoff); + blank.prob -= backoff; + } + middle[fix].Insert(Middle::Packing::Make(keys[fix], blank)); + backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[fix + 2]); + } +} + +template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, ProbBackoff *unigrams, std::vector<Middle> &middle, Activate activate, Store &store, PositiveProbWarn &warn) { + ReadNGramHeader(f, n); // vocab ids of words in reverse order WordIndex vocab_ids[n]; uint64_t keys[n - 1]; typename Store::Packing::Value value; - typename Middle::ConstIterator found; + typename Middle::MutableIterator found; for (size_t i = 0; i < count; ++i) { ReadNGram(f, n, vocab, vocab_ids, value, warn); + keys[0] = detail::CombineWordHash(static_cast<uint64_t>(*vocab_ids), vocab_ids[1]); for (unsigned int h = 1; h < n - 1; ++h) { keys[h] = detail::CombineWordHash(keys[h-1], vocab_ids[h+1]); } + // Initially the sign bit is on, indicating it does not extend left. Most already have this but there might +0.0. + util::SetSign(value.prob); store.Insert(Store::Packing::Make(keys[n-2], value)); - // Go back and insert blanks. - for (int lower = n - 3; lower >= 0; --lower) { - if (middle[lower].Find(keys[lower], found)) break; - middle[lower].Insert(Middle::Packing::Make(keys[lower], blank)); + // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb. + int lower; + util::FloatEnc fix_prob; + for (lower = n - 3; ; --lower) { + if (lower == -1) { + fix_prob.f = unigrams[vocab_ids[0]].prob; + fix_prob.i &= ~util::kSignBit; + unigrams[vocab_ids[0]].prob = fix_prob.f; + break; + } + if (middle[lower].UnsafeMutableFind(keys[lower], found)) { + // Turn off sign bit to indicate that it extends left. + fix_prob.f = found->MutableValue().prob; + fix_prob.i &= ~util::kSignBit; + found->MutableValue().prob = fix_prob.f; + // We don't need to recurse further down because this entry already set the bits for lower entries. + break; + } } + if (lower != static_cast<int>(n) - 3) FixSRI(lower, fix_prob.f, n, keys, vocab_ids, unigrams, middle); activate(vocab_ids, n); } @@ -107,15 +156,15 @@ template <class MiddleT, class LongestT> template <class Voc> void TemplateHashe try { if (counts.size() > 2) { - ReadNGrams(f, 2, counts[1], vocab, middle_, ActivateUnigram(unigram.Raw()), middle_[0], warn); + ReadNGrams(f, 2, counts[1], vocab, unigram.Raw(), middle_, ActivateUnigram(unigram.Raw()), middle_[0], warn); } for (unsigned int n = 3; n < counts.size(); ++n) { - ReadNGrams(f, n, counts[n-1], vocab, middle_, ActivateLowerMiddle<Middle>(middle_[n-3]), middle_[n-2], warn); + ReadNGrams(f, n, counts[n-1], vocab, unigram.Raw(), middle_, ActivateLowerMiddle<Middle>(middle_[n-3]), middle_[n-2], warn); } if (counts.size() > 2) { - ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle_, ActivateLowerMiddle<Middle>(middle_.back()), longest, warn); + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, unigram.Raw(), middle_, ActivateLowerMiddle<Middle>(middle_.back()), longest, warn); } else { - ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle_, ActivateUnigram(unigram.Raw()), longest, warn); + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, unigram.Raw(), middle_, ActivateUnigram(unigram.Raw()), longest, warn); } } catch (util::ProbingSizeException &e) { UTIL_THROW(util::ProbingSizeException, "Avoid pruning n-grams like \"bar baz quux\" when \"foo bar baz quux\" is still in the model. KenLM will work when this pruning happens, but the probing model assumes these events are rare enough that using blank space in the probing hash table will cover all of them. Increase probing_multiplier (-p to build_binary) to add more blank spaces.\n"); @@ -133,7 +182,7 @@ template <class MiddleT, class LongestT> void TemplateHashedSearch<MiddleT, Long template class TemplateHashedSearch<ProbingHashedSearch::Middle, ProbingHashedSearch::Longest>; -template void TemplateHashedSearch<ProbingHashedSearch::Middle, ProbingHashedSearch::Longest>::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &, ProbingVocabulary &vocab, Backing &backing); +template void TemplateHashedSearch<ProbingHashedSearch::Middle, ProbingHashedSearch::Longest>::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &, ProbingVocabulary &vocab, Backing &backing); } // namespace detail } // namespace ngram diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index c62985e4..e289fd11 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -1,15 +1,18 @@ #ifndef LM_SEARCH_HASHED__ #define LM_SEARCH_HASHED__ -#include "lm/binary_format.hh" +#include "lm/model_type.hh" #include "lm/config.hh" #include "lm/read_arpa.hh" +#include "lm/return.hh" #include "lm/weights.hh" +#include "util/bit_packing.hh" #include "util/key_value_packing.hh" #include "util/probing_hash_table.hh" #include <algorithm> +#include <iostream> #include <vector> namespace util { class FilePiece; } @@ -52,9 +55,14 @@ struct HashedSearch { Unigram unigram; - void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const { + void LookupUnigram(WordIndex word, float &backoff, Node &next, FullScoreReturn &ret) const { const ProbBackoff &entry = unigram.Lookup(word); - prob = entry.prob; + util::FloatEnc val; + val.f = entry.prob; + ret.independent_left = (val.i & util::kSignBit); + ret.extend_left = static_cast<uint64_t>(word); + val.i |= util::kSignBit; + ret.prob = val.f; backoff = entry.backoff; next = static_cast<Node>(word); } @@ -67,6 +75,8 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has typedef LongestT Longest; Longest longest; + static const unsigned int kVersion = 0; + // TODO: move probing_multiplier here with next binary file format update. static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {} @@ -85,11 +95,33 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has const Middle *MiddleBegin() const { return &*middle_.begin(); } const Middle *MiddleEnd() const { return &*middle_.end(); } - bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const { + Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const { + util::FloatEnc val; + if (extend_length == 1) { + val.f = unigram.Lookup(static_cast<uint64_t>(extend_pointer)).prob; + } else { + typename Middle::ConstIterator found; + if (!middle_[extend_length - 2].Find(extend_pointer, found)) { + std::cerr << "Extend pointer " << extend_pointer << " should have been found for length " << (unsigned) extend_length << std::endl; + abort(); + } + val.f = found->GetValue().prob; + } + val.i |= util::kSignBit; + prob = val.f; + return extend_pointer; + } + + bool LookupMiddle(const Middle &middle, WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const { node = CombineWordHash(node, word); typename Middle::ConstIterator found; if (!middle.Find(node, found)) return false; - prob = found->GetValue().prob; + util::FloatEnc enc; + enc.f = found->GetValue().prob; + ret.independent_left = (enc.i & util::kSignBit); + ret.extend_left = node; + enc.i |= util::kSignBit; + ret.prob = enc.f; backoff = found->GetValue().backoff; return true; } @@ -105,6 +137,7 @@ template <class MiddleT, class LongestT> class TemplateHashedSearch : public Has } bool LookupLongest(WordIndex word, float &prob, Node &node) const { + // Sign bit is always on because longest n-grams do not extend left. node = CombineWordHash(node, word); typename Longest::ConstIterator found; if (!longest.Find(node, found)) return false; diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 05059ffb..6479813b 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -2,26 +2,25 @@ #include "lm/search_trie.hh" #include "lm/bhiksha.hh" +#include "lm/binary_format.hh" #include "lm/blank.hh" #include "lm/lm_exception.hh" #include "lm/max_order.hh" #include "lm/quantize.hh" -#include "lm/read_arpa.hh" #include "lm/trie.hh" +#include "lm/trie_sort.hh" #include "lm/vocab.hh" #include "lm/weights.hh" #include "lm/word_index.hh" #include "util/ersatz_progress.hh" -#include "util/file_piece.hh" -#include "util/have.hh" #include "util/proxy_iterator.hh" #include "util/scoped.hh" +#include "util/sized_iterator.hh" #include <algorithm> -#include <cmath> #include <cstring> #include <cstdio> -#include <deque> +#include <queue> #include <limits> #include <numeric> #include <vector> @@ -29,575 +28,221 @@ #include <sys/mman.h> #include <sys/types.h> #include <sys/stat.h> -#include <fcntl.h> -#include <stdlib.h> -#include <unistd.h> namespace lm { namespace ngram { namespace trie { namespace { -/* An entry is a n-gram with probability. It consists of: - * WordIndex[order] - * float probability - * backoff probability (omitted for highest order n-gram) - * These are stored consecutively in memory. We want to sort them. - * - * The problem is the length depends on order (but all n-grams being compared - * have the same order). Allocating each entry on the heap (i.e. std::vector - * or std::string) then sorting pointers is the normal solution. But that's - * too memory inefficient. A lot of this code is just here to force std::sort - * to work with records where length is specified at runtime (and avoid using - * Boost for LM code). I could have used qsort, but the point is to also - * support __gnu_cxx:parallel_sort which doesn't have a qsort version. - */ - -class EntryIterator { - public: - EntryIterator() {} - - EntryIterator(void *ptr, std::size_t size) : ptr_(static_cast<uint8_t*>(ptr)), size_(size) {} - - bool operator==(const EntryIterator &other) const { - return ptr_ == other.ptr_; - } - bool operator<(const EntryIterator &other) const { - return ptr_ < other.ptr_; - } - EntryIterator &operator+=(std::ptrdiff_t amount) { - ptr_ += amount * size_; - return *this; - } - std::ptrdiff_t operator-(const EntryIterator &other) const { - return (ptr_ - other.ptr_) / size_; - } - - const void *Data() const { return ptr_; } - void *Data() { return ptr_; } - std::size_t EntrySize() const { return size_; } - - private: - uint8_t *ptr_; - std::size_t size_; -}; - -class EntryProxy { - public: - EntryProxy() {} - - EntryProxy(void *ptr, std::size_t size) : inner_(ptr, size) {} - - operator std::string() const { - return std::string(reinterpret_cast<const char*>(inner_.Data()), inner_.EntrySize()); - } - - EntryProxy &operator=(const EntryProxy &from) { - memcpy(inner_.Data(), from.inner_.Data(), inner_.EntrySize()); - return *this; - } - - EntryProxy &operator=(const std::string &from) { - memcpy(inner_.Data(), from.data(), inner_.EntrySize()); - return *this; - } - - const WordIndex *Indices() const { - return reinterpret_cast<const WordIndex*>(inner_.Data()); - } - - private: - friend class util::ProxyIterator<EntryProxy>; - - typedef std::string value_type; - - typedef EntryIterator InnerIterator; - InnerIterator &Inner() { return inner_; } - const InnerIterator &Inner() const { return inner_; } - InnerIterator inner_; -}; - -typedef util::ProxyIterator<EntryProxy> NGramIter; - -// Proxy for an entry except there is some extra cruft between the entries. This is used to sort (n-1)-grams using the same memory as the sorted n-grams. -class PartialViewProxy { - public: - PartialViewProxy() : attention_size_(0), inner_() {} - - PartialViewProxy(void *ptr, std::size_t block_size, std::size_t attention_size) : attention_size_(attention_size), inner_(ptr, block_size) {} - - operator std::string() const { - return std::string(reinterpret_cast<const char*>(inner_.Data()), attention_size_); - } - - PartialViewProxy &operator=(const PartialViewProxy &from) { - memcpy(inner_.Data(), from.inner_.Data(), attention_size_); - return *this; - } - - PartialViewProxy &operator=(const std::string &from) { - memcpy(inner_.Data(), from.data(), attention_size_); - return *this; - } - - const WordIndex *Indices() const { - return reinterpret_cast<const WordIndex*>(inner_.Data()); - } - - private: - friend class util::ProxyIterator<PartialViewProxy>; - - typedef std::string value_type; - - const std::size_t attention_size_; - - typedef EntryIterator InnerIterator; - InnerIterator &Inner() { return inner_; } - const InnerIterator &Inner() const { return inner_; } - InnerIterator inner_; -}; - -typedef util::ProxyIterator<PartialViewProxy> PartialIter; - -template <class Proxy> class CompareRecords : public std::binary_function<const Proxy &, const Proxy &, bool> { - public: - explicit CompareRecords(unsigned char order) : order_(order) {} - - bool operator()(const Proxy &first, const Proxy &second) const { - return Compare(first.Indices(), second.Indices()); - } - bool operator()(const Proxy &first, const std::string &second) const { - return Compare(first.Indices(), reinterpret_cast<const WordIndex*>(second.data())); - } - bool operator()(const std::string &first, const Proxy &second) const { - return Compare(reinterpret_cast<const WordIndex*>(first.data()), second.Indices()); - } - bool operator()(const std::string &first, const std::string &second) const { - return Compare(reinterpret_cast<const WordIndex*>(first.data()), reinterpret_cast<const WordIndex*>(second.data())); - } - - private: - bool Compare(const WordIndex *first, const WordIndex *second) const { - const WordIndex *end = first + order_; - for (; first != end; ++first, ++second) { - if (*first < *second) return true; - if (*first > *second) return false; - } - return false; - } - - unsigned char order_; -}; - -FILE *OpenOrThrow(const char *name, const char *mode) { - FILE *ret = fopen(name, mode); - if (!ret) UTIL_THROW(util::ErrnoException, "Could not open " << name << " for " << mode); - return ret; -} - -void WriteOrThrow(FILE *to, const void *data, size_t size) { - assert(size); - if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size); -} - void ReadOrThrow(FILE *from, void *data, size_t size) { - if (1 != std::fread(data, size, 1, from)) UTIL_THROW(util::ErrnoException, "Short read; requested size " << size); -} - -const std::size_t kCopyBufSize = 512; -void CopyOrThrow(FILE *from, FILE *to, size_t size) { - char buf[std::min<size_t>(size, kCopyBufSize)]; - for (size_t i = 0; i < size; i += kCopyBufSize) { - std::size_t amount = std::min(size - i, kCopyBufSize); - ReadOrThrow(from, buf, amount); - WriteOrThrow(to, buf, amount); - } + UTIL_THROW_IF(1 != std::fread(data, size, 1, from), util::ErrnoException, "Short read"); } -void CopyRestOrThrow(FILE *from, FILE *to) { - char buf[kCopyBufSize]; - size_t amount; - while ((amount = fread(buf, 1, kCopyBufSize, from))) { - WriteOrThrow(to, buf, amount); +int Compare(unsigned char order, const void *first_void, const void *second_void) { + const WordIndex *first = reinterpret_cast<const WordIndex*>(first_void), *second = reinterpret_cast<const WordIndex*>(second_void); + const WordIndex *end = first + order; + for (; first != end; ++first, ++second) { + if (*first < *second) return -1; + if (*first > *second) return 1; } - if (!feof(from)) UTIL_THROW(util::ErrnoException, "Short read"); -} - -void RemoveOrThrow(const char *name) { - if (std::remove(name)) UTIL_THROW(util::ErrnoException, "Could not remove " << name); + return 0; } -std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::string &file_prefix, std::size_t batch, unsigned char order, std::size_t weights_size) { - const std::size_t entry_size = sizeof(WordIndex) * order + weights_size; - const std::size_t prefix_size = sizeof(WordIndex) * (order - 1); - std::stringstream assembled; - assembled << file_prefix << static_cast<unsigned int>(order) << '_' << batch; - std::string ret(assembled.str()); - util::scoped_FILE out(OpenOrThrow(ret.c_str(), "w")); - // Compress entries that being with the same (order-1) words. - for (const uint8_t *group_begin = static_cast<const uint8_t*>(mem_begin); group_begin != static_cast<const uint8_t*>(mem_end);) { - const uint8_t *group_end; - for (group_end = group_begin + entry_size; - (group_end != static_cast<const uint8_t*>(mem_end)) && !memcmp(group_begin, group_end, prefix_size); - group_end += entry_size) {} - WriteOrThrow(out.get(), group_begin, prefix_size); - WordIndex group_size = (group_end - group_begin) / entry_size; - WriteOrThrow(out.get(), &group_size, sizeof(group_size)); - for (const uint8_t *i = group_begin; i != group_end; i += entry_size) { - WriteOrThrow(out.get(), i + prefix_size, sizeof(WordIndex)); - WriteOrThrow(out.get(), i + sizeof(WordIndex) * order, weights_size); - } - group_begin = group_end; - } - return ret; -} +struct ProbPointer { + unsigned char array; + uint64_t index; +}; -class SortedFileReader { +// Array of n-grams and float indices. +class BackoffMessages { public: - SortedFileReader() : ended_(false) {} - - void Init(const std::string &name, unsigned char order) { - file_.reset(OpenOrThrow(name.c_str(), "r")); - header_.resize(order - 1); - NextHeader(); + void Init(std::size_t entry_size) { + current_ = NULL; + allocated_ = NULL; + entry_size_ = entry_size; } - // Preceding words. - const WordIndex *Header() const { - return &*header_.begin(); - } - const std::vector<WordIndex> &HeaderVector() const { return header_;} - - std::size_t HeaderBytes() const { return header_.size() * sizeof(WordIndex); } - - void NextHeader() { - if (1 != fread(&*header_.begin(), HeaderBytes(), 1, file_.get())) { - if (feof(file_.get())) { - ended_ = true; - } else { - UTIL_THROW(util::ErrnoException, "Short read of counts"); + void Add(const WordIndex *to, ProbPointer index) { + while (current_ + entry_size_ > allocated_) { + std::size_t allocated_size = allocated_ - (uint8_t*)backing_.get(); + Resize(std::max<std::size_t>(allocated_size * 2, entry_size_)); + } + memcpy(current_, to, entry_size_ - sizeof(ProbPointer)); + *reinterpret_cast<ProbPointer*>(current_ + entry_size_ - sizeof(ProbPointer)) = index; + current_ += entry_size_; + } + + void Apply(float *const *const base, FILE *unigrams) { + FinishedAdding(); + if (current_ == allocated_) return; + rewind(unigrams); + ProbBackoff weights; + WordIndex unigram = 0; + ReadOrThrow(unigrams, &weights, sizeof(weights)); + for (; current_ != allocated_; current_ += entry_size_) { + const WordIndex &cur_word = *reinterpret_cast<const WordIndex*>(current_); + for (; unigram < cur_word; ++unigram) { + ReadOrThrow(unigrams, &weights, sizeof(weights)); } + if (!HasExtension(weights.backoff)) { + weights.backoff = kExtensionBackoff; + UTIL_THROW_IF(fseek(unigrams, -sizeof(weights), SEEK_CUR), util::ErrnoException, "Seeking backwards to denote unigram extension failed."); + WriteOrThrow(unigrams, &weights, sizeof(weights)); + } + const ProbPointer &write_to = *reinterpret_cast<const ProbPointer*>(current_ + sizeof(WordIndex)); + base[write_to.array][write_to.index] += weights.backoff; } + backing_.reset(); + } + + void Apply(float *const *const base, RecordReader &reader) { + FinishedAdding(); + if (current_ == allocated_) return; + // We'll also use the same buffer to record messages to blanks that they extend. + WordIndex *extend_out = reinterpret_cast<WordIndex*>(current_); + const unsigned char order = (entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex); + for (reader.Rewind(); reader && (current_ != allocated_); ) { + switch (Compare(order, reader.Data(), current_)) { + case -1: + ++reader; + break; + case 1: + // Message but nobody to receive it. Write it down at the beginning of the buffer so we can inform this blank that it extends. + for (const WordIndex *w = reinterpret_cast<const WordIndex *>(current_); w != reinterpret_cast<const WordIndex *>(current_) + order; ++w, ++extend_out) *extend_out = *w; + current_ += entry_size_; + break; + case 0: + float &backoff = reinterpret_cast<ProbBackoff*>((uint8_t*)reader.Data() + order * sizeof(WordIndex))->backoff; + if (!HasExtension(backoff)) { + backoff = kExtensionBackoff; + reader.Overwrite(&backoff, sizeof(float)); + } else { + const ProbPointer &write_to = *reinterpret_cast<const ProbPointer*>(current_ + entry_size_ - sizeof(ProbPointer)); + base[write_to.array][write_to.index] += backoff; + } + current_ += entry_size_; + break; + } + } + // Now this is a list of blanks that extend right. + entry_size_ = sizeof(WordIndex) * order; + Resize(sizeof(WordIndex) * (extend_out - (const WordIndex*)backing_.get())); + current_ = (uint8_t*)backing_.get(); } - WordIndex ReadCount() { - WordIndex ret; - ReadOrThrow(file_.get(), &ret, sizeof(WordIndex)); - return ret; - } - - WordIndex ReadWord() { - WordIndex ret; - ReadOrThrow(file_.get(), &ret, sizeof(WordIndex)); - return ret; - } - - template <class Weights> void ReadWeights(Weights &weights) { - ReadOrThrow(file_.get(), &weights, sizeof(Weights)); + // Call after Apply + bool Extends(unsigned char order, const WordIndex *words) { + if (current_ == allocated_) return false; + assert(order * sizeof(WordIndex) == entry_size_); + while (true) { + switch(Compare(order, words, current_)) { + case 1: + current_ += entry_size_; + if (current_ == allocated_) return false; + break; + case -1: + return false; + case 0: + return true; + } + } } - bool Ended() const { - return ended_; + private: + void FinishedAdding() { + Resize(current_ - (uint8_t*)backing_.get()); + current_ = (uint8_t*)backing_.get(); } - void Rewind() { - rewind(file_.get()); - ended_ = false; - NextHeader(); + void Resize(std::size_t to) { + std::size_t current = current_ - (uint8_t*)backing_.get(); + backing_.call_realloc(to); + current_ = (uint8_t*)backing_.get() + current; + allocated_ = (uint8_t*)backing_.get() + to; } - FILE *File() { return file_.get(); } - - private: - util::scoped_FILE file_; + util::scoped_malloc backing_; - std::vector<WordIndex> header_; + uint8_t *current_, *allocated_; - bool ended_; + std::size_t entry_size_; }; -void CopyFullRecord(SortedFileReader &from, FILE *to, std::size_t weights_size) { - WriteOrThrow(to, from.Header(), from.HeaderBytes()); - WordIndex count = from.ReadCount(); - WriteOrThrow(to, &count, sizeof(WordIndex)); - - CopyOrThrow(from.File(), to, (weights_size + sizeof(WordIndex)) * count); -} - -void MergeSortedFiles(const std::string &first_name, const std::string &second_name, const std::string &out, std::size_t weights_size, unsigned char order) { - SortedFileReader first, second; - first.Init(first_name.c_str(), order); - RemoveOrThrow(first_name.c_str()); - second.Init(second_name.c_str(), order); - RemoveOrThrow(second_name.c_str()); - util::scoped_FILE out_file(OpenOrThrow(out.c_str(), "w")); - while (!first.Ended() && !second.Ended()) { - if (first.HeaderVector() < second.HeaderVector()) { - CopyFullRecord(first, out_file.get(), weights_size); - first.NextHeader(); - continue; - } - if (first.HeaderVector() > second.HeaderVector()) { - CopyFullRecord(second, out_file.get(), weights_size); - second.NextHeader(); - continue; - } - // Merge at the entry level. - WriteOrThrow(out_file.get(), first.Header(), first.HeaderBytes()); - WordIndex first_count = first.ReadCount(), second_count = second.ReadCount(); - WordIndex total_count = first_count + second_count; - WriteOrThrow(out_file.get(), &total_count, sizeof(WordIndex)); - - WordIndex first_word = first.ReadWord(), second_word = second.ReadWord(); - WordIndex first_index = 0, second_index = 0; - while (true) { - if (first_word < second_word) { - WriteOrThrow(out_file.get(), &first_word, sizeof(WordIndex)); - CopyOrThrow(first.File(), out_file.get(), weights_size); - if (++first_index == first_count) break; - first_word = first.ReadWord(); - } else { - WriteOrThrow(out_file.get(), &second_word, sizeof(WordIndex)); - CopyOrThrow(second.File(), out_file.get(), weights_size); - if (++second_index == second_count) break; - second_word = second.ReadWord(); - } - } - if (first_index == first_count) { - WriteOrThrow(out_file.get(), &second_word, sizeof(WordIndex)); - CopyOrThrow(second.File(), out_file.get(), (second_count - second_index) * (weights_size + sizeof(WordIndex)) - sizeof(WordIndex)); - } else { - WriteOrThrow(out_file.get(), &first_word, sizeof(WordIndex)); - CopyOrThrow(first.File(), out_file.get(), (first_count - first_index) * (weights_size + sizeof(WordIndex)) - sizeof(WordIndex)); - } - first.NextHeader(); - second.NextHeader(); - } - - for (SortedFileReader &remaining = first.Ended() ? second : first; !remaining.Ended(); remaining.NextHeader()) { - CopyFullRecord(remaining, out_file.get(), weights_size); - } -} - -const char *kContextSuffix = "_contexts"; - -void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_file_name, std::size_t entry_size, unsigned char order) { - const size_t context_size = sizeof(WordIndex) * (order - 1); - // Sort just the contexts using the same memory. - PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size)); - PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, context_size)); - - std::sort(context_begin, context_end, CompareRecords<PartialViewProxy>(order - 1)); - - std::string name(ngram_file_name + kContextSuffix); - util::scoped_FILE out(OpenOrThrow(name.c_str(), "w")); - - // Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator. - if (context_begin == context_end) return; - PartialIter i(context_begin); - WriteOrThrow(out.get(), i->Indices(), context_size); - const WordIndex *previous = i->Indices(); - ++i; - for (; i != context_end; ++i) { - if (memcmp(previous, i->Indices(), context_size)) { - WriteOrThrow(out.get(), i->Indices(), context_size); - previous = i->Indices(); - } - } -} +const float kBadProb = std::numeric_limits<float>::infinity(); -class ContextReader { +class SRISucks { public: - ContextReader() : valid_(false) {} - - ContextReader(const char *name, unsigned char order) { - Reset(name, order); - } - - void Reset(const char *name, unsigned char order) { - file_.reset(OpenOrThrow(name, "r")); - length_ = sizeof(WordIndex) * static_cast<size_t>(order); - words_.resize(order); - valid_ = true; - ++*this; - } - - ContextReader &operator++() { - if (1 != fread(&*words_.begin(), length_, 1, file_.get())) { - if (!feof(file_.get())) - UTIL_THROW(util::ErrnoException, "Short read"); - valid_ = false; + SRISucks() { + for (BackoffMessages *i = messages_; i != messages_ + kMaxOrder - 1; ++i) + i->Init(sizeof(ProbPointer) + sizeof(WordIndex) * (i - messages_ + 1)); + } + + void Send(unsigned char begin, unsigned char order, const WordIndex *to, float prob_basis) { + assert(prob_basis != kBadProb); + ProbPointer pointer; + pointer.array = order - 1; + pointer.index = values_[order - 1].size(); + for (unsigned char i = begin; i < order; ++i) { + messages_[i - 1].Add(to, pointer); } - return *this; + values_[order - 1].push_back(prob_basis); } - const WordIndex *operator*() const { return &*words_.begin(); } - - operator bool() const { return valid_; } - - FILE *GetFile() { return file_.get(); } - - private: - util::scoped_FILE file_; - - size_t length_; - - std::vector<WordIndex> words_; - - bool valid_; -}; - -void MergeContextFiles(const std::string &first_base, const std::string &second_base, const std::string &out_base, unsigned char order) { - const size_t context_size = sizeof(WordIndex) * (order - 1); - std::string first_name(first_base + kContextSuffix); - std::string second_name(second_base + kContextSuffix); - ContextReader first(first_name.c_str(), order - 1), second(second_name.c_str(), order - 1); - RemoveOrThrow(first_name.c_str()); - RemoveOrThrow(second_name.c_str()); - std::string out_name(out_base + kContextSuffix); - util::scoped_FILE out(OpenOrThrow(out_name.c_str(), "w")); - while (first && second) { - for (const WordIndex *f = *first, *s = *second; ; ++f, ++s) { - if (f == *first + order - 1) { - // Equal. - WriteOrThrow(out.get(), *first, context_size); - ++first; - ++second; - break; + void ObtainBackoffs(unsigned char total_order, FILE *unigram_file, RecordReader *reader) { + for (unsigned char i = 0; i < kMaxOrder - 1; ++i) { + it_[i] = &*values_[i].begin(); } - if (*f < *s) { - // First lower - WriteOrThrow(out.get(), *first, context_size); - ++first; - break; - } else if (*f > *s) { - WriteOrThrow(out.get(), *second, context_size); - ++second; - break; + messages_[0].Apply(it_, unigram_file); + BackoffMessages *messages = messages_ + 1; + const RecordReader *end = reader + total_order - 2 /* exclude unigrams and longest order */; + for (; reader != end; ++messages, ++reader) { + messages->Apply(it_, *reader); } } - } - ContextReader &remaining = first ? first : second; - if (!remaining) return; - WriteOrThrow(out.get(), *remaining, context_size); - CopyRestOrThrow(remaining.GetFile(), out.get()); -} -void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn) { - ReadNGramHeader(f, order); - const size_t count = counts[order - 1]; - // Size of weights. Does it include backoff? - const size_t words_size = sizeof(WordIndex) * order; - const size_t weights_size = sizeof(float) + ((order == counts.size()) ? 0 : sizeof(float)); - const size_t entry_size = words_size + weights_size; - const size_t batch_size = std::min(count, mem.size() / entry_size); - uint8_t *const begin = reinterpret_cast<uint8_t*>(mem.get()); - std::deque<std::string> files; - for (std::size_t batch = 0, done = 0; done < count; ++batch) { - uint8_t *out = begin; - uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size; - if (order == counts.size()) { - for (; out != out_end; out += entry_size) { - ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<Prob*>(out + words_size), warn); - } - } else { - for (; out != out_end; out += entry_size) { - ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size), warn); - } + ProbBackoff GetBlank(unsigned char total_order, unsigned char order, const WordIndex *indices) { + assert(order > 1); + ProbBackoff ret; + ret.prob = *(it_[order - 1]++); + ret.backoff = ((order != total_order - 1) && messages_[order - 1].Extends(order, indices)) ? kExtensionBackoff : kNoExtensionBackoff; + return ret; } - // Sort full records by full n-gram. - EntryProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size); - // parallel_sort uses too much RAM - std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), CompareRecords<EntryProxy>(order)); - files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order, weights_size)); - WriteContextFile(begin, out_end, files.back(), entry_size, order); - - done += (out_end - begin) / entry_size; - } - // All individual files created. Merge them. - - std::size_t merge_count = 0; - while (files.size() > 1) { - std::stringstream assembled; - assembled << file_prefix << static_cast<unsigned int>(order) << "_merge_" << (merge_count++); - files.push_back(assembled.str()); - MergeSortedFiles(files[0], files[1], files.back(), weights_size, order); - MergeContextFiles(files[0], files[1], files.back(), order); - files.pop_front(); - files.pop_front(); - } - if (!files.empty()) { - std::stringstream assembled; - assembled << file_prefix << static_cast<unsigned int>(order) << "_merged"; - std::string merged_name(assembled.str()); - if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str()); - std::string context_name = files[0] + kContextSuffix; - merged_name += kContextSuffix; - if (std::rename(context_name.c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << context_name << " to " << merged_name.c_str()); - } -} - -void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { - PositiveProbWarn warn(config.positive_log_probability); - { - std::string unigram_name = file_prefix + "unigrams"; - util::scoped_fd unigram_file; - // In case <unk> appears. - size_t file_out = (counts[0] + 1) * sizeof(ProbBackoff); - util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), file_out, unigram_file), file_out); - Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()), warn); - CheckSpecials(config, vocab); - if (!vocab.SawUnk()) ++counts[0]; - } + const std::vector<float> &Values(unsigned char order) const { + return values_[order - 1]; + } - // Only use as much buffer as we need. - size_t buffer_use = 0; - for (unsigned int order = 2; order < counts.size(); ++order) { - buffer_use = std::max<size_t>(buffer_use, static_cast<size_t>((sizeof(WordIndex) * order + 2 * sizeof(float)) * counts[order - 1])); - } - buffer_use = std::max<size_t>(buffer_use, static_cast<size_t>((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back())); - buffer = std::min<size_t>(buffer, buffer_use); + private: + // This used to be one array. Then I needed to separate it by order for quantization to work. + std::vector<float> values_[kMaxOrder - 1]; + BackoffMessages messages_[kMaxOrder - 1]; - util::scoped_memory mem; - mem.reset(malloc(buffer), buffer, util::scoped_memory::MALLOC_ALLOCATED); - if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer); + float *it_[kMaxOrder - 1]; +}; - for (unsigned char order = 2; order <= counts.size(); ++order) { - ConvertToSorted(f, vocab, counts, mem, file_prefix, order, warn); - } - ReadEnd(f); -} +class FindBlanks { + public: + FindBlanks(uint64_t *counts, unsigned char order, const ProbBackoff *unigrams, SRISucks &messages) + : counts_(counts), longest_counts_(counts + order - 1), unigrams_(unigrams), sri_(messages) {} -bool HeadMatch(const WordIndex *words, const WordIndex *const words_end, const WordIndex *header) { - for (; words != words_end; ++words, ++header) { - if (*words != *header) { - //assert(*words <= *header); - return false; + float UnigramProb(WordIndex index) const { + return unigrams_[index].prob; } - } - return true; -} -// Phase to count n-grams, including blanks inserted because they were pruned but have extensions -class JustCount { - public: - template <class Middle, class Longest> JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, Middle * /*middle*/, Longest &/*longest*/, uint64_t *counts, unsigned char order) - : counts_(counts), longest_counts_(counts + order - 1) {} - - void Unigrams(WordIndex begin, WordIndex end) { - counts_[0] += end - begin; + void Unigram(WordIndex /*index*/) { + ++counts_[0]; } - void MiddleBlank(const unsigned char mid_idx, WordIndex /* idx */) { - ++counts_[mid_idx + 1]; + void MiddleBlank(const unsigned char order, const WordIndex *indices, unsigned char lower, float prob_basis) { + sri_.Send(lower, order, indices + 1, prob_basis); + ++counts_[order - 1]; } - void Middle(const unsigned char mid_idx, const WordIndex * /*before*/, WordIndex /*key*/, const ProbBackoff &/*weights*/) { - ++counts_[mid_idx + 1]; + void Middle(const unsigned char order, const void * /*data*/) { + ++counts_[order - 1]; } - void Longest(WordIndex /*key*/, Prob /*prob*/) { + void Longest(const void * /*data*/) { ++*longest_counts_; } @@ -608,167 +253,156 @@ class JustCount { private: uint64_t *const counts_, *const longest_counts_; + + const ProbBackoff *unigrams_; + + SRISucks &sri_; }; // Phase to actually write n-grams to the trie. template <class Quant, class Bhiksha> class WriteEntries { public: - WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle<typename Quant::Middle, Bhiksha> *middle, BitPackedLongest<typename Quant::Longest> &longest, const uint64_t * /*counts*/, unsigned char order) : + WriteEntries(RecordReader *contexts, UnigramValue *unigrams, BitPackedMiddle<typename Quant::Middle, Bhiksha> *middle, BitPackedLongest<typename Quant::Longest> &longest, unsigned char order, SRISucks &sri) : contexts_(contexts), unigrams_(unigrams), middle_(middle), longest_(longest), - bigram_pack_((order == 2) ? static_cast<BitPacked&>(longest_) : static_cast<BitPacked&>(*middle_)) {} + bigram_pack_((order == 2) ? static_cast<BitPacked&>(longest_) : static_cast<BitPacked&>(*middle_)), + order_(order), + sri_(sri) {} - void Unigrams(WordIndex begin, WordIndex end) { - uint64_t next = bigram_pack_.InsertIndex(); - for (UnigramValue *i = unigrams_ + begin; i < unigrams_ + end; ++i) { - i->next = next; - } + float UnigramProb(WordIndex index) const { return unigrams_[index].weights.prob; } + + void Unigram(WordIndex word) { + unigrams_[word].next = bigram_pack_.InsertIndex(); } - void MiddleBlank(const unsigned char mid_idx, WordIndex key) { - middle_[mid_idx].Insert(key, kBlankProb, kBlankBackoff); + void MiddleBlank(const unsigned char order, const WordIndex *indices, unsigned char /*lower*/, float /*prob_base*/) { + ProbBackoff weights = sri_.GetBlank(order_, order, indices); + middle_[order - 2].Insert(indices[order - 1], weights.prob, weights.backoff); } - void Middle(const unsigned char mid_idx, const WordIndex *before, WordIndex key, ProbBackoff weights) { - // Order (mid_idx+2). - ContextReader &context = contexts_[mid_idx + 1]; - if (context && !memcmp(before, *context, sizeof(WordIndex) * (mid_idx + 1)) && (*context)[mid_idx + 1] == key) { + void Middle(const unsigned char order, const void *data) { + RecordReader &context = contexts_[order - 1]; + const WordIndex *words = reinterpret_cast<const WordIndex*>(data); + ProbBackoff weights = *reinterpret_cast<const ProbBackoff*>(words + order); + if (context && !memcmp(data, context.Data(), sizeof(WordIndex) * order)) { SetExtension(weights.backoff); ++context; } - middle_[mid_idx].Insert(key, weights.prob, weights.backoff); + middle_[order - 2].Insert(words[order - 1], weights.prob, weights.backoff); } - void Longest(WordIndex key, Prob prob) { - longest_.Insert(key, prob.prob); + void Longest(const void *data) { + const WordIndex *words = reinterpret_cast<const WordIndex*>(data); + longest_.Insert(words[order_ - 1], reinterpret_cast<const Prob*>(words + order_)->prob); } void Cleanup() {} private: - ContextReader *contexts_; + RecordReader *contexts_; UnigramValue *const unigrams_; BitPackedMiddle<typename Quant::Middle, Bhiksha> *const middle_; BitPackedLongest<typename Quant::Longest> &longest_; BitPacked &bigram_pack_; + const unsigned char order_; + SRISucks &sri_; }; -template <class Doing> class RecursiveInsert { - public: - template <class MiddleT, class LongestT> RecursiveInsert(SortedFileReader *inputs, ContextReader *contexts, UnigramValue *unigrams, MiddleT *middle, LongestT &longest, uint64_t *counts, unsigned char order) : - doing_(contexts, unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), order_minus_2_(order - 2) { - } +struct Gram { + Gram(const WordIndex *in_begin, unsigned char order) : begin(in_begin), end(in_begin + order) {} - // Outer unigram loop. - void Apply(std::ostream *progress_out, const char *message, WordIndex unigram_count) { - util::ErsatzProgress progress(progress_out, message, unigram_count + 1); - for (words_[0] = 0; ; ++words_[0]) { - progress.Set(words_[0]); - WordIndex min_continue = unigram_count; - for (SortedFileReader *other = inputs_; other != inputs_end_; ++other) { - if (other->Ended()) continue; - min_continue = std::min(min_continue, other->Header()[0]); - } - // This will write at unigram_count. This is by design so that the next pointers will make sense. - doing_.Unigrams(words_[0], min_continue + 1); - if (min_continue == unigram_count) break; - words_[0] = min_continue; - Middle(0); - } - doing_.Cleanup(); - } + const WordIndex *begin, *end; - private: - void Middle(const unsigned char mid_idx) { - // (mid_idx + 2)-gram. - if (mid_idx == order_minus_2_) { - Longest(); - return; - } - // Orders [2, order) + // For queue, this is the direction we want. + bool operator<(const Gram &other) const { + return std::lexicographical_compare(other.begin, other.end, begin, end); + } +}; - SortedFileReader &reader = inputs_[mid_idx]; +template <class Doing> class BlankManager { + public: + BlankManager(unsigned char total_order, Doing &doing) : total_order_(total_order), been_length_(0), doing_(doing) { + for (float *i = basis_; i != basis_ + kMaxOrder - 1; ++i) *i = kBadProb; + } - if (reader.Ended() || !HeadMatch(words_, words_ + mid_idx + 1, reader.Header())) { - // This order doesn't have a header match, but longer ones might. - MiddleAllBlank(mid_idx); - return; + void Visit(const WordIndex *to, unsigned char length, float prob) { + basis_[length - 1] = prob; + unsigned char overlap = std::min<unsigned char>(length - 1, been_length_); + const WordIndex *cur; + WordIndex *pre; + for (cur = to, pre = been_; cur != to + overlap; ++cur, ++pre) { + if (*pre != *cur) break; } - - // There is a header match. - WordIndex count = reader.ReadCount(); - WordIndex current = reader.ReadWord(); - while (count) { - WordIndex min_continue = std::numeric_limits<WordIndex>::max(); - for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) { - if (!other->Ended() && HeadMatch(words_, words_ + mid_idx + 1, other->Header())) - min_continue = std::min(min_continue, other->Header()[mid_idx + 1]); - } - while (true) { - if (current > min_continue) { - doing_.MiddleBlank(mid_idx, min_continue); - words_[mid_idx + 1] = min_continue; - Middle(mid_idx + 1); - break; - } - ProbBackoff weights; - reader.ReadWeights(weights); - doing_.Middle(mid_idx, words_, current, weights); - --count; - if (current == min_continue) { - words_[mid_idx + 1] = min_continue; - Middle(mid_idx + 1); - if (count) current = reader.ReadWord(); - break; - } - if (!count) break; - current = reader.ReadWord(); - } + if (cur == to + length - 1) { + *pre = *cur; + been_length_ = length; + return; } - // Count is now zero. Finish off remaining blanks. - MiddleAllBlank(mid_idx); - reader.NextHeader(); - } - - void MiddleAllBlank(const unsigned char mid_idx) { - while (true) { - WordIndex min_continue = std::numeric_limits<WordIndex>::max(); - for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) { - if (!other->Ended() && HeadMatch(words_, words_ + mid_idx + 1, other->Header())) - min_continue = std::min(min_continue, other->Header()[mid_idx + 1]); - } - if (min_continue == std::numeric_limits<WordIndex>::max()) return; - doing_.MiddleBlank(mid_idx, min_continue); - words_[mid_idx + 1] = min_continue; - Middle(mid_idx + 1); + // There are blanks to insert starting with order blank. + unsigned char blank = cur - to + 1; + UTIL_THROW_IF(blank == 1, FormatLoadException, "Missing a unigram that appears as context."); + const float *lower_basis; + for (lower_basis = basis_ + blank - 2; *lower_basis == kBadProb; --lower_basis) {} + unsigned char based_on = lower_basis - basis_ + 1; + for (; cur != to + length - 1; ++blank, ++cur, ++pre) { + assert(*lower_basis != kBadProb); + doing_.MiddleBlank(blank, to, based_on, *lower_basis); + *pre = *cur; + // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram. + basis_[blank - 1] = kBadProb; } + been_length_ = length; } - void Longest() { - SortedFileReader &reader = *(inputs_end_ - 1); - if (reader.Ended() || !HeadMatch(words_, words_ + order_minus_2_ + 1, reader.Header())) return; - WordIndex count = reader.ReadCount(); - for (WordIndex i = 0; i < count; ++i) { - WordIndex word = reader.ReadWord(); - Prob prob; - reader.ReadWeights(prob); - doing_.Longest(word, prob); - } - reader.NextHeader(); - return; - } + private: + const unsigned char total_order_; - Doing doing_; + WordIndex been_[kMaxOrder]; + unsigned char been_length_; - SortedFileReader *inputs_; - SortedFileReader *inputs_end_; + float basis_[kMaxOrder]; + + Doing &doing_; +}; - WordIndex words_[kMaxOrder]; +template <class Doing> void RecursiveInsert(const unsigned char total_order, const WordIndex unigram_count, RecordReader *input, std::ostream *progress_out, const char *message, Doing &doing) { + util::ErsatzProgress progress(progress_out, message, unigram_count + 1); + unsigned int unigram = 0; + std::priority_queue<Gram> grams; + grams.push(Gram(&unigram, 1)); + for (unsigned char i = 2; i <= total_order; ++i) { + if (input[i-2]) grams.push(Gram(reinterpret_cast<const WordIndex*>(input[i-2].Data()), i)); + } - const unsigned char order_minus_2_; -}; + BlankManager<Doing> blank(total_order, doing); + + while (true) { + Gram top = grams.top(); + grams.pop(); + unsigned char order = top.end - top.begin; + if (order == 1) { + blank.Visit(&unigram, 1, doing.UnigramProb(unigram)); + doing.Unigram(unigram); + progress.Set(unigram); + if (++unigram == unigram_count + 1) break; + grams.push(top); + } else { + if (order == total_order) { + blank.Visit(top.begin, order, reinterpret_cast<const Prob*>(top.end)->prob); + doing.Longest(top.begin); + } else { + blank.Visit(top.begin, order, reinterpret_cast<const ProbBackoff*>(top.end)->prob); + doing.Middle(order, top.begin); + } + RecordReader &reader = input[order - 2]; + if (++reader) grams.push(top); + } + } + assert(grams.empty()); + doing.Cleanup(); +} void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<uint64_t> &fixed) { if (fixed[0] != initial[0]) UTIL_THROW(util::Exception, "Unigram count should be constant but initial is " << initial[0] << " and recounted is " << fixed[0]); @@ -778,120 +412,122 @@ void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<u } } -bool IsDirectory(const char *path) { - struct stat info; - if (0 != stat(path, &info)) return false; - return S_ISDIR(info.st_mode); -} - -template <class Quant> void TrainQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) { - ProbBackoff weights; - std::vector<float> probs, backoffs; - probs.reserve(count); +template <class Quant> void TrainQuantizer(uint8_t order, uint64_t count, const std::vector<float> &additional, RecordReader &reader, util::ErsatzProgress &progress, Quant &quant) { + std::vector<float> probs(additional), backoffs; + probs.reserve(count + additional.size()); backoffs.reserve(count); - for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) { - uint64_t entries = reader.ReadCount(); - for (uint64_t c = 0; c < entries; ++c) { - reader.ReadWord(); - reader.ReadWeights(weights); - // kBlankProb isn't added yet. - probs.push_back(weights.prob); - if (weights.backoff != 0.0) backoffs.push_back(weights.backoff); - ++progress; - } + for (reader.Rewind(); reader; ++reader) { + const ProbBackoff &weights = *reinterpret_cast<const ProbBackoff*>(reinterpret_cast<const uint8_t*>(reader.Data()) + sizeof(WordIndex) * order); + probs.push_back(weights.prob); + if (weights.backoff != 0.0) backoffs.push_back(weights.backoff); + ++progress; } quant.Train(order, probs, backoffs); } -template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) { - Prob weights; +template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, RecordReader &reader, util::ErsatzProgress &progress, Quant &quant) { std::vector<float> probs, backoffs; probs.reserve(count); - for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) { - uint64_t entries = reader.ReadCount(); - for (uint64_t c = 0; c < entries; ++c) { - reader.ReadWord(); - reader.ReadWeights(weights); - // kBlankProb isn't added yet. - probs.push_back(weights.prob); - ++progress; - } + for (reader.Rewind(); reader; ++reader) { + const Prob &weights = *reinterpret_cast<const Prob*>(reinterpret_cast<const uint8_t*>(reader.Data()) + sizeof(WordIndex) * order); + probs.push_back(weights.prob); + ++progress; } quant.TrainProb(order, probs); } +void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &contexts, UnigramValue *unigrams) { + // Fill unigram probabilities. + try { + rewind(file); + for (WordIndex i = 0; i < unigram_count; ++i) { + ReadOrThrow(file, &unigrams[i].weights, sizeof(ProbBackoff)); + if (contexts && *reinterpret_cast<const WordIndex*>(contexts.Data()) == i) { + SetExtension(unigrams[i].weights.backoff); + ++contexts; + } + } + } catch (util::Exception &e) { + e << " while re-reading unigram probabilities"; + throw; + } +} + } // namespace template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) { - std::vector<SortedFileReader> inputs(counts.size() - 1); - std::vector<ContextReader> contexts(counts.size() - 1); + RecordReader inputs[kMaxOrder - 1]; + RecordReader contexts[kMaxOrder - 1]; for (unsigned char i = 2; i <= counts.size(); ++i) { std::stringstream assembled; assembled << file_prefix << static_cast<unsigned int>(i) << "_merged"; - inputs[i-2].Init(assembled.str(), i); - RemoveOrThrow(assembled.str().c_str()); + inputs[i-2].Init(assembled.str(), i * sizeof(WordIndex) + (i == counts.size() ? sizeof(Prob) : sizeof(ProbBackoff))); + util::RemoveOrThrow(assembled.str().c_str()); assembled << kContextSuffix; - contexts[i-2].Reset(assembled.str().c_str(), i-1); - RemoveOrThrow(assembled.str().c_str()); + contexts[i-2].Init(assembled.str(), (i-1) * sizeof(WordIndex)); + util::RemoveOrThrow(assembled.str().c_str()); } + SRISucks sri; std::vector<uint64_t> fixed_counts(counts.size()); { - RecursiveInsert<JustCount> counter(&*inputs.begin(), &*contexts.begin(), NULL, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); - counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]); + std::string temp(file_prefix); temp += "unigrams"; + util::scoped_fd unigram_file(util::OpenReadOrThrow(temp.c_str())); + util::scoped_memory unigrams; + MapRead(util::POPULATE_OR_READ, unigram_file.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams); + FindBlanks finder(&*fixed_counts.begin(), counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri); + RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Identifying n-grams omitted by SRI", finder); } - for (std::vector<SortedFileReader>::const_iterator i = inputs.begin(); i != inputs.end(); ++i) { - if (!i->Ended()) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs.begin() + 2) << "-gram table did not complete reading"); + for (const RecordReader *i = inputs; i != inputs + counts.size() - 2; ++i) { + if (*i) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs + 2) << "-gram table did not complete reading"); } SanityCheckCounts(counts, fixed_counts); counts = fixed_counts; + util::scoped_FILE unigram_file; + { + std::string name(file_prefix + "unigrams"); + unigram_file.reset(OpenOrThrow(name.c_str(), "r")); + util::RemoveOrThrow(name.c_str()); + } + sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs); + out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config); + for (unsigned char i = 2; i <= counts.size(); ++i) { + inputs[i-2].Rewind(); + } if (Quant::kTrain) { util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0)); for (unsigned char i = 2; i < counts.size(); ++i) { - TrainQuantizer(i, counts[i-1], inputs[i-2], progress, quant); + TrainQuantizer(i, counts[i-1], sri.Values(i), inputs[i-2], progress, quant); } TrainProbQuantizer(counts.size(), counts.back(), inputs[counts.size() - 2], progress, quant); quant.FinishedLoading(config); } + UnigramValue *unigrams = out.unigram.Raw(); + PopulateUnigramWeights(unigram_file.get(), counts[0], contexts[0], unigrams); + unigram_file.reset(); + for (unsigned char i = 2; i <= counts.size(); ++i) { inputs[i-2].Rewind(); } - UnigramValue *unigrams = out.unigram.Raw(); // Fill entries except unigram probabilities. { - RecursiveInsert<WriteEntries<Quant, Bhiksha> > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); - inserter.Apply(config.messages, "Building trie", fixed_counts[0]); - } - - // Fill unigram probabilities. - try { - std::string name(file_prefix + "unigrams"); - util::scoped_FILE file(OpenOrThrow(name.c_str(), "r")); - for (WordIndex i = 0; i < counts[0]; ++i) { - ReadOrThrow(file.get(), &unigrams[i].weights, sizeof(ProbBackoff)); - if (contexts[0] && **contexts[0] == i) { - SetExtension(unigrams[i].weights.backoff); - ++contexts[0]; - } - } - RemoveOrThrow(name.c_str()); - } catch (util::Exception &e) { - e << " while re-reading unigram probabilities"; - throw; + WriteEntries<Quant, Bhiksha> writer(contexts, unigrams, out.middle_begin_, out.longest, counts.size(), sri); + RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Writing trie", writer); } // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. for (unsigned char order = 2; order <= counts.size(); ++order) { - const ContextReader &context = contexts[order - 2]; + const RecordReader &context = contexts[order - 2]; if (context) { FormatLoadException e; - e << "An " << static_cast<unsigned int>(order) << "-gram has the context (i.e. all but the last word):"; - for (const WordIndex *i = *context; i != *context + order - 1; ++i) { + e << "An " << static_cast<unsigned int>(order) << "-gram has context"; + const WordIndex *ctx = reinterpret_cast<const WordIndex*>(context.Data()); + for (const WordIndex *i = ctx; i != ctx + order - 1; ++i) { e << ' ' << *i; } e << " so this context must appear in the model as a " << static_cast<unsigned int>(order - 1) << "-gram but it does not"; @@ -945,6 +581,14 @@ template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBin longest.LoadedBinary(); } +namespace { +bool IsDirectory(const char *path) { + struct stat info; + if (0 != stat(path, &info)) return false; + return S_ISDIR(info.st_mode); +} +} // namespace + template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { std::string temporary_directory; if (config.temporary_directory_prefix) { diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 2f39c09f..c3e02a98 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -1,10 +1,16 @@ #ifndef LM_SEARCH_TRIE__ #define LM_SEARCH_TRIE__ -#include "lm/binary_format.hh" +#include "lm/config.hh" +#include "lm/model_type.hh" +#include "lm/return.hh" #include "lm/trie.hh" #include "lm/weights.hh" +#include "util/file_piece.hh" + +#include <vector> + #include <assert.h> namespace lm { @@ -30,6 +36,8 @@ template <class Quant, class Bhiksha> class TrieSearch { static const ModelType kModelType = static_cast<ModelType>(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd); + static const unsigned int kVersion = 0; + static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) { Quant::UpdateConfigFromBinary(fd, counts, config); AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0])); @@ -57,12 +65,16 @@ template <class Quant, class Bhiksha> class TrieSearch { void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); - void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { - unigram.Find(word, prob, backoff, node); + void LookupUnigram(WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const { + unigram.Find(word, ret.prob, backoff, node); + ret.independent_left = (node.begin == node.end); + ret.extend_left = static_cast<uint64_t>(word); } - bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const { - return mid.Find(word, prob, backoff, node); + bool LookupMiddle(const Middle &mid, WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const { + if (!mid.Find(word, ret.prob, backoff, node, ret.extend_left)) return false; + ret.independent_left = (node.begin == node.end); + return true; } bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const { @@ -76,14 +88,25 @@ template <class Quant, class Bhiksha> class TrieSearch { bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { // TODO: don't decode backoff. assert(begin != end); - float ignored_prob, ignored_backoff; - LookupUnigram(*begin, ignored_prob, ignored_backoff, node); + FullScoreReturn ignored; + float ignored_backoff; + LookupUnigram(*begin, ignored_backoff, node, ignored); for (const WordIndex *i = begin + 1; i < end; ++i) { if (!LookupMiddleNoProb(middle_begin_[i - begin - 1], *i, ignored_backoff, node)) return false; } return true; } + Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const { + if (extend_length == 1) { + float ignored; + Node ret; + unigram.Find(static_cast<WordIndex>(extend_pointer), prob, ignored, ret); + return ret; + } + return middle_begin_[extend_length - 2].ReadEntry(extend_pointer, prob); + } + private: friend void BuildTrie<Quant, Bhiksha>(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 8c536e66..4e60b184 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -86,7 +86,7 @@ template <class Quant, class Bhiksha> void BitPackedMiddle<Quant, Bhiksha>::Inse ++insert_index_; } -template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { +template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::Find(WordIndex word, float &prob, float &backoff, NodeRange &range, uint64_t &pointer) const { uint64_t at_pointer; if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) { return false; @@ -94,6 +94,9 @@ template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::Find uint64_t index = at_pointer; at_pointer *= total_bits_; at_pointer += word_bits_; + + pointer = at_pointer; + quant_.Read(base_, at_pointer, prob, backoff); at_pointer += quant_.TotalBits(); diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index 53612064..a9f5e417 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -94,10 +94,18 @@ template <class Quant, class Bhiksha> class BitPackedMiddle : public BitPacked { void LoadedBinary() { bhiksha_.LoadedBinary(); } - bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const; + bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range, uint64_t &pointer) const; bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const; + NodeRange ReadEntry(uint64_t pointer, float &prob) { + quant_.ReadProb(base_, pointer, prob); + NodeRange ret; + // pointer/total_bits_ should always round down. + bhiksha_.ReadNext(base_, pointer + quant_.TotalBits(), pointer / total_bits_, total_bits_, ret); + return ret; + } + private: Quant quant_; Bhiksha bhiksha_; diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc new file mode 100644 index 00000000..01c4e490 --- /dev/null +++ b/klm/lm/trie_sort.cc @@ -0,0 +1,261 @@ +#include "lm/trie_sort.hh" + +#include "lm/config.hh" +#include "lm/lm_exception.hh" +#include "lm/read_arpa.hh" +#include "lm/vocab.hh" +#include "lm/weights.hh" +#include "lm/word_index.hh" +#include "util/file_piece.hh" +#include "util/mmap.hh" +#include "util/proxy_iterator.hh" +#include "util/sized_iterator.hh" + +#include <algorithm> +#include <cstring> +#include <cstdio> +#include <deque> +#include <limits> +#include <vector> + +namespace lm { +namespace ngram { +namespace trie { + +const char *kContextSuffix = "_contexts"; + +FILE *OpenOrThrow(const char *name, const char *mode) { + FILE *ret = fopen(name, mode); + if (!ret) UTIL_THROW(util::ErrnoException, "Could not open " << name << " for " << mode); + return ret; +} + +void WriteOrThrow(FILE *to, const void *data, size_t size) { + assert(size); + if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size); +} + +namespace { + +typedef util::SizedIterator NGramIter; + +// Proxy for an entry except there is some extra cruft between the entries. This is used to sort (n-1)-grams using the same memory as the sorted n-grams. +class PartialViewProxy { + public: + PartialViewProxy() : attention_size_(0), inner_() {} + + PartialViewProxy(void *ptr, std::size_t block_size, std::size_t attention_size) : attention_size_(attention_size), inner_(ptr, block_size) {} + + operator std::string() const { + return std::string(reinterpret_cast<const char*>(inner_.Data()), attention_size_); + } + + PartialViewProxy &operator=(const PartialViewProxy &from) { + memcpy(inner_.Data(), from.inner_.Data(), attention_size_); + return *this; + } + + PartialViewProxy &operator=(const std::string &from) { + memcpy(inner_.Data(), from.data(), attention_size_); + return *this; + } + + const void *Data() const { return inner_.Data(); } + void *Data() { return inner_.Data(); } + + private: + friend class util::ProxyIterator<PartialViewProxy>; + + typedef std::string value_type; + + const std::size_t attention_size_; + + typedef util::SizedInnerIterator InnerIterator; + InnerIterator &Inner() { return inner_; } + const InnerIterator &Inner() const { return inner_; } + InnerIterator inner_; +}; + +typedef util::ProxyIterator<PartialViewProxy> PartialIter; + +std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::string &file_prefix, std::size_t batch, unsigned char order) { + std::stringstream assembled; + assembled << file_prefix << static_cast<unsigned int>(order) << '_' << batch; + std::string ret(assembled.str()); + util::scoped_fd out(util::CreateOrThrow(ret.c_str())); + util::WriteOrThrow(out.get(), mem_begin, (uint8_t*)mem_end - (uint8_t*)mem_begin); + return ret; +} + +void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_file_name, std::size_t entry_size, unsigned char order) { + const size_t context_size = sizeof(WordIndex) * (order - 1); + // Sort just the contexts using the same memory. + PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size)); + PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, context_size)); + + std::sort(context_begin, context_end, util::SizedCompare<EntryCompare, PartialViewProxy>(EntryCompare(order - 1))); + + std::string name(ngram_file_name + kContextSuffix); + util::scoped_FILE out(OpenOrThrow(name.c_str(), "w")); + + // Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator. + if (context_begin == context_end) return; + PartialIter i(context_begin); + WriteOrThrow(out.get(), i->Data(), context_size); + const void *previous = i->Data(); + ++i; + for (; i != context_end; ++i) { + if (memcmp(previous, i->Data(), context_size)) { + WriteOrThrow(out.get(), i->Data(), context_size); + previous = i->Data(); + } + } +} + +struct ThrowCombine { + void operator()(std::size_t /*entry_size*/, const void * /*first*/, const void * /*second*/, FILE * /*out*/) const { + UTIL_THROW(FormatLoadException, "Duplicate n-gram detected."); + } +}; + +// Useful for context files that just contain records with no value. +struct FirstCombine { + void operator()(std::size_t entry_size, const void *first, const void * /*second*/, FILE *out) const { + WriteOrThrow(out, first, entry_size); + } +}; + +template <class Combine> void MergeSortedFiles(const std::string &first_name, const std::string &second_name, const std::string &out, std::size_t weights_size, unsigned char order, const Combine &combine = ThrowCombine()) { + std::size_t entry_size = sizeof(WordIndex) * order + weights_size; + RecordReader first, second; + first.Init(first_name.c_str(), entry_size); + util::RemoveOrThrow(first_name.c_str()); + second.Init(second_name.c_str(), entry_size); + util::RemoveOrThrow(second_name.c_str()); + util::scoped_FILE out_file(OpenOrThrow(out.c_str(), "w")); + EntryCompare less(order); + while (first && second) { + if (less(first.Data(), second.Data())) { + WriteOrThrow(out_file.get(), first.Data(), entry_size); + ++first; + } else if (less(second.Data(), first.Data())) { + WriteOrThrow(out_file.get(), second.Data(), entry_size); + ++second; + } else { + combine(entry_size, first.Data(), second.Data(), out_file.get()); + ++first; ++second; + } + } + for (RecordReader &remains = (first ? second : first); remains; ++remains) { + WriteOrThrow(out_file.get(), remains.Data(), entry_size); + } +} + +void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn) { + ReadNGramHeader(f, order); + const size_t count = counts[order - 1]; + // Size of weights. Does it include backoff? + const size_t words_size = sizeof(WordIndex) * order; + const size_t weights_size = sizeof(float) + ((order == counts.size()) ? 0 : sizeof(float)); + const size_t entry_size = words_size + weights_size; + const size_t batch_size = std::min(count, mem.size() / entry_size); + uint8_t *const begin = reinterpret_cast<uint8_t*>(mem.get()); + std::deque<std::string> files; + for (std::size_t batch = 0, done = 0; done < count; ++batch) { + uint8_t *out = begin; + uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size; + if (order == counts.size()) { + for (; out != out_end; out += entry_size) { + ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<Prob*>(out + words_size), warn); + } + } else { + for (; out != out_end; out += entry_size) { + ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size), warn); + } + } + // Sort full records by full n-gram. + util::SizedProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size); + // parallel_sort uses too much RAM + std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), util::SizedCompare<EntryCompare>(EntryCompare(order))); + files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order)); + WriteContextFile(begin, out_end, files.back(), entry_size, order); + + done += (out_end - begin) / entry_size; + } + + // All individual files created. Merge them. + + std::size_t merge_count = 0; + while (files.size() > 1) { + std::stringstream assembled; + assembled << file_prefix << static_cast<unsigned int>(order) << "_merge_" << (merge_count++); + files.push_back(assembled.str()); + MergeSortedFiles(files[0], files[1], files.back(), weights_size, order, ThrowCombine()); + MergeSortedFiles(files[0], files[1], files.back(), 0, order, FirstCombine()); + files.pop_front(); + files.pop_front(); + } + if (!files.empty()) { + std::stringstream assembled; + assembled << file_prefix << static_cast<unsigned int>(order) << "_merged"; + std::string merged_name(assembled.str()); + if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str()); + std::string context_name = files[0] + kContextSuffix; + merged_name += kContextSuffix; + if (std::rename(context_name.c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << context_name << " to " << merged_name.c_str()); + } +} + +} // namespace + +void RecordReader::Init(const std::string &name, std::size_t entry_size) { + file_.reset(OpenOrThrow(name.c_str(), "r+")); + data_.reset(malloc(entry_size)); + UTIL_THROW_IF(!data_.get(), util::ErrnoException, "Failed to malloc read buffer"); + remains_ = true; + entry_size_ = entry_size; + ++*this; +} + +void RecordReader::Overwrite(const void *start, std::size_t amount) { + long internal = (uint8_t*)start - (uint8_t*)data_.get(); + UTIL_THROW_IF(fseek(file_.get(), internal - entry_size_, SEEK_CUR), util::ErrnoException, "Couldn't seek backwards for revision"); + WriteOrThrow(file_.get(), start, amount); + long forward = entry_size_ - internal - amount; + if (forward) UTIL_THROW_IF(fseek(file_.get(), forward, SEEK_CUR), util::ErrnoException, "Couldn't seek forwards past revision"); +} + +void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { + PositiveProbWarn warn(config.positive_log_probability); + { + std::string unigram_name = file_prefix + "unigrams"; + util::scoped_fd unigram_file; + // In case <unk> appears. + size_t file_out = (counts[0] + 1) * sizeof(ProbBackoff); + util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), file_out, unigram_file), file_out); + Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()), warn); + CheckSpecials(config, vocab); + if (!vocab.SawUnk()) ++counts[0]; + } + + // Only use as much buffer as we need. + size_t buffer_use = 0; + for (unsigned int order = 2; order < counts.size(); ++order) { + buffer_use = std::max<size_t>(buffer_use, static_cast<size_t>((sizeof(WordIndex) * order + 2 * sizeof(float)) * counts[order - 1])); + } + buffer_use = std::max<size_t>(buffer_use, static_cast<size_t>((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back())); + buffer = std::min<size_t>(buffer, buffer_use); + + util::scoped_memory mem; + mem.reset(malloc(buffer), buffer, util::scoped_memory::MALLOC_ALLOCATED); + if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer); + + for (unsigned char order = 2; order <= counts.size(); ++order) { + ConvertToSorted(f, vocab, counts, mem, file_prefix, order, warn); + } + ReadEnd(f); +} + +} // namespace trie +} // namespace ngram +} // namespace lm diff --git a/klm/lm/trie_sort.hh b/klm/lm/trie_sort.hh new file mode 100644 index 00000000..a6916483 --- /dev/null +++ b/klm/lm/trie_sort.hh @@ -0,0 +1,94 @@ +#ifndef LM_TRIE_SORT__ +#define LM_TRIE_SORT__ + +#include "lm/word_index.hh" + +#include "util/file.hh" +#include "util/scoped.hh" + +#include <cstddef> +#include <functional> +#include <string> +#include <vector> + +#include <inttypes.h> + +namespace util { class FilePiece; } + +// Step of trie builder: create sorted files. +namespace lm { +namespace ngram { +class SortedVocabulary; +class Config; + +namespace trie { + +extern const char *kContextSuffix; +FILE *OpenOrThrow(const char *name, const char *mode); +void WriteOrThrow(FILE *to, const void *data, size_t size); + +class EntryCompare : public std::binary_function<const void*, const void*, bool> { + public: + explicit EntryCompare(unsigned char order) : order_(order) {} + + bool operator()(const void *first_void, const void *second_void) const { + const WordIndex *first = static_cast<const WordIndex*>(first_void); + const WordIndex *second = static_cast<const WordIndex*>(second_void); + const WordIndex *end = first + order_; + for (; first != end; ++first, ++second) { + if (*first < *second) return true; + if (*first > *second) return false; + } + return false; + } + private: + unsigned char order_; +}; + +class RecordReader { + public: + RecordReader() : remains_(true) {} + + void Init(const std::string &name, std::size_t entry_size); + + void *Data() { return data_.get(); } + const void *Data() const { return data_.get(); } + + RecordReader &operator++() { + std::size_t ret = fread(data_.get(), entry_size_, 1, file_.get()); + if (!ret) { + UTIL_THROW_IF(!feof(file_.get()), util::ErrnoException, "Error reading temporary file"); + remains_ = false; + } + return *this; + } + + operator bool() const { return remains_; } + + void Rewind() { + rewind(file_.get()); + remains_ = true; + ++*this; + } + + std::size_t EntrySize() const { return entry_size_; } + + void Overwrite(const void *start, std::size_t amount); + + private: + util::scoped_malloc data_; + + bool remains_; + + std::size_t entry_size_; + + util::scoped_FILE file_; +}; + +void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab); + +} // namespace trie +} // namespace ngram +} // namespace lm + +#endif // LM_TRIE_SORT__ diff --git a/klm/lm/virtual_interface.hh b/klm/lm/virtual_interface.hh index 08627efd..6a5a0196 100644 --- a/klm/lm/virtual_interface.hh +++ b/klm/lm/virtual_interface.hh @@ -1,37 +1,13 @@ #ifndef LM_VIRTUAL_INTERFACE__ #define LM_VIRTUAL_INTERFACE__ +#include "lm/return.hh" #include "lm/word_index.hh" #include "util/string_piece.hh" #include <string> namespace lm { - -/* Structure returned by scoring routines. */ -struct FullScoreReturn { - // log10 probability - float prob; - - /* The length of n-gram matched. Do not use this for recombination. - * Consider a model containing only the following n-grams: - * -1 foo - * -3.14 bar - * -2.718 baz -5 - * -6 foo bar - * - * If you score ``bar'' then ngram_length is 1 and recombination state is the - * empty string because bar has zero backoff and does not extend to the - * right. - * If you score ``foo'' then ngram_length is 1 and recombination state is - * ``foo''. - * - * Ideally, keep output states around and compare them. Failing that, - * get out_state.ValidLength() and use that length for recombination. - */ - unsigned char ngram_length; -}; - namespace base { template <class T, class U, class V> class ModelFacade; diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 04979d51..03b0767a 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -1,5 +1,6 @@ #include "lm/vocab.hh" +#include "lm/binary_format.hh" #include "lm/enumerate_vocab.hh" #include "lm/lm_exception.hh" #include "lm/config.hh" @@ -56,16 +57,6 @@ WordIndex ReadWords(int fd, EnumerateVocab *enumerate) { } } -void WriteOrThrow(int fd, const void *data_void, std::size_t size) { - const uint8_t *data = static_cast<const uint8_t*>(data_void); - while (size) { - ssize_t ret = write(fd, data, size); - if (ret < 1) UTIL_THROW(util::ErrnoException, "Write failed"); - data += ret; - size -= ret; - } -} - } // namespace WriteWordsWrapper::WriteWordsWrapper(EnumerateVocab *inner) : inner_(inner) {} @@ -80,7 +71,7 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) { void WriteWordsWrapper::Write(int fd) { if ((off_t)-1 == lseek(fd, 0, SEEK_END)) UTIL_THROW(util::ErrnoException, "Failed to seek in binary to vocab words"); - WriteOrThrow(fd, buffer_.data(), buffer_.size()); + util::WriteOrThrow(fd, buffer_.data(), buffer_.size()); } SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {} @@ -146,15 +137,28 @@ void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { SetSpecial(Index("<s>"), Index("</s>"), 0); } +namespace { +const unsigned int kProbingVocabularyVersion = 0; +} // namespace + +namespace detail { +struct ProbingVocabularyHeader { + // Lowest unused vocab id. This is also the number of words, including <unk>. + unsigned int version; + WordIndex bound; +}; +} // namespace detail + ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {} std::size_t ProbingVocabulary::Size(std::size_t entries, const Config &config) { - return Lookup::Size(entries, config.probing_multiplier); + return Align8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, config.probing_multiplier); } void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) { - lookup_ = Lookup(start, allocated); - available_ = 1; + header_ = static_cast<detail::ProbingVocabularyHeader*>(start); + lookup_ = Lookup(static_cast<uint8_t*>(start) + Align8(sizeof(detail::ProbingVocabularyHeader)), allocated); + bound_ = 1; saw_unk_ = false; } @@ -172,20 +176,24 @@ WordIndex ProbingVocabulary::Insert(const StringPiece &str) { saw_unk_ = true; return 0; } else { - if (enumerate_) enumerate_->Add(available_, str); - lookup_.Insert(Lookup::Packing::Make(hashed, available_)); - return available_++; + if (enumerate_) enumerate_->Add(bound_, str); + lookup_.Insert(Lookup::Packing::Make(hashed, bound_)); + return bound_++; } } void ProbingVocabulary::FinishedLoading(ProbBackoff * /*reorder_vocab*/) { lookup_.FinishedInserting(); + header_->bound = bound_; + header_->version = kProbingVocabularyVersion; SetSpecial(Index("<s>"), Index("</s>"), 0); } void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { + UTIL_THROW_IF(header_->version != kProbingVocabularyVersion, FormatLoadException, "The binary file has probing version " << header_->version << " but the code expects version " << kProbingVocabularyVersion << ". Please rerun build_binary using the same version of the code."); lookup_.LoadedBinary(); - available_ = ReadWords(fd, to); + ReadWords(fd, to); + bound_ = header_->bound; SetSpecial(Index("<s>"), Index("</s>"), 0); } diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 9d218fff..41e97052 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -25,6 +25,7 @@ uint64_t HashForVocab(const char *str, std::size_t len); inline uint64_t HashForVocab(const StringPiece &str) { return HashForVocab(str.data(), str.length()); } +class ProbingVocabularyHeader; } // namespace detail class WriteWordsWrapper : public EnumerateVocab { @@ -113,10 +114,7 @@ class ProbingVocabulary : public base::Vocabulary { static size_t Size(std::size_t entries, const Config &config); // Vocab words are [0, Bound()). - // WARNING WARNING: returns UINT_MAX when loading binary and not enumerating vocabulary. - // Fixing this bug requires a binary file format change and will be fixed with the next binary file format update. - // Specifically, the binary file format does not currently indicate whether <unk> is in count or not. - WordIndex Bound() const { return available_; } + WordIndex Bound() const { return bound_; } // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); @@ -141,11 +139,13 @@ class ProbingVocabulary : public base::Vocabulary { Lookup lookup_; - WordIndex available_; + WordIndex bound_; bool saw_unk_; EnumerateVocab *enumerate_; + + detail::ProbingVocabularyHeader *header_; }; void MissingUnknown(const Config &config) throw(SpecialWordMissingException); diff --git a/klm/test.sh b/klm/test.sh index d02a3dc9..fb33300a 100755 --- a/klm/test.sh +++ b/klm/test.sh @@ -2,7 +2,7 @@ #Run tests. Requires Boost. set -e ./compile.sh -for i in util/{bit_packing,file_piece,joint_sort,key_value_packing,probing_hash_table,sorted_uniform}_test lm/model_test; do +for i in util/{bit_packing,file_piece,joint_sort,key_value_packing,probing_hash_table,sorted_uniform}_test lm/{model,left}_test; do g++ -I. -O3 $CXXFLAGS $i.cc {lm,util}/*.o -lboost_test_exec_monitor -lz -o $i pushd $(dirname $i) >/dev/null && ./$(basename $i) || echo "$i failed"; popd >/dev/null done diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh index 9f47d559..33266b94 100644 --- a/klm/util/bit_packing.hh +++ b/klm/util/bit_packing.hh @@ -86,6 +86,20 @@ inline void WriteFloat32(void *base, uint64_t bit_off, float value) { const uint32_t kSignBit = 0x80000000; +inline void SetSign(float &to) { + FloatEnc enc; + enc.f = to; + enc.i |= kSignBit; + to = enc.f; +} + +inline void UnsetSign(float &to) { + FloatEnc enc; + enc.f = to; + enc.i &= ~kSignBit; + to = enc.f; +} + inline float ReadNonPositiveFloat31(const void *base, uint64_t bit_off) { FloatEnc encoded; encoded.i = ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 31); diff --git a/klm/util/exception.cc b/klm/util/exception.cc index 62280970..96951495 100644 --- a/klm/util/exception.cc +++ b/klm/util/exception.cc @@ -79,4 +79,9 @@ ErrnoException::ErrnoException() throw() : errno_(errno) { ErrnoException::~ErrnoException() throw() {} +EndOfFileException::EndOfFileException() throw() { + *this << "End of file"; +} +EndOfFileException::~EndOfFileException() throw() {} + } // namespace util diff --git a/klm/util/exception.hh b/klm/util/exception.hh index 81675a57..6d6a37cb 100644 --- a/klm/util/exception.hh +++ b/klm/util/exception.hh @@ -105,6 +105,12 @@ class ErrnoException : public Exception { int errno_; }; +class EndOfFileException : public Exception { + public: + EndOfFileException() throw(); + ~EndOfFileException() throw(); +}; + } // namespace util #endif // UTIL_EXCEPTION__ diff --git a/klm/util/file.cc b/klm/util/file.cc new file mode 100644 index 00000000..d707568e --- /dev/null +++ b/klm/util/file.cc @@ -0,0 +1,74 @@ +#include "util/file.hh" + +#include "util/exception.hh" + +#include <cstdlib> +#include <cstdio> +#include <iostream> + +#include <sys/types.h> +#include <sys/stat.h> +#include <fcntl.h> +#include <unistd.h> +#include <inttypes.h> + +namespace util { + +scoped_fd::~scoped_fd() { + if (fd_ != -1 && close(fd_)) { + std::cerr << "Could not close file " << fd_ << std::endl; + std::abort(); + } +} + +scoped_FILE::~scoped_FILE() { + if (file_ && std::fclose(file_)) { + std::cerr << "Could not close file " << std::endl; + std::abort(); + } +} + +int OpenReadOrThrow(const char *name) { + int ret; + UTIL_THROW_IF(-1 == (ret = open(name, O_RDONLY)), ErrnoException, "while opening " << name); + return ret; +} + +int CreateOrThrow(const char *name) { + int ret; + UTIL_THROW_IF(-1 == (ret = open(name, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR)), ErrnoException, "while creating " << name); + return ret; +} + +off_t SizeFile(int fd) { + struct stat sb; + if (fstat(fd, &sb) == -1 || (!sb.st_size && !S_ISREG(sb.st_mode))) return kBadSize; + return sb.st_size; +} + +void ReadOrThrow(int fd, void *to_void, std::size_t amount) { + uint8_t *to = static_cast<uint8_t*>(to_void); + while (amount) { + ssize_t ret = read(fd, to, amount); + if (ret == -1) UTIL_THROW(ErrnoException, "Reading " << amount << " from fd " << fd << " failed."); + if (ret == 0) UTIL_THROW(Exception, "Hit EOF in fd " << fd << " but there should be " << amount << " more bytes to read."); + amount -= ret; + to += ret; + } +} + +void WriteOrThrow(int fd, const void *data_void, std::size_t size) { + const uint8_t *data = static_cast<const uint8_t*>(data_void); + while (size) { + ssize_t ret = write(fd, data, size); + if (ret < 1) UTIL_THROW(util::ErrnoException, "Write failed"); + data += ret; + size -= ret; + } +} + +void RemoveOrThrow(const char *name) { + UTIL_THROW_IF(std::remove(name), util::ErrnoException, "Could not remove " << name); +} + +} // namespace util diff --git a/klm/util/file.hh b/klm/util/file.hh new file mode 100644 index 00000000..d6cca41d --- /dev/null +++ b/klm/util/file.hh @@ -0,0 +1,74 @@ +#ifndef UTIL_FILE__ +#define UTIL_FILE__ + +#include <cstdio> +#include <unistd.h> + +namespace util { + +class scoped_fd { + public: + scoped_fd() : fd_(-1) {} + + explicit scoped_fd(int fd) : fd_(fd) {} + + ~scoped_fd(); + + void reset(int to) { + scoped_fd other(fd_); + fd_ = to; + } + + int get() const { return fd_; } + + int operator*() const { return fd_; } + + int release() { + int ret = fd_; + fd_ = -1; + return ret; + } + + operator bool() { return fd_ != -1; } + + private: + int fd_; + + scoped_fd(const scoped_fd &); + scoped_fd &operator=(const scoped_fd &); +}; + +class scoped_FILE { + public: + explicit scoped_FILE(std::FILE *file = NULL) : file_(file) {} + + ~scoped_FILE(); + + std::FILE *get() { return file_; } + const std::FILE *get() const { return file_; } + + void reset(std::FILE *to = NULL) { + scoped_FILE other(file_); + file_ = to; + } + + private: + std::FILE *file_; +}; + +int OpenReadOrThrow(const char *name); + +int CreateOrThrow(const char *name); + +// Return value for SizeFile when it can't size properly. +const off_t kBadSize = -1; +off_t SizeFile(int fd); + +void ReadOrThrow(int fd, void *to, std::size_t size); +void WriteOrThrow(int fd, const void *data_void, std::size_t size); + +void RemoveOrThrow(const char *name); + +} // namespace util + +#endif // UTIL_FILE__ diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc index cbe4234f..b57582a0 100644 --- a/klm/util/file_piece.cc +++ b/klm/util/file_piece.cc @@ -1,6 +1,7 @@ #include "util/file_piece.hh" #include "util/exception.hh" +#include "util/file.hh" #include <iostream> #include <string> @@ -21,11 +22,6 @@ namespace util { -EndOfFileException::EndOfFileException() throw() { - *this << "End of file"; -} -EndOfFileException::~EndOfFileException() throw() {} - ParseNumberException::ParseNumberException(StringPiece value) throw() { *this << "Could not parse \"" << value << "\" into a number"; } @@ -40,18 +36,6 @@ GZException::GZException(void *file) { // Sigh this is the only way I could come up with to do a _const_ bool. It has ' ', '\f', '\n', '\r', '\t', and '\v' (same as isspace on C locale). const bool kSpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}; -int OpenReadOrThrow(const char *name) { - int ret; - UTIL_THROW_IF(-1 == (ret = open(name, O_RDONLY)), ErrnoException, "while opening " << name); - return ret; -} - -off_t SizeFile(int fd) { - struct stat sb; - if (fstat(fd, &sb) == -1 || (!sb.st_size && !S_ISREG(sb.st_mode))) return kBadSize; - return sb.st_size; -} - FilePiece::FilePiece(const char *name, std::ostream *show_progress, off_t min_buffer) : file_(OpenReadOrThrow(name)), total_size_(SizeFile(file_.get())), page_(sysconf(_SC_PAGE_SIZE)), progress_(total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name, total_size_) { diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh index a5c00910..a627f38c 100644 --- a/klm/util/file_piece.hh +++ b/klm/util/file_piece.hh @@ -3,9 +3,9 @@ #include "util/ersatz_progress.hh" #include "util/exception.hh" +#include "util/file.hh" #include "util/have.hh" #include "util/mmap.hh" -#include "util/scoped.hh" #include "util/string_piece.hh" #include <string> @@ -14,12 +14,6 @@ namespace util { -class EndOfFileException : public Exception { - public: - EndOfFileException() throw(); - ~EndOfFileException() throw(); -}; - class ParseNumberException : public Exception { public: explicit ParseNumberException(StringPiece value) throw(); @@ -33,14 +27,8 @@ class GZException : public Exception { ~GZException() throw() {} }; -int OpenReadOrThrow(const char *name); - extern const bool kSpaces[256]; -// Return value for SizeFile when it can't size properly. -const off_t kBadSize = -1; -off_t SizeFile(int fd); - // Memory backing the returned StringPiece may vanish on the next call. class FilePiece { public: diff --git a/klm/util/mmap.cc b/klm/util/mmap.cc index e7c0643b..5ce7adc9 100644 --- a/klm/util/mmap.cc +++ b/klm/util/mmap.cc @@ -1,6 +1,6 @@ #include "util/exception.hh" +#include "util/file.hh" #include "util/mmap.hh" -#include "util/scoped.hh" #include <iostream> @@ -66,20 +66,6 @@ void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int return ret; } -namespace { -void ReadAll(int fd, void *to_void, std::size_t amount) { - uint8_t *to = static_cast<uint8_t*>(to_void); - while (amount) { - ssize_t ret = read(fd, to, amount); - if (ret == -1) UTIL_THROW(ErrnoException, "Reading " << amount << " from fd " << fd << " failed."); - if (ret == 0) UTIL_THROW(Exception, "Hit EOF in fd " << fd << " but there should be " << amount << " more bytes to read."); - amount -= ret; - to += ret; - } -} - -} // namespace - const int kFileFlags = #ifdef MAP_FILE MAP_FILE | MAP_SHARED @@ -106,7 +92,7 @@ void MapRead(LoadMethod method, int fd, off_t offset, std::size_t size, scoped_m out.reset(malloc(size), size, scoped_memory::MALLOC_ALLOCATED); if (!out.get()) UTIL_THROW(util::ErrnoException, "Allocating " << size << " bytes with malloc"); if (-1 == lseek(fd, offset, SEEK_SET)) UTIL_THROW(ErrnoException, "lseek to " << offset << " in fd " << fd << " failed."); - ReadAll(fd, out.get(), size); + ReadOrThrow(fd, out.get(), size); break; } } diff --git a/klm/util/mmap.hh b/klm/util/mmap.hh index e4439fa4..b0eb6672 100644 --- a/klm/util/mmap.hh +++ b/klm/util/mmap.hh @@ -2,8 +2,6 @@ #define UTIL_MMAP__ // Utilities for mmaped files. -#include "util/scoped.hh" - #include <cstddef> #include <inttypes.h> @@ -11,6 +9,8 @@ namespace util { +class scoped_fd; + // (void*)-1 is MAP_FAILED; this is done to avoid including the mmap header here. class scoped_mmap { public: diff --git a/klm/util/murmur_hash.cc b/klm/util/murmur_hash.cc index fec47fd9..d58a0727 100644 --- a/klm/util/murmur_hash.cc +++ b/klm/util/murmur_hash.cc @@ -1,129 +1,129 @@ -/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All - * code is released to the public domain. For business purposes, Murmurhash is - * under the MIT license." - * This is modified from the original: - * ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit. - * length changed to unsigned int. - * placed in namespace util - * add MurmurHashNative - * default option = 0 for seed - */ - -#include "util/murmur_hash.hh" - -namespace util { - -//----------------------------------------------------------------------------- -// MurmurHash2, 64-bit versions, by Austin Appleby - -// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment -// and endian-ness issues if used across multiple platforms. - -// 64-bit hash for 64-bit platforms - -uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed ) -{ - const uint64_t m = 0xc6a4a7935bd1e995ULL; - const int r = 47; - - uint64_t h = seed ^ (len * m); - - const uint64_t * data = (const uint64_t *)key; - const uint64_t * end = data + (len/8); - - while(data != end) - { - uint64_t k = *data++; - - k *= m; - k ^= k >> r; - k *= m; - - h ^= k; - h *= m; - } - - const unsigned char * data2 = (const unsigned char*)data; - - switch(len & 7) - { - case 7: h ^= uint64_t(data2[6]) << 48; - case 6: h ^= uint64_t(data2[5]) << 40; - case 5: h ^= uint64_t(data2[4]) << 32; - case 4: h ^= uint64_t(data2[3]) << 24; - case 3: h ^= uint64_t(data2[2]) << 16; - case 2: h ^= uint64_t(data2[1]) << 8; - case 1: h ^= uint64_t(data2[0]); - h *= m; - }; - - h ^= h >> r; - h *= m; - h ^= h >> r; - - return h; -} - - -// 64-bit hash for 32-bit platforms - -uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed ) -{ - const unsigned int m = 0x5bd1e995; - const int r = 24; - - unsigned int h1 = seed ^ len; - unsigned int h2 = 0; - - const unsigned int * data = (const unsigned int *)key; - - while(len >= 8) - { - unsigned int k1 = *data++; - k1 *= m; k1 ^= k1 >> r; k1 *= m; - h1 *= m; h1 ^= k1; - len -= 4; - - unsigned int k2 = *data++; - k2 *= m; k2 ^= k2 >> r; k2 *= m; - h2 *= m; h2 ^= k2; - len -= 4; - } - - if(len >= 4) - { - unsigned int k1 = *data++; - k1 *= m; k1 ^= k1 >> r; k1 *= m; - h1 *= m; h1 ^= k1; - len -= 4; - } - - switch(len) - { - case 3: h2 ^= ((unsigned char*)data)[2] << 16; - case 2: h2 ^= ((unsigned char*)data)[1] << 8; - case 1: h2 ^= ((unsigned char*)data)[0]; - h2 *= m; - }; - - h1 ^= h2 >> 18; h1 *= m; - h2 ^= h1 >> 22; h2 *= m; - h1 ^= h2 >> 17; h1 *= m; - h2 ^= h1 >> 19; h2 *= m; - - uint64_t h = h1; - - h = (h << 32) | h2; - - return h; -} - -uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) { - if (sizeof(int) == 4) { - return MurmurHash64B(key, len, seed); - } else { - return MurmurHash64A(key, len, seed); - } -} - -} // namespace util +/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All
+ * code is released to the public domain. For business purposes, Murmurhash is
+ * under the MIT license."
+ * This is modified from the original:
+ * ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit.
+ * length changed to unsigned int.
+ * placed in namespace util
+ * add MurmurHashNative
+ * default option = 0 for seed
+ */
+
+#include "util/murmur_hash.hh"
+
+namespace util {
+
+//-----------------------------------------------------------------------------
+// MurmurHash2, 64-bit versions, by Austin Appleby
+
+// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment
+// and endian-ness issues if used across multiple platforms.
+
+// 64-bit hash for 64-bit platforms
+
+uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed )
+{
+ const uint64_t m = 0xc6a4a7935bd1e995ULL;
+ const int r = 47;
+
+ uint64_t h = seed ^ (len * m);
+
+ const uint64_t * data = (const uint64_t *)key;
+ const uint64_t * end = data + (len/8);
+
+ while(data != end)
+ {
+ uint64_t k = *data++;
+
+ k *= m;
+ k ^= k >> r;
+ k *= m;
+
+ h ^= k;
+ h *= m;
+ }
+
+ const unsigned char * data2 = (const unsigned char*)data;
+
+ switch(len & 7)
+ {
+ case 7: h ^= uint64_t(data2[6]) << 48;
+ case 6: h ^= uint64_t(data2[5]) << 40;
+ case 5: h ^= uint64_t(data2[4]) << 32;
+ case 4: h ^= uint64_t(data2[3]) << 24;
+ case 3: h ^= uint64_t(data2[2]) << 16;
+ case 2: h ^= uint64_t(data2[1]) << 8;
+ case 1: h ^= uint64_t(data2[0]);
+ h *= m;
+ };
+
+ h ^= h >> r;
+ h *= m;
+ h ^= h >> r;
+
+ return h;
+}
+
+
+// 64-bit hash for 32-bit platforms
+
+uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed )
+{
+ const unsigned int m = 0x5bd1e995;
+ const int r = 24;
+
+ unsigned int h1 = seed ^ len;
+ unsigned int h2 = 0;
+
+ const unsigned int * data = (const unsigned int *)key;
+
+ while(len >= 8)
+ {
+ unsigned int k1 = *data++;
+ k1 *= m; k1 ^= k1 >> r; k1 *= m;
+ h1 *= m; h1 ^= k1;
+ len -= 4;
+
+ unsigned int k2 = *data++;
+ k2 *= m; k2 ^= k2 >> r; k2 *= m;
+ h2 *= m; h2 ^= k2;
+ len -= 4;
+ }
+
+ if(len >= 4)
+ {
+ unsigned int k1 = *data++;
+ k1 *= m; k1 ^= k1 >> r; k1 *= m;
+ h1 *= m; h1 ^= k1;
+ len -= 4;
+ }
+
+ switch(len)
+ {
+ case 3: h2 ^= ((unsigned char*)data)[2] << 16;
+ case 2: h2 ^= ((unsigned char*)data)[1] << 8;
+ case 1: h2 ^= ((unsigned char*)data)[0];
+ h2 *= m;
+ };
+
+ h1 ^= h2 >> 18; h1 *= m;
+ h2 ^= h1 >> 22; h2 *= m;
+ h1 ^= h2 >> 17; h1 *= m;
+ h2 ^= h1 >> 19; h2 *= m;
+
+ uint64_t h = h1;
+
+ h = (h << 32) | h2;
+
+ return h;
+}
+
+uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) {
+ if (sizeof(int) == 4) {
+ return MurmurHash64B(key, len, seed);
+ } else {
+ return MurmurHash64A(key, len, seed);
+ }
+}
+
+} // namespace util
diff --git a/klm/util/scoped.cc b/klm/util/scoped.cc deleted file mode 100644 index a4cc5016..00000000 --- a/klm/util/scoped.cc +++ /dev/null @@ -1,24 +0,0 @@ -#include "util/scoped.hh" - -#include <iostream> - -#include <stdlib.h> -#include <unistd.h> - -namespace util { - -scoped_fd::~scoped_fd() { - if (fd_ != -1 && close(fd_)) { - std::cerr << "Could not close file " << fd_ << std::endl; - abort(); - } -} - -scoped_FILE::~scoped_FILE() { - if (file_ && fclose(file_)) { - std::cerr << "Could not close file " << std::endl; - abort(); - } -} - -} // namespace util diff --git a/klm/util/scoped.hh b/klm/util/scoped.hh index d36a7df3..12e6652b 100644 --- a/klm/util/scoped.hh +++ b/klm/util/scoped.hh @@ -1,10 +1,11 @@ #ifndef UTIL_SCOPED__ #define UTIL_SCOPED__ -/* Other scoped objects in the style of scoped_ptr. */ +#include "util/exception.hh" +/* Other scoped objects in the style of scoped_ptr. */ #include <cstddef> -#include <cstdio> +#include <cstdlib> namespace util { @@ -34,52 +35,33 @@ template <class T, class R, R (*Free)(T*)> class scoped_thing { scoped_thing &operator=(const scoped_thing &); }; -class scoped_fd { +class scoped_malloc { public: - scoped_fd() : fd_(-1) {} + scoped_malloc() : p_(NULL) {} - explicit scoped_fd(int fd) : fd_(fd) {} + scoped_malloc(void *p) : p_(p) {} - ~scoped_fd(); + ~scoped_malloc() { std::free(p_); } - void reset(int to) { - scoped_fd other(fd_); - fd_ = to; + void reset(void *p = NULL) { + scoped_malloc other(p_); + p_ = p; } - int get() const { return fd_; } - - int operator*() const { return fd_; } - - int release() { - int ret = fd_; - fd_ = -1; - return ret; + void call_realloc(std::size_t to) { + void *ret; + UTIL_THROW_IF(!(ret = std::realloc(p_, to)), util::ErrnoException, "realloc to " << to << " bytes failed."); + p_ = ret; } - private: - int fd_; - - scoped_fd(const scoped_fd &); - scoped_fd &operator=(const scoped_fd &); -}; - -class scoped_FILE { - public: - explicit scoped_FILE(std::FILE *file = NULL) : file_(file) {} - - ~scoped_FILE(); - - std::FILE *get() { return file_; } - const std::FILE *get() const { return file_; } - - void reset(std::FILE *to = NULL) { - scoped_FILE other(file_); - file_ = to; - } + void *get() { return p_; } + const void *get() const { return p_; } private: - std::FILE *file_; + void *p_; + + scoped_malloc(const scoped_malloc &); + scoped_malloc &operator=(const scoped_malloc &); }; // Hat tip to boost. diff --git a/klm/util/sized_iterator.hh b/klm/util/sized_iterator.hh new file mode 100644 index 00000000..47dfc245 --- /dev/null +++ b/klm/util/sized_iterator.hh @@ -0,0 +1,107 @@ +#ifndef UTIL_SIZED_ITERATOR__ +#define UTIL_SIZED_ITERATOR__ + +#include "util/proxy_iterator.hh" + +#include <functional> +#include <string> + +#include <inttypes.h> +#include <string.h> + +namespace util { + +class SizedInnerIterator { + public: + SizedInnerIterator() {} + + SizedInnerIterator(void *ptr, std::size_t size) : ptr_(static_cast<uint8_t*>(ptr)), size_(size) {} + + bool operator==(const SizedInnerIterator &other) const { + return ptr_ == other.ptr_; + } + bool operator<(const SizedInnerIterator &other) const { + return ptr_ < other.ptr_; + } + SizedInnerIterator &operator+=(std::ptrdiff_t amount) { + ptr_ += amount * size_; + return *this; + } + std::ptrdiff_t operator-(const SizedInnerIterator &other) const { + return (ptr_ - other.ptr_) / size_; + } + + const void *Data() const { return ptr_; } + void *Data() { return ptr_; } + std::size_t EntrySize() const { return size_; } + + private: + uint8_t *ptr_; + std::size_t size_; +}; + +class SizedProxy { + public: + SizedProxy() {} + + SizedProxy(void *ptr, std::size_t size) : inner_(ptr, size) {} + + operator std::string() const { + return std::string(reinterpret_cast<const char*>(inner_.Data()), inner_.EntrySize()); + } + + SizedProxy &operator=(const SizedProxy &from) { + memcpy(inner_.Data(), from.inner_.Data(), inner_.EntrySize()); + return *this; + } + + SizedProxy &operator=(const std::string &from) { + memcpy(inner_.Data(), from.data(), inner_.EntrySize()); + return *this; + } + + const void *Data() const { return inner_.Data(); } + void *Data() { return inner_.Data(); } + + private: + friend class util::ProxyIterator<SizedProxy>; + + typedef std::string value_type; + + typedef SizedInnerIterator InnerIterator; + + InnerIterator &Inner() { return inner_; } + const InnerIterator &Inner() const { return inner_; } + InnerIterator inner_; +}; + +typedef ProxyIterator<SizedProxy> SizedIterator; + +inline SizedIterator SizedIt(void *ptr, std::size_t size) { return SizedIterator(SizedProxy(ptr, size)); } + +// Useful wrapper for a comparison function i.e. sort. +template <class Delegate, class Proxy = SizedProxy> class SizedCompare : public std::binary_function<const Proxy &, const Proxy &, bool> { + public: + explicit SizedCompare(const Delegate &delegate = Delegate()) : delegate_(delegate) {} + + bool operator()(const Proxy &first, const Proxy &second) const { + return delegate_(first.Data(), second.Data()); + } + bool operator()(const Proxy &first, const std::string &second) const { + return delegate_(first.Data(), second.data()); + } + bool operator()(const std::string &first, const Proxy &second) const { + return delegate_(first.data(), second.Data()); + } + bool operator()(const std::string &first, const std::string &second) const { + return delegate_(first.data(), second.data()); + } + + const Delegate &GetDelegate() const { return delegate_; } + + private: + const Delegate delegate_; +}; + +} // namespace util +#endif // UTIL_SIZED_ITERATOR__ diff --git a/klm/util/tokenize_piece.hh b/klm/util/tokenize_piece.hh new file mode 100644 index 00000000..ee1c7ab2 --- /dev/null +++ b/klm/util/tokenize_piece.hh @@ -0,0 +1,69 @@ +#ifndef UTIL_TOKENIZE_PIECE__ +#define UTIL_TOKENIZE_PIECE__ + +#include "util/string_piece.hh" + +#include <boost/iterator/iterator_facade.hpp> + +/* Usage: + * + * for (PieceIterator<' '> i(" foo \r\n bar "); i; ++i) { + * std::cout << *i << "\n"; + * } + * + */ + +namespace util { + +// Tokenize a StringPiece using an iterator interface. boost::tokenizer doesn't work with StringPiece. +template <char d> class PieceIterator : public boost::iterator_facade<PieceIterator<d>, const StringPiece, boost::forward_traversal_tag> { + public: + // Default construct is end, which is also accessed by kEndPieceIterator; + PieceIterator() {} + + explicit PieceIterator(const StringPiece &str) + : after_(str) { + increment(); + } + + bool operator!() const { + return after_.data() == 0; + } + operator bool() const { + return after_.data() != 0; + } + + static PieceIterator<d> end() { + return PieceIterator<d>(); + } + + private: + friend class boost::iterator_core_access; + + void increment() { + const char *start = after_.data(); + for (; (start != after_.data() + after_.size()) && (d == *start); ++start) {} + if (start == after_.data() + after_.size()) { + // End condition. + after_.clear(); + return; + } + const char *finish = start; + for (; (finish != after_.data() + after_.size()) && (d != *finish); ++finish) {} + current_ = StringPiece(start, finish - start); + after_ = StringPiece(finish, after_.data() + after_.size() - finish); + } + + bool equal(const PieceIterator &other) const { + return after_.data() == other.after_.data(); + } + + const StringPiece &dereference() const { return current_; } + + StringPiece current_; + StringPiece after_; +}; + +} // namespace util + +#endif // UTIL_TOKENIZE_PIECE__ |