diff options
author | Patrick Simianer <p@simianer.de> | 2014-10-13 19:03:48 +0100 |
---|---|---|
committer | Patrick Simianer <p@simianer.de> | 2014-10-13 19:03:48 +0100 |
commit | cb9fb7088dde35881516c088db402abe747d49fa (patch) | |
tree | a91e4935a7941f1b261f76d88ab41fa3078a1891 /klm/lm/builder/adjust_counts.hh | |
parent | 0a00e57e921c8eca8e02364db7d2e6607bfdcebc (diff) | |
parent | b1ed81ef3216b212295afa76c5d20a56fb647204 (diff) |
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'klm/lm/builder/adjust_counts.hh')
-rw-r--r-- | klm/lm/builder/adjust_counts.hh | 41 |
1 files changed, 33 insertions, 8 deletions
diff --git a/klm/lm/builder/adjust_counts.hh b/klm/lm/builder/adjust_counts.hh index f38ff79d..a5435c28 100644 --- a/klm/lm/builder/adjust_counts.hh +++ b/klm/lm/builder/adjust_counts.hh @@ -1,24 +1,35 @@ -#ifndef LM_BUILDER_ADJUST_COUNTS__ -#define LM_BUILDER_ADJUST_COUNTS__ +#ifndef LM_BUILDER_ADJUST_COUNTS_H +#define LM_BUILDER_ADJUST_COUNTS_H #include "lm/builder/discount.hh" +#include "lm/lm_exception.hh" #include "util/exception.hh" #include <vector> #include <stdint.h> +namespace util { namespace stream { class ChainPositions; } } + namespace lm { namespace builder { -class ChainPositions; - class BadDiscountException : public util::Exception { public: BadDiscountException() throw(); ~BadDiscountException() throw(); }; +struct DiscountConfig { + // Overrides discounts for orders [1,discount_override.size()]. + std::vector<Discount> overwrite; + // If discounting fails for an order, copy them from here. + Discount fallback; + // What to do when discounts are out of range or would trigger divison by + // zero. It it does something other than THROW_UP, use fallback_discount. + WarningAction bad_action; +}; + /* Compute adjusted counts. * Input: unique suffix sorted N-grams (and just the N-grams) with raw counts. * Output: [1,N]-grams with adjusted counts. @@ -27,18 +38,32 @@ class BadDiscountException : public util::Exception { */ class AdjustCounts { public: - AdjustCounts(std::vector<uint64_t> &counts, std::vector<Discount> &discounts) - : counts_(counts), discounts_(discounts) {} + // counts: output + // counts_pruned: output + // discounts: mostly output. If the input already has entries, they will be kept. + // prune_thresholds: input. n-grams with normal (not adjusted) count below this will be pruned. + AdjustCounts( + const std::vector<uint64_t> &prune_thresholds, + std::vector<uint64_t> &counts, + std::vector<uint64_t> &counts_pruned, + const DiscountConfig &discount_config, + std::vector<Discount> &discounts) + : prune_thresholds_(prune_thresholds), counts_(counts), counts_pruned_(counts_pruned), discount_config_(discount_config), discounts_(discounts) + {} - void Run(const ChainPositions &positions); + void Run(const util::stream::ChainPositions &positions); private: + const std::vector<uint64_t> &prune_thresholds_; std::vector<uint64_t> &counts_; + std::vector<uint64_t> &counts_pruned_; + + DiscountConfig discount_config_; std::vector<Discount> &discounts_; }; } // namespace builder } // namespace lm -#endif // LM_BUILDER_ADJUST_COUNTS__ +#endif // LM_BUILDER_ADJUST_COUNTS_H |