diff options
Diffstat (limited to 'klm/search')
-rw-r--r-- | klm/search/Makefile.am | 17 | ||||
-rw-r--r-- | klm/search/context.hh | 12 | ||||
-rw-r--r-- | klm/search/edge_generator.cc | 12 | ||||
-rw-r--r-- | klm/search/vertex.cc | 204 | ||||
-rw-r--r-- | klm/search/vertex.hh | 121 | ||||
-rw-r--r-- | klm/search/vertex_generator.hh | 36 |
6 files changed, 271 insertions, 131 deletions
diff --git a/klm/search/Makefile.am b/klm/search/Makefile.am index 03554276..b8c8a050 100644 --- a/klm/search/Makefile.am +++ b/klm/search/Makefile.am @@ -1,23 +1,10 @@ noinst_LIBRARIES = libksearch.a libksearch_a_SOURCES = \ - applied.hh \ - config.hh \ - context.hh \ - dedupe.hh \ - edge.hh \ - edge_generator.hh \ - header.hh \ - nbest.hh \ - rule.hh \ - types.hh \ - vertex.hh \ - vertex_generator.hh \ edge_generator.cc \ nbest.cc \ rule.cc \ - vertex.cc \ - vertex_generator.cc + vertex.cc -AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/klm +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. diff --git a/klm/search/context.hh b/klm/search/context.hh index 08f21bbf..c3c8e53b 100644 --- a/klm/search/context.hh +++ b/klm/search/context.hh @@ -12,16 +12,6 @@ class ContextBase { public: explicit ContextBase(const Config &config) : config_(config) {} - VertexNode *NewVertexNode() { - VertexNode *ret = vertex_node_pool_.construct(); - assert(ret); - return ret; - } - - void DeleteVertexNode(VertexNode *node) { - vertex_node_pool_.destroy(node); - } - unsigned int PopLimit() const { return config_.PopLimit(); } Score LMWeight() const { return config_.LMWeight(); } @@ -29,8 +19,6 @@ class ContextBase { const Config &GetConfig() const { return config_; } private: - boost::object_pool<VertexNode> vertex_node_pool_; - Config config_; }; diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc index eacf5de5..dd9d61e4 100644 --- a/klm/search/edge_generator.cc +++ b/klm/search/edge_generator.cc @@ -54,20 +54,20 @@ template <class Model> PartialEdge EdgeGenerator::Pop(Context<Model> &context) { Arity victim = 0; Arity victim_completed; Arity incomplete; + unsigned char lowest_niceness = 255; // Select victim or return if complete. { Arity completed = 0; - unsigned char lowest_length = 255; for (Arity i = 0; i != arity; ++i) { if (top_nt[i].Complete()) { ++completed; - } else if (top_nt[i].Length() < lowest_length) { - lowest_length = top_nt[i].Length(); + } else if (top_nt[i].Niceness() < lowest_niceness) { + lowest_niceness = top_nt[i].Niceness(); victim = i; victim_completed = completed; } } - if (lowest_length == 255) { + if (lowest_niceness == 255) { return top; } incomplete = arity - completed; @@ -92,10 +92,14 @@ template <class Model> PartialEdge EdgeGenerator::Pop(Context<Model> &context) { generate_.push(alternate); } +#ifndef NDEBUG + Score before = top.GetScore(); +#endif // top is now the continuation. FastScore(context, victim, victim - victim_completed, incomplete, old_value, top); // TODO: dedupe? generate_.push(top); + assert(lowest_niceness != 254 || top.GetScore() == before); // Invalid indicates no new hypothesis generated. return PartialEdge(); diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc index 45842982..bf40810e 100644 --- a/klm/search/vertex.cc +++ b/klm/search/vertex.cc @@ -2,6 +2,8 @@ #include "search/context.hh" +#include <boost/unordered_map.hpp> + #include <algorithm> #include <functional> @@ -11,45 +13,193 @@ namespace search { namespace { -struct GreaterByBound : public std::binary_function<const VertexNode *, const VertexNode *, bool> { - bool operator()(const VertexNode *first, const VertexNode *second) const { - return first->Bound() > second->Bound(); +const uint64_t kCompleteAdd = static_cast<uint64_t>(-1); + +class DivideLeft { + public: + explicit DivideLeft(unsigned char index) + : index_(index) {} + + uint64_t operator()(const lm::ngram::ChartState &state) const { + return (index_ < state.left.length) ? + state.left.pointers[index_] : + (kCompleteAdd - state.left.full); + } + + private: + unsigned char index_; +}; + +class DivideRight { + public: + explicit DivideRight(unsigned char index) + : index_(index) {} + + uint64_t operator()(const lm::ngram::ChartState &state) const { + return (index_ < state.right.length) ? + static_cast<uint64_t>(state.right.words[index_]) : + (kCompleteAdd - state.left.full); + } + + private: + unsigned char index_; +}; + +template <class Divider> void Split(const Divider ÷r, const std::vector<HypoState> &hypos, std::vector<VertexNode> &extend) { + // Map from divider to index in extend. + typedef boost::unordered_map<uint64_t, std::size_t> Lookup; + Lookup lookup; + for (std::vector<HypoState>::const_iterator i = hypos.begin(); i != hypos.end(); ++i) { + uint64_t key = divider(i->state); + std::pair<Lookup::iterator, bool> res(lookup.insert(std::make_pair(key, extend.size()))); + if (res.second) { + extend.resize(extend.size() + 1); + extend.back().AppendHypothesis(*i); + } else { + extend[res.first->second].AppendHypothesis(*i); + } } + //assert((extend.size() != 1) || (hypos.size() == 1)); +} + +lm::WordIndex Identify(const lm::ngram::Right &right, unsigned char index) { + return right.words[index]; +} + +uint64_t Identify(const lm::ngram::Left &left, unsigned char index) { + return left.pointers[index]; +} + +template <class Side> class DetermineSame { + public: + DetermineSame(const Side &side, unsigned char guaranteed) + : side_(side), guaranteed_(guaranteed), shared_(side.length), complete_(true) {} + + void Consider(const Side &other) { + if (shared_ != other.length) { + complete_ = false; + if (shared_ > other.length) + shared_ = other.length; + } + for (unsigned char i = guaranteed_; i < shared_; ++i) { + if (Identify(side_, i) != Identify(other, i)) { + shared_ = i; + complete_ = false; + return; + } + } + } + + unsigned char Shared() const { return shared_; } + + bool Complete() const { return complete_; } + + private: + const Side &side_; + unsigned char guaranteed_, shared_; + bool complete_; }; +// Custom enum to save memory: valid values of policy_. +// Alternate and there is still alternation to do. +const unsigned char kPolicyAlternate = 0; +// Branch based on left state only, because right ran out or this is a left tree. +const unsigned char kPolicyOneLeft = 1; +// Branch based on right state only. +const unsigned char kPolicyOneRight = 2; +// Reveal everything in the next branch. Used to terminate the left/right policies. +// static const unsigned char kPolicyEverything = 3; + +} // namespace + +namespace { +struct GreaterByScore : public std::binary_function<const HypoState &, const HypoState &, bool> { + bool operator()(const HypoState &first, const HypoState &second) const { + return first.score > second.score; + } +}; } // namespace -void VertexNode::RecursiveSortAndSet(ContextBase &context, VertexNode *&parent_ptr) { - if (Complete()) { - assert(end_); - assert(extend_.empty()); - return; +void VertexNode::FinishRoot() { + std::sort(hypos_.begin(), hypos_.end(), GreaterByScore()); + extend_.clear(); + // HACK: extend to one hypo so that root can be blank. + state_.left.full = false; + state_.left.length = 0; + state_.right.length = 0; + right_full_ = false; + niceness_ = 0; + policy_ = kPolicyAlternate; + if (hypos_.size() == 1) { + extend_.resize(1); + extend_.front().AppendHypothesis(hypos_.front()); + extend_.front().FinishedAppending(0, 0); + } + if (hypos_.empty()) { + bound_ = -INFINITY; + } else { + bound_ = hypos_.front().score; } - if (extend_.size() == 1) { - parent_ptr = extend_[0]; - extend_[0]->RecursiveSortAndSet(context, parent_ptr); - context.DeleteVertexNode(this); - return; +} + +void VertexNode::FinishedAppending(const unsigned char common_left, const unsigned char common_right) { + assert(!hypos_.empty()); + assert(extend_.empty()); + bound_ = hypos_.front().score; + state_ = hypos_.front().state; + bool all_full = state_.left.full; + bool all_non_full = !state_.left.full; + DetermineSame<lm::ngram::Left> left(state_.left, common_left); + DetermineSame<lm::ngram::Right> right(state_.right, common_right); + for (std::vector<HypoState>::const_iterator i = hypos_.begin() + 1; i != hypos_.end(); ++i) { + all_full &= i->state.left.full; + all_non_full &= !i->state.left.full; + left.Consider(i->state.left); + right.Consider(i->state.right); } - for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) { - (*i)->RecursiveSortAndSet(context, *i); + state_.left.full = all_full && left.Complete(); + right_full_ = all_full && right.Complete(); + state_.left.length = left.Shared(); + state_.right.length = right.Shared(); + + if (!all_full && !all_non_full) { + policy_ = kPolicyAlternate; + } else if (left.Complete()) { + policy_ = kPolicyOneRight; + } else if (right.Complete()) { + policy_ = kPolicyOneLeft; + } else { + policy_ = kPolicyAlternate; } - std::sort(extend_.begin(), extend_.end(), GreaterByBound()); - bound_ = extend_.front()->Bound(); + niceness_ = state_.left.length + state_.right.length; } -void VertexNode::SortAndSet(ContextBase &context) { - // This is the root. The root might be empty. - if (extend_.empty()) { - bound_ = -INFINITY; - return; +void VertexNode::BuildExtend() { + // Already built. + if (!extend_.empty()) return; + // Nothing to build since this is a leaf. + if (hypos_.size() <= 1) return; + bool left_branch = true; + switch (policy_) { + case kPolicyAlternate: + left_branch = (state_.left.length <= state_.right.length); + break; + case kPolicyOneLeft: + left_branch = true; + break; + case kPolicyOneRight: + left_branch = false; + break; + } + if (left_branch) { + Split(DivideLeft(state_.left.length), hypos_, extend_); + } else { + Split(DivideRight(state_.right.length), hypos_, extend_); } - // The root cannot be replaced. There's always one transition. - for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) { - (*i)->RecursiveSortAndSet(context, *i); + for (std::vector<VertexNode>::iterator i = extend_.begin(); i != extend_.end(); ++i) { + // TODO: provide more here for branching? + i->FinishedAppending(state_.left.length, state_.right.length); } - std::sort(extend_.begin(), extend_.end(), GreaterByBound()); - bound_ = extend_.front()->Bound(); } } // namespace search diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh index ca9a4fcd..81c3cfed 100644 --- a/klm/search/vertex.hh +++ b/klm/search/vertex.hh @@ -16,59 +16,74 @@ namespace search { class ContextBase; +struct HypoState { + History history; + lm::ngram::ChartState state; + Score score; +}; + class VertexNode { public: - VertexNode() : end_() {} - - void InitRoot() { - extend_.clear(); - state_.left.full = false; - state_.left.length = 0; - state_.right.length = 0; - right_full_ = false; - end_ = History(); + VertexNode() {} + + void InitRoot() { hypos_.clear(); } + + /* The steps of building a VertexNode: + * 1. Default construct. + * 2. AppendHypothesis at least once, possibly multiple times. + * 3. FinishAppending with the number of words on left and right guaranteed + * to be common. + * 4. If !Complete(), call BuildExtend to construct the extensions + */ + // Must default construct, call AppendHypothesis 1 or more times then do FinishedAppending. + void AppendHypothesis(const NBestComplete &best) { + assert(hypos_.empty() || !(hypos_.front().state == *best.state)); + HypoState hypo; + hypo.history = best.history; + hypo.state = *best.state; + hypo.score = best.score; + hypos_.push_back(hypo); + } + void AppendHypothesis(const HypoState &hypo) { + hypos_.push_back(hypo); } - lm::ngram::ChartState &MutableState() { return state_; } - bool &MutableRightFull() { return right_full_; } + // Sort hypotheses for the root. + void FinishRoot(); - void AddExtend(VertexNode *next) { - extend_.push_back(next); - } + void FinishedAppending(const unsigned char common_left, const unsigned char common_right); - void SetEnd(History end, Score score) { - assert(!end_); - end_ = end; - bound_ = score; - } - - void SortAndSet(ContextBase &context); + void BuildExtend(); // Should only happen to a root node when the entire vertex is empty. bool Empty() const { - return !end_ && extend_.empty(); + return hypos_.empty() && extend_.empty(); } bool Complete() const { - return end_; + // HACK: prevent root from being complete. TODO: allow root to be complete. + return hypos_.size() == 1 && extend_.empty(); } const lm::ngram::ChartState &State() const { return state_; } bool RightFull() const { return right_full_; } + // Priority relative to other non-terminals. 0 is highest. + unsigned char Niceness() const { return niceness_; } + Score Bound() const { return bound_; } - unsigned char Length() const { - return state_.left.length + state_.right.length; - } - // Will be invalid unless this is a leaf. - History End() const { return end_; } + History End() const { + assert(hypos_.size() == 1); + return hypos_.front().history; + } - const VertexNode &operator[](size_t index) const { - return *extend_[index]; + VertexNode &operator[](size_t index) { + assert(!extend_.empty()); + return extend_[index]; } size_t Size() const { @@ -76,22 +91,26 @@ class VertexNode { } private: - void RecursiveSortAndSet(ContextBase &context, VertexNode *&parent); + // Hypotheses to be split. + std::vector<HypoState> hypos_; - std::vector<VertexNode*> extend_; + std::vector<VertexNode> extend_; lm::ngram::ChartState state_; bool right_full_; + unsigned char niceness_; + + unsigned char policy_; + Score bound_; - History end_; }; class PartialVertex { public: PartialVertex() {} - explicit PartialVertex(const VertexNode &back) : back_(&back), index_(0) {} + explicit PartialVertex(VertexNode &back) : back_(&back), index_(0) {} bool Empty() const { return back_->Empty(); } @@ -100,17 +119,14 @@ class PartialVertex { const lm::ngram::ChartState &State() const { return back_->State(); } bool RightFull() const { return back_->RightFull(); } - Score Bound() const { return Complete() ? back_->Bound() : (*back_)[index_].Bound(); } - - unsigned char Length() const { return back_->Length(); } + Score Bound() const { return index_ ? (*back_)[index_].Bound() : back_->Bound(); } - bool HasAlternative() const { - return index_ + 1 < back_->Size(); - } + unsigned char Niceness() const { return back_->Niceness(); } // Split into continuation and alternative, rendering this the continuation. bool Split(PartialVertex &alternative) { assert(!Complete()); + back_->BuildExtend(); bool ret; if (index_ + 1 < back_->Size()) { alternative.index_ = index_ + 1; @@ -129,7 +145,7 @@ class PartialVertex { } private: - const VertexNode *back_; + VertexNode *back_; unsigned int index_; }; @@ -139,10 +155,21 @@ class Vertex { public: Vertex() {} - PartialVertex RootPartial() const { return PartialVertex(root_); } + //PartialVertex RootFirst() const { return PartialVertex(right_); } + PartialVertex RootAlternate() { return PartialVertex(root_); } + //PartialVertex RootLast() const { return PartialVertex(left_); } + + bool Empty() const { + return root_.Empty(); + } + + Score Bound() const { + return root_.Bound(); + } - History BestChild() const { - PartialVertex top(RootPartial()); + History BestChild() { + // left_ and right_ are not set at the root. + PartialVertex top(RootAlternate()); if (top.Empty()) { return History(); } else { @@ -158,6 +185,12 @@ class Vertex { template <class Output> friend class VertexGenerator; template <class Output> friend class RootVertexGenerator; VertexNode root_; + + // These will not be set for the root vertex. + // Branches only on left state. + //VertexNode left_; + // Branches only on right state. + //VertexNode right_; }; } // namespace search diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh index 646b8189..91000012 100644 --- a/klm/search/vertex_generator.hh +++ b/klm/search/vertex_generator.hh @@ -4,10 +4,8 @@ #include "search/edge.hh" #include "search/types.hh" #include "search/vertex.hh" -#include "util/exception.hh" #include <boost/unordered_map.hpp> -#include <boost/version.hpp> namespace lm { namespace ngram { @@ -19,45 +17,25 @@ namespace search { class ContextBase; -#if BOOST_VERSION > 104200 -// Parallel structure to VertexNode. -struct Trie { - Trie() : under(NULL) {} - - VertexNode *under; - boost::unordered_map<uint64_t, Trie> extend; -}; - -void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end); - -#endif // BOOST_VERSION - // Output makes the single-best or n-best list. template <class Output> class VertexGenerator { public: - VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) { - gen.root_.InitRoot(); - } + VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) {} void NewHypothesis(PartialEdge partial) { nbest_.Add(existing_[hash_value(partial.CompletedState())], partial); } void FinishedSearch() { -#if BOOST_VERSION > 104200 - Trie root; - root.under = &gen_.root_; + gen_.root_.InitRoot(); for (typename Existing::iterator i(existing_.begin()); i != existing_.end(); ++i) { - AddHypothesis(context_, root, nbest_.Complete(i->second)); + gen_.root_.AppendHypothesis(nbest_.Complete(i->second)); } existing_.clear(); - root.under->SortAndSet(context_); -#else - UTIL_THROW(util::Exception, "Upgrade Boost to >= 1.42.0 to use incremental search."); -#endif + gen_.root_.FinishRoot(); } - const Vertex &Generating() const { return gen_; } + Vertex &Generating() { return gen_; } private: ContextBase &context_; @@ -84,8 +62,8 @@ template <class Output> class RootVertexGenerator { void FinishedSearch() { gen_.root_.InitRoot(); - NBestComplete completed(out_.Complete(combine_)); - gen_.root_.SetEnd(completed.history, completed.score); + gen_.root_.AppendHypothesis(out_.Complete(combine_)); + gen_.root_.FinishRoot(); } private: |