summaryrefslogtreecommitdiff
path: root/klm/lm/builder/interpolate.cc
diff options
context:
space:
mode:
authorAvneesh Saluja <asaluja@gmail.com>2013-03-28 18:28:16 -0700
committerAvneesh Saluja <asaluja@gmail.com>2013-03-28 18:28:16 -0700
commit3d8d656fa7911524e0e6885647173474524e0784 (patch)
tree81b1ee2fcb67980376d03f0aa48e42e53abff222 /klm/lm/builder/interpolate.cc
parentbe7f57fdd484e063775d7abf083b9fa4c403b610 (diff)
parent96fedabebafe7a38a6d5928be8fff767e411d705 (diff)
fixed conflicts
Diffstat (limited to 'klm/lm/builder/interpolate.cc')
-rw-r--r--klm/lm/builder/interpolate.cc65
1 files changed, 65 insertions, 0 deletions
diff --git a/klm/lm/builder/interpolate.cc b/klm/lm/builder/interpolate.cc
new file mode 100644
index 00000000..50026806
--- /dev/null
+++ b/klm/lm/builder/interpolate.cc
@@ -0,0 +1,65 @@
+#include "lm/builder/interpolate.hh"
+
+#include "lm/builder/joint_order.hh"
+#include "lm/builder/multi_stream.hh"
+#include "lm/builder/sort.hh"
+#include "lm/lm_exception.hh"
+
+#include <assert.h>
+
+namespace lm { namespace builder {
+namespace {
+
+class Callback {
+ public:
+ Callback(float uniform_prob, const ChainPositions &backoffs) : backoffs_(backoffs.size()), probs_(backoffs.size() + 2) {
+ probs_[0] = uniform_prob;
+ for (std::size_t i = 0; i < backoffs.size(); ++i) {
+ backoffs_.push_back(backoffs[i]);
+ }
+ }
+
+ ~Callback() {
+ for (std::size_t i = 0; i < backoffs_.size(); ++i) {
+ if (backoffs_[i]) {
+ std::cerr << "Backoffs do not match for order " << (i + 1) << std::endl;
+ abort();
+ }
+ }
+ }
+
+ void Enter(unsigned order_minus_1, NGram &gram) {
+ Payload &pay = gram.Value();
+ pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1];
+ probs_[order_minus_1 + 1] = pay.complete.prob;
+ pay.complete.prob = log10(pay.complete.prob);
+ // TODO: this is a hack to skip n-grams that don't appear as context. Pruning will require some different handling.
+ if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) {
+ pay.complete.backoff = log10(*static_cast<const float*>(backoffs_[order_minus_1].Get()));
+ ++backoffs_[order_minus_1];
+ } else {
+ // Not a context.
+ pay.complete.backoff = 0.0;
+ }
+ }
+
+ void Exit(unsigned, const NGram &) const {}
+
+ private:
+ FixedArray<util::stream::Stream> backoffs_;
+
+ std::vector<float> probs_;
+};
+} // namespace
+
+Interpolate::Interpolate(uint64_t unigram_count, const ChainPositions &backoffs)
+ : uniform_prob_(1.0 / static_cast<float>(unigram_count - 1)), backoffs_(backoffs) {}
+
+// perform order-wise interpolation
+void Interpolate::Run(const ChainPositions &positions) {
+ assert(positions.size() == backoffs_.size() + 1);
+ Callback callback(uniform_prob_, backoffs_);
+ JointOrder<Callback, SuffixOrder>(positions, callback);
+}
+
+}} // namespaces