diff options
author | armatthews <armatthe@cmu.edu> | 2014-10-13 14:59:23 -0400 |
---|---|---|
committer | armatthews <armatthe@cmu.edu> | 2014-10-13 14:59:23 -0400 |
commit | 9a06ff1465eb3477ac3d1e92ab52e7eae40316a8 (patch) | |
tree | 808c266a3f510d00f37cd19c3f1da91d8fc683f7 /klm/lm/builder/adjust_counts.hh | |
parent | e51da099233df0a384b04fe5908b30e44040d13e (diff) | |
parent | d3e2ec203a5cf550320caa8023ac3dd103b0be7d (diff) |
Merge branch 'master' of github.com:redpony/cdec
Diffstat (limited to 'klm/lm/builder/adjust_counts.hh')
-rw-r--r-- | klm/lm/builder/adjust_counts.hh | 41 |
1 files changed, 33 insertions, 8 deletions
diff --git a/klm/lm/builder/adjust_counts.hh b/klm/lm/builder/adjust_counts.hh index f38ff79d..a5435c28 100644 --- a/klm/lm/builder/adjust_counts.hh +++ b/klm/lm/builder/adjust_counts.hh @@ -1,24 +1,35 @@ -#ifndef LM_BUILDER_ADJUST_COUNTS__ -#define LM_BUILDER_ADJUST_COUNTS__ +#ifndef LM_BUILDER_ADJUST_COUNTS_H +#define LM_BUILDER_ADJUST_COUNTS_H #include "lm/builder/discount.hh" +#include "lm/lm_exception.hh" #include "util/exception.hh" #include <vector> #include <stdint.h> +namespace util { namespace stream { class ChainPositions; } } + namespace lm { namespace builder { -class ChainPositions; - class BadDiscountException : public util::Exception { public: BadDiscountException() throw(); ~BadDiscountException() throw(); }; +struct DiscountConfig { + // Overrides discounts for orders [1,discount_override.size()]. + std::vector<Discount> overwrite; + // If discounting fails for an order, copy them from here. + Discount fallback; + // What to do when discounts are out of range or would trigger divison by + // zero. It it does something other than THROW_UP, use fallback_discount. + WarningAction bad_action; +}; + /* Compute adjusted counts. * Input: unique suffix sorted N-grams (and just the N-grams) with raw counts. * Output: [1,N]-grams with adjusted counts. @@ -27,18 +38,32 @@ class BadDiscountException : public util::Exception { */ class AdjustCounts { public: - AdjustCounts(std::vector<uint64_t> &counts, std::vector<Discount> &discounts) - : counts_(counts), discounts_(discounts) {} + // counts: output + // counts_pruned: output + // discounts: mostly output. If the input already has entries, they will be kept. + // prune_thresholds: input. n-grams with normal (not adjusted) count below this will be pruned. + AdjustCounts( + const std::vector<uint64_t> &prune_thresholds, + std::vector<uint64_t> &counts, + std::vector<uint64_t> &counts_pruned, + const DiscountConfig &discount_config, + std::vector<Discount> &discounts) + : prune_thresholds_(prune_thresholds), counts_(counts), counts_pruned_(counts_pruned), discount_config_(discount_config), discounts_(discounts) + {} - void Run(const ChainPositions &positions); + void Run(const util::stream::ChainPositions &positions); private: + const std::vector<uint64_t> &prune_thresholds_; std::vector<uint64_t> &counts_; + std::vector<uint64_t> &counts_pruned_; + + DiscountConfig discount_config_; std::vector<Discount> &discounts_; }; } // namespace builder } // namespace lm -#endif // LM_BUILDER_ADJUST_COUNTS__ +#endif // LM_BUILDER_ADJUST_COUNTS_H |