diff options
author | Paul Baltescu <pauldb89@gmail.com> | 2013-02-21 14:13:55 +0000 |
---|---|---|
committer | Paul Baltescu <pauldb89@gmail.com> | 2013-02-21 14:13:55 +0000 |
commit | bca26d953a774b8efca12f30407390b3f5eef9d0 (patch) | |
tree | fe922de5c89b1844f677d550dcc24e87edd67a55 /klm/lm/builder/interpolate.cc | |
parent | 54a1c0e2bde259e3acc9c0a8ec8da3c7704e80ca (diff) | |
parent | 95c364f2cb002241c4a62bedb1c5ef6f1e9a7f22 (diff) |
Merge branch 'master' of https://github.com/pauldb89/cdec
Diffstat (limited to 'klm/lm/builder/interpolate.cc')
-rw-r--r-- | klm/lm/builder/interpolate.cc | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/klm/lm/builder/interpolate.cc b/klm/lm/builder/interpolate.cc new file mode 100644 index 00000000..50026806 --- /dev/null +++ b/klm/lm/builder/interpolate.cc @@ -0,0 +1,65 @@ +#include "lm/builder/interpolate.hh" + +#include "lm/builder/joint_order.hh" +#include "lm/builder/multi_stream.hh" +#include "lm/builder/sort.hh" +#include "lm/lm_exception.hh" + +#include <assert.h> + +namespace lm { namespace builder { +namespace { + +class Callback { + public: + Callback(float uniform_prob, const ChainPositions &backoffs) : backoffs_(backoffs.size()), probs_(backoffs.size() + 2) { + probs_[0] = uniform_prob; + for (std::size_t i = 0; i < backoffs.size(); ++i) { + backoffs_.push_back(backoffs[i]); + } + } + + ~Callback() { + for (std::size_t i = 0; i < backoffs_.size(); ++i) { + if (backoffs_[i]) { + std::cerr << "Backoffs do not match for order " << (i + 1) << std::endl; + abort(); + } + } + } + + void Enter(unsigned order_minus_1, NGram &gram) { + Payload &pay = gram.Value(); + pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1]; + probs_[order_minus_1 + 1] = pay.complete.prob; + pay.complete.prob = log10(pay.complete.prob); + // TODO: this is a hack to skip n-grams that don't appear as context. Pruning will require some different handling. + if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) { + pay.complete.backoff = log10(*static_cast<const float*>(backoffs_[order_minus_1].Get())); + ++backoffs_[order_minus_1]; + } else { + // Not a context. + pay.complete.backoff = 0.0; + } + } + + void Exit(unsigned, const NGram &) const {} + + private: + FixedArray<util::stream::Stream> backoffs_; + + std::vector<float> probs_; +}; +} // namespace + +Interpolate::Interpolate(uint64_t unigram_count, const ChainPositions &backoffs) + : uniform_prob_(1.0 / static_cast<float>(unigram_count - 1)), backoffs_(backoffs) {} + +// perform order-wise interpolation +void Interpolate::Run(const ChainPositions &positions) { + assert(positions.size() == backoffs_.size() + 1); + Callback callback(uniform_prob_, backoffs_); + JointOrder<Callback, SuffixOrder>(positions, callback); +} + +}} // namespaces |