From b1ed81ef3216b212295afa76c5d20a56fb647204 Mon Sep 17 00:00:00 2001
From: Chris Dyer <redpony@gmail.com>
Date: Mon, 13 Oct 2014 00:42:37 -0400
Subject: new kenlm

---
 klm/lm/builder/adjust_counts.cc | 164 ++++++++++++++++++++++++++++------------
 1 file changed, 117 insertions(+), 47 deletions(-)

(limited to 'klm/lm/builder/adjust_counts.cc')

diff --git a/klm/lm/builder/adjust_counts.cc b/klm/lm/builder/adjust_counts.cc
index a6f48011..803c557d 100644
--- a/klm/lm/builder/adjust_counts.cc
+++ b/klm/lm/builder/adjust_counts.cc
@@ -1,8 +1,9 @@
 #include "lm/builder/adjust_counts.hh"
-#include "lm/builder/multi_stream.hh"
+#include "lm/builder/ngram_stream.hh"
 #include "util/stream/timer.hh"
 
 #include <algorithm>
+#include <iostream>
 
 namespace lm { namespace builder {
 
@@ -10,56 +11,78 @@ BadDiscountException::BadDiscountException() throw() {}
 BadDiscountException::~BadDiscountException() throw() {}
 
 namespace {
-// Return last word in full that is different.  
+// Return last word in full that is different.
 const WordIndex* FindDifference(const NGram &full, const NGram &lower_last) {
   const WordIndex *cur_word = full.end() - 1;
   const WordIndex *pre_word = lower_last.end() - 1;
-  // Find last difference.  
+  // Find last difference.
   for (; pre_word >= lower_last.begin() && *pre_word == *cur_word; --cur_word, --pre_word) {}
   return cur_word;
 }
 
 class StatCollector {
   public:
-    StatCollector(std::size_t order, std::vector<uint64_t> &counts, std::vector<Discount> &discounts) 
-      : orders_(order), full_(orders_.back()), counts_(counts), discounts_(discounts) {
+    StatCollector(std::size_t order, std::vector<uint64_t> &counts, std::vector<uint64_t> &counts_pruned, std::vector<Discount> &discounts)
+      : orders_(order), full_(orders_.back()), counts_(counts), counts_pruned_(counts_pruned), discounts_(discounts) {
       memset(&orders_[0], 0, sizeof(OrderStat) * order);
     }
 
     ~StatCollector() {}
 
-    void CalculateDiscounts() {
+    void CalculateDiscounts(const DiscountConfig &config) {
       counts_.resize(orders_.size());
-      discounts_.resize(orders_.size());
+      counts_pruned_.resize(orders_.size());
       for (std::size_t i = 0; i < orders_.size(); ++i) {
         const OrderStat &s = orders_[i];
         counts_[i] = s.count;
+        counts_pruned_[i] = s.count_pruned;
+      }
 
-        for (unsigned j = 1; j < 4; ++j) {
-          // TODO: Specialize error message for j == 3, meaning 3+
-          UTIL_THROW_IF(s.n[j] == 0, BadDiscountException, "Could not calculate Kneser-Ney discounts for "
-              << (i+1) << "-grams with adjusted count " << (j+1) << " because we didn't observe any "
-              << (i+1) << "-grams with adjusted count " << j << "; Is this small or artificial data?");
-        }
-
-        // See equation (26) in Chen and Goodman.
-        discounts_[i].amount[0] = 0.0;
-        float y = static_cast<float>(s.n[1]) / static_cast<float>(s.n[1] + 2.0 * s.n[2]);
-        for (unsigned j = 1; j < 4; ++j) {
-          discounts_[i].amount[j] = static_cast<float>(j) - static_cast<float>(j + 1) * y * static_cast<float>(s.n[j+1]) / static_cast<float>(s.n[j]);
-          UTIL_THROW_IF(discounts_[i].amount[j] < 0.0 || discounts_[i].amount[j] > j, BadDiscountException, "ERROR: " << (i+1) << "-gram discount out of range for adjusted count " << j << ": " << discounts_[i].amount[j]);
+      discounts_ = config.overwrite;
+      discounts_.resize(orders_.size());
+      for (std::size_t i = config.overwrite.size(); i < orders_.size(); ++i) {
+        const OrderStat &s = orders_[i];
+        try {
+          for (unsigned j = 1; j < 4; ++j) {
+            // TODO: Specialize error message for j == 3, meaning 3+
+            UTIL_THROW_IF(s.n[j] == 0, BadDiscountException, "Could not calculate Kneser-Ney discounts for "
+                << (i+1) << "-grams with adjusted count " << (j+1) << " because we didn't observe any "
+                << (i+1) << "-grams with adjusted count " << j << "; Is this small or artificial data?");
+          }
+
+          // See equation (26) in Chen and Goodman.
+          discounts_[i].amount[0] = 0.0;
+          float y = static_cast<float>(s.n[1]) / static_cast<float>(s.n[1] + 2.0 * s.n[2]);
+          for (unsigned j = 1; j < 4; ++j) {
+            discounts_[i].amount[j] = static_cast<float>(j) - static_cast<float>(j + 1) * y * static_cast<float>(s.n[j+1]) / static_cast<float>(s.n[j]);
+            UTIL_THROW_IF(discounts_[i].amount[j] < 0.0 || discounts_[i].amount[j] > j, BadDiscountException, "ERROR: " << (i+1) << "-gram discount out of range for adjusted count " << j << ": " << discounts_[i].amount[j]);
+          }
+        } catch (const BadDiscountException &e) {
+          switch (config.bad_action) {
+            case THROW_UP:
+              throw;
+            case COMPLAIN:
+              std::cerr << e.what() << "  Substituting fallback discounts D1=" << config.fallback.amount[1] << " D2=" << config.fallback.amount[2] << " D3+=" << config.fallback.amount[3] << std::endl;
+            case SILENT:
+              break;
+          }
+          discounts_[i] = config.fallback;
         }
       }
     }
 
-    void Add(std::size_t order_minus_1, uint64_t count) {
+    void Add(std::size_t order_minus_1, uint64_t count, bool pruned = false) {
       OrderStat &stat = orders_[order_minus_1];
       ++stat.count;
+      if (!pruned)
+        ++stat.count_pruned;
       if (count < 5) ++stat.n[count];
     }
 
-    void AddFull(uint64_t count) {
+    void AddFull(uint64_t count, bool pruned = false) {
       ++full_.count;
+      if (!pruned)
+        ++full_.count_pruned;
       if (count < 5) ++full_.n[count];
     }
 
@@ -68,24 +91,27 @@ class StatCollector {
       // n_1 in equation 26 of Chen and Goodman etc
       uint64_t n[5];
       uint64_t count;
+      uint64_t count_pruned;
     };
 
     std::vector<OrderStat> orders_;
     OrderStat &full_;
 
     std::vector<uint64_t> &counts_;
+    std::vector<uint64_t> &counts_pruned_;
     std::vector<Discount> &discounts_;
 };
 
-// Reads all entries in order like NGramStream does.  
+// Reads all entries in order like NGramStream does.
 // But deletes any entries that have <s> in the 1st (not 0th) position on the
 // way out by putting other entries in their place.  This disrupts the sort
-// order but we don't care because the data is going to be sorted again.  
+// order but we don't care because the data is going to be sorted again.
 class CollapseStream {
   public:
-    CollapseStream(const util::stream::ChainPosition &position) :
+    CollapseStream(const util::stream::ChainPosition &position, uint64_t prune_threshold) :
       current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())),
-      block_(position) {
+      prune_threshold_(prune_threshold),
+      block_(position) { 
       StartBlock();
     }
 
@@ -96,10 +122,18 @@ class CollapseStream {
 
     CollapseStream &operator++() {
       assert(block_);
+      
       if (current_.begin()[1] == kBOS && current_.Base() < copy_from_) {
         memcpy(current_.Base(), copy_from_, current_.TotalSize());
         UpdateCopyFrom();
+        
+        // Mark highest order n-grams for later pruning
+        if(current_.Count() <= prune_threshold_) {
+          current_.Mark(); 
+        }
+        
       }
+    
       current_.NextInMemory();
       uint8_t *block_base = static_cast<uint8_t*>(block_->Get());
       if (current_.Base() == block_base + block_->ValidSize()) {
@@ -107,6 +141,12 @@ class CollapseStream {
         ++block_;
         StartBlock();
       }
+      
+      // Mark highest order n-grams for later pruning
+      if(current_.Count() <= prune_threshold_) {
+        current_.Mark(); 
+      }
+      
       return *this;
     }
 
@@ -119,9 +159,15 @@ class CollapseStream {
       current_.ReBase(block_->Get());
       copy_from_ = static_cast<uint8_t*>(block_->Get()) + block_->ValidSize();
       UpdateCopyFrom();
+      
+      // Mark highest order n-grams for later pruning
+      if(current_.Count() <= prune_threshold_) {
+        current_.Mark(); 
+      }
+      
     }
 
-    // Find last without bos.  
+    // Find last without bos.
     void UpdateCopyFrom() {
       for (copy_from_ -= current_.TotalSize(); copy_from_ >= current_.Base(); copy_from_ -= current_.TotalSize()) {
         if (NGram(copy_from_, current_.Order()).begin()[1] != kBOS) break;
@@ -132,83 +178,107 @@ class CollapseStream {
 
     // Goes backwards in the block
     uint8_t *copy_from_;
-
+    uint64_t prune_threshold_;
     util::stream::Link block_;
 };
 
 } // namespace
 
-void AdjustCounts::Run(const ChainPositions &positions) {
+void AdjustCounts::Run(const util::stream::ChainPositions &positions) {
   UTIL_TIMER("(%w s) Adjusted counts\n");
 
   const std::size_t order = positions.size();
-  StatCollector stats(order, counts_, discounts_);
+  StatCollector stats(order, counts_, counts_pruned_, discounts_);
   if (order == 1) {
+
     // Only unigrams.  Just collect stats.  
     for (NGramStream full(positions[0]); full; ++full) 
       stats.AddFull(full->Count());
-    stats.CalculateDiscounts();
+
+    stats.CalculateDiscounts(discount_config_);
     return;
   }
 
   NGramStreams streams;
   streams.Init(positions, positions.size() - 1);
-  CollapseStream full(positions[positions.size() - 1]);
+  
+  CollapseStream full(positions[positions.size() - 1], prune_thresholds_.back());
 
-  // Initialization: <unk> has count 0 and so does <s>.  
+  // Initialization: <unk> has count 0 and so does <s>.
   NGramStream *lower_valid = streams.begin();
   streams[0]->Count() = 0;
   *streams[0]->begin() = kUNK;
   stats.Add(0, 0);
   (++streams[0])->Count() = 0;
   *streams[0]->begin() = kBOS;
-  // not in stats because it will get put in later.  
+  // not in stats because it will get put in later.
 
+  std::vector<uint64_t> lower_counts(positions.size(), 0);
+  
   // iterate over full (the stream of the highest order ngrams)
-  for (; full; ++full) {
+  for (; full; ++full) {  
     const WordIndex *different = FindDifference(*full, **lower_valid);
     std::size_t same = full->end() - 1 - different;
-    // Increment the adjusted count.  
+    // Increment the adjusted count.
     if (same) ++streams[same - 1]->Count();
 
-    // Output all the valid ones that changed.  
+    // Output all the valid ones that changed.
     for (; lower_valid >= &streams[same]; --lower_valid) {
-      stats.Add(lower_valid - streams.begin(), (*lower_valid)->Count());
+      
+      // mjd: review this!
+      uint64_t order = (*lower_valid)->Order();
+      uint64_t realCount = lower_counts[order - 1];
+      if(order > 1 && prune_thresholds_[order - 1] && realCount <= prune_thresholds_[order - 1])
+        (*lower_valid)->Mark();
+      
+      stats.Add(lower_valid - streams.begin(), (*lower_valid)->UnmarkedCount(), (*lower_valid)->IsMarked());
       ++*lower_valid;
     }
+    
+    // Count the true occurrences of lower-order n-grams
+    for (std::size_t i = 0; i < lower_counts.size(); ++i) {
+        if (i >= same) {
+          lower_counts[i] = 0;
+        }
+        lower_counts[i] += full->UnmarkedCount();
+    }
 
     // This is here because bos is also const WordIndex *, so copy gets
-    // consistent argument types.  
+    // consistent argument types.
     const WordIndex *full_end = full->end();
-    // Initialize and mark as valid up to bos.  
+    // Initialize and mark as valid up to bos.
     const WordIndex *bos;
     for (bos = different; (bos > full->begin()) && (*bos != kBOS); --bos) {
       ++lower_valid;
       std::copy(bos, full_end, (*lower_valid)->begin());
       (*lower_valid)->Count() = 1;
     }
-    // Now bos indicates where <s> is or is the 0th word of full.  
+    // Now bos indicates where <s> is or is the 0th word of full.
     if (bos != full->begin()) {
-      // There is an <s> beyond the 0th word.  
+      // There is an <s> beyond the 0th word.
       NGramStream &to = *++lower_valid;
       std::copy(bos, full_end, to->begin());
-      to->Count() = full->Count();  
+
+      // mjd: what is this doing?
+      to->Count() = full->UnmarkedCount(); 
     } else {
-      stats.AddFull(full->Count());
+      stats.AddFull(full->UnmarkedCount(), full->IsMarked()); 
     }
     assert(lower_valid >= &streams[0]);
   }
 
   // Output everything valid.
   for (NGramStream *s = streams.begin(); s <= lower_valid; ++s) {
-    stats.Add(s - streams.begin(), (*s)->Count());
+    if((*s)->Count() <= prune_thresholds_[(*s)->Order() - 1])
+      (*s)->Mark();
+    stats.Add(s - streams.begin(), (*s)->UnmarkedCount(), (*s)->IsMarked());
     ++*s;
   }
-  // Poison everyone!  Except the N-grams which were already poisoned by the input.   
+  // Poison everyone!  Except the N-grams which were already poisoned by the input.
   for (NGramStream *s = streams.begin(); s != streams.end(); ++s)
     s->Poison();
 
-  stats.CalculateDiscounts();
+  stats.CalculateDiscounts(discount_config_);
 
   // NOTE: See special early-return case for unigrams near the top of this function
 }
-- 
cgit v1.2.3