diff options
Diffstat (limited to 'klm/search/nbest.cc')
-rw-r--r-- | klm/search/nbest.cc | 106 |
1 files changed, 106 insertions, 0 deletions
diff --git a/klm/search/nbest.cc b/klm/search/nbest.cc new file mode 100644 index 00000000..ec3322c9 --- /dev/null +++ b/klm/search/nbest.cc @@ -0,0 +1,106 @@ +#include "search/nbest.hh" + +#include "util/pool.hh" + +#include <algorithm> +#include <functional> +#include <queue> + +#include <assert.h> +#include <math.h> + +namespace search { + +NBestList::NBestList(std::vector<PartialEdge> &partials, util::Pool &entry_pool, std::size_t keep) { + assert(!partials.empty()); + std::vector<PartialEdge>::iterator end; + if (partials.size() > keep) { + end = partials.begin() + keep; + std::nth_element(partials.begin(), end, partials.end(), std::greater<PartialEdge>()); + } else { + end = partials.end(); + } + for (std::vector<PartialEdge>::const_iterator i(partials.begin()); i != end; ++i) { + queue_.push(QueueEntry(entry_pool.Allocate(QueueEntry::Size(i->GetArity())), *i)); + } +} + +Score NBestList::TopAfterConstructor() const { + assert(revealed_.empty()); + return queue_.top().GetScore(); +} + +const std::vector<Applied> &NBestList::Extract(util::Pool &pool, std::size_t n) { + while (revealed_.size() < n && !queue_.empty()) { + MoveTop(pool); + } + return revealed_; +} + +Score NBestList::Visit(util::Pool &pool, std::size_t index) { + if (index + 1 < revealed_.size()) + return revealed_[index + 1].GetScore() - revealed_[index].GetScore(); + if (queue_.empty()) + return -INFINITY; + if (index + 1 == revealed_.size()) + return queue_.top().GetScore() - revealed_[index].GetScore(); + assert(index == revealed_.size()); + + MoveTop(pool); + + if (queue_.empty()) return -INFINITY; + return queue_.top().GetScore() - revealed_[index].GetScore(); +} + +Applied NBestList::Get(util::Pool &pool, std::size_t index) { + assert(index <= revealed_.size()); + if (index == revealed_.size()) MoveTop(pool); + return revealed_[index]; +} + +void NBestList::MoveTop(util::Pool &pool) { + assert(!queue_.empty()); + QueueEntry entry(queue_.top()); + queue_.pop(); + RevealedRef *const children_begin = entry.Children(); + RevealedRef *const children_end = children_begin + entry.GetArity(); + Score basis = entry.GetScore(); + for (RevealedRef *child = children_begin; child != children_end; ++child) { + Score change = child->in_->Visit(pool, child->index_); + if (change != -INFINITY) { + assert(change < 0.001); + QueueEntry new_entry(pool.Allocate(QueueEntry::Size(entry.GetArity())), basis + change, entry.GetArity(), entry.GetNote()); + std::copy(children_begin, child, new_entry.Children()); + RevealedRef *update = new_entry.Children() + (child - children_begin); + update->in_ = child->in_; + update->index_ = child->index_ + 1; + std::copy(child + 1, children_end, update + 1); + queue_.push(new_entry); + } + // Gesmundo, A. and Henderson, J. Faster Cube Pruning, IWSLT 2010. + if (child->index_) break; + } + + // Convert QueueEntry to Applied. This leaves some unused memory. + void *overwrite = entry.Children(); + for (unsigned int i = 0; i < entry.GetArity(); ++i) { + RevealedRef from(*(static_cast<const RevealedRef*>(overwrite) + i)); + *(static_cast<Applied*>(overwrite) + i) = from.in_->Get(pool, from.index_); + } + revealed_.push_back(Applied(entry.Base())); +} + +NBestComplete NBest::Complete(std::vector<PartialEdge> &partials) { + assert(!partials.empty()); + NBestList *list = list_pool_.construct(partials, entry_pool_, config_.keep); + return NBestComplete( + list, + partials.front().CompletedState(), // All partials have the same state + list->TopAfterConstructor()); +} + +const std::vector<Applied> &NBest::Extract(History history) { + return static_cast<NBestList*>(history)->Extract(entry_pool_, config_.size); +} + +} // namespace search |