diff options
author | Chris Dyer <redpony@gmail.com> | 2014-10-13 00:42:37 -0400 |
---|---|---|
committer | Chris Dyer <redpony@gmail.com> | 2014-10-13 00:42:37 -0400 |
commit | b1ed81ef3216b212295afa76c5d20a56fb647204 (patch) | |
tree | 9633cdc1b8a341dfa58b0b7fec0e2cae44d28835 /klm/lm/builder/adjust_counts.hh | |
parent | 1b17f61d359be6e1c3cea29f8c100db3bcdd73a0 (diff) |
new kenlm
Diffstat (limited to 'klm/lm/builder/adjust_counts.hh')
-rw-r--r-- | klm/lm/builder/adjust_counts.hh | 41 |
1 files changed, 33 insertions, 8 deletions
diff --git a/klm/lm/builder/adjust_counts.hh b/klm/lm/builder/adjust_counts.hh index f38ff79d..a5435c28 100644 --- a/klm/lm/builder/adjust_counts.hh +++ b/klm/lm/builder/adjust_counts.hh @@ -1,24 +1,35 @@ -#ifndef LM_BUILDER_ADJUST_COUNTS__ -#define LM_BUILDER_ADJUST_COUNTS__ +#ifndef LM_BUILDER_ADJUST_COUNTS_H +#define LM_BUILDER_ADJUST_COUNTS_H #include "lm/builder/discount.hh" +#include "lm/lm_exception.hh" #include "util/exception.hh" #include <vector> #include <stdint.h> +namespace util { namespace stream { class ChainPositions; } } + namespace lm { namespace builder { -class ChainPositions; - class BadDiscountException : public util::Exception { public: BadDiscountException() throw(); ~BadDiscountException() throw(); }; +struct DiscountConfig { + // Overrides discounts for orders [1,discount_override.size()]. + std::vector<Discount> overwrite; + // If discounting fails for an order, copy them from here. + Discount fallback; + // What to do when discounts are out of range or would trigger divison by + // zero. It it does something other than THROW_UP, use fallback_discount. + WarningAction bad_action; +}; + /* Compute adjusted counts. * Input: unique suffix sorted N-grams (and just the N-grams) with raw counts. * Output: [1,N]-grams with adjusted counts. @@ -27,18 +38,32 @@ class BadDiscountException : public util::Exception { */ class AdjustCounts { public: - AdjustCounts(std::vector<uint64_t> &counts, std::vector<Discount> &discounts) - : counts_(counts), discounts_(discounts) {} + // counts: output + // counts_pruned: output + // discounts: mostly output. If the input already has entries, they will be kept. + // prune_thresholds: input. n-grams with normal (not adjusted) count below this will be pruned. + AdjustCounts( + const std::vector<uint64_t> &prune_thresholds, + std::vector<uint64_t> &counts, + std::vector<uint64_t> &counts_pruned, + const DiscountConfig &discount_config, + std::vector<Discount> &discounts) + : prune_thresholds_(prune_thresholds), counts_(counts), counts_pruned_(counts_pruned), discount_config_(discount_config), discounts_(discounts) + {} - void Run(const ChainPositions &positions); + void Run(const util::stream::ChainPositions &positions); private: + const std::vector<uint64_t> &prune_thresholds_; std::vector<uint64_t> &counts_; + std::vector<uint64_t> &counts_pruned_; + + DiscountConfig discount_config_; std::vector<Discount> &discounts_; }; } // namespace builder } // namespace lm -#endif // LM_BUILDER_ADJUST_COUNTS__ +#endif // LM_BUILDER_ADJUST_COUNTS_H |