summaryrefslogtreecommitdiff
path: root/klm/search
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2013-06-18 11:34:20 -0700
committerKenneth Heafield <github@kheafield.com>2013-06-18 11:34:20 -0700
commit8dc383a74c12d44ab3f51947575ed5828653f4f1 (patch)
tree30371d17e6c2daf3837068b2617e357ad6b37d89 /klm/search
parent354787aa16539702802b9ea075c4bd8a72071035 (diff)
lazy dd880b4 including kenlm 6eef0f1
Diffstat (limited to 'klm/search')
-rw-r--r--klm/search/Makefile.am17
-rw-r--r--klm/search/context.hh12
-rw-r--r--klm/search/edge_generator.cc12
-rw-r--r--klm/search/vertex.cc204
-rw-r--r--klm/search/vertex.hh121
-rw-r--r--klm/search/vertex_generator.hh36
6 files changed, 271 insertions, 131 deletions
diff --git a/klm/search/Makefile.am b/klm/search/Makefile.am
index 03554276..b8c8a050 100644
--- a/klm/search/Makefile.am
+++ b/klm/search/Makefile.am
@@ -1,23 +1,10 @@
noinst_LIBRARIES = libksearch.a
libksearch_a_SOURCES = \
- applied.hh \
- config.hh \
- context.hh \
- dedupe.hh \
- edge.hh \
- edge_generator.hh \
- header.hh \
- nbest.hh \
- rule.hh \
- types.hh \
- vertex.hh \
- vertex_generator.hh \
edge_generator.cc \
nbest.cc \
rule.cc \
- vertex.cc \
- vertex_generator.cc
+ vertex.cc
-AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/klm
+AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I..
diff --git a/klm/search/context.hh b/klm/search/context.hh
index 08f21bbf..c3c8e53b 100644
--- a/klm/search/context.hh
+++ b/klm/search/context.hh
@@ -12,16 +12,6 @@ class ContextBase {
public:
explicit ContextBase(const Config &config) : config_(config) {}
- VertexNode *NewVertexNode() {
- VertexNode *ret = vertex_node_pool_.construct();
- assert(ret);
- return ret;
- }
-
- void DeleteVertexNode(VertexNode *node) {
- vertex_node_pool_.destroy(node);
- }
-
unsigned int PopLimit() const { return config_.PopLimit(); }
Score LMWeight() const { return config_.LMWeight(); }
@@ -29,8 +19,6 @@ class ContextBase {
const Config &GetConfig() const { return config_; }
private:
- boost::object_pool<VertexNode> vertex_node_pool_;
-
Config config_;
};
diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc
index eacf5de5..dd9d61e4 100644
--- a/klm/search/edge_generator.cc
+++ b/klm/search/edge_generator.cc
@@ -54,20 +54,20 @@ template <class Model> PartialEdge EdgeGenerator::Pop(Context<Model> &context) {
Arity victim = 0;
Arity victim_completed;
Arity incomplete;
+ unsigned char lowest_niceness = 255;
// Select victim or return if complete.
{
Arity completed = 0;
- unsigned char lowest_length = 255;
for (Arity i = 0; i != arity; ++i) {
if (top_nt[i].Complete()) {
++completed;
- } else if (top_nt[i].Length() < lowest_length) {
- lowest_length = top_nt[i].Length();
+ } else if (top_nt[i].Niceness() < lowest_niceness) {
+ lowest_niceness = top_nt[i].Niceness();
victim = i;
victim_completed = completed;
}
}
- if (lowest_length == 255) {
+ if (lowest_niceness == 255) {
return top;
}
incomplete = arity - completed;
@@ -92,10 +92,14 @@ template <class Model> PartialEdge EdgeGenerator::Pop(Context<Model> &context) {
generate_.push(alternate);
}
+#ifndef NDEBUG
+ Score before = top.GetScore();
+#endif
// top is now the continuation.
FastScore(context, victim, victim - victim_completed, incomplete, old_value, top);
// TODO: dedupe?
generate_.push(top);
+ assert(lowest_niceness != 254 || top.GetScore() == before);
// Invalid indicates no new hypothesis generated.
return PartialEdge();
diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc
index 45842982..bf40810e 100644
--- a/klm/search/vertex.cc
+++ b/klm/search/vertex.cc
@@ -2,6 +2,8 @@
#include "search/context.hh"
+#include <boost/unordered_map.hpp>
+
#include <algorithm>
#include <functional>
@@ -11,45 +13,193 @@ namespace search {
namespace {
-struct GreaterByBound : public std::binary_function<const VertexNode *, const VertexNode *, bool> {
- bool operator()(const VertexNode *first, const VertexNode *second) const {
- return first->Bound() > second->Bound();
+const uint64_t kCompleteAdd = static_cast<uint64_t>(-1);
+
+class DivideLeft {
+ public:
+ explicit DivideLeft(unsigned char index)
+ : index_(index) {}
+
+ uint64_t operator()(const lm::ngram::ChartState &state) const {
+ return (index_ < state.left.length) ?
+ state.left.pointers[index_] :
+ (kCompleteAdd - state.left.full);
+ }
+
+ private:
+ unsigned char index_;
+};
+
+class DivideRight {
+ public:
+ explicit DivideRight(unsigned char index)
+ : index_(index) {}
+
+ uint64_t operator()(const lm::ngram::ChartState &state) const {
+ return (index_ < state.right.length) ?
+ static_cast<uint64_t>(state.right.words[index_]) :
+ (kCompleteAdd - state.left.full);
+ }
+
+ private:
+ unsigned char index_;
+};
+
+template <class Divider> void Split(const Divider &divider, const std::vector<HypoState> &hypos, std::vector<VertexNode> &extend) {
+ // Map from divider to index in extend.
+ typedef boost::unordered_map<uint64_t, std::size_t> Lookup;
+ Lookup lookup;
+ for (std::vector<HypoState>::const_iterator i = hypos.begin(); i != hypos.end(); ++i) {
+ uint64_t key = divider(i->state);
+ std::pair<Lookup::iterator, bool> res(lookup.insert(std::make_pair(key, extend.size())));
+ if (res.second) {
+ extend.resize(extend.size() + 1);
+ extend.back().AppendHypothesis(*i);
+ } else {
+ extend[res.first->second].AppendHypothesis(*i);
+ }
}
+ //assert((extend.size() != 1) || (hypos.size() == 1));
+}
+
+lm::WordIndex Identify(const lm::ngram::Right &right, unsigned char index) {
+ return right.words[index];
+}
+
+uint64_t Identify(const lm::ngram::Left &left, unsigned char index) {
+ return left.pointers[index];
+}
+
+template <class Side> class DetermineSame {
+ public:
+ DetermineSame(const Side &side, unsigned char guaranteed)
+ : side_(side), guaranteed_(guaranteed), shared_(side.length), complete_(true) {}
+
+ void Consider(const Side &other) {
+ if (shared_ != other.length) {
+ complete_ = false;
+ if (shared_ > other.length)
+ shared_ = other.length;
+ }
+ for (unsigned char i = guaranteed_; i < shared_; ++i) {
+ if (Identify(side_, i) != Identify(other, i)) {
+ shared_ = i;
+ complete_ = false;
+ return;
+ }
+ }
+ }
+
+ unsigned char Shared() const { return shared_; }
+
+ bool Complete() const { return complete_; }
+
+ private:
+ const Side &side_;
+ unsigned char guaranteed_, shared_;
+ bool complete_;
};
+// Custom enum to save memory: valid values of policy_.
+// Alternate and there is still alternation to do.
+const unsigned char kPolicyAlternate = 0;
+// Branch based on left state only, because right ran out or this is a left tree.
+const unsigned char kPolicyOneLeft = 1;
+// Branch based on right state only.
+const unsigned char kPolicyOneRight = 2;
+// Reveal everything in the next branch. Used to terminate the left/right policies.
+// static const unsigned char kPolicyEverything = 3;
+
+} // namespace
+
+namespace {
+struct GreaterByScore : public std::binary_function<const HypoState &, const HypoState &, bool> {
+ bool operator()(const HypoState &first, const HypoState &second) const {
+ return first.score > second.score;
+ }
+};
} // namespace
-void VertexNode::RecursiveSortAndSet(ContextBase &context, VertexNode *&parent_ptr) {
- if (Complete()) {
- assert(end_);
- assert(extend_.empty());
- return;
+void VertexNode::FinishRoot() {
+ std::sort(hypos_.begin(), hypos_.end(), GreaterByScore());
+ extend_.clear();
+ // HACK: extend to one hypo so that root can be blank.
+ state_.left.full = false;
+ state_.left.length = 0;
+ state_.right.length = 0;
+ right_full_ = false;
+ niceness_ = 0;
+ policy_ = kPolicyAlternate;
+ if (hypos_.size() == 1) {
+ extend_.resize(1);
+ extend_.front().AppendHypothesis(hypos_.front());
+ extend_.front().FinishedAppending(0, 0);
+ }
+ if (hypos_.empty()) {
+ bound_ = -INFINITY;
+ } else {
+ bound_ = hypos_.front().score;
}
- if (extend_.size() == 1) {
- parent_ptr = extend_[0];
- extend_[0]->RecursiveSortAndSet(context, parent_ptr);
- context.DeleteVertexNode(this);
- return;
+}
+
+void VertexNode::FinishedAppending(const unsigned char common_left, const unsigned char common_right) {
+ assert(!hypos_.empty());
+ assert(extend_.empty());
+ bound_ = hypos_.front().score;
+ state_ = hypos_.front().state;
+ bool all_full = state_.left.full;
+ bool all_non_full = !state_.left.full;
+ DetermineSame<lm::ngram::Left> left(state_.left, common_left);
+ DetermineSame<lm::ngram::Right> right(state_.right, common_right);
+ for (std::vector<HypoState>::const_iterator i = hypos_.begin() + 1; i != hypos_.end(); ++i) {
+ all_full &= i->state.left.full;
+ all_non_full &= !i->state.left.full;
+ left.Consider(i->state.left);
+ right.Consider(i->state.right);
}
- for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) {
- (*i)->RecursiveSortAndSet(context, *i);
+ state_.left.full = all_full && left.Complete();
+ right_full_ = all_full && right.Complete();
+ state_.left.length = left.Shared();
+ state_.right.length = right.Shared();
+
+ if (!all_full && !all_non_full) {
+ policy_ = kPolicyAlternate;
+ } else if (left.Complete()) {
+ policy_ = kPolicyOneRight;
+ } else if (right.Complete()) {
+ policy_ = kPolicyOneLeft;
+ } else {
+ policy_ = kPolicyAlternate;
}
- std::sort(extend_.begin(), extend_.end(), GreaterByBound());
- bound_ = extend_.front()->Bound();
+ niceness_ = state_.left.length + state_.right.length;
}
-void VertexNode::SortAndSet(ContextBase &context) {
- // This is the root. The root might be empty.
- if (extend_.empty()) {
- bound_ = -INFINITY;
- return;
+void VertexNode::BuildExtend() {
+ // Already built.
+ if (!extend_.empty()) return;
+ // Nothing to build since this is a leaf.
+ if (hypos_.size() <= 1) return;
+ bool left_branch = true;
+ switch (policy_) {
+ case kPolicyAlternate:
+ left_branch = (state_.left.length <= state_.right.length);
+ break;
+ case kPolicyOneLeft:
+ left_branch = true;
+ break;
+ case kPolicyOneRight:
+ left_branch = false;
+ break;
+ }
+ if (left_branch) {
+ Split(DivideLeft(state_.left.length), hypos_, extend_);
+ } else {
+ Split(DivideRight(state_.right.length), hypos_, extend_);
}
- // The root cannot be replaced. There's always one transition.
- for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) {
- (*i)->RecursiveSortAndSet(context, *i);
+ for (std::vector<VertexNode>::iterator i = extend_.begin(); i != extend_.end(); ++i) {
+ // TODO: provide more here for branching?
+ i->FinishedAppending(state_.left.length, state_.right.length);
}
- std::sort(extend_.begin(), extend_.end(), GreaterByBound());
- bound_ = extend_.front()->Bound();
}
} // namespace search
diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh
index ca9a4fcd..81c3cfed 100644
--- a/klm/search/vertex.hh
+++ b/klm/search/vertex.hh
@@ -16,59 +16,74 @@ namespace search {
class ContextBase;
+struct HypoState {
+ History history;
+ lm::ngram::ChartState state;
+ Score score;
+};
+
class VertexNode {
public:
- VertexNode() : end_() {}
-
- void InitRoot() {
- extend_.clear();
- state_.left.full = false;
- state_.left.length = 0;
- state_.right.length = 0;
- right_full_ = false;
- end_ = History();
+ VertexNode() {}
+
+ void InitRoot() { hypos_.clear(); }
+
+ /* The steps of building a VertexNode:
+ * 1. Default construct.
+ * 2. AppendHypothesis at least once, possibly multiple times.
+ * 3. FinishAppending with the number of words on left and right guaranteed
+ * to be common.
+ * 4. If !Complete(), call BuildExtend to construct the extensions
+ */
+ // Must default construct, call AppendHypothesis 1 or more times then do FinishedAppending.
+ void AppendHypothesis(const NBestComplete &best) {
+ assert(hypos_.empty() || !(hypos_.front().state == *best.state));
+ HypoState hypo;
+ hypo.history = best.history;
+ hypo.state = *best.state;
+ hypo.score = best.score;
+ hypos_.push_back(hypo);
+ }
+ void AppendHypothesis(const HypoState &hypo) {
+ hypos_.push_back(hypo);
}
- lm::ngram::ChartState &MutableState() { return state_; }
- bool &MutableRightFull() { return right_full_; }
+ // Sort hypotheses for the root.
+ void FinishRoot();
- void AddExtend(VertexNode *next) {
- extend_.push_back(next);
- }
+ void FinishedAppending(const unsigned char common_left, const unsigned char common_right);
- void SetEnd(History end, Score score) {
- assert(!end_);
- end_ = end;
- bound_ = score;
- }
-
- void SortAndSet(ContextBase &context);
+ void BuildExtend();
// Should only happen to a root node when the entire vertex is empty.
bool Empty() const {
- return !end_ && extend_.empty();
+ return hypos_.empty() && extend_.empty();
}
bool Complete() const {
- return end_;
+ // HACK: prevent root from being complete. TODO: allow root to be complete.
+ return hypos_.size() == 1 && extend_.empty();
}
const lm::ngram::ChartState &State() const { return state_; }
bool RightFull() const { return right_full_; }
+ // Priority relative to other non-terminals. 0 is highest.
+ unsigned char Niceness() const { return niceness_; }
+
Score Bound() const {
return bound_;
}
- unsigned char Length() const {
- return state_.left.length + state_.right.length;
- }
-
// Will be invalid unless this is a leaf.
- History End() const { return end_; }
+ History End() const {
+ assert(hypos_.size() == 1);
+ return hypos_.front().history;
+ }
- const VertexNode &operator[](size_t index) const {
- return *extend_[index];
+ VertexNode &operator[](size_t index) {
+ assert(!extend_.empty());
+ return extend_[index];
}
size_t Size() const {
@@ -76,22 +91,26 @@ class VertexNode {
}
private:
- void RecursiveSortAndSet(ContextBase &context, VertexNode *&parent);
+ // Hypotheses to be split.
+ std::vector<HypoState> hypos_;
- std::vector<VertexNode*> extend_;
+ std::vector<VertexNode> extend_;
lm::ngram::ChartState state_;
bool right_full_;
+ unsigned char niceness_;
+
+ unsigned char policy_;
+
Score bound_;
- History end_;
};
class PartialVertex {
public:
PartialVertex() {}
- explicit PartialVertex(const VertexNode &back) : back_(&back), index_(0) {}
+ explicit PartialVertex(VertexNode &back) : back_(&back), index_(0) {}
bool Empty() const { return back_->Empty(); }
@@ -100,17 +119,14 @@ class PartialVertex {
const lm::ngram::ChartState &State() const { return back_->State(); }
bool RightFull() const { return back_->RightFull(); }
- Score Bound() const { return Complete() ? back_->Bound() : (*back_)[index_].Bound(); }
-
- unsigned char Length() const { return back_->Length(); }
+ Score Bound() const { return index_ ? (*back_)[index_].Bound() : back_->Bound(); }
- bool HasAlternative() const {
- return index_ + 1 < back_->Size();
- }
+ unsigned char Niceness() const { return back_->Niceness(); }
// Split into continuation and alternative, rendering this the continuation.
bool Split(PartialVertex &alternative) {
assert(!Complete());
+ back_->BuildExtend();
bool ret;
if (index_ + 1 < back_->Size()) {
alternative.index_ = index_ + 1;
@@ -129,7 +145,7 @@ class PartialVertex {
}
private:
- const VertexNode *back_;
+ VertexNode *back_;
unsigned int index_;
};
@@ -139,10 +155,21 @@ class Vertex {
public:
Vertex() {}
- PartialVertex RootPartial() const { return PartialVertex(root_); }
+ //PartialVertex RootFirst() const { return PartialVertex(right_); }
+ PartialVertex RootAlternate() { return PartialVertex(root_); }
+ //PartialVertex RootLast() const { return PartialVertex(left_); }
+
+ bool Empty() const {
+ return root_.Empty();
+ }
+
+ Score Bound() const {
+ return root_.Bound();
+ }
- History BestChild() const {
- PartialVertex top(RootPartial());
+ History BestChild() {
+ // left_ and right_ are not set at the root.
+ PartialVertex top(RootAlternate());
if (top.Empty()) {
return History();
} else {
@@ -158,6 +185,12 @@ class Vertex {
template <class Output> friend class VertexGenerator;
template <class Output> friend class RootVertexGenerator;
VertexNode root_;
+
+ // These will not be set for the root vertex.
+ // Branches only on left state.
+ //VertexNode left_;
+ // Branches only on right state.
+ //VertexNode right_;
};
} // namespace search
diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh
index 646b8189..91000012 100644
--- a/klm/search/vertex_generator.hh
+++ b/klm/search/vertex_generator.hh
@@ -4,10 +4,8 @@
#include "search/edge.hh"
#include "search/types.hh"
#include "search/vertex.hh"
-#include "util/exception.hh"
#include <boost/unordered_map.hpp>
-#include <boost/version.hpp>
namespace lm {
namespace ngram {
@@ -19,45 +17,25 @@ namespace search {
class ContextBase;
-#if BOOST_VERSION > 104200
-// Parallel structure to VertexNode.
-struct Trie {
- Trie() : under(NULL) {}
-
- VertexNode *under;
- boost::unordered_map<uint64_t, Trie> extend;
-};
-
-void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end);
-
-#endif // BOOST_VERSION
-
// Output makes the single-best or n-best list.
template <class Output> class VertexGenerator {
public:
- VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) {
- gen.root_.InitRoot();
- }
+ VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) {}
void NewHypothesis(PartialEdge partial) {
nbest_.Add(existing_[hash_value(partial.CompletedState())], partial);
}
void FinishedSearch() {
-#if BOOST_VERSION > 104200
- Trie root;
- root.under = &gen_.root_;
+ gen_.root_.InitRoot();
for (typename Existing::iterator i(existing_.begin()); i != existing_.end(); ++i) {
- AddHypothesis(context_, root, nbest_.Complete(i->second));
+ gen_.root_.AppendHypothesis(nbest_.Complete(i->second));
}
existing_.clear();
- root.under->SortAndSet(context_);
-#else
- UTIL_THROW(util::Exception, "Upgrade Boost to >= 1.42.0 to use incremental search.");
-#endif
+ gen_.root_.FinishRoot();
}
- const Vertex &Generating() const { return gen_; }
+ Vertex &Generating() { return gen_; }
private:
ContextBase &context_;
@@ -84,8 +62,8 @@ template <class Output> class RootVertexGenerator {
void FinishedSearch() {
gen_.root_.InitRoot();
- NBestComplete completed(out_.Complete(combine_));
- gen_.root_.SetEnd(completed.history, completed.score);
+ gen_.root_.AppendHypothesis(out_.Complete(combine_));
+ gen_.root_.FinishRoot();
}
private: