summaryrefslogtreecommitdiff
path: root/klm/lm/builder/adjust_counts.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/builder/adjust_counts.hh')
-rw-r--r--klm/lm/builder/adjust_counts.hh41
1 files changed, 33 insertions, 8 deletions
diff --git a/klm/lm/builder/adjust_counts.hh b/klm/lm/builder/adjust_counts.hh
index f38ff79d..a5435c28 100644
--- a/klm/lm/builder/adjust_counts.hh
+++ b/klm/lm/builder/adjust_counts.hh
@@ -1,24 +1,35 @@
-#ifndef LM_BUILDER_ADJUST_COUNTS__
-#define LM_BUILDER_ADJUST_COUNTS__
+#ifndef LM_BUILDER_ADJUST_COUNTS_H
+#define LM_BUILDER_ADJUST_COUNTS_H
#include "lm/builder/discount.hh"
+#include "lm/lm_exception.hh"
#include "util/exception.hh"
#include <vector>
#include <stdint.h>
+namespace util { namespace stream { class ChainPositions; } }
+
namespace lm {
namespace builder {
-class ChainPositions;
-
class BadDiscountException : public util::Exception {
public:
BadDiscountException() throw();
~BadDiscountException() throw();
};
+struct DiscountConfig {
+ // Overrides discounts for orders [1,discount_override.size()].
+ std::vector<Discount> overwrite;
+ // If discounting fails for an order, copy them from here.
+ Discount fallback;
+ // What to do when discounts are out of range or would trigger divison by
+ // zero. It it does something other than THROW_UP, use fallback_discount.
+ WarningAction bad_action;
+};
+
/* Compute adjusted counts.
* Input: unique suffix sorted N-grams (and just the N-grams) with raw counts.
* Output: [1,N]-grams with adjusted counts.
@@ -27,18 +38,32 @@ class BadDiscountException : public util::Exception {
*/
class AdjustCounts {
public:
- AdjustCounts(std::vector<uint64_t> &counts, std::vector<Discount> &discounts)
- : counts_(counts), discounts_(discounts) {}
+ // counts: output
+ // counts_pruned: output
+ // discounts: mostly output. If the input already has entries, they will be kept.
+ // prune_thresholds: input. n-grams with normal (not adjusted) count below this will be pruned.
+ AdjustCounts(
+ const std::vector<uint64_t> &prune_thresholds,
+ std::vector<uint64_t> &counts,
+ std::vector<uint64_t> &counts_pruned,
+ const DiscountConfig &discount_config,
+ std::vector<Discount> &discounts)
+ : prune_thresholds_(prune_thresholds), counts_(counts), counts_pruned_(counts_pruned), discount_config_(discount_config), discounts_(discounts)
+ {}
- void Run(const ChainPositions &positions);
+ void Run(const util::stream::ChainPositions &positions);
private:
+ const std::vector<uint64_t> &prune_thresholds_;
std::vector<uint64_t> &counts_;
+ std::vector<uint64_t> &counts_pruned_;
+
+ DiscountConfig discount_config_;
std::vector<Discount> &discounts_;
};
} // namespace builder
} // namespace lm
-#endif // LM_BUILDER_ADJUST_COUNTS__
+#endif // LM_BUILDER_ADJUST_COUNTS_H