diff options
-rw-r--r-- | decoder/cdec_ff.cc | 1 | ||||
-rw-r--r-- | decoder/ff_klm.cc | 1 | ||||
-rw-r--r-- | klm/lm/binary_format.cc | 2 | ||||
-rw-r--r-- | klm/lm/binary_format.hh | 13 | ||||
-rw-r--r-- | klm/lm/build_binary.cc | 15 | ||||
-rw-r--r-- | klm/lm/config.cc | 1 | ||||
-rw-r--r-- | klm/lm/config.hh | 6 | ||||
-rw-r--r-- | klm/lm/lm_exception.hh | 2 | ||||
-rw-r--r-- | klm/lm/model.cc | 2 | ||||
-rw-r--r-- | klm/lm/model.hh | 33 | ||||
-rw-r--r-- | klm/lm/read_arpa.cc | 14 | ||||
-rw-r--r-- | klm/lm/read_arpa.hh | 33 | ||||
-rw-r--r-- | klm/lm/search_hashed.cc | 16 | ||||
-rw-r--r-- | klm/lm/search_trie.cc | 27 | ||||
-rw-r--r-- | klm/lm/sri.cc | 18 | ||||
-rw-r--r-- | klm/lm/virtual_interface.hh | 26 | ||||
-rw-r--r-- | klm/lm/vocab.cc | 12 | ||||
-rw-r--r-- | klm/util/file_piece.cc | 7 | ||||
-rw-r--r-- | klm/util/mmap.cc | 5 | ||||
-rw-r--r-- | klm/util/string_piece.hh | 19 |
20 files changed, 180 insertions, 73 deletions
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 7ec54a5a..64a2bd94 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -53,7 +53,6 @@ void register_feature_functions() { ff_registry.Register("SpanFeatures", new FFFactory<SpanFeatures>()); ff_registry.Register("CMR2008ReorderingFeatures", new FFFactory<CMR2008ReorderingFeatures>()); ff_registry.Register("KLanguageModel", new FFFactory<KLanguageModel<lm::ngram::ProbingModel> >()); - ff_registry.Register("KLanguageModel_Sorted", new FFFactory<KLanguageModel<lm::ngram::SortedModel> >()); ff_registry.Register("KLanguageModel_Trie", new FFFactory<KLanguageModel<lm::ngram::TrieModel> >()); ff_registry.Register("KLanguageModel_Probing", new FFFactory<KLanguageModel<lm::ngram::ProbingModel> >()); ff_registry.Register("NonLatinCount", new FFFactory<NonLatinCount>); diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index cc7cd427..ab44232a 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -426,6 +426,5 @@ void KLanguageModel<Model>::FinalTraversalFeatures(const void* ant_state, // instantiate templates template class KLanguageModel<lm::ngram::ProbingModel>; -template class KLanguageModel<lm::ngram::SortedModel>; template class KLanguageModel<lm::ngram::TrieModel>; diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 9be0bc8e..34d9ffca 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -112,6 +112,8 @@ uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &b void FinishFile(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, Backing &backing) { if (config.write_mmap) { + if (msync(backing.search.get(), backing.search.size(), MS_SYNC) || msync(backing.vocab.get(), backing.vocab.size(), MS_SYNC)) + UTIL_THROW(util::ErrnoException, "msync failed for " << config.write_mmap); // header and vocab share the same mmap. The header is written here because we know the counts. 
Parameters params; params.counts = counts; diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index 72d8c159..1fc71be4 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -18,6 +18,12 @@ namespace ngram { typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2} ModelType; +/*Inspect a file to determine if it is a binary lm. If not, return false. + * If so, return true and set recognized to the type. This is the only API in + * this header designed for use by decoder authors. + */ +bool RecognizeBinary(const char *file, ModelType &recognized); + struct FixedWidthParameters { unsigned char order; float probing_multiplier; @@ -27,6 +33,7 @@ struct FixedWidthParameters { bool has_vocabulary; }; +// Parameters stored in the header of a binary file. struct Parameters { FixedWidthParameters fixed; std::vector<uint64_t> counts; @@ -41,10 +48,13 @@ struct Backing { util::scoped_memory search; }; +// Create just enough of a binary file to write vocabulary to it. uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing); // Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin. uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing); +// Write header to binary file. This is done last to prevent incomplete files +// from loading. 
void FinishFile(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, Backing &backing); namespace detail { @@ -61,8 +71,6 @@ void ComplainAboutARPA(const Config &config, ModelType model_type); } // namespace detail -bool RecognizeBinary(const char *file, ModelType &recognized); - template <class To> void LoadLM(const char *file, const Config &config, To &to) { Backing &backing = to.MutableBacking(); backing.file.reset(util::OpenReadOrThrow(file)); @@ -86,7 +94,6 @@ template <class To> void LoadLM(const char *file, const Config &config, To &to) e << " File: " << file; throw; } - } } // namespace ngram diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index 920ff080..91ad2fb9 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -15,10 +15,11 @@ namespace ngram { namespace { void Usage(const char *name) { - std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [type] input.arpa output.mmap\n\n" + std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-n] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [type] input.arpa output.mmap\n\n" "-u sets the default log10 probability for <unk> if the ARPA file does not have\n" "one.\n" -"-s allows models to be built even if they do not have <s> and </s>.\n\n" +"-s allows models to be built even if they do not have <s> and </s>.\n" +"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n" "type is one of probing, trie, or sorted:\n\n" "probing uses a probing hash table. It is the fastest but uses the most memory.\n" "-p sets the space multiplier and must be >1.0. 
The default is 1.5.\n\n" @@ -63,7 +64,6 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) { std::cout << "bytes\n" "probing " << std::setw(length) << probing_size << " assuming -p " << config.probing_multiplier << "\n" "trie " << std::setw(length) << TrieModel::Size(counts, config) << "\n"; -/* "sorted " << std::setw(length) << SortedModel::Size(counts, config) << "\n";*/ } } // namespace ngram @@ -76,7 +76,7 @@ int main(int argc, char *argv[]) { try { lm::ngram::Config config; int opt; - while ((opt = getopt(argc, argv, "su:p:t:m:")) != -1) { + while ((opt = getopt(argc, argv, "siu:p:t:m:")) != -1) { switch(opt) { case 'u': config.unknown_missing_logprob = ParseFloat(optarg); @@ -91,7 +91,10 @@ int main(int argc, char *argv[]) { config.building_memory = ParseUInt(optarg) * 1048576; break; case 's': - config.sentence_marker_missing = lm::ngram::Config::SILENT; + config.sentence_marker_missing = lm::SILENT; + break; + case 'i': + config.positive_log_probability = lm::SILENT; break; default: Usage(argv[0]); @@ -108,8 +111,6 @@ int main(int argc, char *argv[]) { config.write_mmap = argv[optind + 2]; if (!strcmp(model_type, "probing")) { ProbingModel(from_file, config); - } else if (!strcmp(model_type, "sorted")) { - SortedModel(from_file, config); } else if (!strcmp(model_type, "trie")) { TrieModel(from_file, config); } else { diff --git a/klm/lm/config.cc b/klm/lm/config.cc index 71646e51..cee8fce2 100644 --- a/klm/lm/config.cc +++ b/klm/lm/config.cc @@ -10,6 +10,7 @@ Config::Config() : enumerate_vocab(NULL), unknown_missing(COMPLAIN), sentence_marker_missing(THROW_UP), + positive_log_probability(THROW_UP), unknown_missing_logprob(-100.0), probing_multiplier(1.5), building_memory(1073741824ULL), // 1 GB diff --git a/klm/lm/config.hh b/klm/lm/config.hh index 1f7762be..6c7fe39b 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -3,6 +3,7 @@ #include <iosfwd> +#include "lm/lm_exception.hh" #include "util/mmap.hh" /* Configuration for ngram model. 
Separate header to reduce pollution. */ @@ -27,13 +28,16 @@ struct Config { // ONLY EFFECTIVE WHEN READING ARPA - typedef enum {THROW_UP, COMPLAIN, SILENT} WarningAction; // What to do when <unk> isn't in the provided model. WarningAction unknown_missing; // What to do when <s> or </s> is missing from the model. // If THROW_UP, the exception will be of type util::SpecialWordMissingException. WarningAction sentence_marker_missing; + // What to do with a positive log probability. For COMPLAIN and SILENT, map + // to 0. + WarningAction positive_log_probability; + // The probability to substitute for <unk> if it's missing from the model. // No effect if the model has <unk> or unknown_missing == THROW_UP. float unknown_missing_logprob; diff --git a/klm/lm/lm_exception.hh b/klm/lm/lm_exception.hh index aa3ca886..f607ced1 100644 --- a/klm/lm/lm_exception.hh +++ b/klm/lm/lm_exception.hh @@ -11,6 +11,8 @@ namespace lm { +typedef enum {THROW_UP, COMPLAIN, SILENT} WarningAction; + class ConfigException : public util::Exception { public: ConfigException() throw(); diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 1492276a..f0579c0c 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -83,7 +83,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT } if (!vocab_.SawUnk()) { - assert(config.unknown_missing != Config::THROW_UP); + assert(config.unknown_missing != THROW_UP); // Default probabilities for unknown. search_.unigram.Unknown().backoff = 0.0; search_.unigram.Unknown().prob = config.unknown_missing_logprob; diff --git a/klm/lm/model.hh b/klm/lm/model.hh index fd9640c3..b85ccdcc 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -65,7 +65,7 @@ size_t hash_value(const State &state); namespace detail { // Should return the same results as SRI. -// Why VocabularyT instead of just Vocabulary? ModelFacade defines Vocabulary. +// ModelFacade typedefs Vocabulary so we use VocabularyT to avoid naming conflicts. 
template <class Search, class VocabularyT> class GenericModel : public base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> { private: typedef base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> P; @@ -75,23 +75,37 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod // itself. static size_t Size(const std::vector<uint64_t> &counts, const Config &config = Config()); + /* Load the model from a file. It may be an ARPA or binary file. Binary + * files must have the format expected by this class or you'll get an + * exception. So TrieModel can only load ARPA or binary created by + * TrieModel. To classify binary files, call RecognizeBinary in + * lm/binary_format.hh. + */ GenericModel(const char *file, const Config &config = Config()); + /* Score p(new_word | in_state) and incorporate new_word into out_state. + * Note that in_state and out_state must be different references: + * &in_state != &out_state. + */ FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const; - /* Slower call without in_state. Don't use this if you can avoid it. This - * is mostly a hack for Hieu to integrate it into Moses which sometimes - * forgets LM state (i.e. it doesn't store it with the phrase). Sigh. - * The context indices should be in an array. - * If context_rbegin != context_rend then *context_rbegin is the word - * before new_word. + /* Slower call without in_state. Try to remember state, but sometimes it + * would cost too much memory or your decoder isn't setup properly. + * To use this function, make an array of WordIndex containing the context + * vocabulary ids in reverse order. Then, pass the bounds of the array: + * [context_rbegin, context_rend). The new_word is not part of the context + * array unless you intend to repeat words. 
*/ FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const; /* Get the state for a context. Don't use this if you can avoid it. Use * BeginSentenceState or EmptyContextState and extend from those. If * you're only going to use this state to call FullScore once, use - * FullScoreForgotState. */ + * FullScoreForgotState. + * To use this function, make an array of WordIndex containing the context + * vocabulary ids in reverse order. Then, pass the bounds of the array: + * [context_rbegin, context_rend). + */ void GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const; private: @@ -131,9 +145,8 @@ typedef detail::GenericModel<detail::ProbingHashedSearch, Vocabulary> ProbingMod // Default implementation. No real reason for it to be the default. typedef ProbingModel Model; +// Smaller implementation. typedef ::lm::ngram::SortedVocabulary SortedVocabulary; -typedef detail::GenericModel<detail::SortedHashedSearch, SortedVocabulary> SortedModel; - typedef detail::GenericModel<trie::TrieSearch, SortedVocabulary> TrieModel; } // namespace ngram diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index 0e90196d..060a97ea 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -3,6 +3,7 @@ #include "lm/blank.hh" #include <cstdlib> +#include <iostream> #include <vector> #include <ctype.h> @@ -115,4 +116,17 @@ void ReadEnd(util::FilePiece &in) { } catch (const util::EndOfFileException &e) {} } +void PositiveProbWarn::Warn(float prob) { + switch (action_) { + case THROW_UP: + UTIL_THROW(FormatLoadException, "Positive log probability " << prob << " in the model. This is a bug in IRSTLM; you can set config.positive_log_probability = SILENT or pass -i to build_binary to substitute 0.0 for the log probability. 
Error"); + case COMPLAIN: + std::cerr << "There's a positive log probability " << prob << " in the APRA file, probably because of a bug in IRSTLM. This and subsequent entires will be mapepd to 0 log probability." << std::endl; + action_ = SILENT; + break; + case SILENT: + break; + } +} + } // namespace lm diff --git a/klm/lm/read_arpa.hh b/klm/lm/read_arpa.hh index 4953d40e..ab996bde 100644 --- a/klm/lm/read_arpa.hh +++ b/klm/lm/read_arpa.hh @@ -22,10 +22,26 @@ void ReadEnd(util::FilePiece &in); extern const bool kARPASpaces[256]; -template <class Voc> void Read1Gram(util::FilePiece &f, Voc &vocab, ProbBackoff *unigrams) { +// Positive log probability warning. +class PositiveProbWarn { + public: + PositiveProbWarn() : action_(THROW_UP) {} + + explicit PositiveProbWarn(WarningAction action) : action_(action) {} + + void Warn(float prob); + + private: + WarningAction action_; +}; + +template <class Voc> void Read1Gram(util::FilePiece &f, Voc &vocab, ProbBackoff *unigrams, PositiveProbWarn &warn) { try { float prob = f.ReadFloat(); - if (prob > 0) UTIL_THROW(FormatLoadException, "Positive probability " << prob); + if (prob > 0.0) { + warn.Warn(prob); + prob = 0.0; + } if (f.get() != '\t') UTIL_THROW(FormatLoadException, "Expected tab after probability"); ProbBackoff &value = unigrams[vocab.Insert(f.ReadDelimited(kARPASpaces))]; value.prob = prob; @@ -36,18 +52,23 @@ template <class Voc> void Read1Gram(util::FilePiece &f, Voc &vocab, ProbBackoff } } -template <class Voc> void Read1Grams(util::FilePiece &f, std::size_t count, Voc &vocab, ProbBackoff *unigrams) { +// Positive log probabilities are reported through warn and replaced with 0.0. 
+template <class Voc> void Read1Grams(util::FilePiece &f, std::size_t count, Voc &vocab, ProbBackoff *unigrams, PositiveProbWarn &warn) { ReadNGramHeader(f, 1); for (std::size_t i = 0; i < count; ++i) { - Read1Gram(f, vocab, unigrams); + Read1Gram(f, vocab, unigrams, warn); } vocab.FinishedLoading(unigrams); } -template <class Voc, class Weights> void ReadNGram(util::FilePiece &f, const unsigned char n, const Voc &vocab, WordIndex *const reverse_indices, Weights &weights) { +// Positive log probabilities are reported through warn and replaced with 0.0. +template <class Voc, class Weights> void ReadNGram(util::FilePiece &f, const unsigned char n, const Voc &vocab, WordIndex *const reverse_indices, Weights &weights, PositiveProbWarn &warn) { try { weights.prob = f.ReadFloat(); - if (weights.prob > 0) UTIL_THROW(FormatLoadException, "Positive probability " << weights.prob); + if (weights.prob > 0.0) { + warn.Warn(weights.prob); + weights.prob = 0.0; + } for (WordIndex *vocab_out = reverse_indices + n - 1; vocab_out >= reverse_indices; --vocab_out) { *vocab_out = vocab.Index(f.ReadDelimited(kARPASpaces)); } diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index bb3b955a..eaad59ab 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -48,7 +48,7 @@ class ActivateUnigram { ProbBackoff *modify_; }; -template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, std::vector<Middle> &middle, Activate activate, Store &store) { +template <class Voc, class Store, class Middle, class Activate> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, std::vector<Middle> &middle, Activate activate, Store &store, PositiveProbWarn &warn) { ReadNGramHeader(f, n); ProbBackoff blank; @@ -61,7 +61,7 @@ template <class Voc, class Store, class Middle, class Activate> void ReadNGrams( typename Store::Packing::Value value; typename 
Middle::ConstIterator found; for (size_t i = 0; i < count; ++i) { - ReadNGram(f, n, vocab, vocab_ids, value); + ReadNGram(f, n, vocab, vocab_ids, value, warn); keys[0] = detail::CombineWordHash(static_cast<uint64_t>(*vocab_ids), vocab_ids[1]); for (unsigned int h = 1; h < n - 1; ++h) { keys[h] = detail::CombineWordHash(keys[h-1], vocab_ids[h+1]); @@ -85,20 +85,22 @@ template <class MiddleT, class LongestT> template <class Voc> void TemplateHashe // TODO: fix sorted. SetupMemory(GrowForSearch(config, Size(counts, config), backing), counts, config); - Read1Grams(f, counts[0], vocab, unigram.Raw()); + PositiveProbWarn warn(config.positive_log_probability); + + Read1Grams(f, counts[0], vocab, unigram.Raw(), warn); CheckSpecials(config, vocab); try { if (counts.size() > 2) { - ReadNGrams(f, 2, counts[1], vocab, middle, ActivateUnigram(unigram.Raw()), middle[0]); + ReadNGrams(f, 2, counts[1], vocab, middle, ActivateUnigram(unigram.Raw()), middle[0], warn); } for (unsigned int n = 3; n < counts.size(); ++n) { - ReadNGrams(f, n, counts[n-1], vocab, middle, ActivateLowerMiddle<Middle>(middle[n-3]), middle[n-2]); + ReadNGrams(f, n, counts[n-1], vocab, middle, ActivateLowerMiddle<Middle>(middle[n-3]), middle[n-2], warn); } if (counts.size() > 2) { - ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, ActivateLowerMiddle<Middle>(middle.back()), longest); + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, ActivateLowerMiddle<Middle>(middle.back()), longest, warn); } else { - ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, ActivateUnigram(unigram.Raw()), longest); + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, ActivateUnigram(unigram.Raw()), longest, warn); } } catch (util::ProbingSizeException &e) { UTIL_THROW(util::ProbingSizeException, "Avoid pruning n-grams like \"bar baz quux\" when \"foo bar baz quux\" is still in the model. 
KenLM will work when this pruning happens, but the probing model assumes these events are rare enough that using blank space in the probing hash table will cover all of them. Increase probing_multiplier (-p to build_binary) to add more blank spaces.\n"); diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index b830dfc3..7c57072b 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -293,7 +293,7 @@ class SortedFileReader { ReadOrThrow(file_.get(), &weights, sizeof(Weights)); } - bool Ended() { + bool Ended() const { return ended_; } @@ -480,7 +480,7 @@ void MergeContextFiles(const std::string &first_base, const std::string &second_ CopyRestOrThrow(remaining.GetFile(), out.get()); } -void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order) { +void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn) { ReadNGramHeader(f, order); const size_t count = counts[order - 1]; // Size of weights. Does it include backoff? 
@@ -495,11 +495,11 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size; if (order == counts.size()) { for (; out != out_end; out += entry_size) { - ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<Prob*>(out + words_size)); + ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<Prob*>(out + words_size), warn); } } else { for (; out != out_end; out += entry_size) { - ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size)); + ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size), warn); } } // Sort full records by full n-gram. @@ -536,13 +536,14 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st } void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { + PositiveProbWarn warn(config.positive_log_probability); { std::string unigram_name = file_prefix + "unigrams"; util::scoped_fd unigram_file; // In case <unk> appears. 
size_t extra_count = counts[0] + 1; util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), extra_count * sizeof(ProbBackoff), unigram_file), extra_count * sizeof(ProbBackoff)); - Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get())); + Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()), warn); CheckSpecials(config, vocab); if (!vocab.SawUnk()) ++counts[0]; } @@ -560,7 +561,7 @@ void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uin if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer); for (unsigned char order = 2; order <= counts.size(); ++order) { - ConvertToSorted(f, vocab, counts, mem, file_prefix, order); + ConvertToSorted(f, vocab, counts, mem, file_prefix, order, warn); } ReadEnd(f); } @@ -775,8 +776,8 @@ void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<u } void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch &out, Backing &backing) { - SortedFileReader inputs[counts.size() - 1]; - ContextReader contexts[counts.size() - 1]; + std::vector<SortedFileReader> inputs(counts.size() - 1); + std::vector<ContextReader> contexts(counts.size() - 1); for (unsigned char i = 2; i <= counts.size(); ++i) { std::stringstream assembled; @@ -790,11 +791,11 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co std::vector<uint64_t> fixed_counts(counts.size()); { - RecursiveInsert<JustCount> counter(inputs, contexts, NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); + RecursiveInsert<JustCount> counter(&*inputs.begin(), &*contexts.begin(), NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]); } - for (SortedFileReader *i = inputs; i < inputs + counts.size() - 1; 
++i) { - if (!i->Ended()) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs + 2) << "-gram table did not complete reading"); + for (std::vector<SortedFileReader>::const_iterator i = inputs.begin(); i != inputs.end(); ++i) { + if (!i->Ended()) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs.begin() + 2) << "-gram table did not complete reading"); } SanityCheckCounts(counts, fixed_counts); counts = fixed_counts; @@ -807,7 +808,7 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co UnigramValue *unigrams = out.unigram.Raw(); // Fill entries except unigram probabilities. { - RecursiveInsert<WriteEntries> inserter(inputs, contexts, unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); + RecursiveInsert<WriteEntries> inserter(&*inputs.begin(), &*contexts.begin(), unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); inserter.Apply(config.messages, "Building trie", fixed_counts[0]); } @@ -849,7 +850,7 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co out.middle[i].FinishedLoading(out.middle[i+1].InsertIndex()); } out.middle.back().FinishedLoading(out.longest.InsertIndex()); - } + } } bool IsDirectory(const char *path) { diff --git a/klm/lm/sri.cc b/klm/lm/sri.cc index b634d200..825f699b 100644 --- a/klm/lm/sri.cc +++ b/klm/lm/sri.cc @@ -93,18 +93,12 @@ FullScoreReturn Model::FullScore(const State &in_state, const WordIndex new_word const_history = local_history; } FullScoreReturn ret; - if (new_word != not_found_) { - ret.ngram_length = MatchedLength(*sri_, new_word, const_history); - out_state.history_[0] = new_word; - out_state.valid_length_ = std::min<unsigned char>(ret.ngram_length, Order() - 1); - std::copy(const_history, const_history + out_state.valid_length_ - 1, out_state.history_ + 1); - if (out_state.valid_length_ < kMaxOrder - 1) { - 
out_state.history_[out_state.valid_length_] = Vocab_None; - } - } else { - ret.ngram_length = 0; - if (kMaxOrder > 1) out_state.history_[0] = Vocab_None; - out_state.valid_length_ = 0; + ret.ngram_length = MatchedLength(*sri_, new_word, const_history); + out_state.history_[0] = new_word; + out_state.valid_length_ = std::min<unsigned char>(ret.ngram_length, Order() - 1); + std::copy(const_history, const_history + out_state.valid_length_ - 1, out_state.history_ + 1); + if (out_state.valid_length_ < kMaxOrder - 1) { + out_state.history_[out_state.valid_length_] = Vocab_None; } ret.prob = sri_->wordProb(new_word, const_history); return ret; diff --git a/klm/lm/virtual_interface.hh b/klm/lm/virtual_interface.hh index f15f8789..08627efd 100644 --- a/klm/lm/virtual_interface.hh +++ b/klm/lm/virtual_interface.hh @@ -8,8 +8,27 @@ namespace lm { +/* Structure returned by scoring routines. */ struct FullScoreReturn { + // log10 probability float prob; + + /* The length of n-gram matched. Do not use this for recombination. + * Consider a model containing only the following n-grams: + * -1 foo + * -3.14 bar + * -2.718 baz -5 + * -6 foo bar + * + * If you score ``bar'' then ngram_length is 1 and recombination state is the + * empty string because bar has zero backoff and does not extend to the + * right. + * If you score ``foo'' then ngram_length is 1 and recombination state is + * ``foo''. + * + * Ideally, keep output states around and compare them. Failing that, + * get out_state.ValidLength() and use that length for recombination. + */ unsigned char ngram_length; }; @@ -72,7 +91,8 @@ class Vocabulary { /* There are two ways to access a Model. * * - * OPTION 1: Access the Model directly (e.g. lm::ngram::Model in ngram.hh). + * OPTION 1: Access the Model directly (e.g. lm::ngram::Model in model.hh). 
+ * * Every Model implements the scoring function: * float Score( * const Model::State &in_state, @@ -85,6 +105,7 @@ class Vocabulary { * const WordIndex new_word, * Model::State &out_state) const; * + * * There are also accessor functions: * const State &BeginSentenceState() const; * const State &NullContextState() const; @@ -114,6 +135,7 @@ class Vocabulary { * * All the State objects are POD, so it's ok to use raw memory for storing * State. + * in_state and out_state must not have the same address. */ class Model { public: @@ -123,8 +145,10 @@ class Model { const void *BeginSentenceMemory() const { return begin_sentence_memory_; } const void *NullContextMemory() const { return null_context_memory_; } + // Requires in_state != out_state virtual float Score(const void *in_state, const WordIndex new_word, void *out_state) const = 0; + // Requires in_state != out_state virtual FullScoreReturn FullScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0; unsigned char Order() const { return order_; } diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index fd11ad2c..515af5db 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -189,24 +189,24 @@ void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { void MissingUnknown(const Config &config) throw(SpecialWordMissingException) { switch(config.unknown_missing) { - case Config::SILENT: + case SILENT: return; - case Config::COMPLAIN: + case COMPLAIN: if (config.messages) *config.messages << "The ARPA file is missing <unk>. Substituting log10 probability " << config.unknown_missing_logprob << "." 
<< std::endl; break; - case Config::THROW_UP: + case THROW_UP: UTIL_THROW(SpecialWordMissingException, "The ARPA file is missing <unk> and the model is configured to throw an exception."); } } void MissingSentenceMarker(const Config &config, const char *str) throw(SpecialWordMissingException) { switch (config.sentence_marker_missing) { - case Config::SILENT: + case SILENT: return; - case Config::COMPLAIN: + case COMPLAIN: if (config.messages) *config.messages << "Missing special word " << str << "; will treat it as <unk>."; break; - case Config::THROW_UP: + case THROW_UP: UTIL_THROW(SpecialWordMissingException, "The ARPA file is missing " << str << " and the model is configured to reject these models. Run build_binary -s to disable this check."); } } diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc index 67681f7e..f447a70c 100644 --- a/klm/util/file_piece.cc +++ b/klm/util/file_piece.cc @@ -237,7 +237,12 @@ void FilePiece::MMapShift(off_t desired_begin) throw() { // Forcibly clear the existing mmap first. data_.reset(); - data_.reset(mmap(NULL, mapped_size, PROT_READ, MAP_PRIVATE, *file_, mapped_offset), mapped_size, scoped_memory::MMAP_ALLOCATED); + data_.reset(mmap(NULL, mapped_size, PROT_READ, MAP_SHARED + // Populate where available on linux +#ifdef MAP_POPULATE + | MAP_POPULATE +#endif + , *file_, mapped_offset), mapped_size, scoped_memory::MMAP_ALLOCATED); if (data_.get() == MAP_FAILED) { if (desired_begin) { if (((off_t)-1) == lseek(*file_, desired_begin, SEEK_SET)) UTIL_THROW(ErrnoException, "mmap failed even though it worked before. lseek failed too, so using read isn't an option either."); diff --git a/klm/util/mmap.cc b/klm/util/mmap.cc index 456ce953..e7c0643b 100644 --- a/klm/util/mmap.cc +++ b/klm/util/mmap.cc @@ -15,8 +15,9 @@ namespace util { scoped_mmap::~scoped_mmap() { if (data_ != (void*)-1) { - if (munmap(data_, size_)) { - std::cerr << "munmap failed for " << size_ << " bytes." 
<< std::endl; + // Thanks to Denis Filimonov for pointing out that NFS likes msync first. + if (msync(data_, size_, MS_SYNC) || munmap(data_, size_)) { + std::cerr << "msync or mmap failed for " << size_ << " bytes." << std::endl; abort(); } } diff --git a/klm/util/string_piece.hh b/klm/util/string_piece.hh index e5b16e38..5de053aa 100644 --- a/klm/util/string_piece.hh +++ b/klm/util/string_piece.hh @@ -60,6 +60,23 @@ #ifdef HAVE_ICU #include <unicode/stringpiece.h> +#include <unicode/uversion.h> + +// Old versions of ICU don't define operator== and operator!=. +#if (U_ICU_VERSION_MAJOR_NUM < 4) || ((U_ICU_VERSION_MAJOR_NUM == 4) && (U_ICU_VERSION_MINOR_NUM < 4)) +#warning You are using an old version of ICU. Consider upgrading to ICU >= 4.6. +inline bool operator==(const StringPiece& x, const StringPiece& y) { + if (x.size() != y.size()) + return false; + + return std::memcmp(x.data(), y.data(), x.size()) == 0; +} + +inline bool operator!=(const StringPiece& x, const StringPiece& y) { + return !(x == y); +} +#endif // old version of ICU + U_NAMESPACE_BEGIN #else @@ -209,7 +226,7 @@ inline bool operator!=(const StringPiece& x, const StringPiece& y) { return !(x == y); } -#endif +#endif // HAVE_ICU undefined inline bool operator<(const StringPiece& x, const StringPiece& y) { const int r = std::memcmp(x.data(), y.data(), |