diff options
Diffstat (limited to 'klm/search/vertex.cc')
-rw-r--r-- | klm/search/vertex.cc | 204 |
1 files changed, 177 insertions, 27 deletions
diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc index 45842982..bf40810e 100644 --- a/klm/search/vertex.cc +++ b/klm/search/vertex.cc @@ -2,6 +2,8 @@ #include "search/context.hh" +#include <boost/unordered_map.hpp> + #include <algorithm> #include <functional> @@ -11,45 +13,193 @@ namespace search { namespace { -struct GreaterByBound : public std::binary_function<const VertexNode *, const VertexNode *, bool> { - bool operator()(const VertexNode *first, const VertexNode *second) const { - return first->Bound() > second->Bound(); +const uint64_t kCompleteAdd = static_cast<uint64_t>(-1); + +class DivideLeft { + public: + explicit DivideLeft(unsigned char index) + : index_(index) {} + + uint64_t operator()(const lm::ngram::ChartState &state) const { + return (index_ < state.left.length) ? + state.left.pointers[index_] : + (kCompleteAdd - state.left.full); + } + + private: + unsigned char index_; +}; + +class DivideRight { + public: + explicit DivideRight(unsigned char index) + : index_(index) {} + + uint64_t operator()(const lm::ngram::ChartState &state) const { + return (index_ < state.right.length) ? + static_cast<uint64_t>(state.right.words[index_]) : + (kCompleteAdd - state.left.full); + } + + private: + unsigned char index_; +}; + +template <class Divider> void Split(const Divider ÷r, const std::vector<HypoState> &hypos, std::vector<VertexNode> &extend) { + // Map from divider to index in extend. + typedef boost::unordered_map<uint64_t, std::size_t> Lookup; + Lookup lookup; + for (std::vector<HypoState>::const_iterator i = hypos.begin(); i != hypos.end(); ++i) { + uint64_t key = divider(i->state); + std::pair<Lookup::iterator, bool> res(lookup.insert(std::make_pair(key, extend.size()))); + if (res.second) { + extend.resize(extend.size() + 1); + extend.back().AppendHypothesis(*i); + } else { + extend[res.first->second].AppendHypothesis(*i); + } } + //assert((extend.size() != 1) || (hypos.size() == 1)); +} + +lm::WordIndex Identify(const lm::ngram::Right &right, unsigned char index) { + return right.words[index]; +} + +uint64_t Identify(const lm::ngram::Left &left, unsigned char index) { + return left.pointers[index]; +} + +template <class Side> class DetermineSame { + public: + DetermineSame(const Side &side, unsigned char guaranteed) + : side_(side), guaranteed_(guaranteed), shared_(side.length), complete_(true) {} + + void Consider(const Side &other) { + if (shared_ != other.length) { + complete_ = false; + if (shared_ > other.length) + shared_ = other.length; + } + for (unsigned char i = guaranteed_; i < shared_; ++i) { + if (Identify(side_, i) != Identify(other, i)) { + shared_ = i; + complete_ = false; + return; + } + } + } + + unsigned char Shared() const { return shared_; } + + bool Complete() const { return complete_; } + + private: + const Side &side_; + unsigned char guaranteed_, shared_; + bool complete_; }; +// Custom enum to save memory: valid values of policy_. +// Alternate and there is still alternation to do. +const unsigned char kPolicyAlternate = 0; +// Branch based on left state only, because right ran out or this is a left tree. +const unsigned char kPolicyOneLeft = 1; +// Branch based on right state only. +const unsigned char kPolicyOneRight = 2; +// Reveal everything in the next branch. Used to terminate the left/right policies. +// static const unsigned char kPolicyEverything = 3; + +} // namespace + +namespace { +struct GreaterByScore : public std::binary_function<const HypoState &, const HypoState &, bool> { + bool operator()(const HypoState &first, const HypoState &second) const { + return first.score > second.score; + } +}; } // namespace -void VertexNode::RecursiveSortAndSet(ContextBase &context, VertexNode *&parent_ptr) { - if (Complete()) { - assert(end_); - assert(extend_.empty()); - return; +void VertexNode::FinishRoot() { + std::sort(hypos_.begin(), hypos_.end(), GreaterByScore()); + extend_.clear(); + // HACK: extend to one hypo so that root can be blank. + state_.left.full = false; + state_.left.length = 0; + state_.right.length = 0; + right_full_ = false; + niceness_ = 0; + policy_ = kPolicyAlternate; + if (hypos_.size() == 1) { + extend_.resize(1); + extend_.front().AppendHypothesis(hypos_.front()); + extend_.front().FinishedAppending(0, 0); + } + if (hypos_.empty()) { + bound_ = -INFINITY; + } else { + bound_ = hypos_.front().score; } - if (extend_.size() == 1) { - parent_ptr = extend_[0]; - extend_[0]->RecursiveSortAndSet(context, parent_ptr); - context.DeleteVertexNode(this); - return; +} + +void VertexNode::FinishedAppending(const unsigned char common_left, const unsigned char common_right) { + assert(!hypos_.empty()); + assert(extend_.empty()); + bound_ = hypos_.front().score; + state_ = hypos_.front().state; + bool all_full = state_.left.full; + bool all_non_full = !state_.left.full; + DetermineSame<lm::ngram::Left> left(state_.left, common_left); + DetermineSame<lm::ngram::Right> right(state_.right, common_right); + for (std::vector<HypoState>::const_iterator i = hypos_.begin() + 1; i != hypos_.end(); ++i) { + all_full &= i->state.left.full; + all_non_full &= !i->state.left.full; + left.Consider(i->state.left); + right.Consider(i->state.right); } - for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) { - (*i)->RecursiveSortAndSet(context, *i); + state_.left.full = all_full && left.Complete(); + right_full_ = all_full && right.Complete(); + state_.left.length = left.Shared(); + state_.right.length = right.Shared(); + + if (!all_full && !all_non_full) { + policy_ = kPolicyAlternate; + } else if (left.Complete()) { + policy_ = kPolicyOneRight; + } else if (right.Complete()) { + policy_ = kPolicyOneLeft; + } else { + policy_ = kPolicyAlternate; } - std::sort(extend_.begin(), extend_.end(), GreaterByBound()); - bound_ = extend_.front()->Bound(); + niceness_ = state_.left.length + state_.right.length; } -void VertexNode::SortAndSet(ContextBase &context) { - // This is the root. The root might be empty. - if (extend_.empty()) { - bound_ = -INFINITY; - return; +void VertexNode::BuildExtend() { + // Already built. + if (!extend_.empty()) return; + // Nothing to build since this is a leaf. + if (hypos_.size() <= 1) return; + bool left_branch = true; + switch (policy_) { + case kPolicyAlternate: + left_branch = (state_.left.length <= state_.right.length); + break; + case kPolicyOneLeft: + left_branch = true; + break; + case kPolicyOneRight: + left_branch = false; + break; + } + if (left_branch) { + Split(DivideLeft(state_.left.length), hypos_, extend_); + } else { + Split(DivideRight(state_.right.length), hypos_, extend_); } - // The root cannot be replaced. There's always one transition. - for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) { - (*i)->RecursiveSortAndSet(context, *i); + for (std::vector<VertexNode>::iterator i = extend_.begin(); i != extend_.end(); ++i) { + // TODO: provide more here for branching? + i->FinishedAppending(state_.left.length, state_.right.length); } - std::sort(extend_.begin(), extend_.end(), GreaterByBound()); - bound_ = extend_.front()->Bound(); } } // namespace search |