#include "search/nbest.hh"

#include "util/pool.hh"

#include <algorithm>
#include <functional>
#include <queue>

#include <assert.h>
#include <math.h>

namespace search {

// Seeds the list's priority queue with the best partial edges.
//
// partials:   non-empty candidate edges.  NOTE: when there are more than
//             `keep` of them the vector is reordered in place by
//             std::nth_element.
// entry_pool: pool from which storage for each QueueEntry is allocated.
// keep:       maximum number of partials to enqueue.
NBestList::NBestList(std::vector<PartialEdge> &partials, util::Pool &entry_pool, std::size_t keep) {
  assert(!partials.empty());
  std::vector<PartialEdge>::iterator end;
  if (partials.size() > keep) {
    // Partition so that [begin, end) holds the `keep` greatest elements
    // according to std::greater<PartialEdge>; no full sort is needed because
    // the queue orders entries itself.
    end = partials.begin() + keep;
    std::nth_element(partials.begin(), end, partials.end(), std::greater<PartialEdge>());
  } else {
    end = partials.end();
  }
  for (std::vector<PartialEdge>::const_iterator i(partials.begin()); i != end; ++i) {
    queue_.push(QueueEntry(entry_pool.Allocate(QueueEntry::Size(i->GetArity())), *i));
  }
}

// Returns the score of the entry at the top of the queue.  Only valid before
// any hypothesis has been revealed, and only if the constructor enqueued at
// least one entry (i.e. keep was nonzero).
Score NBestList::TopAfterConstructor() const {
  assert(revealed_.empty());
  // top() on an empty priority queue is undefined behavior; fail loudly in
  // debug builds instead.
  assert(!queue_.empty());
  return queue_.top().GetScore();
}

// Reveals hypotheses until `n` are available or the queue runs dry, then
// returns every hypothesis revealed so far (possibly fewer, or more, than n).
const std::vector<Applied> &NBestList::Extract(util::Pool &pool, std::size_t n) {
  while (revealed_.size() < n && !queue_.empty()) {
    MoveTop(pool);
  }
  return revealed_;
}

// Returns the score change incurred by moving from hypothesis `index` to
// hypothesis `index + 1` of this list, or -INFINITY when there is no
// hypothesis after `index`.  Reveals hypothesis `index` first if it is the
// next unrevealed one.
Score NBestList::Visit(util::Pool &pool, std::size_t index) {
  // Both hypotheses are already revealed.
  if (index + 1 < revealed_.size())
    return revealed_[index + 1].GetScore() - revealed_[index].GetScore();
  if (queue_.empty())
    return -INFINITY;
  // `index` is the last revealed hypothesis; its successor is still on top of
  // the queue, so peek at it without revealing.
  if (index + 1 == revealed_.size())
    return queue_.top().GetScore() - revealed_[index].GetScore();
  assert(index == revealed_.size());

  // Reveal hypothesis `index` itself; this may also push its successors.
  MoveTop(pool);

  if (queue_.empty()) return -INFINITY;
  return queue_.top().GetScore() - revealed_[index].GetScore();
}

// Returns hypothesis `index`, revealing it first when it is the next
// unrevealed one.  `index` must not exceed the number revealed so far.
Applied NBestList::Get(util::Pool &pool, std::size_t index) {
  assert(index <= revealed_.size());
  if (index == revealed_.size()) MoveTop(pool);
  return revealed_[index];
}

// Pops the best queue entry, enqueues its successors, and appends the entry
// to revealed_.  The queue must be non-empty.
void NBestList::MoveTop(util::Pool &pool) {
  assert(!queue_.empty());
  QueueEntry entry(queue_.top());
  queue_.pop();
  RevealedRef *const children_begin = entry.Children();
  RevealedRef *const children_end = children_begin + entry.GetArity();
  Score basis = entry.GetScore();
  // Each successor is a copy of `entry` in which exactly one child has been
  // advanced to its next hypothesis.
  for (RevealedRef *child = children_begin; child != children_end; ++child) {
    Score change = child->in_->Visit(pool, child->index_);
    // -INFINITY is the sentinel Visit returns when the child has no further
    // hypothesis, so the exact comparison is intentional.
    if (change != -INFINITY) {
      // Advancing a child should not raise the score; the slack presumably
      // absorbs floating-point rounding.
      assert(change < 0.001);
      QueueEntry new_entry(pool.Allocate(QueueEntry::Size(entry.GetArity())), basis + change, entry.GetArity(), entry.GetNote());
      // Children before and after `child` are copied unchanged; the slot for
      // `child` gets the same list with the index bumped by one.
      std::copy(children_begin, child, new_entry.Children());
      RevealedRef *update = new_entry.Children() + (child - children_begin);
      update->in_ = child->in_;
      update->index_ = child->index_ + 1;
      std::copy(child + 1, children_end, update + 1);
      queue_.push(new_entry);
    }
    // Gesmundo, A. and Henderson, J. Faster Cube Pruning, IWSLT 2010.
    // Stop once a child with a nonzero index has been handled; presumably this
    // keeps the same combination from being generated along several paths.
    if (child->index_) break;
  }

  // Convert QueueEntry to Applied.  This leaves some unused memory.
  // The children array is rewritten in place: each RevealedRef is copied out
  // before its slot is overwritten.  NOTE(review): this relies on Applied
  // being no larger than RevealedRef -- confirm against the header.
  void *overwrite = entry.Children();
  for (unsigned int i = 0; i < entry.GetArity(); ++i) {
    RevealedRef from(*(static_cast<const RevealedRef*>(overwrite) + i));
    *(static_cast<Applied*>(overwrite) + i) = from.in_->Get(pool, from.index_);
  }
  revealed_.push_back(Applied(entry.Base()));
}

// Builds an NBestList from `partials` (which must be non-empty and may be
// reordered by the NBestList constructor) and returns its handle together
// with the completed state and the best score.
NBestComplete NBest::Complete(std::vector<PartialEdge> &partials) {
  assert(!partials.empty());
  NBestList *list = list_pool_.construct(partials, entry_pool_, config_.keep);
  return NBestComplete(
      list,
      partials.front().CompletedState(), // All partials have the same state
      list->TopAfterConstructor());
}

// Extracts up to config_.size hypotheses from the list behind `history`.
// `history` is cast to NBestList*; presumably it is the handle produced by
// Complete() -- verify against callers.
const std::vector<Applied> &NBest::Extract(History history) {
  return static_cast<NBestList*>(history)->Extract(entry_pool_, config_.size);
}

} // namespace search