diff options
-rw-r--r-- | klm/lm/build_binary.cc | 102 | ||||
-rw-r--r-- | klm/lm/config.cc | 2 | ||||
-rw-r--r-- | klm/lm/config.hh | 2 | ||||
-rw-r--r-- | klm/lm/model.cc | 2 | ||||
-rw-r--r-- | klm/lm/search_trie.cc | 20 | ||||
-rw-r--r-- | klm/lm/vocab.cc | 2 | ||||
-rw-r--r-- | klm/util/exception.cc | 4 |
7 files changed, 69 insertions, 65 deletions
diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index d6dd5994..920ff080 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -15,8 +15,9 @@ namespace ngram { namespace { void Usage(const char *name) { - std::cerr << "Usage: " << name << " [-u unknown_probability] [-s] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [type] input.arpa output.mmap\n\n" -"-u sets the default probability for <unk> if the ARPA file does not have one.\n" + std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [type] input.arpa output.mmap\n\n" +"-u sets the default log10 probability for <unk> if the ARPA file does not have\n" +"one.\n" "-s allows models to be built even if they do not have <s> and </s>.\n\n" "type is one of probing, trie, or sorted:\n\n" "probing uses a probing hash table. It is the fastest but uses the most memory.\n" @@ -69,65 +70,58 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) { } // namespace lm } // namespace -void terminate_handler() { - try { throw; } - catch(const std::exception& e) { - std::cerr << e.what() << std::endl; - } - catch(...) { - std::cerr << "A non-standard exception was thrown." << std::endl; - } - std::abort(); -} - int main(int argc, char *argv[]) { using namespace lm::ngram; - std::set_terminate(terminate_handler); - - lm::ngram::Config config; - int opt; - while ((opt = getopt(argc, argv, "su:p:t:m:")) != -1) { - switch(opt) { - case 'u': - config.unknown_missing_prob = ParseFloat(optarg); - break; - case 'p': - config.probing_multiplier = ParseFloat(optarg); - break; - case 't': - config.temporary_directory_prefix = optarg; - break; - case 'm': - config.building_memory = ParseUInt(optarg) * 1048576; - break; - case 's': - config.sentence_marker_missing = lm::ngram::Config::SILENT; - break; - default: - Usage(argv[0]); + try { + lm::ngram::Config config; + int opt; + while ((opt = getopt(argc, argv, "su:p:t:m:")) != -1) { + switch(opt) { + case 'u': + config.unknown_missing_logprob = ParseFloat(optarg); + break; + case 'p': + config.probing_multiplier = ParseFloat(optarg); + break; + case 't': + config.temporary_directory_prefix = optarg; + break; + case 'm': + config.building_memory = ParseUInt(optarg) * 1048576; + break; + case 's': + config.sentence_marker_missing = lm::ngram::Config::SILENT; + break; + default: + Usage(argv[0]); + } } - } - if (optind + 1 == argc) { - ShowSizes(argv[optind], config); - } else if (optind + 2 == argc) { - config.write_mmap = argv[optind + 1]; - ProbingModel(argv[optind], config); - } else if (optind + 3 == argc) { - const char *model_type = argv[optind]; - const char *from_file = argv[optind + 1]; - config.write_mmap = argv[optind + 2]; - if (!strcmp(model_type, "probing")) { - ProbingModel(from_file, config); - } else if (!strcmp(model_type, "sorted")) { - SortedModel(from_file, config); - } else if (!strcmp(model_type, "trie")) { - TrieModel(from_file, config); + if (optind + 1 == argc) { + ShowSizes(argv[optind], config); + } else if (optind + 2 == argc) { + config.write_mmap = argv[optind + 1]; + ProbingModel(argv[optind], config); + } else if (optind + 3 == argc) { + const char *model_type = argv[optind]; + const char *from_file = argv[optind + 1]; + config.write_mmap = argv[optind + 2]; + if (!strcmp(model_type, "probing")) { + ProbingModel(from_file, config); + } else if (!strcmp(model_type, "sorted")) { + SortedModel(from_file, config); + } else if (!strcmp(model_type, "trie")) { + TrieModel(from_file, config); + } else { + Usage(argv[0]); + } } else { Usage(argv[0]); } - } else { - Usage(argv[0]); + } + catch (std::exception &e) { + std::cerr << e.what() << std::endl; + abort(); } return 0; } diff --git a/klm/lm/config.cc b/klm/lm/config.cc index d8773fe5..71646e51 100644 --- a/klm/lm/config.cc +++ b/klm/lm/config.cc @@ -10,7 +10,7 @@ Config::Config() : enumerate_vocab(NULL), unknown_missing(COMPLAIN), sentence_marker_missing(THROW_UP), - unknown_missing_prob(0.0), + unknown_missing_logprob(-100.0), probing_multiplier(1.5), building_memory(1073741824ULL), // 1 GB temporary_directory_prefix(NULL), diff --git a/klm/lm/config.hh b/klm/lm/config.hh index 17f67df3..1f7762be 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -36,7 +36,7 @@ struct Config { // The probability to substitute for <unk> if it's missing from the model. // No effect if the model has <unk> or unknown_missing == THROW_UP. - float unknown_missing_prob; + float unknown_missing_logprob; // Size multiplier for probing hash table. Must be > 1. Space is linear in // this. Time is probing_multiplier / (probing_multiplier - 1). No effect diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 14949e97..1492276a 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -86,7 +86,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT assert(config.unknown_missing != Config::THROW_UP); // Default probabilities for unknown. search_.unigram.Unknown().backoff = 0.0; - search_.unigram.Unknown().prob = config.unknown_missing_prob; + search_.unigram.Unknown().prob = config.unknown_missing_logprob; } FinishFile(config, kModelType, counts, backing_); } diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 63631223..b830dfc3 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -535,13 +535,16 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st } } -void ARPAToSortedFiles(const Config &config, util::FilePiece &f, const std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { +void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { { std::string unigram_name = file_prefix + "unigrams"; util::scoped_fd unigram_file; - util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), counts[0] * sizeof(ProbBackoff), unigram_file), counts[0] * sizeof(ProbBackoff)); + // In case <unk> appears. + size_t extra_count = counts[0] + 1; + util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), extra_count * sizeof(ProbBackoff), unigram_file), extra_count * sizeof(ProbBackoff)); Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get())); CheckSpecials(config, vocab); + if (!vocab.SawUnk()) ++counts[0]; } // Only use as much buffer as we need. @@ -572,7 +575,7 @@ bool HeadMatch(const WordIndex *words, const WordIndex *const words_end, const W return true; } -// Counting phrase +// Phase to count n-grams, including blanks inserted because they were pruned but have extensions class JustCount { public: JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, BitPackedMiddle * /*middle*/, BitPackedLongest &/*longest*/, uint64_t *counts, unsigned char order) @@ -603,6 +606,7 @@ class JustCount { uint64_t *const counts_, *const longest_counts_; }; +// Phase to actually write n-grams to the trie. class WriteEntries { public: WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : @@ -764,7 +768,7 @@ template <class Doing> class RecursiveInsert { void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<uint64_t> &fixed) { if (fixed[0] != initial[0]) UTIL_THROW(util::Exception, "Unigram count should be constant but initial is " << initial[0] << " and recounted is " << fixed[0]); - if (fixed.back() != initial.back()) UTIL_THROW(util::Exception, "Longest count should be constant"); + if (fixed.back() != initial.back()) UTIL_THROW(util::Exception, "Longest count should be constant but it changed from " << initial.back() << " to " << fixed.back()); for (unsigned char i = 0; i < initial.size(); ++i) { if (fixed[i] < initial[i]) UTIL_THROW(util::Exception, "Counts came out lower than expected. This shouldn't happen"); } @@ -789,6 +793,9 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co RecursiveInsert<JustCount> counter(inputs, contexts, NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]); } + for (SortedFileReader *i = inputs; i < inputs + counts.size() - 1; ++i) { + if (!i->Ended()) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs + 2) << "-gram table did not complete reading"); + } SanityCheckCounts(counts, fixed_counts); counts = fixed_counts; @@ -805,7 +812,7 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co } // Fill unigram probabilities. - { + try { std::string name(file_prefix + "unigrams"); util::scoped_FILE file(OpenOrThrow(name.c_str(), "r")); for (WordIndex i = 0; i < counts[0]; ++i) { @@ -816,6 +823,9 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co } } RemoveOrThrow(name.c_str()); + } catch (util::Exception &e) { + e << " while re-reading unigram probabilities"; + throw; } // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 415f8331..fd11ad2c 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -192,7 +192,7 @@ void MissingUnknown(const Config &config) throw(SpecialWordMissingException) { case Config::SILENT: return; case Config::COMPLAIN: - if (config.messages) *config.messages << "The ARPA file is missing <unk>. Substituting probability " << config.unknown_missing_prob << "." << std::endl; + if (config.messages) *config.messages << "The ARPA file is missing <unk>. Substituting log10 probability " << config.unknown_missing_logprob << "." << std::endl; break; case Config::THROW_UP: UTIL_THROW(SpecialWordMissingException, "The ARPA file is missing <unk> and the model is configured to throw an exception."); diff --git a/klm/util/exception.cc b/klm/util/exception.cc index 077405f4..84f9fe7c 100644 --- a/klm/util/exception.cc +++ b/klm/util/exception.cc @@ -9,11 +9,11 @@ Exception::Exception() throw() {} Exception::~Exception() throw() {} Exception::Exception(const Exception &from) : std::exception() { - stream_.str(from.stream_.str()); + stream_ << from.stream_.str(); } Exception &Exception::operator=(const Exception &from) { - stream_.str(from.stream_.str()); + stream_ << from.stream_.str(); return *this; } |