summaryrefslogtreecommitdiff
path: root/klm/search
diff options
context:
space:
mode:
Diffstat (limited to 'klm/search')
-rw-r--r--klm/search/Jamfile5
-rw-r--r--klm/search/arity.hh8
-rw-r--r--klm/search/config.hh25
-rw-r--r--klm/search/context.hh66
-rw-r--r--klm/search/edge.hh54
-rw-r--r--klm/search/edge_generator.cc129
-rw-r--r--klm/search/edge_generator.hh54
-rw-r--r--klm/search/final.hh40
-rw-r--r--klm/search/rule.cc55
-rw-r--r--klm/search/rule.hh60
-rw-r--r--klm/search/source.hh48
-rw-r--r--klm/search/types.hh18
-rw-r--r--klm/search/vertex.cc48
-rw-r--r--klm/search/vertex.hh165
-rw-r--r--klm/search/vertex_generator.cc99
-rw-r--r--klm/search/vertex_generator.hh70
-rw-r--r--klm/search/weights.cc69
-rw-r--r--klm/search/weights.hh49
-rw-r--r--klm/search/weights_test.cc38
-rw-r--r--klm/search/word.hh47
20 files changed, 1147 insertions, 0 deletions
diff --git a/klm/search/Jamfile b/klm/search/Jamfile
new file mode 100644
index 00000000..ac47c249
--- /dev/null
+++ b/klm/search/Jamfile
@@ -0,0 +1,5 @@
+lib search : weights.cc vertex.cc vertex_generator.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil : : : <include>.. ;
+
+import testing ;
+
+unit-test weights_test : weights_test.cc search /top//boost_unit_test_framework ;
diff --git a/klm/search/arity.hh b/klm/search/arity.hh
new file mode 100644
index 00000000..09c2c671
--- /dev/null
+++ b/klm/search/arity.hh
@@ -0,0 +1,8 @@
+#ifndef SEARCH_ARITY__
+#define SEARCH_ARITY__
+namespace search {
+
+const unsigned int kMaxArity = 2;
+
+} // namespace search
+#endif // SEARCH_ARITY__
diff --git a/klm/search/config.hh b/klm/search/config.hh
new file mode 100644
index 00000000..e21e4b7c
--- /dev/null
+++ b/klm/search/config.hh
@@ -0,0 +1,25 @@
+#ifndef SEARCH_CONFIG__
+#define SEARCH_CONFIG__
+
+#include "search/weights.hh"
+#include "util/string_piece.hh"
+
+namespace search {
+
+class Config {
+ public:
+ Config(StringPiece weight_str, unsigned int pop_limit) :
+ weights_(weight_str), pop_limit_(pop_limit) {}
+
+ const Weights &GetWeights() const { return weights_; }
+
+ unsigned int PopLimit() const { return pop_limit_; }
+
+ private:
+ search::Weights weights_;
+ unsigned int pop_limit_;
+};
+
+} // namespace search
+
+#endif // SEARCH_CONFIG__
diff --git a/klm/search/context.hh b/klm/search/context.hh
new file mode 100644
index 00000000..ae248549
--- /dev/null
+++ b/klm/search/context.hh
@@ -0,0 +1,66 @@
+#ifndef SEARCH_CONTEXT__
+#define SEARCH_CONTEXT__
+
+#include "lm/model.hh"
+#include "search/config.hh"
+#include "search/final.hh"
+#include "search/types.hh"
+#include "search/vertex.hh"
+#include "search/word.hh"
+#include "util/exception.hh"
+
+#include <boost/pool/object_pool.hpp>
+#include <boost/ptr_container/ptr_vector.hpp>
+
+#include <vector>
+
+namespace search {
+
+class Weights;
+
+class ContextBase {
+ public:
+ explicit ContextBase(const Config &config) : pop_limit_(config.PopLimit()), weights_(config.GetWeights()) {}
+
+ Final *NewFinal() {
+ Final *ret = final_pool_.construct();
+ assert(ret);
+ return ret;
+ }
+
+ VertexNode *NewVertexNode() {
+ VertexNode *ret = vertex_node_pool_.construct();
+ assert(ret);
+ return ret;
+ }
+
+ void DeleteVertexNode(VertexNode *node) {
+ vertex_node_pool_.destroy(node);
+ }
+
+ unsigned int PopLimit() const { return pop_limit_; }
+
+ const Weights &GetWeights() const { return weights_; }
+
+ private:
+ boost::object_pool<Final> final_pool_;
+ boost::object_pool<VertexNode> vertex_node_pool_;
+
+ unsigned int pop_limit_;
+
+ const Weights &weights_;
+};
+
+template <class Model> class Context : public ContextBase {
+ public:
+ Context(const Config &config, const Model &model) : ContextBase(config), model_(model) {}
+
+ const Model &LanguageModel() const { return model_; }
+
+ private:
+ const Model &model_;
+};
+
+} // namespace search
+
+#endif // SEARCH_CONTEXT__
diff --git a/klm/search/edge.hh b/klm/search/edge.hh
new file mode 100644
index 00000000..4d2a5cbf
--- /dev/null
+++ b/klm/search/edge.hh
@@ -0,0 +1,54 @@
+#ifndef SEARCH_EDGE__
+#define SEARCH_EDGE__
+
+#include "lm/state.hh"
+#include "search/arity.hh"
+#include "search/rule.hh"
+#include "search/types.hh"
+#include "search/vertex.hh"
+
+#include <queue>
+
+namespace search {
+
+class Edge {
+ public:
+ Edge() {
+ end_to_ = to_;
+ }
+
+ Rule &InitRule() { return rule_; }
+
+ void Add(Vertex &vertex) {
+ assert(end_to_ - to_ < kMaxArity);
+ *(end_to_++) = &vertex;
+ }
+
+ const Vertex &GetVertex(std::size_t index) const {
+ return *to_[index];
+ }
+
+ const Rule &GetRule() const { return rule_; }
+
+ private:
+ // Rule and pointers to rule arguments.
+ Rule rule_;
+
+ Vertex *to_[kMaxArity];
+ Vertex **end_to_;
+};
+
+struct PartialEdge {
+ Score score;
+ // Terminals
+ lm::ngram::ChartState between[kMaxArity + 1];
+ // Non-terminals
+ PartialVertex nt[kMaxArity];
+
+ bool operator<(const PartialEdge &other) const {
+ return score < other.score;
+ }
+};
+
+} // namespace search
+#endif // SEARCH_EDGE__
diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc
new file mode 100644
index 00000000..d135899a
--- /dev/null
+++ b/klm/search/edge_generator.cc
@@ -0,0 +1,129 @@
+#include "search/edge_generator.hh"
+
+#include "lm/left.hh"
+#include "lm/partial.hh"
+#include "search/context.hh"
+#include "search/vertex.hh"
+#include "search/vertex_generator.hh"
+
+#include <numeric>
+
+namespace search {
+
+bool EdgeGenerator::Init(Edge &edge, VertexGenerator &parent) {
+ from_ = &edge;
+ for (unsigned int i = 0; i < GetRule().Arity(); ++i) {
+ if (edge.GetVertex(i).RootPartial().Empty()) return false;
+ }
+ PartialEdge &root = *parent.MallocPartialEdge();
+ root.score = GetRule().Bound();
+ for (unsigned int i = 0; i < GetRule().Arity(); ++i) {
+ root.nt[i] = edge.GetVertex(i).RootPartial();
+ root.score += root.nt[i].Bound();
+ }
+ for (unsigned int i = GetRule().Arity(); i < 2; ++i) {
+ root.nt[i] = kBlankPartialVertex;
+ }
+ for (unsigned int i = 0; i < GetRule().Arity() + 1; ++i) {
+ root.between[i] = GetRule().Lexical(i);
+ }
+ // wtf no clear method?
+ generate_ = Generate();
+ generate_.push(&root);
+ top_ = root.score;
+ return true;
+}
+
+namespace {
+
+template <class Model> float FastScore(const Context<Model> &context, unsigned char victim, unsigned char arity, const PartialEdge &previous, PartialEdge &update) {
+ memcpy(update.between, previous.between, sizeof(lm::ngram::ChartState) * (arity + 1));
+
+ float ret = 0.0;
+ lm::ngram::ChartState *before, *after;
+ if (victim == 0) {
+ before = &update.between[0];
+ after = &update.between[(arity == 2 && previous.nt[1].Complete()) ? 2 : 1];
+ } else {
+ assert(victim == 1);
+ assert(arity == 2);
+ before = &update.between[previous.nt[0].Complete() ? 0 : 1];
+ after = &update.between[2];
+ }
+ const lm::ngram::ChartState &previous_reveal = previous.nt[victim].State();
+ const PartialVertex &update_nt = update.nt[victim];
+ const lm::ngram::ChartState &update_reveal = update_nt.State();
+ float just_after = 0.0;
+ if ((update_reveal.left.length > previous_reveal.left.length) || (update_reveal.left.full && !previous_reveal.left.full)) {
+ just_after += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length);
+ }
+ if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous.nt[victim].RightFull())) {
+ ret += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right);
+ }
+ if (update_nt.Complete()) {
+ if (update_reveal.left.full) {
+ before->left.full = true;
+ } else {
+ assert(update_reveal.left.length == update_reveal.right.length);
+ ret += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length);
+ }
+ if (victim == 0) {
+ update.between[0].right = after->right;
+ } else {
+ update.between[2].left = before->left;
+ }
+ }
+ return previous.score + (ret + just_after) * context.GetWeights().LM();
+}
+
+} // namespace
+
+template <class Model> bool EdgeGenerator::Pop(Context<Model> &context, VertexGenerator &parent) {
+ assert(!generate_.empty());
+ PartialEdge &top = *generate_.top();
+ generate_.pop();
+ unsigned int victim = 0;
+ unsigned char lowest_length = 255;
+ for (unsigned int i = 0; i != GetRule().Arity(); ++i) {
+ if (!top.nt[i].Complete() && top.nt[i].Length() < lowest_length) {
+ lowest_length = top.nt[i].Length();
+ victim = i;
+ }
+ }
+ if (lowest_length == 255) {
+ // All states report complete.
+ top.between[0].right = top.between[GetRule().Arity()].right;
+ parent.NewHypothesis(top.between[0], *from_, top);
+ top_ = generate_.empty() ? -kScoreInf : generate_.top()->score;
+ return !generate_.empty();
+ }
+
+ unsigned int stay = !victim;
+ PartialEdge &continuation = *parent.MallocPartialEdge();
+ float old_bound = top.nt[victim].Bound();
+ // The alternate's score will change because alternate.nt[victim] changes.
+ bool split = top.nt[victim].Split(continuation.nt[victim]);
+ // top is now the alternate.
+
+ continuation.nt[stay] = top.nt[stay];
+ continuation.score = FastScore(context, victim, GetRule().Arity(), top, continuation);
+ // TODO: dedupe?
+ generate_.push(&continuation);
+
+ if (split) {
+ // We have an alternate.
+ top.score += top.nt[victim].Bound() - old_bound;
+ // TODO: dedupe?
+ generate_.push(&top);
+ } else {
+ parent.FreePartialEdge(&top);
+ }
+
+ top_ = generate_.top()->score;
+ return true;
+}
+
+template bool EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context, VertexGenerator &parent);
+template bool EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context, VertexGenerator &parent);
+
+} // namespace search
diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh
new file mode 100644
index 00000000..e306dc61
--- /dev/null
+++ b/klm/search/edge_generator.hh
@@ -0,0 +1,54 @@
+#ifndef SEARCH_EDGE_GENERATOR__
+#define SEARCH_EDGE_GENERATOR__
+
+#include "search/edge.hh"
+
+#include <boost/unordered_map.hpp>
+
+#include <functional>
+#include <queue>
+
+namespace lm {
+namespace ngram {
+class ChartState;
+} // namespace ngram
+} // namespace lm
+
+namespace search {
+
+template <class Model> class Context;
+
+class VertexGenerator;
+
+struct PartialEdgePointerLess : std::binary_function<const PartialEdge *, const PartialEdge *, bool> {
+ bool operator()(const PartialEdge *first, const PartialEdge *second) const {
+ return *first < *second;
+ }
+};
+
+class EdgeGenerator {
+ public:
+ // True if it has a hypothesis.
+ bool Init(Edge &edge, VertexGenerator &parent);
+
+ Score Top() const {
+ return top_;
+ }
+
+ template <class Model> bool Pop(Context<Model> &context, VertexGenerator &parent);
+
+ private:
+ const Rule &GetRule() const {
+ return from_->GetRule();
+ }
+
+ Score top_;
+
+ typedef std::priority_queue<PartialEdge*, std::vector<PartialEdge*>, PartialEdgePointerLess> Generate;
+ Generate generate_;
+
+ Edge *from_;
+};
+
+} // namespace search
+#endif // SEARCH_EDGE_GENERATOR__
diff --git a/klm/search/final.hh b/klm/search/final.hh
new file mode 100644
index 00000000..24e6f0a5
--- /dev/null
+++ b/klm/search/final.hh
@@ -0,0 +1,40 @@
+#ifndef SEARCH_FINAL__
+#define SEARCH_FINAL__
+
+#include "search/rule.hh"
+#include "search/types.hh"
+
+#include <boost/array.hpp>
+
+namespace search {
+
+class Final {
+ public:
+ typedef boost::array<const Final*, search::kMaxArity> ChildArray;
+
+ void Reset(Score bound, const Rule &from, const Final &left, const Final &right) {
+ bound_ = bound;
+ from_ = &from;
+ children_[0] = &left;
+ children_[1] = &right;
+ }
+
+ const ChildArray &Children() const { return children_; }
+
+ unsigned int ChildCount() const { return from_->Arity(); }
+
+ const Rule &From() const { return *from_; }
+
+ Score Bound() const { return bound_; }
+
+ private:
+ Score bound_;
+
+ const Rule *from_;
+
+ ChildArray children_;
+};
+
+} // namespace search
+
+#endif // SEARCH_FINAL__
diff --git a/klm/search/rule.cc b/klm/search/rule.cc
new file mode 100644
index 00000000..a8b993eb
--- /dev/null
+++ b/klm/search/rule.cc
@@ -0,0 +1,55 @@
+#include "search/rule.hh"
+
+#include "search/context.hh"
+#include "search/final.hh"
+
+#include <ostream>
+
+#include <math.h>
+
+namespace search {
+
+template <class Model> void Rule::FinishedAdding(const Context<Model> &context, Score additive, bool prepend_bos) {
+ additive_ = additive;
+ Score lm_score = 0.0;
+ lexical_.clear();
+ const lm::WordIndex oov = context.LanguageModel().GetVocabulary().NotFound();
+
+ for (std::vector<Word>::const_iterator word = items_.begin(); ; ++word) {
+ lexical_.resize(lexical_.size() + 1);
+ lm::ngram::RuleScore<Model> scorer(context.LanguageModel(), lexical_.back());
+ // TODO: optimize
+ if (prepend_bos && (word == items_.begin())) {
+ scorer.BeginSentence();
+ }
+ for (; ; ++word) {
+ if (word == items_.end()) {
+ lm_score += scorer.Finish();
+ bound_ = additive_ + context.GetWeights().LM() * lm_score;
+ assert(lexical_.size() == arity_ + 1);
+ return;
+ }
+ if (!word->Terminal()) break;
+ if (word->Index() == oov) additive_ += context.GetWeights().OOV();
+ scorer.Terminal(word->Index());
+ }
+ lm_score += scorer.Finish();
+ }
+}
+
+template void Rule::FinishedAdding(const Context<lm::ngram::RestProbingModel> &context, Score additive, bool prepend_bos);
+template void Rule::FinishedAdding(const Context<lm::ngram::ProbingModel> &context, Score additive, bool prepend_bos);
+
+std::ostream &operator<<(std::ostream &o, const Rule &rule) {
+ const Rule::ItemsRet &items = rule.Items();
+ for (Rule::ItemsRet::const_iterator i = items.begin(); i != items.end(); ++i) {
+ if (i->Terminal()) {
+ o << i->String() << ' ';
+ } else {
+ o << "[] ";
+ }
+ }
+ return o;
+}
+
+} // namespace search
diff --git a/klm/search/rule.hh b/klm/search/rule.hh
new file mode 100644
index 00000000..79192d40
--- /dev/null
+++ b/klm/search/rule.hh
@@ -0,0 +1,60 @@
+#ifndef SEARCH_RULE__
+#define SEARCH_RULE__
+
+#include "lm/left.hh"
+#include "search/arity.hh"
+#include "search/types.hh"
+#include "search/word.hh"
+
+#include <boost/array.hpp>
+
+#include <iosfwd>
+#include <vector>
+
+namespace search {
+
+template <class Model> class Context;
+
+class Rule {
+ public:
+ Rule() : arity_(0) {}
+
+ void AppendTerminal(Word w) { items_.push_back(w); }
+
+ void AppendNonTerminal() {
+ items_.resize(items_.size() + 1);
+ ++arity_;
+ }
+
+ template <class Model> void FinishedAdding(const Context<Model> &context, Score additive, bool prepend_bos);
+
+ Score Bound() const { return bound_; }
+
+ Score Additive() const { return additive_; }
+
+ unsigned int Arity() const { return arity_; }
+
+ const lm::ngram::ChartState &Lexical(unsigned int index) const {
+ return lexical_[index];
+ }
+
+ // For printing.
+ typedef const std::vector<Word> ItemsRet;
+ ItemsRet &Items() const { return items_; }
+
+ private:
+ Score bound_, additive_;
+
+ unsigned int arity_;
+
+ // TODO: pool?
+ std::vector<Word> items_;
+
+ std::vector<lm::ngram::ChartState> lexical_;
+};
+
+std::ostream &operator<<(std::ostream &o, const Rule &rule);
+
+} // namespace search
+
+#endif // SEARCH_RULE__
diff --git a/klm/search/source.hh b/klm/search/source.hh
new file mode 100644
index 00000000..11839f7b
--- /dev/null
+++ b/klm/search/source.hh
@@ -0,0 +1,48 @@
+#ifndef SEARCH_SOURCE__
+#define SEARCH_SOURCE__
+
+#include "search/types.hh"
+
+#include <assert.h>
+#include <vector>
+
+namespace search {
+
+template <class Final> class Source {
+ public:
+ Source() : bound_(kScoreInf) {}
+
+ Index Size() const {
+ return final_.size();
+ }
+
+ Score Bound() const {
+ return bound_;
+ }
+
+ const Final &operator[](Index index) const {
+ return *final_[index];
+ }
+
+ Score ScoreOrBound(Index index) const {
+ return Size() > index ? final_[index]->Total() : Bound();
+ }
+
+ protected:
+ void AddFinal(const Final &store) {
+ final_.push_back(&store);
+ }
+
+ void SetBound(Score to) {
+ assert(to <= bound_ + 0.001);
+ bound_ = to;
+ }
+
+ private:
+ std::vector<const Final *> final_;
+
+ Score bound_;
+};
+
+} // namespace search
+#endif // SEARCH_SOURCE__
diff --git a/klm/search/types.hh b/klm/search/types.hh
new file mode 100644
index 00000000..9726379f
--- /dev/null
+++ b/klm/search/types.hh
@@ -0,0 +1,18 @@
+#ifndef SEARCH_TYPES__
+#define SEARCH_TYPES__
+
+#include <cmath>
+
+namespace search {
+
+typedef float Score;
+const Score kScoreInf = INFINITY;
+
+// This could have been an enum but gcc wants 4 bytes.
+typedef bool ExtendDirection;
+const ExtendDirection kExtendLeft = 0;
+const ExtendDirection kExtendRight = 1;
+
+} // namespace search
+
+#endif // SEARCH_TYPES__
diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc
new file mode 100644
index 00000000..cc53c0dd
--- /dev/null
+++ b/klm/search/vertex.cc
@@ -0,0 +1,48 @@
+#include "search/vertex.hh"
+
+#include "search/context.hh"
+
+#include <algorithm>
+#include <functional>
+
+#include <assert.h>
+
+namespace search {
+
+namespace {
+
+struct GreaterByBound : public std::binary_function<const VertexNode *, const VertexNode *, bool> {
+ bool operator()(const VertexNode *first, const VertexNode *second) const {
+ return first->Bound() > second->Bound();
+ }
+};
+
+} // namespace
+
+void VertexNode::SortAndSet(ContextBase &context, VertexNode **parent_ptr) {
+ if (Complete()) {
+ assert(end_);
+ assert(extend_.empty());
+ bound_ = end_->Bound();
+ return;
+ }
+ if (extend_.size() == 1 && parent_ptr) {
+ *parent_ptr = extend_[0];
+ extend_[0]->SortAndSet(context, parent_ptr);
+ context.DeleteVertexNode(this);
+ return;
+ }
+ for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) {
+ (*i)->SortAndSet(context, &*i);
+ }
+ std::sort(extend_.begin(), extend_.end(), GreaterByBound());
+ bound_ = extend_.front()->Bound();
+}
+
+namespace {
+VertexNode kBlankVertexNode;
+} // namespace
+
+PartialVertex kBlankPartialVertex(kBlankVertexNode);
+
+} // namespace search
diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh
new file mode 100644
index 00000000..7ef29efc
--- /dev/null
+++ b/klm/search/vertex.hh
@@ -0,0 +1,165 @@
+#ifndef SEARCH_VERTEX__
+#define SEARCH_VERTEX__
+
+#include "lm/left.hh"
+#include "search/final.hh"
+#include "search/types.hh"
+
+#include <boost/unordered_set.hpp>
+
+#include <queue>
+#include <vector>
+
+#include <stdint.h>
+
+namespace search {
+
+class ContextBase;
+
+class Edge;
+
+class VertexNode {
+ public:
+ VertexNode() : end_(NULL) {}
+
+ void InitRoot() {
+ extend_.clear();
+ state_.left.full = false;
+ state_.left.length = 0;
+ state_.right.length = 0;
+ right_full_ = false;
+ bound_ = -kScoreInf;
+ end_ = NULL;
+ }
+
+ lm::ngram::ChartState &MutableState() { return state_; }
+ bool &MutableRightFull() { return right_full_; }
+
+ void AddExtend(VertexNode *next) {
+ extend_.push_back(next);
+ }
+
+ void SetEnd(Final *end) { end_ = end; }
+
+ Final &MutableEnd() { return *end_; }
+
+ void SortAndSet(ContextBase &context, VertexNode **parent_pointer);
+
+ // Should only happen to a root node when the entire vertex is empty.
+ bool Empty() const {
+ return !end_ && extend_.empty();
+ }
+
+ bool Complete() const {
+ return end_;
+ }
+
+ const lm::ngram::ChartState &State() const { return state_; }
+ bool RightFull() const { return right_full_; }
+
+ Score Bound() const {
+ return bound_;
+ }
+
+ unsigned char Length() const {
+ return state_.left.length + state_.right.length;
+ }
+
+ // May be NULL.
+ const Final *End() const { return end_; }
+
+ const VertexNode &operator[](size_t index) const {
+ return *extend_[index];
+ }
+
+ size_t Size() const {
+ return extend_.size();
+ }
+
+ private:
+ std::vector<VertexNode*> extend_;
+
+ lm::ngram::ChartState state_;
+ bool right_full_;
+
+ Score bound_;
+ Final *end_;
+};
+
+class PartialVertex {
+ public:
+ PartialVertex() {}
+
+ explicit PartialVertex(const VertexNode &back) : back_(&back), index_(0) {}
+
+ bool Empty() const { return back_->Empty(); }
+
+ bool Complete() const { return back_->Complete(); }
+
+ const lm::ngram::ChartState &State() const { return back_->State(); }
+ bool RightFull() const { return back_->RightFull(); }
+
+ Score Bound() const { return Complete() ? back_->End()->Bound() : (*back_)[index_].Bound(); }
+
+ unsigned char Length() const { return back_->Length(); }
+
+ // Split into continuation and alternative, rendering this the alternative.
+ bool Split(PartialVertex &continuation) {
+ assert(!Complete());
+ continuation.back_ = &((*back_)[index_]);
+ continuation.index_ = 0;
+ if (index_ + 1 < back_->Size()) {
+ ++index_;
+ return true;
+ }
+ return false;
+ }
+
+ const Final &End() const {
+ return *back_->End();
+ }
+
+ private:
+ const VertexNode *back_;
+ unsigned int index_;
+};
+
+extern PartialVertex kBlankPartialVertex;
+
+class Vertex {
+ public:
+ Vertex()
+#ifdef DEBUG
+ : finished_adding_(false)
+#endif
+ {}
+
+ void Add(Edge &edge) {
+#ifdef DEBUG
+ assert(!finished_adding_);
+#endif
+ edges_.push_back(&edge);
+ }
+
+ void FinishedAdding() {
+#ifdef DEBUG
+ assert(!finished_adding_);
+ finished_adding_ = true;
+#endif
+ }
+
+ PartialVertex RootPartial() const { return PartialVertex(root_); }
+
+ private:
+ friend class VertexGenerator;
+ std::vector<Edge*> edges_;
+
+#ifdef DEBUG
+ bool finished_adding_;
+#endif
+
+ VertexNode root_;
+};
+
+} // namespace search
+#endif // SEARCH_VERTEX__
diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc
new file mode 100644
index 00000000..0281fc37
--- /dev/null
+++ b/klm/search/vertex_generator.cc
@@ -0,0 +1,99 @@
+#include "search/vertex_generator.hh"
+
+#include "lm/left.hh"
+#include "search/context.hh"
+
+#include <stdint.h>
+
+namespace search {
+
+template <class Model> VertexGenerator::VertexGenerator(Context<Model> &context, Vertex &gen) : context_(context), edges_(gen.edges_.size()), partial_edge_pool_(sizeof(PartialEdge), context.PopLimit() * 2) {
+ for (std::size_t i = 0; i < gen.edges_.size(); ++i) {
+ if (edges_[i].Init(*gen.edges_[i], *this))
+ generate_.push(&edges_[i]);
+ }
+ gen.root_.InitRoot();
+ root_.under = &gen.root_;
+ to_pop_ = context.PopLimit();
+ while (to_pop_ > 0 && !generate_.empty()) {
+ EdgeGenerator *top = generate_.top();
+ generate_.pop();
+ if (top->Pop(context, *this)) {
+ generate_.push(top);
+ }
+ }
+ gen.root_.SortAndSet(context, NULL);
+}
+
+template VertexGenerator::VertexGenerator(Context<lm::ngram::ProbingModel> &context, Vertex &gen);
+template VertexGenerator::VertexGenerator(Context<lm::ngram::RestProbingModel> &context, Vertex &gen);
+
+namespace {
+const uint64_t kCompleteAdd = static_cast<uint64_t>(-1);
+} // namespace
+
+void VertexGenerator::NewHypothesis(const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial) {
+ std::pair<Existing::iterator, bool> got(existing_.insert(std::pair<uint64_t, Final*>(hash_value(state), NULL)));
+ if (!got.second) {
+ // Found it already.
+ Final &exists = *got.first->second;
+ if (exists.Bound() < partial.score) {
+ exists.Reset(partial.score, from.GetRule(), partial.nt[0].End(), partial.nt[1].End());
+ }
+ --to_pop_;
+ return;
+ }
+ unsigned char left = 0, right = 0;
+ Trie *node = &root_;
+ while (true) {
+ if (left == state.left.length) {
+ node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, true, right, false);
+ for (; right < state.right.length; ++right) {
+ node = &FindOrInsert(*node, state.right.words[right], state, left, true, right + 1, false);
+ }
+ break;
+ }
+ node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, false);
+ left++;
+ if (right == state.right.length) {
+ node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, false, right, true);
+ for (; left < state.left.length; ++left) {
+ node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, true);
+ }
+ break;
+ }
+ node = &FindOrInsert(*node, state.right.words[right], state, left, false, right + 1, false);
+ right++;
+ }
+
+ node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true);
+ got.first->second = CompleteTransition(*node, state, from, partial);
+ --to_pop_;
+}
+
+VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) {
+ VertexGenerator::Trie &next = node.extend[added];
+ if (!next.under) {
+ next.under = context_.NewVertexNode();
+ lm::ngram::ChartState &writing = next.under->MutableState();
+ writing = state;
+ writing.left.full &= left_full && state.left.full;
+ next.under->MutableRightFull() = right_full && state.left.full;
+ writing.left.length = left;
+ writing.right.length = right;
+ node.under->AddExtend(next.under);
+ }
+ return next;
+}
+
+Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial) {
+ VertexNode &node = *starter.under;
+ assert(node.State().left.full == state.left.full);
+ assert(!node.End());
+ Final *final = context_.NewFinal();
+ final->Reset(partial.score, from.GetRule(), partial.nt[0].End(), partial.nt[1].End());
+ node.SetEnd(final);
+ return final;
+}
+
+} // namespace search
diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh
new file mode 100644
index 00000000..8cdf1420
--- /dev/null
+++ b/klm/search/vertex_generator.hh
@@ -0,0 +1,70 @@
+#ifndef SEARCH_VERTEX_GENERATOR__
+#define SEARCH_VERTEX_GENERATOR__
+
+#include "search/edge.hh"
+#include "search/edge_generator.hh"
+
+#include <boost/pool/pool.hpp>
+#include <boost/unordered_map.hpp>
+
+#include <queue>
+
+namespace lm {
+namespace ngram {
+class ChartState;
+} // namespace ngram
+} // namespace lm
+
+namespace search {
+
+template <class Model> class Context;
+class ContextBase;
+class Final;
+
+class VertexGenerator {
+ public:
+ template <class Model> VertexGenerator(Context<Model> &context, Vertex &gen);
+
+ PartialEdge *MallocPartialEdge() { return static_cast<PartialEdge*>(partial_edge_pool_.malloc()); }
+ void FreePartialEdge(PartialEdge *value) { partial_edge_pool_.free(value); }
+
+ void NewHypothesis(const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial);
+
+ private:
+ // Parallel structure to VertexNode.
+ struct Trie {
+ Trie() : under(NULL) {}
+
+ VertexNode *under;
+ boost::unordered_map<uint64_t, Trie> extend;
+ };
+
+ Trie &FindOrInsert(Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full);
+
+ Final *CompleteTransition(Trie &node, const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial);
+
+ ContextBase &context_;
+
+ std::vector<EdgeGenerator> edges_;
+
+ struct LessByTop : public std::binary_function<const EdgeGenerator *, const EdgeGenerator *, bool> {
+ bool operator()(const EdgeGenerator *first, const EdgeGenerator *second) const {
+ return first->Top() < second->Top();
+ }
+ };
+
+ typedef std::priority_queue<EdgeGenerator*, std::vector<EdgeGenerator*>, LessByTop> Generate;
+ Generate generate_;
+
+ Trie root_;
+
+ typedef boost::unordered_map<uint64_t, Final*> Existing;
+ Existing existing_;
+
+ int to_pop_;
+
+ boost::pool<> partial_edge_pool_;
+};
+
+} // namespace search
+#endif // SEARCH_VERTEX_GENERATOR__
diff --git a/klm/search/weights.cc b/klm/search/weights.cc
new file mode 100644
index 00000000..82ff3f12
--- /dev/null
+++ b/klm/search/weights.cc
@@ -0,0 +1,69 @@
+#include "search/weights.hh"
+#include "util/tokenize_piece.hh"
+
+#include <cstdlib>
+
+namespace search {
+
+namespace {
+struct Insert {
+ void operator()(boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) const {
+ std::string copy(name.data(), name.size());
+ map[copy] = score;
+ }
+};
+
+struct DotProduct {
+ search::Score total;
+ DotProduct() : total(0.0) {}
+
+ void operator()(const boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) {
+ boost::unordered_map<std::string, search::Score>::const_iterator i(FindStringPiece(map, name));
+ if (i != map.end())
+ total += score * i->second;
+ }
+};
+
+template <class Map, class Op> void Parse(StringPiece text, Map &map, Op &op) {
+ for (util::TokenIter<util::SingleCharacter, true> spaces(text, ' '); spaces; ++spaces) {
+ util::TokenIter<util::SingleCharacter> equals(*spaces, '=');
+ UTIL_THROW_IF(!equals, WeightParseException, "Bad weight token " << *spaces);
+ StringPiece name(*equals);
+ UTIL_THROW_IF(!++equals, WeightParseException, "Bad weight token " << *spaces);
+ char *end;
+ // Assumes proper termination.
+ double value = std::strtod(equals->data(), &end);
+ UTIL_THROW_IF(end != equals->data() + equals->size(), WeightParseException, "Failed to parse weight" << *equals);
+ UTIL_THROW_IF(++equals, WeightParseException, "Too many equals in " << *spaces);
+ op(map, name, value);
+ }
+}
+
+} // namespace
+
+Weights::Weights(StringPiece text) {
+ Insert op;
+ Parse<Map, Insert>(text, map_, op);
+ lm_ = Steal("LanguageModel");
+ oov_ = Steal("OOV");
+ word_penalty_ = Steal("WordPenalty");
+}
+
+search::Score Weights::DotNoLM(StringPiece text) const {
+ DotProduct dot;
+ Parse<const Map, DotProduct>(text, map_, dot);
+ return dot.total;
+}
+
+float Weights::Steal(const std::string &str) {
+ Map::iterator i(map_.find(str));
+ if (i == map_.end()) {
+ return 0.0;
+ } else {
+ float ret = i->second;
+ map_.erase(i);
+ return ret;
+ }
+}
+
+} // namespace search
diff --git a/klm/search/weights.hh b/klm/search/weights.hh
new file mode 100644
index 00000000..4a4388c7
--- /dev/null
+++ b/klm/search/weights.hh
@@ -0,0 +1,49 @@
+// For now, the individual features are not kept.
+#ifndef SEARCH_WEIGHTS__
+#define SEARCH_WEIGHTS__
+
+#include "search/types.hh"
+#include "util/exception.hh"
+#include "util/string_piece.hh"
+
+#include <boost/unordered_map.hpp>
+
+#include <string>
+
+namespace search {
+
+class WeightParseException : public util::Exception {
+ public:
+ WeightParseException() {}
+ ~WeightParseException() throw() {}
+};
+
+class Weights {
+ public:
+ // Parses weights, sets lm_weight_, removes it from map_.
+ explicit Weights(StringPiece text);
+
+ search::Score DotNoLM(StringPiece text) const;
+
+ search::Score LM() const { return lm_; }
+
+ search::Score OOV() const { return oov_; }
+
+ search::Score WordPenalty() const { return word_penalty_; }
+
+ // Mostly for testing.
+ const boost::unordered_map<std::string, search::Score> &GetMap() const { return map_; }
+
+ private:
+ float Steal(const std::string &str);
+
+ typedef boost::unordered_map<std::string, search::Score> Map;
+
+ Map map_;
+
+ search::Score lm_, oov_, word_penalty_;
+};
+
+} // namespace search
+
+#endif // SEARCH_WEIGHTS__
diff --git a/klm/search/weights_test.cc b/klm/search/weights_test.cc
new file mode 100644
index 00000000..4811ff06
--- /dev/null
+++ b/klm/search/weights_test.cc
@@ -0,0 +1,38 @@
+#include "search/weights.hh"
+
+#define BOOST_TEST_MODULE WeightTest
+#include <boost/test/unit_test.hpp>
+#include <boost/test/floating_point_comparison.hpp>
+
+namespace search {
+namespace {
+
+#define CHECK_WEIGHT(value, string) \
+ i = parsed.find(string); \
+ BOOST_REQUIRE(i != parsed.end()); \
+ BOOST_CHECK_CLOSE((value), i->second, 0.001);
+
+BOOST_AUTO_TEST_CASE(parse) {
+ // These are not real feature weights.
+ Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5");
+ const boost::unordered_map<std::string, search::Score> &parsed = w.GetMap();
+ boost::unordered_map<std::string, search::Score>::const_iterator i;
+ CHECK_WEIGHT(0.0, "rarity");
+ CHECK_WEIGHT(0.0, "phrase-SGT");
+ CHECK_WEIGHT(9.45117, "phrase-TGS");
+ CHECK_WEIGHT(2.33833, "lexical-SGT");
+ BOOST_CHECK(parsed.end() == parsed.find("lm"));
+ BOOST_CHECK_CLOSE(3.0, w.LM(), 0.001);
+ CHECK_WEIGHT(-28.3317, "lexical-TGS");
+ CHECK_WEIGHT(5.0, "glue?");
+}
+
+BOOST_AUTO_TEST_CASE(dot) {
+ Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5");
+ BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0"), 0.001);
+ BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0 LanguageModel=10"), 0.001);
+ BOOST_CHECK_CLOSE(9.45117 * 3.0 + 28.3317 * 17.4, w.DotNoLM("rarity=5 phrase-TGS=3.0 LanguageModel=10 lexical-TGS=-17.4"), 0.001);
+}
+
+} // namespace
+} // namespace search
diff --git a/klm/search/word.hh b/klm/search/word.hh
new file mode 100644
index 00000000..e7a15be9
--- /dev/null
+++ b/klm/search/word.hh
@@ -0,0 +1,47 @@
+#ifndef SEARCH_WORD__
+#define SEARCH_WORD__
+
+#include "lm/word_index.hh"
+
+#include <boost/functional/hash.hpp>
+
+#include <string>
+#include <utility>
+
+namespace search {
+
+class Word {
+ public:
+ // Construct a non-terminal.
+ Word() : entry_(NULL) {}
+
+ explicit Word(const std::pair<const std::string, lm::WordIndex> &entry) {
+ entry_ = &entry;
+ }
+
+ // Returns true for two non-terminals even if their labels are different (since we don't care about labels).
+ bool operator==(const Word &other) const {
+ return entry_ == other.entry_;
+ }
+
+ bool Terminal() const { return entry_ != NULL; }
+
+ const std::string &String() const { return entry_->first; }
+
+ lm::WordIndex Index() const { return entry_->second; }
+
+ protected:
+ friend size_t hash_value(const Word &word);
+
+ const std::pair<const std::string, lm::WordIndex> *Entry() const { return entry_; }
+
+ private:
+ const std::pair<const std::string, lm::WordIndex> *entry_;
+};
+
+inline size_t hash_value(const Word &word) {
+ return boost::hash_value(word.Entry());
+}
+
+} // namespace search
+#endif // SEARCH_WORD__