diff options
| author | armatthews <armatthe@cmu.edu> | 2014-10-13 14:59:23 -0400 | 
|---|---|---|
| committer | armatthews <armatthe@cmu.edu> | 2014-10-13 14:59:23 -0400 | 
| commit | b26cda84e05d4523eee069234a975a0153bf8608 (patch) | |
| tree | 61c9da4f8dd6070f27c8e81812a76fc0a8cf2d8d /klm/lm/builder/adjust_counts.hh | |
| parent | cd7bc67f475fdfd07fba003ac4cca40e83944740 (diff) | |
| parent | b1ed81ef3216b212295afa76c5d20a56fb647204 (diff) | |
Merge branch 'master' of github.com:redpony/cdec
Diffstat (limited to 'klm/lm/builder/adjust_counts.hh')
| -rw-r--r-- | klm/lm/builder/adjust_counts.hh | 41 | 
1 files changed, 33 insertions, 8 deletions
| diff --git a/klm/lm/builder/adjust_counts.hh b/klm/lm/builder/adjust_counts.hh index f38ff79d..a5435c28 100644 --- a/klm/lm/builder/adjust_counts.hh +++ b/klm/lm/builder/adjust_counts.hh @@ -1,24 +1,35 @@ -#ifndef LM_BUILDER_ADJUST_COUNTS__ -#define LM_BUILDER_ADJUST_COUNTS__ +#ifndef LM_BUILDER_ADJUST_COUNTS_H +#define LM_BUILDER_ADJUST_COUNTS_H  #include "lm/builder/discount.hh" +#include "lm/lm_exception.hh"  #include "util/exception.hh"  #include <vector>  #include <stdint.h> +namespace util { namespace stream { class ChainPositions; } } +  namespace lm {  namespace builder { -class ChainPositions; -  class BadDiscountException : public util::Exception {    public:      BadDiscountException() throw();      ~BadDiscountException() throw();  }; +struct DiscountConfig { +  // Overrides discounts for orders [1,discount_override.size()]. +  std::vector<Discount> overwrite; +  // If discounting fails for an order, copy them from here. +  Discount fallback; +  // What to do when discounts are out of range or would trigger divison by +  // zero.  It it does something other than THROW_UP, use fallback_discount. +  WarningAction bad_action; +}; +  /* Compute adjusted counts.     * Input: unique suffix sorted N-grams (and just the N-grams) with raw counts.   * Output: [1,N]-grams with adjusted counts.   @@ -27,18 +38,32 @@ class BadDiscountException : public util::Exception {   */  class AdjustCounts {    public: -    AdjustCounts(std::vector<uint64_t> &counts, std::vector<Discount> &discounts) -      : counts_(counts), discounts_(discounts) {} +    // counts: output +    // counts_pruned: output +    // discounts: mostly output.  If the input already has entries, they will be kept. +    // prune_thresholds: input.  n-grams with normal (not adjusted) count below this will be pruned. +    AdjustCounts( +        const std::vector<uint64_t> &prune_thresholds, +        std::vector<uint64_t> &counts, +        std::vector<uint64_t> &counts_pruned, +        const DiscountConfig &discount_config, +        std::vector<Discount> &discounts) +      : prune_thresholds_(prune_thresholds), counts_(counts), counts_pruned_(counts_pruned), discount_config_(discount_config), discounts_(discounts)  +    {} -    void Run(const ChainPositions &positions); +    void Run(const util::stream::ChainPositions &positions);    private: +    const std::vector<uint64_t> &prune_thresholds_;       std::vector<uint64_t> &counts_; +    std::vector<uint64_t> &counts_pruned_; + +    DiscountConfig discount_config_;      std::vector<Discount> &discounts_;  };  } // namespace builder  } // namespace lm -#endif // LM_BUILDER_ADJUST_COUNTS__ +#endif // LM_BUILDER_ADJUST_COUNTS_H | 
