summaryrefslogtreecommitdiff
path: root/klm
diff options
context:
space:
mode:
authorJonathan Clark <jon.h.clark@gmail.com>2011-03-24 09:51:40 -0400
committerJonathan Clark <jon.h.clark@gmail.com>2011-03-24 09:51:40 -0400
commiteb33700d1c868662b5d0abedaaf3fa47948a89d0 (patch)
treeed70be84820d243524bab0b59a84b8da033a9c41 /klm
parentba4f147f84aa0d4623da640a2d0de7e6242a53af (diff)
parenta580faa8177331cf51138a2208e276b703470934 (diff)
Undo some silly local changes so we can pull
Diffstat (limited to 'klm')
-rw-r--r--klm/lm/build_binary.cc102
-rw-r--r--klm/lm/config.cc2
-rw-r--r--klm/lm/config.hh2
-rw-r--r--klm/lm/model.cc2
-rw-r--r--klm/lm/search_trie.cc20
-rw-r--r--klm/lm/vocab.cc2
-rw-r--r--klm/util/bit_packing.hh7
-rw-r--r--klm/util/exception.cc4
-rw-r--r--klm/util/have.hh6
9 files changed, 80 insertions, 67 deletions
diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc
index d6dd5994..920ff080 100644
--- a/klm/lm/build_binary.cc
+++ b/klm/lm/build_binary.cc
@@ -15,8 +15,9 @@ namespace ngram {
namespace {
void Usage(const char *name) {
- std::cerr << "Usage: " << name << " [-u unknown_probability] [-s] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [type] input.arpa output.mmap\n\n"
-"-u sets the default probability for <unk> if the ARPA file does not have one.\n"
+ std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [type] input.arpa output.mmap\n\n"
+"-u sets the default log10 probability for <unk> if the ARPA file does not have\n"
+"one.\n"
"-s allows models to be built even if they do not have <s> and </s>.\n\n"
"type is one of probing, trie, or sorted:\n\n"
"probing uses a probing hash table. It is the fastest but uses the most memory.\n"
@@ -69,65 +70,58 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) {
} // namespace lm
} // namespace
-void terminate_handler() {
- try { throw; }
- catch(const std::exception& e) {
- std::cerr << e.what() << std::endl;
- }
- catch(...) {
- std::cerr << "A non-standard exception was thrown." << std::endl;
- }
- std::abort();
-}
-
int main(int argc, char *argv[]) {
using namespace lm::ngram;
- std::set_terminate(terminate_handler);
-
- lm::ngram::Config config;
- int opt;
- while ((opt = getopt(argc, argv, "su:p:t:m:")) != -1) {
- switch(opt) {
- case 'u':
- config.unknown_missing_prob = ParseFloat(optarg);
- break;
- case 'p':
- config.probing_multiplier = ParseFloat(optarg);
- break;
- case 't':
- config.temporary_directory_prefix = optarg;
- break;
- case 'm':
- config.building_memory = ParseUInt(optarg) * 1048576;
- break;
- case 's':
- config.sentence_marker_missing = lm::ngram::Config::SILENT;
- break;
- default:
- Usage(argv[0]);
+ try {
+ lm::ngram::Config config;
+ int opt;
+ while ((opt = getopt(argc, argv, "su:p:t:m:")) != -1) {
+ switch(opt) {
+ case 'u':
+ config.unknown_missing_logprob = ParseFloat(optarg);
+ break;
+ case 'p':
+ config.probing_multiplier = ParseFloat(optarg);
+ break;
+ case 't':
+ config.temporary_directory_prefix = optarg;
+ break;
+ case 'm':
+ config.building_memory = ParseUInt(optarg) * 1048576;
+ break;
+ case 's':
+ config.sentence_marker_missing = lm::ngram::Config::SILENT;
+ break;
+ default:
+ Usage(argv[0]);
+ }
}
- }
- if (optind + 1 == argc) {
- ShowSizes(argv[optind], config);
- } else if (optind + 2 == argc) {
- config.write_mmap = argv[optind + 1];
- ProbingModel(argv[optind], config);
- } else if (optind + 3 == argc) {
- const char *model_type = argv[optind];
- const char *from_file = argv[optind + 1];
- config.write_mmap = argv[optind + 2];
- if (!strcmp(model_type, "probing")) {
- ProbingModel(from_file, config);
- } else if (!strcmp(model_type, "sorted")) {
- SortedModel(from_file, config);
- } else if (!strcmp(model_type, "trie")) {
- TrieModel(from_file, config);
+ if (optind + 1 == argc) {
+ ShowSizes(argv[optind], config);
+ } else if (optind + 2 == argc) {
+ config.write_mmap = argv[optind + 1];
+ ProbingModel(argv[optind], config);
+ } else if (optind + 3 == argc) {
+ const char *model_type = argv[optind];
+ const char *from_file = argv[optind + 1];
+ config.write_mmap = argv[optind + 2];
+ if (!strcmp(model_type, "probing")) {
+ ProbingModel(from_file, config);
+ } else if (!strcmp(model_type, "sorted")) {
+ SortedModel(from_file, config);
+ } else if (!strcmp(model_type, "trie")) {
+ TrieModel(from_file, config);
+ } else {
+ Usage(argv[0]);
+ }
} else {
Usage(argv[0]);
}
- } else {
- Usage(argv[0]);
+ }
+ catch (std::exception &e) {
+ std::cerr << e.what() << std::endl;
+ abort();
}
return 0;
}
diff --git a/klm/lm/config.cc b/klm/lm/config.cc
index d8773fe5..71646e51 100644
--- a/klm/lm/config.cc
+++ b/klm/lm/config.cc
@@ -10,7 +10,7 @@ Config::Config() :
enumerate_vocab(NULL),
unknown_missing(COMPLAIN),
sentence_marker_missing(THROW_UP),
- unknown_missing_prob(0.0),
+ unknown_missing_logprob(-100.0),
probing_multiplier(1.5),
building_memory(1073741824ULL), // 1 GB
temporary_directory_prefix(NULL),
diff --git a/klm/lm/config.hh b/klm/lm/config.hh
index 17f67df3..1f7762be 100644
--- a/klm/lm/config.hh
+++ b/klm/lm/config.hh
@@ -36,7 +36,7 @@ struct Config {
// The probability to substitute for <unk> if it's missing from the model.
// No effect if the model has <unk> or unknown_missing == THROW_UP.
- float unknown_missing_prob;
+ float unknown_missing_logprob;
// Size multiplier for probing hash table. Must be > 1. Space is linear in
// this. Time is probing_multiplier / (probing_multiplier - 1). No effect
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index 14949e97..1492276a 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -86,7 +86,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
assert(config.unknown_missing != Config::THROW_UP);
// Default probabilities for unknown.
search_.unigram.Unknown().backoff = 0.0;
- search_.unigram.Unknown().prob = config.unknown_missing_prob;
+ search_.unigram.Unknown().prob = config.unknown_missing_logprob;
}
FinishFile(config, kModelType, counts, backing_);
}
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index 63631223..b830dfc3 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -535,13 +535,16 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st
}
}
-void ARPAToSortedFiles(const Config &config, util::FilePiece &f, const std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
+void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
{
std::string unigram_name = file_prefix + "unigrams";
util::scoped_fd unigram_file;
- util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), counts[0] * sizeof(ProbBackoff), unigram_file), counts[0] * sizeof(ProbBackoff));
+ // In case <unk> appears.
+ size_t extra_count = counts[0] + 1;
+ util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), extra_count * sizeof(ProbBackoff), unigram_file), extra_count * sizeof(ProbBackoff));
Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()));
CheckSpecials(config, vocab);
+ if (!vocab.SawUnk()) ++counts[0];
}
// Only use as much buffer as we need.
@@ -572,7 +575,7 @@ bool HeadMatch(const WordIndex *words, const WordIndex *const words_end, const W
return true;
}
-// Counting phrase
+// Phase to count n-grams, including blanks inserted because they were pruned but have extensions
class JustCount {
public:
JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, BitPackedMiddle * /*middle*/, BitPackedLongest &/*longest*/, uint64_t *counts, unsigned char order)
@@ -603,6 +606,7 @@ class JustCount {
uint64_t *const counts_, *const longest_counts_;
};
+// Phase to actually write n-grams to the trie.
class WriteEntries {
public:
WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) :
@@ -764,7 +768,7 @@ template <class Doing> class RecursiveInsert {
void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<uint64_t> &fixed) {
if (fixed[0] != initial[0]) UTIL_THROW(util::Exception, "Unigram count should be constant but initial is " << initial[0] << " and recounted is " << fixed[0]);
- if (fixed.back() != initial.back()) UTIL_THROW(util::Exception, "Longest count should be constant");
+ if (fixed.back() != initial.back()) UTIL_THROW(util::Exception, "Longest count should be constant but it changed from " << initial.back() << " to " << fixed.back());
for (unsigned char i = 0; i < initial.size(); ++i) {
if (fixed[i] < initial[i]) UTIL_THROW(util::Exception, "Counts came out lower than expected. This shouldn't happen");
}
@@ -789,6 +793,9 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co
RecursiveInsert<JustCount> counter(inputs, contexts, NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size());
counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]);
}
+ for (SortedFileReader *i = inputs; i < inputs + counts.size() - 1; ++i) {
+ if (!i->Ended()) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs + 2) << "-gram table did not complete reading");
+ }
SanityCheckCounts(counts, fixed_counts);
counts = fixed_counts;
@@ -805,7 +812,7 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co
}
// Fill unigram probabilities.
- {
+ try {
std::string name(file_prefix + "unigrams");
util::scoped_FILE file(OpenOrThrow(name.c_str(), "r"));
for (WordIndex i = 0; i < counts[0]; ++i) {
@@ -816,6 +823,9 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co
}
}
RemoveOrThrow(name.c_str());
+ } catch (util::Exception &e) {
+ e << " while re-reading unigram probabilities";
+ throw;
}
// Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.
diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc
index 415f8331..fd11ad2c 100644
--- a/klm/lm/vocab.cc
+++ b/klm/lm/vocab.cc
@@ -192,7 +192,7 @@ void MissingUnknown(const Config &config) throw(SpecialWordMissingException) {
case Config::SILENT:
return;
case Config::COMPLAIN:
- if (config.messages) *config.messages << "The ARPA file is missing <unk>. Substituting probability " << config.unknown_missing_prob << "." << std::endl;
+ if (config.messages) *config.messages << "The ARPA file is missing <unk>. Substituting log10 probability " << config.unknown_missing_logprob << "." << std::endl;
break;
case Config::THROW_UP:
UTIL_THROW(SpecialWordMissingException, "The ARPA file is missing <unk> and the model is configured to throw an exception.");
diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh
index 70cfc2d2..5c71c792 100644
--- a/klm/util/bit_packing.hh
+++ b/klm/util/bit_packing.hh
@@ -28,16 +28,19 @@ namespace util {
* but it may be called multiple times when that's inconvenient.
*/
-inline uint8_t BitPackShift(uint8_t bit, uint8_t length) {
+
// Fun fact: __BYTE_ORDER is wrong on Solaris Sparc, but the version without __ is correct.
#if BYTE_ORDER == LITTLE_ENDIAN
+inline uint8_t BitPackShift(uint8_t bit, uint8_t /*length*/) {
return bit;
+}
#elif BYTE_ORDER == BIG_ENDIAN
+inline uint8_t BitPackShift(uint8_t bit, uint8_t length) {
return 64 - length - bit;
+}
#else
#error "Bit packing code isn't written for your byte order."
#endif
-}
/* Pack integers up to 57 bits using their least significant digits.
* The length is specified using mask:
diff --git a/klm/util/exception.cc b/klm/util/exception.cc
index 077405f4..84f9fe7c 100644
--- a/klm/util/exception.cc
+++ b/klm/util/exception.cc
@@ -9,11 +9,11 @@ Exception::Exception() throw() {}
Exception::~Exception() throw() {}
Exception::Exception(const Exception &from) : std::exception() {
- stream_.str(from.stream_.str());
+ stream_ << from.stream_.str();
}
Exception &Exception::operator=(const Exception &from) {
- stream_.str(from.stream_.str());
+ stream_ << from.stream_.str();
return *this;
}
diff --git a/klm/util/have.hh b/klm/util/have.hh
index 7cf62008..f2f0cf90 100644
--- a/klm/util/have.hh
+++ b/klm/util/have.hh
@@ -2,8 +2,14 @@
#ifndef UTIL_HAVE__
#define UTIL_HAVE__
+#ifndef HAVE_ZLIB
#define HAVE_ZLIB
+#endif
+
// #define HAVE_ICU
+
+#ifndef HAVE_BOOST
#define HAVE_BOOST
+#endif
#endif // UTIL_HAVE__