diff options
Diffstat (limited to 'klm')
| -rw-r--r-- | klm/lm/build_binary.cc | 102 | ||||
| -rw-r--r-- | klm/lm/config.cc | 2 | ||||
| -rw-r--r-- | klm/lm/config.hh | 2 | ||||
| -rw-r--r-- | klm/lm/model.cc | 2 | ||||
| -rw-r--r-- | klm/lm/search_trie.cc | 20 | ||||
| -rw-r--r-- | klm/lm/vocab.cc | 2 | ||||
| -rw-r--r-- | klm/util/bit_packing.hh | 7 | ||||
| -rw-r--r-- | klm/util/exception.cc | 4 | ||||
| -rw-r--r-- | klm/util/have.hh | 6 | 
9 files changed, 80 insertions, 67 deletions
| diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index d6dd5994..920ff080 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -15,8 +15,9 @@ namespace ngram {  namespace {  void Usage(const char *name) { -  std::cerr << "Usage: " << name << " [-u unknown_probability] [-s] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [type] input.arpa output.mmap\n\n" -"-u sets the default probability for <unk> if the ARPA file does not have one.\n" +  std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [type] input.arpa output.mmap\n\n" +"-u sets the default log10 probability for <unk> if the ARPA file does not have\n" +"one.\n"  "-s allows models to be built even if they do not have <s> and </s>.\n\n"  "type is one of probing, trie, or sorted:\n\n"  "probing uses a probing hash table.  It is the fastest but uses the most memory.\n" @@ -69,65 +70,58 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) {  } // namespace lm  } // namespace -void terminate_handler() { -  try { throw; } -  catch(const std::exception& e) { -    std::cerr << e.what() << std::endl; -  } -  catch(...) { -    std::cerr << "A non-standard exception was thrown." << std::endl; -  } -  std::abort(); -} -  int main(int argc, char *argv[]) {    using namespace lm::ngram; -  std::set_terminate(terminate_handler); - -  lm::ngram::Config config; -  int opt; -  while ((opt = getopt(argc, argv, "su:p:t:m:")) != -1) { -    switch(opt) { -      case 'u': -        config.unknown_missing_prob = ParseFloat(optarg); -        break; -      case 'p': -        config.probing_multiplier = ParseFloat(optarg); -        break; -      case 't': -        config.temporary_directory_prefix = optarg; -        break; -      case 'm': -        config.building_memory = ParseUInt(optarg) * 1048576; -        break; -      case 's': -        config.sentence_marker_missing = lm::ngram::Config::SILENT; -        break; -      default: -        Usage(argv[0]); +  try { +    lm::ngram::Config config; +    int opt; +    while ((opt = getopt(argc, argv, "su:p:t:m:")) != -1) { +      switch(opt) { +        case 'u': +          config.unknown_missing_logprob = ParseFloat(optarg); +          break; +        case 'p': +          config.probing_multiplier = ParseFloat(optarg); +          break; +        case 't': +          config.temporary_directory_prefix = optarg; +          break; +        case 'm': +          config.building_memory = ParseUInt(optarg) * 1048576; +          break; +        case 's': +          config.sentence_marker_missing = lm::ngram::Config::SILENT; +          break; +        default: +          Usage(argv[0]); +      }      } -  } -  if (optind + 1 == argc) { -    ShowSizes(argv[optind], config); -  } else if (optind + 2 == argc) { -    config.write_mmap = argv[optind + 1]; -    ProbingModel(argv[optind], config); -  } else if (optind + 3 == argc) { -    const char *model_type = argv[optind]; -    const char *from_file = argv[optind + 1]; -    config.write_mmap = argv[optind + 2]; -    if (!strcmp(model_type, "probing")) { -      ProbingModel(from_file, config); -    } else if (!strcmp(model_type, "sorted")) { -      SortedModel(from_file, config); -    } else if (!strcmp(model_type, "trie")) { -      TrieModel(from_file, config); +    if (optind + 1 == argc) { +      ShowSizes(argv[optind], config); +    } else if (optind + 2 == argc) { +      config.write_mmap = argv[optind + 1]; +      ProbingModel(argv[optind], config); +    } else if (optind + 3 == argc) { +      const char *model_type = argv[optind]; +      const char *from_file = argv[optind + 1]; +      config.write_mmap = argv[optind + 2]; +      if (!strcmp(model_type, "probing")) { +        ProbingModel(from_file, config); +      } else if (!strcmp(model_type, "sorted")) { +        SortedModel(from_file, config); +      } else if (!strcmp(model_type, "trie")) { +        TrieModel(from_file, config); +      } else { +        Usage(argv[0]); +      }      } else {        Usage(argv[0]);      } -  } else { -    Usage(argv[0]); +  } +  catch (std::exception &e) { +    std::cerr << e.what() << std::endl; +    abort();    }    return 0;  } diff --git a/klm/lm/config.cc b/klm/lm/config.cc index d8773fe5..71646e51 100644 --- a/klm/lm/config.cc +++ b/klm/lm/config.cc @@ -10,7 +10,7 @@ Config::Config() :    enumerate_vocab(NULL),    unknown_missing(COMPLAIN),    sentence_marker_missing(THROW_UP), -  unknown_missing_prob(0.0), +  unknown_missing_logprob(-100.0),    probing_multiplier(1.5),    building_memory(1073741824ULL), // 1 GB    temporary_directory_prefix(NULL), diff --git a/klm/lm/config.hh b/klm/lm/config.hh index 17f67df3..1f7762be 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -36,7 +36,7 @@ struct Config {    // The probability to substitute for <unk> if it's missing from the model.      // No effect if the model has <unk> or unknown_missing == THROW_UP. -  float unknown_missing_prob; +  float unknown_missing_logprob;    // Size multiplier for probing hash table.  Must be > 1.  Space is linear in    // this.  Time is probing_multiplier / (probing_multiplier - 1).  No effect diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 14949e97..1492276a 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -86,7 +86,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT      assert(config.unknown_missing != Config::THROW_UP);      // Default probabilities for unknown.        search_.unigram.Unknown().backoff = 0.0; -    search_.unigram.Unknown().prob = config.unknown_missing_prob; +    search_.unigram.Unknown().prob = config.unknown_missing_logprob;    }    FinishFile(config, kModelType, counts, backing_);  } diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 63631223..b830dfc3 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -535,13 +535,16 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st    }  } -void ARPAToSortedFiles(const Config &config, util::FilePiece &f, const std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { +void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {    {      std::string unigram_name = file_prefix + "unigrams";      util::scoped_fd unigram_file; -    util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), counts[0] * sizeof(ProbBackoff), unigram_file), counts[0] * sizeof(ProbBackoff)); +    // In case <unk> appears.   +    size_t extra_count = counts[0] + 1; +    util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), extra_count * sizeof(ProbBackoff), unigram_file), extra_count * sizeof(ProbBackoff));      Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()));      CheckSpecials(config, vocab); +    if (!vocab.SawUnk()) ++counts[0];    }    // Only use as much buffer as we need.   @@ -572,7 +575,7 @@ bool HeadMatch(const WordIndex *words, const WordIndex *const words_end, const W    return true;  } -// Counting phrase +// Phase to count n-grams, including blanks inserted because they were pruned but have extensions  class JustCount {    public:      JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, BitPackedMiddle * /*middle*/, BitPackedLongest &/*longest*/, uint64_t *counts, unsigned char order) @@ -603,6 +606,7 @@ class JustCount {      uint64_t *const counts_, *const longest_counts_;  }; +// Phase to actually write n-grams to the trie.    class WriteEntries {    public:      WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) :  @@ -764,7 +768,7 @@ template <class Doing> class RecursiveInsert {  void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<uint64_t> &fixed) {    if (fixed[0] != initial[0]) UTIL_THROW(util::Exception, "Unigram count should be constant but initial is " << initial[0] << " and recounted is " << fixed[0]); -  if (fixed.back() != initial.back()) UTIL_THROW(util::Exception, "Longest count should be constant"); +  if (fixed.back() != initial.back()) UTIL_THROW(util::Exception, "Longest count should be constant but it changed from " << initial.back() << " to " << fixed.back());    for (unsigned char i = 0; i < initial.size(); ++i) {      if (fixed[i] < initial[i]) UTIL_THROW(util::Exception, "Counts came out lower than expected.  This shouldn't happen");    } @@ -789,6 +793,9 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co      RecursiveInsert<JustCount> counter(inputs, contexts, NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size());      counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]);    } +  for (SortedFileReader *i = inputs; i < inputs + counts.size() - 1; ++i) { +    if (!i->Ended()) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs + 2) << "-gram table did not complete reading"); +  }    SanityCheckCounts(counts, fixed_counts);    counts = fixed_counts; @@ -805,7 +812,7 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co    }    // Fill unigram probabilities.   -  { +  try {      std::string name(file_prefix + "unigrams");      util::scoped_FILE file(OpenOrThrow(name.c_str(), "r"));      for (WordIndex i = 0; i < counts[0]; ++i) { @@ -816,6 +823,9 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co        }      }      RemoveOrThrow(name.c_str()); +  } catch (util::Exception &e) { +    e << " while re-reading unigram probabilities"; +    throw;    }    // Do not disable this error message or else too little state will be returned.  Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.    diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 415f8331..fd11ad2c 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -192,7 +192,7 @@ void MissingUnknown(const Config &config) throw(SpecialWordMissingException) {      case Config::SILENT:        return;      case Config::COMPLAIN: -      if (config.messages) *config.messages << "The ARPA file is missing <unk>.  Substituting probability " << config.unknown_missing_prob << "." << std::endl; +      if (config.messages) *config.messages << "The ARPA file is missing <unk>.  Substituting log10 probability " << config.unknown_missing_logprob << "." << std::endl;        break;      case Config::THROW_UP:        UTIL_THROW(SpecialWordMissingException, "The ARPA file is missing <unk> and the model is configured to throw an exception."); diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh index 70cfc2d2..5c71c792 100644 --- a/klm/util/bit_packing.hh +++ b/klm/util/bit_packing.hh @@ -28,16 +28,19 @@ namespace util {   * but it may be called multiple times when that's inconvenient.     */ -inline uint8_t BitPackShift(uint8_t bit, uint8_t length) { +  // Fun fact: __BYTE_ORDER is wrong on Solaris Sparc, but the version without __ is correct.    #if BYTE_ORDER == LITTLE_ENDIAN +inline uint8_t BitPackShift(uint8_t bit, uint8_t /*length*/) {    return bit; +}  #elif BYTE_ORDER == BIG_ENDIAN +inline uint8_t BitPackShift(uint8_t bit, uint8_t length) {    return 64 - length - bit; +}  #else  #error "Bit packing code isn't written for your byte order."  #endif -}  /* Pack integers up to 57 bits using their least significant digits.    * The length is specified using mask: diff --git a/klm/util/exception.cc b/klm/util/exception.cc index 077405f4..84f9fe7c 100644 --- a/klm/util/exception.cc +++ b/klm/util/exception.cc @@ -9,11 +9,11 @@ Exception::Exception() throw() {}  Exception::~Exception() throw() {}  Exception::Exception(const Exception &from) : std::exception() { -  stream_.str(from.stream_.str()); +  stream_ << from.stream_.str();  }  Exception &Exception::operator=(const Exception &from) { -  stream_.str(from.stream_.str()); +  stream_ << from.stream_.str();    return *this;  } diff --git a/klm/util/have.hh b/klm/util/have.hh index 7cf62008..f2f0cf90 100644 --- a/klm/util/have.hh +++ b/klm/util/have.hh @@ -2,8 +2,14 @@  #ifndef UTIL_HAVE__  #define UTIL_HAVE__ +#ifndef HAVE_ZLIB  #define HAVE_ZLIB +#endif +  // #define HAVE_ICU + +#ifndef HAVE_BOOST  #define HAVE_BOOST +#endif  #endif // UTIL_HAVE__ | 
