From bc95fedbaa083d557840db6ac2cbf14e2a3eccce Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Fri, 20 May 2011 16:19:04 -0400 Subject: kenlm update including being nicer to NFS --- klm/lm/binary_format.cc | 2 ++ klm/lm/binary_format.hh | 13 ++++++++++--- klm/lm/build_binary.cc | 15 ++++++++------- klm/lm/config.cc | 1 + klm/lm/config.hh | 6 +++++- klm/lm/lm_exception.hh | 2 ++ klm/lm/model.cc | 2 +- klm/lm/model.hh | 33 +++++++++++++++++++++++---------- klm/lm/read_arpa.cc | 14 ++++++++++++++ klm/lm/read_arpa.hh | 33 +++++++++++++++++++++++++++------ klm/lm/search_hashed.cc | 16 +++++++++------- klm/lm/search_trie.cc | 27 ++++++++++++++------------- klm/lm/sri.cc | 18 ++++++------------ klm/lm/virtual_interface.hh | 26 +++++++++++++++++++++++++- klm/lm/vocab.cc | 12 ++++++------ 15 files changed, 153 insertions(+), 67 deletions(-) (limited to 'klm/lm') diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 9be0bc8e..34d9ffca 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -112,6 +112,8 @@ uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &b void FinishFile(const Config &config, ModelType model_type, const std::vector &counts, Backing &backing) { if (config.write_mmap) { + if (msync(backing.search.get(), backing.search.size(), MS_SYNC) || msync(backing.vocab.get(), backing.vocab.size(), MS_SYNC)) + UTIL_THROW(util::ErrnoException, "msync failed for " << config.write_mmap); // header and vocab share the same mmap. The header is written here because we know the counts. Parameters params; params.counts = counts; diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index 72d8c159..1fc71be4 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -18,6 +18,12 @@ namespace ngram { typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2} ModelType; +/*Inspect a file to determine if it is a binary lm. If not, return false. + * If so, return true and set recognized to the type. 
This is the only API in + * this header designed for use by decoder authors. + */ +bool RecognizeBinary(const char *file, ModelType &recognized); + struct FixedWidthParameters { unsigned char order; float probing_multiplier; @@ -27,6 +33,7 @@ struct FixedWidthParameters { bool has_vocabulary; }; +// Parameters stored in the header of a binary file. struct Parameters { FixedWidthParameters fixed; std::vector counts; @@ -41,10 +48,13 @@ struct Backing { util::scoped_memory search; }; +// Create just enough of a binary file to write vocabulary to it. uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing); // Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin. uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing); +// Write header to binary file. This is done last to prevent incomplete files +// from loading. void FinishFile(const Config &config, ModelType model_type, const std::vector &counts, Backing &backing); namespace detail { @@ -61,8 +71,6 @@ void ComplainAboutARPA(const Config &config, ModelType model_type); } // namespace detail -bool RecognizeBinary(const char *file, ModelType &recognized); - template void LoadLM(const char *file, const Config &config, To &to) { Backing &backing = to.MutableBacking(); backing.file.reset(util::OpenReadOrThrow(file)); @@ -86,7 +94,6 @@ template void LoadLM(const char *file, const Config &config, To &to) e << " File: " << file; throw; } - } } // namespace ngram diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index 920ff080..91ad2fb9 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -15,10 +15,11 @@ namespace ngram { namespace { void Usage(const char *name) { - std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [type] input.arpa 
output.mmap\n\n" + std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [type] input.arpa output.mmap\n\n" "-u sets the default log10 probability for if the ARPA file does not have\n" "one.\n" -"-s allows models to be built even if they do not have and .\n\n" +"-s allows models to be built even if they do not have and .\n" +"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n" "type is one of probing, trie, or sorted:\n\n" "probing uses a probing hash table. It is the fastest but uses the most memory.\n" "-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n" @@ -63,7 +64,6 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) { std::cout << "bytes\n" "probing " << std::setw(length) << probing_size << " assuming -p " << config.probing_multiplier << "\n" "trie " << std::setw(length) << TrieModel::Size(counts, config) << "\n"; -/* "sorted " << std::setw(length) << SortedModel::Size(counts, config) << "\n";*/ } } // namespace ngram @@ -76,7 +76,7 @@ int main(int argc, char *argv[]) { try { lm::ngram::Config config; int opt; - while ((opt = getopt(argc, argv, "su:p:t:m:")) != -1) { + while ((opt = getopt(argc, argv, "siu:p:t:m:")) != -1) { switch(opt) { case 'u': config.unknown_missing_logprob = ParseFloat(optarg); @@ -91,7 +91,10 @@ int main(int argc, char *argv[]) { config.building_memory = ParseUInt(optarg) * 1048576; break; case 's': - config.sentence_marker_missing = lm::ngram::Config::SILENT; + config.sentence_marker_missing = lm::SILENT; + break; + case 'i': + config.positive_log_probability = lm::SILENT; break; default: Usage(argv[0]); @@ -108,8 +111,6 @@ int main(int argc, char *argv[]) { config.write_mmap = argv[optind + 2]; if (!strcmp(model_type, "probing")) { ProbingModel(from_file, config); - } else if (!strcmp(model_type, "sorted")) { - SortedModel(from_file, config); } else if (!strcmp(model_type, 
"trie")) { TrieModel(from_file, config); } else { diff --git a/klm/lm/config.cc b/klm/lm/config.cc index 71646e51..cee8fce2 100644 --- a/klm/lm/config.cc +++ b/klm/lm/config.cc @@ -10,6 +10,7 @@ Config::Config() : enumerate_vocab(NULL), unknown_missing(COMPLAIN), sentence_marker_missing(THROW_UP), + positive_log_probability(THROW_UP), unknown_missing_logprob(-100.0), probing_multiplier(1.5), building_memory(1073741824ULL), // 1 GB diff --git a/klm/lm/config.hh b/klm/lm/config.hh index 1f7762be..6c7fe39b 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -3,6 +3,7 @@ #include +#include "lm/lm_exception.hh" #include "util/mmap.hh" /* Configuration for ngram model. Separate header to reduce pollution. */ @@ -27,13 +28,16 @@ struct Config { // ONLY EFFECTIVE WHEN READING ARPA - typedef enum {THROW_UP, COMPLAIN, SILENT} WarningAction; // What to do when isn't in the provided model. WarningAction unknown_missing; // What to do when or is missing from the model. // If THROW_UP, the exception will be of type util::SpecialWordMissingException. WarningAction sentence_marker_missing; + // What to do with a positive log probability. For COMPLAIN and SILENT, map + // to 0. + WarningAction positive_log_probability; + // The probability to substitute for if it's missing from the model. // No effect if the model has or unknown_missing == THROW_UP. 
float unknown_missing_logprob; diff --git a/klm/lm/lm_exception.hh b/klm/lm/lm_exception.hh index aa3ca886..f607ced1 100644 --- a/klm/lm/lm_exception.hh +++ b/klm/lm/lm_exception.hh @@ -11,6 +11,8 @@ namespace lm { +typedef enum {THROW_UP, COMPLAIN, SILENT} WarningAction; + class ConfigException : public util::Exception { public: ConfigException() throw(); diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 1492276a..f0579c0c 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -83,7 +83,7 @@ template void GenericModel class GenericModel : public base::ModelFacade, State, VocabularyT> { private: typedef base::ModelFacade, State, VocabularyT> P; @@ -75,23 +75,37 @@ template class GenericModel : public base::Mod // itself. static size_t Size(const std::vector &counts, const Config &config = Config()); + /* Load the model from a file. It may be an ARPA or binary file. Binary + * files must have the format expected by this class or you'll get an + * exception. So TrieModel can only load ARPA or binary created by + * TrieModel. To classify binary files, call RecognizeBinary in + * lm/binary_format.hh. + */ GenericModel(const char *file, const Config &config = Config()); + /* Score p(new_word | in_state) and incorporate new_word into out_state. + * Note that in_state and out_state must be different references: + * &in_state != &out_state. + */ FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const; - /* Slower call without in_state. Don't use this if you can avoid it. This - * is mostly a hack for Hieu to integrate it into Moses which sometimes - * forgets LM state (i.e. it doesn't store it with the phrase). Sigh. - * The context indices should be in an array. - * If context_rbegin != context_rend then *context_rbegin is the word - * before new_word. + /* Slower call without in_state. Try to remember state, but sometimes it + * would cost too much memory or your decoder isn't setup properly. 
+ * To use this function, make an array of WordIndex containing the context + * vocabulary ids in reverse order. Then, pass the bounds of the array: + * [context_rbegin, context_rend). The new_word is not part of the context + * array unless you intend to repeat words. */ FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const; /* Get the state for a context. Don't use this if you can avoid it. Use * BeginSentenceState or EmptyContextState and extend from those. If * you're only going to use this state to call FullScore once, use - * FullScoreForgotState. */ + * FullScoreForgotState. + * To use this function, make an array of WordIndex containing the context + * vocabulary ids in reverse order. Then, pass the bounds of the array: + * [context_rbegin, context_rend). + */ void GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const; private: @@ -131,9 +145,8 @@ typedef detail::GenericModel ProbingMod // Default implementation. No real reason for it to be the default. typedef ProbingModel Model; +// Smaller implementation. typedef ::lm::ngram::SortedVocabulary SortedVocabulary; -typedef detail::GenericModel SortedModel; - typedef detail::GenericModel TrieModel; } // namespace ngram diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index 0e90196d..060a97ea 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -3,6 +3,7 @@ #include "lm/blank.hh" #include +#include #include #include @@ -115,4 +116,17 @@ void ReadEnd(util::FilePiece &in) { } catch (const util::EndOfFileException &e) {} } +void PositiveProbWarn::Warn(float prob) { + switch (action_) { + case THROW_UP: + UTIL_THROW(FormatLoadException, "Positive log probability " << prob << " in the model. This is a bug in IRSTLM; you can set config.positive_log_probability = SILENT or pass -i to build_binary to substitute 0.0 for the log probability. 
Error"); + case COMPLAIN: + std::cerr << "There's a positive log probability " << prob << " in the ARPA file, probably because of a bug in IRSTLM. This and subsequent entries will be mapped to 0 log probability." << std::endl; + action_ = SILENT; + break; + case SILENT: + break; + } +} + } // namespace lm diff --git a/klm/lm/read_arpa.hh b/klm/lm/read_arpa.hh index 4953d40e..ab996bde 100644 --- a/klm/lm/read_arpa.hh +++ b/klm/lm/read_arpa.hh @@ -22,10 +22,26 @@ void ReadEnd(util::FilePiece &in); extern const bool kARPASpaces[256]; -template void Read1Gram(util::FilePiece &f, Voc &vocab, ProbBackoff *unigrams) { +// Positive log probability warning. +class PositiveProbWarn { + public: + PositiveProbWarn() : action_(THROW_UP) {} + + explicit PositiveProbWarn(WarningAction action) : action_(action) {} + + void Warn(float prob); + + private: + WarningAction action_; +}; + +template void Read1Gram(util::FilePiece &f, Voc &vocab, ProbBackoff *unigrams, PositiveProbWarn &warn) { try { float prob = f.ReadFloat(); - if (prob > 0) UTIL_THROW(FormatLoadException, "Positive probability " << prob); + if (prob > 0.0) { + warn.Warn(prob); + prob = 0.0; + } if (f.get() != '\t') UTIL_THROW(FormatLoadException, "Expected tab after probability"); ProbBackoff &value = unigrams[vocab.Insert(f.ReadDelimited(kARPASpaces))]; value.prob = prob; @@ -36,18 +52,23 @@ template void Read1Gram(util::FilePiece &f, Voc &vocab, ProbBackoff } } -template void Read1Grams(util::FilePiece &f, std::size_t count, Voc &vocab, ProbBackoff *unigrams) { +// Return true if a positive log probability came out. 
+template void Read1Grams(util::FilePiece &f, std::size_t count, Voc &vocab, ProbBackoff *unigrams, PositiveProbWarn &warn) { ReadNGramHeader(f, 1); for (std::size_t i = 0; i < count; ++i) { - Read1Gram(f, vocab, unigrams); + Read1Gram(f, vocab, unigrams, warn); } vocab.FinishedLoading(unigrams); } -template void ReadNGram(util::FilePiece &f, const unsigned char n, const Voc &vocab, WordIndex *const reverse_indices, Weights &weights) { +// Return true if a positive log probability came out. +template void ReadNGram(util::FilePiece &f, const unsigned char n, const Voc &vocab, WordIndex *const reverse_indices, Weights &weights, PositiveProbWarn &warn) { try { weights.prob = f.ReadFloat(); - if (weights.prob > 0) UTIL_THROW(FormatLoadException, "Positive probability " << weights.prob); + if (weights.prob > 0.0) { + warn.Warn(weights.prob); + weights.prob = 0.0; + } for (WordIndex *vocab_out = reverse_indices + n - 1; vocab_out >= reverse_indices; --vocab_out) { *vocab_out = vocab.Index(f.ReadDelimited(kARPASpaces)); } diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index bb3b955a..eaad59ab 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -48,7 +48,7 @@ class ActivateUnigram { ProbBackoff *modify_; }; -template void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, std::vector &middle, Activate activate, Store &store) { +template void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, std::vector &middle, Activate activate, Store &store, PositiveProbWarn &warn) { ReadNGramHeader(f, n); ProbBackoff blank; @@ -61,7 +61,7 @@ template void ReadNGrams( typename Store::Packing::Value value; typename Middle::ConstIterator found; for (size_t i = 0; i < count; ++i) { - ReadNGram(f, n, vocab, vocab_ids, value); + ReadNGram(f, n, vocab, vocab_ids, value, warn); keys[0] = detail::CombineWordHash(static_cast(*vocab_ids), vocab_ids[1]); for (unsigned int h = 1; h < n 
- 1; ++h) { keys[h] = detail::CombineWordHash(keys[h-1], vocab_ids[h+1]); @@ -85,20 +85,22 @@ template template void TemplateHashe // TODO: fix sorted. SetupMemory(GrowForSearch(config, Size(counts, config), backing), counts, config); - Read1Grams(f, counts[0], vocab, unigram.Raw()); + PositiveProbWarn warn(config.positive_log_probability); + + Read1Grams(f, counts[0], vocab, unigram.Raw(), warn); CheckSpecials(config, vocab); try { if (counts.size() > 2) { - ReadNGrams(f, 2, counts[1], vocab, middle, ActivateUnigram(unigram.Raw()), middle[0]); + ReadNGrams(f, 2, counts[1], vocab, middle, ActivateUnigram(unigram.Raw()), middle[0], warn); } for (unsigned int n = 3; n < counts.size(); ++n) { - ReadNGrams(f, n, counts[n-1], vocab, middle, ActivateLowerMiddle(middle[n-3]), middle[n-2]); + ReadNGrams(f, n, counts[n-1], vocab, middle, ActivateLowerMiddle(middle[n-3]), middle[n-2], warn); } if (counts.size() > 2) { - ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, ActivateLowerMiddle(middle.back()), longest); + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, ActivateLowerMiddle(middle.back()), longest, warn); } else { - ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, ActivateUnigram(unigram.Raw()), longest); + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, ActivateUnigram(unigram.Raw()), longest, warn); } } catch (util::ProbingSizeException &e) { UTIL_THROW(util::ProbingSizeException, "Avoid pruning n-grams like \"bar baz quux\" when \"foo bar baz quux\" is still in the model. KenLM will work when this pruning happens, but the probing model assumes these events are rare enough that using blank space in the probing hash table will cover all of them. 
Increase probing_multiplier (-p to build_binary) to add more blank spaces.\n"); diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index b830dfc3..7c57072b 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -293,7 +293,7 @@ class SortedFileReader { ReadOrThrow(file_.get(), &weights, sizeof(Weights)); } - bool Ended() { + bool Ended() const { return ended_; } @@ -480,7 +480,7 @@ void MergeContextFiles(const std::string &first_base, const std::string &second_ CopyRestOrThrow(remaining.GetFile(), out.get()); } -void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order) { +void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn) { ReadNGramHeader(f, order); const size_t count = counts[order - 1]; // Size of weights. Does it include backoff? @@ -495,11 +495,11 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size; if (order == counts.size()) { for (; out != out_end; out += entry_size) { - ReadNGram(f, order, vocab, reinterpret_cast(out), *reinterpret_cast(out + words_size)); + ReadNGram(f, order, vocab, reinterpret_cast(out), *reinterpret_cast(out + words_size), warn); } } else { for (; out != out_end; out += entry_size) { - ReadNGram(f, order, vocab, reinterpret_cast(out), *reinterpret_cast(out + words_size)); + ReadNGram(f, order, vocab, reinterpret_cast(out), *reinterpret_cast(out + words_size), warn); } } // Sort full records by full n-gram. 
@@ -536,13 +536,14 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st } void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { + PositiveProbWarn warn(config.positive_log_probability); { std::string unigram_name = file_prefix + "unigrams"; util::scoped_fd unigram_file; // In case appears. size_t extra_count = counts[0] + 1; util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), extra_count * sizeof(ProbBackoff), unigram_file), extra_count * sizeof(ProbBackoff)); - Read1Grams(f, counts[0], vocab, reinterpret_cast(unigram_mmap.get())); + Read1Grams(f, counts[0], vocab, reinterpret_cast(unigram_mmap.get()), warn); CheckSpecials(config, vocab); if (!vocab.SawUnk()) ++counts[0]; } @@ -560,7 +561,7 @@ void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector &initial, const std::vector &counts, const Config &config, TrieSearch &out, Backing &backing) { - SortedFileReader inputs[counts.size() - 1]; - ContextReader contexts[counts.size() - 1]; + std::vector inputs(counts.size() - 1); + std::vector contexts(counts.size() - 1); for (unsigned char i = 2; i <= counts.size(); ++i) { std::stringstream assembled; @@ -790,11 +791,11 @@ void BuildTrie(const std::string &file_prefix, std::vector &counts, co std::vector fixed_counts(counts.size()); { - RecursiveInsert counter(inputs, contexts, NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); + RecursiveInsert counter(&*inputs.begin(), &*contexts.begin(), NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]); } - for (SortedFileReader *i = inputs; i < inputs + counts.size() - 1; ++i) { - if (!i->Ended()) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs + 2) << 
"-gram table did not complete reading"); + for (std::vector::const_iterator i = inputs.begin(); i != inputs.end(); ++i) { + if (!i->Ended()) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs.begin() + 2) << "-gram table did not complete reading"); } SanityCheckCounts(counts, fixed_counts); counts = fixed_counts; @@ -807,7 +808,7 @@ void BuildTrie(const std::string &file_prefix, std::vector &counts, co UnigramValue *unigrams = out.unigram.Raw(); // Fill entries except unigram probabilities. { - RecursiveInsert inserter(inputs, contexts, unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); + RecursiveInsert inserter(&*inputs.begin(), &*contexts.begin(), unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); inserter.Apply(config.messages, "Building trie", fixed_counts[0]); } @@ -849,7 +850,7 @@ void BuildTrie(const std::string &file_prefix, std::vector &counts, co out.middle[i].FinishedLoading(out.middle[i+1].InsertIndex()); } out.middle.back().FinishedLoading(out.longest.InsertIndex()); - } + } } bool IsDirectory(const char *path) { diff --git a/klm/lm/sri.cc b/klm/lm/sri.cc index b634d200..825f699b 100644 --- a/klm/lm/sri.cc +++ b/klm/lm/sri.cc @@ -93,18 +93,12 @@ FullScoreReturn Model::FullScore(const State &in_state, const WordIndex new_word const_history = local_history; } FullScoreReturn ret; - if (new_word != not_found_) { - ret.ngram_length = MatchedLength(*sri_, new_word, const_history); - out_state.history_[0] = new_word; - out_state.valid_length_ = std::min(ret.ngram_length, Order() - 1); - std::copy(const_history, const_history + out_state.valid_length_ - 1, out_state.history_ + 1); - if (out_state.valid_length_ < kMaxOrder - 1) { - out_state.history_[out_state.valid_length_] = Vocab_None; - } - } else { - ret.ngram_length = 0; - if (kMaxOrder > 1) out_state.history_[0] = Vocab_None; - out_state.valid_length_ = 0; + ret.ngram_length = 
MatchedLength(*sri_, new_word, const_history); + out_state.history_[0] = new_word; + out_state.valid_length_ = std::min(ret.ngram_length, Order() - 1); + std::copy(const_history, const_history + out_state.valid_length_ - 1, out_state.history_ + 1); + if (out_state.valid_length_ < kMaxOrder - 1) { + out_state.history_[out_state.valid_length_] = Vocab_None; } ret.prob = sri_->wordProb(new_word, const_history); return ret; diff --git a/klm/lm/virtual_interface.hh b/klm/lm/virtual_interface.hh index f15f8789..08627efd 100644 --- a/klm/lm/virtual_interface.hh +++ b/klm/lm/virtual_interface.hh @@ -8,8 +8,27 @@ namespace lm { +/* Structure returned by scoring routines. */ struct FullScoreReturn { + // log10 probability float prob; + + /* The length of n-gram matched. Do not use this for recombination. + * Consider a model containing only the following n-grams: + * -1 foo + * -3.14 bar + * -2.718 baz -5 + * -6 foo bar + * + * If you score ``bar'' then ngram_length is 1 and recombination state is the + * empty string because bar has zero backoff and does not extend to the + * right. + * If you score ``foo'' then ngram_length is 1 and recombination state is + * ``foo''. + * + * Ideally, keep output states around and compare them. Failing that, + * get out_state.ValidLength() and use that length for recombination. + */ unsigned char ngram_length; }; @@ -72,7 +91,8 @@ class Vocabulary { /* There are two ways to access a Model. * * - * OPTION 1: Access the Model directly (e.g. lm::ngram::Model in ngram.hh). + * OPTION 1: Access the Model directly (e.g. lm::ngram::Model in model.hh). 
+ * * Every Model implements the scoring function: * float Score( * const Model::State &in_state, @@ -85,6 +105,7 @@ class Vocabulary { * const WordIndex new_word, * Model::State &out_state) const; * + * * There are also accessor functions: * const State &BeginSentenceState() const; * const State &NullContextState() const; @@ -114,6 +135,7 @@ class Vocabulary { * * All the State objects are POD, so it's ok to use raw memory for storing * State. + * in_state and out_state must not have the same address. */ class Model { public: @@ -123,8 +145,10 @@ class Model { const void *BeginSentenceMemory() const { return begin_sentence_memory_; } const void *NullContextMemory() const { return null_context_memory_; } + // Requires in_state != out_state virtual float Score(const void *in_state, const WordIndex new_word, void *out_state) const = 0; + // Requires in_state != out_state virtual FullScoreReturn FullScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0; unsigned char Order() const { return order_; } diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index fd11ad2c..515af5db 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -189,24 +189,24 @@ void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { void MissingUnknown(const Config &config) throw(SpecialWordMissingException) { switch(config.unknown_missing) { - case Config::SILENT: + case SILENT: return; - case Config::COMPLAIN: + case COMPLAIN: if (config.messages) *config.messages << "The ARPA file is missing . Substituting log10 probability " << config.unknown_missing_logprob << "." 
<< std::endl; break; - case Config::THROW_UP: + case THROW_UP: UTIL_THROW(SpecialWordMissingException, "The ARPA file is missing and the model is configured to throw an exception."); } } void MissingSentenceMarker(const Config &config, const char *str) throw(SpecialWordMissingException) { switch (config.sentence_marker_missing) { - case Config::SILENT: + case SILENT: return; - case Config::COMPLAIN: + case COMPLAIN: if (config.messages) *config.messages << "Missing special word " << str << "; will treat it as ."; break; - case Config::THROW_UP: + case THROW_UP: UTIL_THROW(SpecialWordMissingException, "The ARPA file is missing " << str << " and the model is configured to reject these models. Run build_binary -s to disable this check."); } } -- cgit v1.2.3