summaryrefslogtreecommitdiff
path: root/klm/lm/builder/lmplz_main.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/builder/lmplz_main.cc')
-rw-r--r--klm/lm/builder/lmplz_main.cc97
1 files changed, 92 insertions, 5 deletions
diff --git a/klm/lm/builder/lmplz_main.cc b/klm/lm/builder/lmplz_main.cc
index 2563deed..265dd216 100644
--- a/klm/lm/builder/lmplz_main.cc
+++ b/klm/lm/builder/lmplz_main.cc
@@ -1,4 +1,5 @@
#include "lm/builder/pipeline.hh"
+#include "lm/lm_exception.hh"
#include "util/file.hh"
#include "util/file_piece.hh"
#include "util/usage.hh"
@@ -7,6 +8,7 @@
#include <boost/program_options.hpp>
#include <boost/version.hpp>
+#include <vector>
namespace {
class SizeNotify {
@@ -25,6 +27,57 @@ boost::program_options::typed_value<std::string> *SizeOption(std::size_t &to, co
return boost::program_options::value<std::string>()->notifier(SizeNotify(to))->default_value(default_value);
}
+// Parse and validate pruning thresholds then return vector of threshold counts
+// for each n-grams order.
+std::vector<uint64_t> ParsePruning(const std::vector<std::string> &param, std::size_t order) {
+ // convert to vector of integers
+ std::vector<uint64_t> prune_thresholds;
+ prune_thresholds.reserve(order);
+ for (std::vector<std::string>::const_iterator it(param.begin()); it != param.end(); ++it) {
+ try {
+ prune_thresholds.push_back(boost::lexical_cast<uint64_t>(*it));
+ } catch(const boost::bad_lexical_cast &) {
+ UTIL_THROW(util::Exception, "Bad pruning threshold " << *it);
+ }
+ }
+
+ // Fill with zeros by default.
+ if (prune_thresholds.empty()) {
+ prune_thresholds.resize(order, 0);
+ return prune_thresholds;
+ }
+
+ // validate pruning threshold if specified
+ // throw if each n-gram order has not threshold specified
+ UTIL_THROW_IF(prune_thresholds.size() > order, util::Exception, "You specified pruning thresholds for orders 1 through " << prune_thresholds.size() << " but the model only has order " << order);
+ // threshold for unigram can only be 0 (no pruning)
+ UTIL_THROW_IF(prune_thresholds[0] != 0, util::Exception, "Unigram pruning is not implemented, so the first pruning threshold must be 0.");
+
+ // check if threshold are not in decreasing order
+ uint64_t lower_threshold = 0;
+ for (std::vector<uint64_t>::iterator it = prune_thresholds.begin(); it != prune_thresholds.end(); ++it) {
+ UTIL_THROW_IF(lower_threshold > *it, util::Exception, "Pruning thresholds should be in non-decreasing order. Otherwise substrings would be removed, which is bad for query-time data structures.");
+ lower_threshold = *it;
+ }
+
+ // Pad to all orders using the last value.
+ prune_thresholds.resize(order, prune_thresholds.back());
+ return prune_thresholds;
+}
+
+lm::builder::Discount ParseDiscountFallback(const std::vector<std::string> &param) {
+ lm::builder::Discount ret;
+ UTIL_THROW_IF(param.size() > 3, util::Exception, "Specify at most three fallback discounts: 1, 2, and 3+");
+ UTIL_THROW_IF(param.empty(), util::Exception, "Fallback discounting enabled, but no discount specified");
+ ret.amount[0] = 0.0;
+ for (unsigned i = 0; i < 3; ++i) {
+ float discount = boost::lexical_cast<float>(param[i < param.size() ? i : (param.size() - 1)]);
+ UTIL_THROW_IF(discount < 0.0 || discount > static_cast<float>(i+1), util::Exception, "The discount for count " << (i+1) << " was parsed as " << discount << " which is not in the range [0, " << (i+1) << "].");
+ ret.amount[i + 1] = discount;
+ }
+ return ret;
+}
+
} // namespace
int main(int argc, char *argv[]) {
@@ -34,25 +87,36 @@ int main(int argc, char *argv[]) {
lm::builder::PipelineConfig pipeline;
std::string text, arpa;
+ std::vector<std::string> pruning;
+ std::vector<std::string> discount_fallback;
+ std::vector<std::string> discount_fallback_default;
+ discount_fallback_default.push_back("0.5");
+ discount_fallback_default.push_back("1");
+ discount_fallback_default.push_back("1.5");
options.add_options()
- ("help", po::bool_switch(), "Show this help message")
+ ("help,h", po::bool_switch(), "Show this help message")
("order,o", po::value<std::size_t>(&pipeline.order)
#if BOOST_VERSION >= 104200
->required()
#endif
, "Order of the model")
- ("interpolate_unigrams", po::bool_switch(&pipeline.initial_probs.interpolate_unigrams), "Interpolate the unigrams (default: emulate SRILM by not interpolating)")
+ ("interpolate_unigrams", po::value<bool>(&pipeline.initial_probs.interpolate_unigrams)->default_value(true)->implicit_value(true), "Interpolate the unigrams (default) as opposed to giving lots of mass to <unk> like SRI. If you want SRI's behavior with a large <unk> and the old lmplz default, use --interpolate_unigrams 0.")
+ ("skip_symbols", po::bool_switch(), "Treat <s>, </s>, and <unk> as whitespace instead of throwing an exception")
("temp_prefix,T", po::value<std::string>(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix")
("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory")
("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow")
("sort_block", SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)")
- ("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table")
("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)")
- ("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file")
+ ("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table")
+ ("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write a file containing the unique vocabulary strings delimited by null bytes")
+ ("vocab_pad", po::value<uint64_t>(&pipeline.vocab_size_for_unk)->default_value(0), "If the vocabulary is smaller than this value, pad with <unk> to reach this size. Requires --interpolate_unigrams")
("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.")
("text", po::value<std::string>(&text), "Read text from a file instead of stdin")
- ("arpa", po::value<std::string>(&arpa), "Write ARPA to a file instead of stdout");
+ ("arpa", po::value<std::string>(&arpa), "Write ARPA to a file instead of stdout")
+ ("collapse_values", po::bool_switch(&pipeline.output_q), "Collapse probability and backoff into a single value, q that yields the same sentence-level probabilities. See http://kheafield.com/professional/edinburgh/rest_paper.pdf for more details, including a proof.")
+ ("prune", po::value<std::vector<std::string> >(&pruning)->multitoken(), "Prune n-grams with count less than or equal to the given threshold. Specify one value for each order i.e. 0 0 1 to prune singleton trigrams and above. The sequence of values must be non-decreasing and the last value applies to any remaining orders. Unigram pruning is not implemented, so the first value must be zero. Default is to not prune, which is equivalent to --prune 0.")
+ ("discount_fallback", po::value<std::vector<std::string> >(&discount_fallback)->multitoken()->implicit_value(discount_fallback_default, "0.5 1 1.5"), "The closed-form estimate for Kneser-Ney discounts does not work without singletons or doubletons. It can also fail if these values are out of range. This option falls back to user-specified discounts when the closed-form estimate fails. Note that this option is generally a bad idea: you should deduplicate your corpus instead. However, class-based models need custom discounts because they lack singleton unigrams. Provide up to three discounts (for adjusted counts 1, 2, and 3+), which will be applied to all orders where the closed-form estimates fail.");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, options), vm);
@@ -95,6 +159,29 @@ int main(int argc, char *argv[]) {
}
#endif
+ if (pipeline.vocab_size_for_unk && !pipeline.initial_probs.interpolate_unigrams) {
+ std::cerr << "--vocab_pad requires --interpolate_unigrams be on" << std::endl;
+ return 1;
+ }
+
+ if (vm["skip_symbols"].as<bool>()) {
+ pipeline.disallowed_symbol_action = lm::COMPLAIN;
+ } else {
+ pipeline.disallowed_symbol_action = lm::THROW_UP;
+ }
+
+ if (vm.count("discount_fallback")) {
+ pipeline.discount.fallback = ParseDiscountFallback(discount_fallback);
+ pipeline.discount.bad_action = lm::COMPLAIN;
+ } else {
+ // Unused, just here to prevent the compiler from complaining about uninitialized.
+ pipeline.discount.fallback = lm::builder::Discount();
+ pipeline.discount.bad_action = lm::THROW_UP;
+ }
+
+ // parse pruning thresholds. These depend on order, so it is not done as a notifier.
+ pipeline.prune_thresholds = ParsePruning(pruning, pipeline.order);
+
util::NormalizeTempPrefix(pipeline.sort.temp_prefix);
lm::builder::InitialProbabilitiesConfig &initial = pipeline.initial_probs;