diff options
Diffstat (limited to 'klm/lm/builder/interpolate.cc')
-rw-r--r-- | klm/lm/builder/interpolate.cc | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/klm/lm/builder/interpolate.cc b/klm/lm/builder/interpolate.cc new file mode 100644 index 00000000..50026806 --- /dev/null +++ b/klm/lm/builder/interpolate.cc @@ -0,0 +1,65 @@ +#include "lm/builder/interpolate.hh" + +#include "lm/builder/joint_order.hh" +#include "lm/builder/multi_stream.hh" +#include "lm/builder/sort.hh" +#include "lm/lm_exception.hh" + +#include <assert.h> + +namespace lm { namespace builder { +namespace { + +class Callback { + public: + Callback(float uniform_prob, const ChainPositions &backoffs) : backoffs_(backoffs.size()), probs_(backoffs.size() + 2) { + probs_[0] = uniform_prob; + for (std::size_t i = 0; i < backoffs.size(); ++i) { + backoffs_.push_back(backoffs[i]); + } + } + + ~Callback() { + for (std::size_t i = 0; i < backoffs_.size(); ++i) { + if (backoffs_[i]) { + std::cerr << "Backoffs do not match for order " << (i + 1) << std::endl; + abort(); + } + } + } + + void Enter(unsigned order_minus_1, NGram &gram) { + Payload &pay = gram.Value(); + pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1]; + probs_[order_minus_1 + 1] = pay.complete.prob; + pay.complete.prob = log10(pay.complete.prob); + // TODO: this is a hack to skip n-grams that don't appear as context. Pruning will require some different handling. + if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) { + pay.complete.backoff = log10(*static_cast<const float*>(backoffs_[order_minus_1].Get())); + ++backoffs_[order_minus_1]; + } else { + // Not a context. + pay.complete.backoff = 0.0; + } + } + + void Exit(unsigned, const NGram &) const {} + + private: + FixedArray<util::stream::Stream> backoffs_; + + std::vector<float> probs_; +}; +} // namespace + +Interpolate::Interpolate(uint64_t unigram_count, const ChainPositions &backoffs) + : uniform_prob_(1.0 / static_cast<float>(unigram_count - 1)), backoffs_(backoffs) {} + +// perform order-wise interpolation +void Interpolate::Run(const ChainPositions &positions) { + assert(positions.size() == backoffs_.size() + 1); + Callback callback(uniform_prob_, backoffs_); + JointOrder<Callback, SuffixOrder>(positions, callback); +} + +}} // namespaces |