diff options
author | Wu, Ke <wuke@cs.umd.edu> | 2014-12-17 16:11:38 -0500 |
---|---|---|
committer | Wu, Ke <wuke@cs.umd.edu> | 2014-12-17 16:11:38 -0500 |
commit | 1613f1fc44ca67820afd7e7b21eb54b316c8ce55 (patch) | |
tree | e02b77084f28a18df6b854f87a986124db44d717 /klm/lm/builder/adjust_counts.hh | |
parent | bd9308e22b5434aa220cc57d82ee867464a011f1 (diff) | |
parent | 796768086a687d3f1856fef6489c34fe4d373642 (diff) |
Merge with upstream
Diffstat (limited to 'klm/lm/builder/adjust_counts.hh')
-rw-r--r-- | klm/lm/builder/adjust_counts.hh | 41 |
1 files changed, 33 insertions, 8 deletions
diff --git a/klm/lm/builder/adjust_counts.hh b/klm/lm/builder/adjust_counts.hh index f38ff79d..a5435c28 100644 --- a/klm/lm/builder/adjust_counts.hh +++ b/klm/lm/builder/adjust_counts.hh @@ -1,24 +1,35 @@ -#ifndef LM_BUILDER_ADJUST_COUNTS__ -#define LM_BUILDER_ADJUST_COUNTS__ +#ifndef LM_BUILDER_ADJUST_COUNTS_H +#define LM_BUILDER_ADJUST_COUNTS_H #include "lm/builder/discount.hh" +#include "lm/lm_exception.hh" #include "util/exception.hh" #include <vector> #include <stdint.h> +namespace util { namespace stream { class ChainPositions; } } + namespace lm { namespace builder { -class ChainPositions; - class BadDiscountException : public util::Exception { public: BadDiscountException() throw(); ~BadDiscountException() throw(); }; +struct DiscountConfig { + // Overrides discounts for orders [1,discount_override.size()]. + std::vector<Discount> overwrite; + // If discounting fails for an order, copy them from here. + Discount fallback; + // What to do when discounts are out of range or would trigger divison by + // zero. It it does something other than THROW_UP, use fallback_discount. + WarningAction bad_action; +}; + /* Compute adjusted counts. * Input: unique suffix sorted N-grams (and just the N-grams) with raw counts. * Output: [1,N]-grams with adjusted counts. @@ -27,18 +38,32 @@ class BadDiscountException : public util::Exception { */ class AdjustCounts { public: - AdjustCounts(std::vector<uint64_t> &counts, std::vector<Discount> &discounts) - : counts_(counts), discounts_(discounts) {} + // counts: output + // counts_pruned: output + // discounts: mostly output. If the input already has entries, they will be kept. + // prune_thresholds: input. n-grams with normal (not adjusted) count below this will be pruned. + AdjustCounts( + const std::vector<uint64_t> &prune_thresholds, + std::vector<uint64_t> &counts, + std::vector<uint64_t> &counts_pruned, + const DiscountConfig &discount_config, + std::vector<Discount> &discounts) + : prune_thresholds_(prune_thresholds), counts_(counts), counts_pruned_(counts_pruned), discount_config_(discount_config), discounts_(discounts) + {} - void Run(const ChainPositions &positions); + void Run(const util::stream::ChainPositions &positions); private: + const std::vector<uint64_t> &prune_thresholds_; std::vector<uint64_t> &counts_; + std::vector<uint64_t> &counts_pruned_; + + DiscountConfig discount_config_; std::vector<Discount> &discounts_; }; } // namespace builder } // namespace lm -#endif // LM_BUILDER_ADJUST_COUNTS__ +#endif // LM_BUILDER_ADJUST_COUNTS_H |