diff options
author | Avneesh Saluja <asaluja@gmail.com> | 2013-03-28 18:28:16 -0700 |
---|---|---|
committer | Avneesh Saluja <asaluja@gmail.com> | 2013-03-28 18:28:16 -0700 |
commit | 3d8d656fa7911524e0e6885647173474524e0784 (patch) | |
tree | 81b1ee2fcb67980376d03f0aa48e42e53abff222 /klm/lm/builder/interpolate.cc | |
parent | be7f57fdd484e063775d7abf083b9fa4c403b610 (diff) | |
parent | 96fedabebafe7a38a6d5928be8fff767e411d705 (diff) |
fixed conflicts
Diffstat (limited to 'klm/lm/builder/interpolate.cc')
-rw-r--r-- | klm/lm/builder/interpolate.cc | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/klm/lm/builder/interpolate.cc b/klm/lm/builder/interpolate.cc new file mode 100644 index 00000000..50026806 --- /dev/null +++ b/klm/lm/builder/interpolate.cc @@ -0,0 +1,65 @@ +#include "lm/builder/interpolate.hh" + +#include "lm/builder/joint_order.hh" +#include "lm/builder/multi_stream.hh" +#include "lm/builder/sort.hh" +#include "lm/lm_exception.hh" + +#include <assert.h> + +namespace lm { namespace builder { +namespace { + +class Callback { + public: + Callback(float uniform_prob, const ChainPositions &backoffs) : backoffs_(backoffs.size()), probs_(backoffs.size() + 2) { + probs_[0] = uniform_prob; + for (std::size_t i = 0; i < backoffs.size(); ++i) { + backoffs_.push_back(backoffs[i]); + } + } + + ~Callback() { + for (std::size_t i = 0; i < backoffs_.size(); ++i) { + if (backoffs_[i]) { + std::cerr << "Backoffs do not match for order " << (i + 1) << std::endl; + abort(); + } + } + } + + void Enter(unsigned order_minus_1, NGram &gram) { + Payload &pay = gram.Value(); + pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1]; + probs_[order_minus_1 + 1] = pay.complete.prob; + pay.complete.prob = log10(pay.complete.prob); + // TODO: this is a hack to skip n-grams that don't appear as context. Pruning will require some different handling. + if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) { + pay.complete.backoff = log10(*static_cast<const float*>(backoffs_[order_minus_1].Get())); + ++backoffs_[order_minus_1]; + } else { + // Not a context. + pay.complete.backoff = 0.0; + } + } + + void Exit(unsigned, const NGram &) const {} + + private: + FixedArray<util::stream::Stream> backoffs_; + + std::vector<float> probs_; +}; +} // namespace + +Interpolate::Interpolate(uint64_t unigram_count, const ChainPositions &backoffs) + : uniform_prob_(1.0 / static_cast<float>(unigram_count - 1)), backoffs_(backoffs) {} + +// perform order-wise interpolation +void Interpolate::Run(const ChainPositions &positions) { + assert(positions.size() == backoffs_.size() + 1); + Callback callback(uniform_prob_, backoffs_); + JointOrder<Callback, SuffixOrder>(positions, callback); +} + +}} // namespaces |