summaryrefslogtreecommitdiff
path: root/klm/search
diff options
context:
space:
mode:
authorAvneesh Saluja <asaluja@gmail.com>2013-03-28 18:28:16 -0700
committerAvneesh Saluja <asaluja@gmail.com>2013-03-28 18:28:16 -0700
commit5b8253e0e1f1393a509fb9975ba8c1347af758ed (patch)
tree1790470b1d07a0b4973ebce19192e896566ea60b /klm/search
parent2389a5a8a43dda87c355579838559515b0428421 (diff)
parentb203f8c5dc8cff1b9c9c2073832b248fcad0765a (diff)
fixed conflicts
Diffstat (limited to 'klm/search')
-rw-r--r--klm/search/Makefile.am23
-rw-r--r--klm/search/applied.hh86
-rw-r--r--klm/search/config.hh38
-rw-r--r--klm/search/context.hh49
-rw-r--r--klm/search/dedupe.hh131
-rw-r--r--klm/search/edge.hh54
-rw-r--r--klm/search/edge_generator.cc111
-rw-r--r--klm/search/edge_generator.hh56
-rw-r--r--klm/search/header.hh64
-rw-r--r--klm/search/nbest.cc106
-rw-r--r--klm/search/nbest.hh81
-rw-r--r--klm/search/rule.cc43
-rw-r--r--klm/search/rule.hh25
-rw-r--r--klm/search/types.hh31
-rw-r--r--klm/search/vertex.cc55
-rw-r--r--klm/search/vertex.hh164
-rw-r--r--klm/search/vertex_generator.cc68
-rw-r--r--klm/search/vertex_generator.hh99
18 files changed, 1284 insertions, 0 deletions
diff --git a/klm/search/Makefile.am b/klm/search/Makefile.am
new file mode 100644
index 00000000..03554276
--- /dev/null
+++ b/klm/search/Makefile.am
@@ -0,0 +1,23 @@
+noinst_LIBRARIES = libksearch.a
+
+libksearch_a_SOURCES = \
+ applied.hh \
+ config.hh \
+ context.hh \
+ dedupe.hh \
+ edge.hh \
+ edge_generator.hh \
+ header.hh \
+ nbest.hh \
+ rule.hh \
+ types.hh \
+ vertex.hh \
+ vertex_generator.hh \
+ edge_generator.cc \
+ nbest.cc \
+ rule.cc \
+ vertex.cc \
+ vertex_generator.cc
+
+AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/klm
+
diff --git a/klm/search/applied.hh b/klm/search/applied.hh
new file mode 100644
index 00000000..bd659e5c
--- /dev/null
+++ b/klm/search/applied.hh
@@ -0,0 +1,86 @@
+#ifndef SEARCH_APPLIED__
+#define SEARCH_APPLIED__
+
+#include "search/edge.hh"
+#include "search/header.hh"
+#include "util/pool.hh"
+
+#include <math.h>
+
+namespace search {
+
+// A full hypothesis: a score, arity of the rule, a pointer to the decoder's rule (Note), and pointers to non-terminals that were substituted.
+template <class Below> class GenericApplied : public Header {
+ public:
+ GenericApplied() {}
+
+ GenericApplied(void *location, PartialEdge partial)
+ : Header(location) {
+ memcpy(Base(), partial.Base(), kHeaderSize);
+ Below *child_out = Children();
+ const PartialVertex *part = partial.NT();
+ const PartialVertex *const part_end_loop = part + partial.GetArity();
+ for (; part != part_end_loop; ++part, ++child_out)
+ *child_out = Below(part->End());
+ }
+
+ GenericApplied(void *location, Score score, Arity arity, Note note) : Header(location, arity) {
+ SetScore(score);
+ SetNote(note);
+ }
+
+ explicit GenericApplied(History from) : Header(from) {}
+
+
+ // These are arrays of length GetArity().
+ Below *Children() {
+ return reinterpret_cast<Below*>(After());
+ }
+ const Below *Children() const {
+ return reinterpret_cast<const Below*>(After());
+ }
+
+ static std::size_t Size(Arity arity) {
+ return kHeaderSize + arity * sizeof(const Below);
+ }
+};
+
+// Applied rule that references itself.
+class Applied : public GenericApplied<Applied> {
+ private:
+ typedef GenericApplied<Applied> P;
+
+ public:
+ Applied() {}
+ Applied(void *location, PartialEdge partial) : P(location, partial) {}
+ Applied(History from) : P(from) {}
+};
+
+// How to build single-best hypotheses.
+class SingleBest {
+ public:
+ typedef PartialEdge Combine;
+
+ void Add(PartialEdge &existing, PartialEdge add) const {
+ if (!existing.Valid() || existing.GetScore() < add.GetScore())
+ existing = add;
+ }
+
+ NBestComplete Complete(PartialEdge partial) {
+ if (!partial.Valid())
+ return NBestComplete(NULL, lm::ngram::ChartState(), -INFINITY);
+ void *place_final = pool_.Allocate(Applied::Size(partial.GetArity()));
+ Applied(place_final, partial);
+ return NBestComplete(
+ place_final,
+ partial.CompletedState(),
+ partial.GetScore());
+ }
+
+ private:
+ util::Pool pool_;
+};
+
+} // namespace search
+
+#endif // SEARCH_APPLIED__
diff --git a/klm/search/config.hh b/klm/search/config.hh
new file mode 100644
index 00000000..ba18c09e
--- /dev/null
+++ b/klm/search/config.hh
@@ -0,0 +1,38 @@
+#ifndef SEARCH_CONFIG__
+#define SEARCH_CONFIG__
+
+#include "search/types.hh"
+
+namespace search {
+
+struct NBestConfig {
+ explicit NBestConfig(unsigned int in_size) {
+ keep = in_size;
+ size = in_size;
+ }
+
+ unsigned int keep, size;
+};
+
+class Config {
+ public:
+ Config(Score lm_weight, unsigned int pop_limit, const NBestConfig &nbest) :
+ lm_weight_(lm_weight), pop_limit_(pop_limit), nbest_(nbest) {}
+
+ Score LMWeight() const { return lm_weight_; }
+
+ unsigned int PopLimit() const { return pop_limit_; }
+
+ const NBestConfig &GetNBest() const { return nbest_; }
+
+ private:
+ Score lm_weight_;
+
+ unsigned int pop_limit_;
+
+ NBestConfig nbest_;
+};
+
+} // namespace search
+
+#endif // SEARCH_CONFIG__
diff --git a/klm/search/context.hh b/klm/search/context.hh
new file mode 100644
index 00000000..08f21bbf
--- /dev/null
+++ b/klm/search/context.hh
@@ -0,0 +1,49 @@
+#ifndef SEARCH_CONTEXT__
+#define SEARCH_CONTEXT__
+
+#include "search/config.hh"
+#include "search/vertex.hh"
+
+#include <boost/pool/object_pool.hpp>
+
+namespace search {
+
+class ContextBase {
+ public:
+ explicit ContextBase(const Config &config) : config_(config) {}
+
+ VertexNode *NewVertexNode() {
+ VertexNode *ret = vertex_node_pool_.construct();
+ assert(ret);
+ return ret;
+ }
+
+ void DeleteVertexNode(VertexNode *node) {
+ vertex_node_pool_.destroy(node);
+ }
+
+ unsigned int PopLimit() const { return config_.PopLimit(); }
+
+ Score LMWeight() const { return config_.LMWeight(); }
+
+ const Config &GetConfig() const { return config_; }
+
+ private:
+ boost::object_pool<VertexNode> vertex_node_pool_;
+
+ Config config_;
+};
+
+template <class Model> class Context : public ContextBase {
+ public:
+ Context(const Config &config, const Model &model) : ContextBase(config), model_(model) {}
+
+ const Model &LanguageModel() const { return model_; }
+
+ private:
+ const Model &model_;
+};
+
+} // namespace search
+
+#endif // SEARCH_CONTEXT__
diff --git a/klm/search/dedupe.hh b/klm/search/dedupe.hh
new file mode 100644
index 00000000..7eaa3b95
--- /dev/null
+++ b/klm/search/dedupe.hh
@@ -0,0 +1,131 @@
+#ifndef SEARCH_DEDUPE__
+#define SEARCH_DEDUPE__
+
+#include "lm/state.hh"
+#include "search/edge_generator.hh"
+
+#include <boost/pool/object_pool.hpp>
+#include <boost/unordered_map.hpp>
+
+namespace search {
+
+class Dedupe {
+ public:
+ Dedupe() {}
+
+ PartialEdge AllocateEdge(Arity arity) {
+ return behind_.AllocateEdge(arity);
+ }
+
+ void AddEdge(PartialEdge edge) {
+ edge.MutableFlags() = 0;
+
+ uint64_t hash = 0;
+ const PartialVertex *v = edge.NT();
+ const PartialVertex *v_end = v + edge.GetArity();
+ for (; v != v_end; ++v) {
+ const void *ptr = v->Identify();
+ hash = util::MurmurHashNative(&ptr, sizeof(const void*), hash);
+ }
+
+ const lm::ngram::ChartState *c = edge.Between();
+ const lm::ngram::ChartState *const c_end = c + edge.GetArity() + 1;
+ for (; c != c_end; ++c) hash = hash_value(*c, hash);
+
+ std::pair<Table::iterator, bool> ret(table_.insert(std::make_pair(hash, edge)));
+ if (!ret.second) FoundDupe(ret.first->second, edge);
+ }
+
+ bool Empty() const { return behind_.Empty(); }
+
+ template <class Model, class Output> void Search(Context<Model> &context, Output &output) {
+ for (Table::const_iterator i(table_.begin()); i != table_.end(); ++i) {
+ behind_.AddEdge(i->second);
+ }
+ Unpack<Output> unpack(output, *this);
+ behind_.Search(context, unpack);
+ }
+
+ private:
+ void FoundDupe(PartialEdge &table, PartialEdge adding) {
+ if (table.GetFlags() & kPackedFlag) {
+ Packed &packed = *static_cast<Packed*>(table.GetNote().mut);
+ if (table.GetScore() >= adding.GetScore()) {
+ packed.others.push_back(adding);
+ return;
+ }
+ Note original(packed.original);
+ packed.original = adding.GetNote();
+ adding.SetNote(table.GetNote());
+ table.SetNote(original);
+ packed.others.push_back(table);
+ packed.starting = adding.GetScore();
+ table = adding;
+ table.MutableFlags() |= kPackedFlag;
+ return;
+ }
+ PartialEdge loser;
+ if (adding.GetScore() > table.GetScore()) {
+ loser = table;
+ table = adding;
+ } else {
+ loser = adding;
+ }
+ // table is winner, loser is loser...
+ packed_.construct(table, loser);
+ }
+
+ struct Packed {
+ Packed(PartialEdge winner, PartialEdge loser)
+ : original(winner.GetNote()), starting(winner.GetScore()), others(1, loser) {
+ winner.MutableNote().vp = this;
+ winner.MutableFlags() |= kPackedFlag;
+ loser.MutableFlags() &= ~kPackedFlag;
+ }
+ Note original;
+ Score starting;
+ std::vector<PartialEdge> others;
+ };
+
+ template <class Output> class Unpack {
+ public:
+ explicit Unpack(Output &output, Dedupe &owner) : output_(output), owner_(owner) {}
+
+ void NewHypothesis(PartialEdge edge) {
+ if (edge.GetFlags() & kPackedFlag) {
+ Packed &packed = *reinterpret_cast<Packed*>(edge.GetNote().mut);
+ edge.SetNote(packed.original);
+ edge.MutableFlags() = 0;
+ std::size_t copy_size = sizeof(PartialVertex) * edge.GetArity() + sizeof(lm::ngram::ChartState);
+ for (std::vector<PartialEdge>::iterator i = packed.others.begin(); i != packed.others.end(); ++i) {
+ PartialEdge copy(owner_.AllocateEdge(edge.GetArity()));
+ copy.SetScore(edge.GetScore() - packed.starting + i->GetScore());
+ copy.MutableFlags() = 0;
+ copy.SetNote(i->GetNote());
+ memcpy(copy.NT(), edge.NT(), copy_size);
+ output_.NewHypothesis(copy);
+ }
+ }
+ output_.NewHypothesis(edge);
+ }
+
+ void FinishedSearch() {
+ output_.FinishedSearch();
+ }
+
+ private:
+ Output &output_;
+ Dedupe &owner_;
+ };
+
+ EdgeGenerator behind_;
+
+ typedef boost::unordered_map<uint64_t, PartialEdge> Table;
+ Table table_;
+
+ boost::object_pool<Packed> packed_;
+
+ static const uint16_t kPackedFlag = 1;
+};
+} // namespace search
+#endif // SEARCH_DEDUPE__
diff --git a/klm/search/edge.hh b/klm/search/edge.hh
new file mode 100644
index 00000000..187904bf
--- /dev/null
+++ b/klm/search/edge.hh
@@ -0,0 +1,54 @@
+#ifndef SEARCH_EDGE__
+#define SEARCH_EDGE__
+
+#include "lm/state.hh"
+#include "search/header.hh"
+#include "search/types.hh"
+#include "search/vertex.hh"
+#include "util/pool.hh"
+
+#include <functional>
+
+#include <stdint.h>
+
+namespace search {
+
+// Copyable, but the copy will be shallow.
+class PartialEdge : public Header {
+ public:
+ // Allow default construction for STL.
+ PartialEdge() {}
+
+ PartialEdge(util::Pool &pool, Arity arity)
+ : Header(pool.Allocate(Size(arity, arity + 1)), arity) {}
+
+ PartialEdge(util::Pool &pool, Arity arity, Arity chart_states)
+ : Header(pool.Allocate(Size(arity, chart_states)), arity) {}
+
+ // Non-terminals
+ const PartialVertex *NT() const {
+ return reinterpret_cast<const PartialVertex*>(After());
+ }
+ PartialVertex *NT() {
+ return reinterpret_cast<PartialVertex*>(After());
+ }
+
+ const lm::ngram::ChartState &CompletedState() const {
+ return *Between();
+ }
+ const lm::ngram::ChartState *Between() const {
+ return reinterpret_cast<const lm::ngram::ChartState*>(After() + GetArity() * sizeof(PartialVertex));
+ }
+ lm::ngram::ChartState *Between() {
+ return reinterpret_cast<lm::ngram::ChartState*>(After() + GetArity() * sizeof(PartialVertex));
+ }
+
+ private:
+ static std::size_t Size(Arity arity, Arity chart_states) {
+ return kHeaderSize + arity * sizeof(PartialVertex) + chart_states * sizeof(lm::ngram::ChartState);
+ }
+};
+
+
+} // namespace search
+#endif // SEARCH_EDGE__
diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc
new file mode 100644
index 00000000..eacf5de5
--- /dev/null
+++ b/klm/search/edge_generator.cc
@@ -0,0 +1,111 @@
+#include "search/edge_generator.hh"
+
+#include "lm/left.hh"
+#include "lm/model.hh"
+#include "lm/partial.hh"
+#include "search/context.hh"
+#include "search/vertex.hh"
+
+#include <numeric>
+
+namespace search {
+
+namespace {
+
+template <class Model> void FastScore(const Context<Model> &context, Arity victim, Arity before_idx, Arity incomplete, const PartialVertex &previous_vertex, PartialEdge update) {
+ lm::ngram::ChartState *between = update.Between();
+ lm::ngram::ChartState *before = &between[before_idx], *after = &between[before_idx + 1];
+
+ float adjustment = 0.0;
+ const lm::ngram::ChartState &previous_reveal = previous_vertex.State();
+ const PartialVertex &update_nt = update.NT()[victim];
+ const lm::ngram::ChartState &update_reveal = update_nt.State();
+ if ((update_reveal.left.length > previous_reveal.left.length) || (update_reveal.left.full && !previous_reveal.left.full)) {
+ adjustment += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length);
+ }
+ if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous_vertex.RightFull())) {
+ adjustment += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right);
+ }
+ if (update_nt.Complete()) {
+ if (update_reveal.left.full) {
+ before->left.full = true;
+ } else {
+ assert(update_reveal.left.length == update_reveal.right.length);
+ adjustment += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length);
+ }
+ before->right = after->right;
+ // Shift the others shifted one down, covering after.
+ for (lm::ngram::ChartState *cover = after; cover < between + incomplete; ++cover) {
+ *cover = *(cover + 1);
+ }
+ }
+ update.SetScore(update.GetScore() + adjustment * context.LMWeight());
+}
+
+} // namespace
+
+template <class Model> PartialEdge EdgeGenerator::Pop(Context<Model> &context) {
+ assert(!generate_.empty());
+ PartialEdge top = generate_.top();
+ generate_.pop();
+ PartialVertex *const top_nt = top.NT();
+ const Arity arity = top.GetArity();
+
+ Arity victim = 0;
+ Arity victim_completed;
+ Arity incomplete;
+ // Select victim or return if complete.
+ {
+ Arity completed = 0;
+ unsigned char lowest_length = 255;
+ for (Arity i = 0; i != arity; ++i) {
+ if (top_nt[i].Complete()) {
+ ++completed;
+ } else if (top_nt[i].Length() < lowest_length) {
+ lowest_length = top_nt[i].Length();
+ victim = i;
+ victim_completed = completed;
+ }
+ }
+ if (lowest_length == 255) {
+ return top;
+ }
+ incomplete = arity - completed;
+ }
+
+ PartialVertex old_value(top_nt[victim]);
+ PartialVertex alternate_changed;
+ if (top_nt[victim].Split(alternate_changed)) {
+ PartialEdge alternate(partial_edge_pool_, arity, incomplete + 1);
+ alternate.SetScore(top.GetScore() + alternate_changed.Bound() - old_value.Bound());
+
+ alternate.SetNote(top.GetNote());
+
+ PartialVertex *alternate_nt = alternate.NT();
+ for (Arity i = 0; i < victim; ++i) alternate_nt[i] = top_nt[i];
+ alternate_nt[victim] = alternate_changed;
+ for (Arity i = victim + 1; i < arity; ++i) alternate_nt[i] = top_nt[i];
+
+ memcpy(alternate.Between(), top.Between(), sizeof(lm::ngram::ChartState) * (incomplete + 1));
+
+ // TODO: dedupe?
+ generate_.push(alternate);
+ }
+
+ // top is now the continuation.
+ FastScore(context, victim, victim - victim_completed, incomplete, old_value, top);
+ // TODO: dedupe?
+ generate_.push(top);
+
+ // Invalid indicates no new hypothesis generated.
+ return PartialEdge();
+}
+
+template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context);
+template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context);
+template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::TrieModel> &context);
+template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::QuantTrieModel> &context);
+template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::ArrayTrieModel> &context);
+template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::QuantArrayTrieModel> &context);
+
+} // namespace search
diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh
new file mode 100644
index 00000000..203942c6
--- /dev/null
+++ b/klm/search/edge_generator.hh
@@ -0,0 +1,56 @@
+#ifndef SEARCH_EDGE_GENERATOR__
+#define SEARCH_EDGE_GENERATOR__
+
+#include "search/edge.hh"
+#include "search/types.hh"
+
+#include <queue>
+
+namespace lm {
+namespace ngram {
+class ChartState;
+} // namespace ngram
+} // namespace lm
+
+namespace search {
+
+template <class Model> class Context;
+
+class EdgeGenerator {
+ public:
+ EdgeGenerator() {}
+
+ PartialEdge AllocateEdge(Arity arity) {
+ return PartialEdge(partial_edge_pool_, arity);
+ }
+
+ void AddEdge(PartialEdge edge) {
+ generate_.push(edge);
+ }
+
+ bool Empty() const { return generate_.empty(); }
+
+ // Pop. If there's a complete hypothesis, return it. Otherwise return an invalid PartialEdge.
+ template <class Model> PartialEdge Pop(Context<Model> &context);
+
+ template <class Model, class Output> void Search(Context<Model> &context, Output &output) {
+ unsigned to_pop = context.PopLimit();
+ while (to_pop > 0 && !generate_.empty()) {
+ PartialEdge got(Pop(context));
+ if (got.Valid()) {
+ output.NewHypothesis(got);
+ --to_pop;
+ }
+ }
+ output.FinishedSearch();
+ }
+
+ private:
+ util::Pool partial_edge_pool_;
+
+ typedef std::priority_queue<PartialEdge> Generate;
+ Generate generate_;
+};
+
+} // namespace search
+#endif // SEARCH_EDGE_GENERATOR__
diff --git a/klm/search/header.hh b/klm/search/header.hh
new file mode 100644
index 00000000..69f0eed0
--- /dev/null
+++ b/klm/search/header.hh
@@ -0,0 +1,64 @@
+#ifndef SEARCH_HEADER__
+#define SEARCH_HEADER__
+
+// Header consisting of Score, Arity, and Note
+
+#include "search/types.hh"
+
+#include <stdint.h>
+
+namespace search {
+
+// Copying is shallow.
+class Header {
+ public:
+ bool Valid() const { return base_; }
+
+ Score GetScore() const {
+ return *reinterpret_cast<const float*>(base_);
+ }
+ void SetScore(Score to) {
+ *reinterpret_cast<float*>(base_) = to;
+ }
+ bool operator<(const Header &other) const {
+ return GetScore() < other.GetScore();
+ }
+ bool operator>(const Header &other) const {
+ return GetScore() > other.GetScore();
+ }
+
+ Arity GetArity() const {
+ return *reinterpret_cast<const Arity*>(base_ + sizeof(Score));
+ }
+
+ Note GetNote() const {
+ return *reinterpret_cast<const Note*>(base_ + sizeof(Score) + sizeof(Arity));
+ }
+ void SetNote(Note to) {
+ *reinterpret_cast<Note*>(base_ + sizeof(Score) + sizeof(Arity)) = to;
+ }
+
+ uint8_t *Base() { return base_; }
+ const uint8_t *Base() const { return base_; }
+
+ protected:
+ Header() : base_(NULL) {}
+
+ explicit Header(void *base) : base_(static_cast<uint8_t*>(base)) {}
+
+ Header(void *base, Arity arity) : base_(static_cast<uint8_t*>(base)) {
+ *reinterpret_cast<Arity*>(base_ + sizeof(Score)) = arity;
+ }
+
+ static const std::size_t kHeaderSize = sizeof(Score) + sizeof(Arity) + sizeof(Note);
+
+ uint8_t *After() { return base_ + kHeaderSize; }
+ const uint8_t *After() const { return base_ + kHeaderSize; }
+
+ private:
+ uint8_t *base_;
+};
+
+} // namespace search
+
+#endif // SEARCH_HEADER__
diff --git a/klm/search/nbest.cc b/klm/search/nbest.cc
new file mode 100644
index 00000000..ec3322c9
--- /dev/null
+++ b/klm/search/nbest.cc
@@ -0,0 +1,106 @@
+#include "search/nbest.hh"
+
+#include "util/pool.hh"
+
+#include <algorithm>
+#include <functional>
+#include <queue>
+
+#include <assert.h>
+#include <math.h>
+
+namespace search {
+
+NBestList::NBestList(std::vector<PartialEdge> &partials, util::Pool &entry_pool, std::size_t keep) {
+ assert(!partials.empty());
+ std::vector<PartialEdge>::iterator end;
+ if (partials.size() > keep) {
+ end = partials.begin() + keep;
+ std::nth_element(partials.begin(), end, partials.end(), std::greater<PartialEdge>());
+ } else {
+ end = partials.end();
+ }
+ for (std::vector<PartialEdge>::const_iterator i(partials.begin()); i != end; ++i) {
+ queue_.push(QueueEntry(entry_pool.Allocate(QueueEntry::Size(i->GetArity())), *i));
+ }
+}
+
+Score NBestList::TopAfterConstructor() const {
+ assert(revealed_.empty());
+ return queue_.top().GetScore();
+}
+
+const std::vector<Applied> &NBestList::Extract(util::Pool &pool, std::size_t n) {
+ while (revealed_.size() < n && !queue_.empty()) {
+ MoveTop(pool);
+ }
+ return revealed_;
+}
+
+Score NBestList::Visit(util::Pool &pool, std::size_t index) {
+ if (index + 1 < revealed_.size())
+ return revealed_[index + 1].GetScore() - revealed_[index].GetScore();
+ if (queue_.empty())
+ return -INFINITY;
+ if (index + 1 == revealed_.size())
+ return queue_.top().GetScore() - revealed_[index].GetScore();
+ assert(index == revealed_.size());
+
+ MoveTop(pool);
+
+ if (queue_.empty()) return -INFINITY;
+ return queue_.top().GetScore() - revealed_[index].GetScore();
+}
+
+Applied NBestList::Get(util::Pool &pool, std::size_t index) {
+ assert(index <= revealed_.size());
+ if (index == revealed_.size()) MoveTop(pool);
+ return revealed_[index];
+}
+
+void NBestList::MoveTop(util::Pool &pool) {
+ assert(!queue_.empty());
+ QueueEntry entry(queue_.top());
+ queue_.pop();
+ RevealedRef *const children_begin = entry.Children();
+ RevealedRef *const children_end = children_begin + entry.GetArity();
+ Score basis = entry.GetScore();
+ for (RevealedRef *child = children_begin; child != children_end; ++child) {
+ Score change = child->in_->Visit(pool, child->index_);
+ if (change != -INFINITY) {
+ assert(change < 0.001);
+ QueueEntry new_entry(pool.Allocate(QueueEntry::Size(entry.GetArity())), basis + change, entry.GetArity(), entry.GetNote());
+ std::copy(children_begin, child, new_entry.Children());
+ RevealedRef *update = new_entry.Children() + (child - children_begin);
+ update->in_ = child->in_;
+ update->index_ = child->index_ + 1;
+ std::copy(child + 1, children_end, update + 1);
+ queue_.push(new_entry);
+ }
+ // Gesmundo, A. and Henderson, J. Faster Cube Pruning, IWSLT 2010.
+ if (child->index_) break;
+ }
+
+ // Convert QueueEntry to Applied. This leaves some unused memory.
+ void *overwrite = entry.Children();
+ for (unsigned int i = 0; i < entry.GetArity(); ++i) {
+ RevealedRef from(*(static_cast<const RevealedRef*>(overwrite) + i));
+ *(static_cast<Applied*>(overwrite) + i) = from.in_->Get(pool, from.index_);
+ }
+ revealed_.push_back(Applied(entry.Base()));
+}
+
+NBestComplete NBest::Complete(std::vector<PartialEdge> &partials) {
+ assert(!partials.empty());
+ NBestList *list = list_pool_.construct(partials, entry_pool_, config_.keep);
+ return NBestComplete(
+ list,
+ partials.front().CompletedState(), // All partials have the same state
+ list->TopAfterConstructor());
+}
+
+const std::vector<Applied> &NBest::Extract(History history) {
+ return static_cast<NBestList*>(history)->Extract(entry_pool_, config_.size);
+}
+
+} // namespace search
diff --git a/klm/search/nbest.hh b/klm/search/nbest.hh
new file mode 100644
index 00000000..cb7651bc
--- /dev/null
+++ b/klm/search/nbest.hh
@@ -0,0 +1,81 @@
+#ifndef SEARCH_NBEST__
+#define SEARCH_NBEST__
+
+#include "search/applied.hh"
+#include "search/config.hh"
+#include "search/edge.hh"
+
+#include <boost/pool/object_pool.hpp>
+
+#include <cstddef>
+#include <queue>
+#include <vector>
+
+#include <assert.h>
+
+namespace search {
+
+class NBestList;
+
+class NBestList {
+ private:
+ class RevealedRef {
+ public:
+ explicit RevealedRef(History history)
+ : in_(static_cast<NBestList*>(history)), index_(0) {}
+
+ private:
+ friend class NBestList;
+
+ NBestList *in_;
+ std::size_t index_;
+ };
+
+ typedef GenericApplied<RevealedRef> QueueEntry;
+
+ public:
+ NBestList(std::vector<PartialEdge> &existing, util::Pool &entry_pool, std::size_t keep);
+
+ Score TopAfterConstructor() const;
+
+ const std::vector<Applied> &Extract(util::Pool &pool, std::size_t n);
+
+ private:
+ Score Visit(util::Pool &pool, std::size_t index);
+
+ Applied Get(util::Pool &pool, std::size_t index);
+
+ void MoveTop(util::Pool &pool);
+
+ typedef std::vector<Applied> Revealed;
+ Revealed revealed_;
+
+ typedef std::priority_queue<QueueEntry> Queue;
+ Queue queue_;
+};
+
+class NBest {
+ public:
+ typedef std::vector<PartialEdge> Combine;
+
+ explicit NBest(const NBestConfig &config) : config_(config) {}
+
+ void Add(std::vector<PartialEdge> &existing, PartialEdge addition) const {
+ existing.push_back(addition);
+ }
+
+ NBestComplete Complete(std::vector<PartialEdge> &partials);
+
+ const std::vector<Applied> &Extract(History root);
+
+ private:
+ const NBestConfig config_;
+
+ boost::object_pool<NBestList> list_pool_;
+
+ util::Pool entry_pool_;
+};
+
+} // namespace search
+
+#endif // SEARCH_NBEST__
diff --git a/klm/search/rule.cc b/klm/search/rule.cc
new file mode 100644
index 00000000..0244a09f
--- /dev/null
+++ b/klm/search/rule.cc
@@ -0,0 +1,43 @@
+#include "search/rule.hh"
+
+#include "lm/model.hh"
+#include "search/context.hh"
+
+#include <ostream>
+
+#include <math.h>
+
+namespace search {
+
+template <class Model> ScoreRuleRet ScoreRule(const Model &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing) {
+ ScoreRuleRet ret;
+ ret.prob = 0.0;
+ ret.oov = 0;
+ const lm::WordIndex oov = model.GetVocabulary().NotFound(), bos = model.GetVocabulary().BeginSentence();
+ lm::ngram::RuleScore<Model> scorer(model, *(writing++));
+ std::vector<lm::WordIndex>::const_iterator word = words.begin();
+ if (word != words.end() && *word == bos) {
+ scorer.BeginSentence();
+ ++word;
+ }
+ for (; word != words.end(); ++word) {
+ if (*word == kNonTerminal) {
+ ret.prob += scorer.Finish();
+ scorer.Reset(*(writing++));
+ } else {
+ if (*word == oov) ++ret.oov;
+ scorer.Terminal(*word);
+ }
+ }
+ ret.prob += scorer.Finish();
+ return ret;
+}
+
+template ScoreRuleRet ScoreRule(const lm::ngram::RestProbingModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
+template ScoreRuleRet ScoreRule(const lm::ngram::ProbingModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
+template ScoreRuleRet ScoreRule(const lm::ngram::TrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
+template ScoreRuleRet ScoreRule(const lm::ngram::QuantTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
+template ScoreRuleRet ScoreRule(const lm::ngram::ArrayTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
+template ScoreRuleRet ScoreRule(const lm::ngram::QuantArrayTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing);
+
+} // namespace search
diff --git a/klm/search/rule.hh b/klm/search/rule.hh
new file mode 100644
index 00000000..43ca6162
--- /dev/null
+++ b/klm/search/rule.hh
@@ -0,0 +1,25 @@
+#ifndef SEARCH_RULE__
+#define SEARCH_RULE__
+
+#include "lm/left.hh"
+#include "lm/word_index.hh"
+#include "search/types.hh"
+
+#include <vector>
+
+namespace search {
+
+const lm::WordIndex kNonTerminal = lm::kMaxWordIndex;
+
+struct ScoreRuleRet {
+ Score prob;
+ unsigned int oov;
+};
+
+// Pass <s> and </s> normally.
+// Indicate non-terminals with kNonTerminal.
+template <class Model> ScoreRuleRet ScoreRule(const Model &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *state_out);
+
+} // namespace search
+
+#endif // SEARCH_RULE__
diff --git a/klm/search/types.hh b/klm/search/types.hh
new file mode 100644
index 00000000..f9c849b3
--- /dev/null
+++ b/klm/search/types.hh
@@ -0,0 +1,31 @@
+#ifndef SEARCH_TYPES__
+#define SEARCH_TYPES__
+
+#include <stdint.h>
+
+namespace lm { namespace ngram { class ChartState; } }
+
+namespace search {
+
+typedef float Score;
+
+typedef uint32_t Arity;
+
+union Note {
+ const void *vp;
+};
+
+typedef void *History;
+
+struct NBestComplete {
+ NBestComplete(History in_history, const lm::ngram::ChartState &in_state, Score in_score)
+ : history(in_history), state(&in_state), score(in_score) {}
+
+ History history;
+ const lm::ngram::ChartState *state;
+ Score score;
+};
+
+} // namespace search
+
+#endif // SEARCH_TYPES__
diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc
new file mode 100644
index 00000000..45842982
--- /dev/null
+++ b/klm/search/vertex.cc
@@ -0,0 +1,55 @@
+#include "search/vertex.hh"
+
+#include "search/context.hh"
+
+#include <algorithm>
+#include <functional>
+
+#include <assert.h>
+
+namespace search {
+
+namespace {
+
+struct GreaterByBound : public std::binary_function<const VertexNode *, const VertexNode *, bool> {
+ bool operator()(const VertexNode *first, const VertexNode *second) const {
+ return first->Bound() > second->Bound();
+ }
+};
+
+} // namespace
+
+void VertexNode::RecursiveSortAndSet(ContextBase &context, VertexNode *&parent_ptr) {
+ if (Complete()) {
+ assert(end_);
+ assert(extend_.empty());
+ return;
+ }
+ if (extend_.size() == 1) {
+ parent_ptr = extend_[0];
+ extend_[0]->RecursiveSortAndSet(context, parent_ptr);
+ context.DeleteVertexNode(this);
+ return;
+ }
+ for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) {
+ (*i)->RecursiveSortAndSet(context, *i);
+ }
+ std::sort(extend_.begin(), extend_.end(), GreaterByBound());
+ bound_ = extend_.front()->Bound();
+}
+
+void VertexNode::SortAndSet(ContextBase &context) {
+ // This is the root. The root might be empty.
+ if (extend_.empty()) {
+ bound_ = -INFINITY;
+ return;
+ }
+ // The root cannot be replaced. There's always one transition.
+ for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) {
+ (*i)->RecursiveSortAndSet(context, *i);
+ }
+ std::sort(extend_.begin(), extend_.end(), GreaterByBound());
+ bound_ = extend_.front()->Bound();
+}
+
+} // namespace search
diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh
new file mode 100644
index 00000000..ca9a4fcd
--- /dev/null
+++ b/klm/search/vertex.hh
@@ -0,0 +1,164 @@
+#ifndef SEARCH_VERTEX__
+#define SEARCH_VERTEX__
+
+#include "lm/left.hh"
+#include "search/types.hh"
+
+#include <boost/unordered_set.hpp>
+
+#include <queue>
+#include <vector>
+
+#include <math.h>
+#include <stdint.h>
+
+namespace search {
+
+class ContextBase;
+
+class VertexNode {
+ public:
+ VertexNode() : end_() {}
+
+ void InitRoot() {
+ extend_.clear();
+ state_.left.full = false;
+ state_.left.length = 0;
+ state_.right.length = 0;
+ right_full_ = false;
+ end_ = History();
+ }
+
+ lm::ngram::ChartState &MutableState() { return state_; }
+ bool &MutableRightFull() { return right_full_; }
+
+ void AddExtend(VertexNode *next) {
+ extend_.push_back(next);
+ }
+
+ void SetEnd(History end, Score score) {
+ assert(!end_);
+ end_ = end;
+ bound_ = score;
+ }
+
+ void SortAndSet(ContextBase &context);
+
+ // Should only happen to a root node when the entire vertex is empty.
+ bool Empty() const {
+ return !end_ && extend_.empty();
+ }
+
+ bool Complete() const {
+ return end_;
+ }
+
+ const lm::ngram::ChartState &State() const { return state_; }
+ bool RightFull() const { return right_full_; }
+
+ Score Bound() const {
+ return bound_;
+ }
+
+ unsigned char Length() const {
+ return state_.left.length + state_.right.length;
+ }
+
+ // Will be invalid unless this is a leaf.
+ History End() const { return end_; }
+
+ const VertexNode &operator[](size_t index) const {
+ return *extend_[index];
+ }
+
+ size_t Size() const {
+ return extend_.size();
+ }
+
+ private:
+ void RecursiveSortAndSet(ContextBase &context, VertexNode *&parent);
+
+ std::vector<VertexNode*> extend_;
+
+ lm::ngram::ChartState state_;
+ bool right_full_;
+
+ Score bound_;
+ History end_;
+};
+
+class PartialVertex {
+ public:
+ PartialVertex() {}
+
+ explicit PartialVertex(const VertexNode &back) : back_(&back), index_(0) {}
+
+ bool Empty() const { return back_->Empty(); }
+
+ bool Complete() const { return back_->Complete(); }
+
+ const lm::ngram::ChartState &State() const { return back_->State(); }
+ bool RightFull() const { return back_->RightFull(); }
+
+ Score Bound() const { return Complete() ? back_->Bound() : (*back_)[index_].Bound(); }
+
+ unsigned char Length() const { return back_->Length(); }
+
+ bool HasAlternative() const {
+ return index_ + 1 < back_->Size();
+ }
+
+ // Split into continuation and alternative, rendering this the continuation.
+ bool Split(PartialVertex &alternative) {
+ assert(!Complete());
+ bool ret;
+ if (index_ + 1 < back_->Size()) {
+ alternative.index_ = index_ + 1;
+ alternative.back_ = back_;
+ ret = true;
+ } else {
+ ret = false;
+ }
+ back_ = &((*back_)[index_]);
+ index_ = 0;
+ return ret;
+ }
+
+ History End() const {
+ return back_->End();
+ }
+
+ private:
+ const VertexNode *back_;
+ unsigned int index_;
+};
+
+template <class Output> class VertexGenerator;
+
+class Vertex {
+ public:
+ Vertex() {}
+
+ PartialVertex RootPartial() const { return PartialVertex(root_); }
+
+ History BestChild() const {
+ PartialVertex top(RootPartial());
+ if (top.Empty()) {
+ return History();
+ } else {
+ PartialVertex continuation;
+ while (!top.Complete()) {
+ top.Split(continuation);
+ }
+ return top.End();
+ }
+ }
+
+ private:
+ template <class Output> friend class VertexGenerator;
+ template <class Output> friend class RootVertexGenerator;
+ VertexNode root_;
+};
+
+} // namespace search
+#endif // SEARCH_VERTEX__
diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc
new file mode 100644
index 00000000..73139ffc
--- /dev/null
+++ b/klm/search/vertex_generator.cc
@@ -0,0 +1,68 @@
+#include "search/vertex_generator.hh"
+
+#include "lm/left.hh"
+#include "search/context.hh"
+#include "search/edge.hh"
+
+#include <boost/unordered_map.hpp>
+#include <boost/version.hpp>
+
+#include <stdint.h>
+
+namespace search {
+
+#if BOOST_VERSION > 104200
+namespace {
+
+const uint64_t kCompleteAdd = static_cast<uint64_t>(-1);
+
+Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) {
+ Trie &next = node.extend[added];
+ if (!next.under) {
+ next.under = context.NewVertexNode();
+ lm::ngram::ChartState &writing = next.under->MutableState();
+ writing = state;
+ writing.left.full &= left_full && state.left.full;
+ next.under->MutableRightFull() = right_full && state.left.full;
+ writing.left.length = left;
+ writing.right.length = right;
+ node.under->AddExtend(next.under);
+ }
+ return next;
+}
+
+} // namespace
+
+void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end) {
+ const lm::ngram::ChartState &state = *end.state;
+
+ unsigned char left = 0, right = 0;
+ Trie *node = &root;
+ while (true) {
+ if (left == state.left.length) {
+ node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, true, right, false);
+ for (; right < state.right.length; ++right) {
+ node = &FindOrInsert(context, *node, state.right.words[right], state, left, true, right + 1, false);
+ }
+ break;
+ }
+ node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, false);
+ left++;
+ if (right == state.right.length) {
+ node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, false, right, true);
+ for (; left < state.left.length; ++left) {
+ node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, true);
+ }
+ break;
+ }
+ node = &FindOrInsert(context, *node, state.right.words[right], state, left, false, right + 1, false);
+ right++;
+ }
+
+ node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true);
+ node->under->SetEnd(end.history, end.score);
+}
+
+#endif // BOOST_VERSION
+
+} // namespace search
diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh
new file mode 100644
index 00000000..646b8189
--- /dev/null
+++ b/klm/search/vertex_generator.hh
@@ -0,0 +1,99 @@
+#ifndef SEARCH_VERTEX_GENERATOR__
+#define SEARCH_VERTEX_GENERATOR__
+
+#include "search/edge.hh"
+#include "search/types.hh"
+#include "search/vertex.hh"
+#include "util/exception.hh"
+
+#include <boost/unordered_map.hpp>
+#include <boost/version.hpp>
+
+namespace lm {
+namespace ngram {
+class ChartState;
+} // namespace ngram
+} // namespace lm
+
+namespace search {
+
+class ContextBase;
+
+#if BOOST_VERSION > 104200
+// Parallel structure to VertexNode.
+struct Trie {
+ Trie() : under(NULL) {}
+
+ VertexNode *under;
+ boost::unordered_map<uint64_t, Trie> extend;
+};
+
+void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end);
+
+#endif // BOOST_VERSION
+
+// Output makes the single-best or n-best list.
+template <class Output> class VertexGenerator {
+ public:
+ VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) {
+ gen.root_.InitRoot();
+ }
+
+ void NewHypothesis(PartialEdge partial) {
+ nbest_.Add(existing_[hash_value(partial.CompletedState())], partial);
+ }
+
+ void FinishedSearch() {
+#if BOOST_VERSION > 104200
+ Trie root;
+ root.under = &gen_.root_;
+ for (typename Existing::iterator i(existing_.begin()); i != existing_.end(); ++i) {
+ AddHypothesis(context_, root, nbest_.Complete(i->second));
+ }
+ existing_.clear();
+ root.under->SortAndSet(context_);
+#else
+ UTIL_THROW(util::Exception, "Upgrade Boost to >= 1.42.0 to use incremental search.");
+#endif
+ }
+
+ const Vertex &Generating() const { return gen_; }
+
+ private:
+ ContextBase &context_;
+
+ Vertex &gen_;
+
+ typedef boost::unordered_map<uint64_t, typename Output::Combine> Existing;
+ Existing existing_;
+
+ Output &nbest_;
+};
+
+// Special case for root vertex: everything should come together into the root
+// node. In theory, this should happen naturally due to state collapsing with
+// <s> and </s>. If that's the case, VertexGenerator is fine, though it will
+// make one connection.
+template <class Output> class RootVertexGenerator {
+ public:
+ RootVertexGenerator(Vertex &gen, Output &out) : gen_(gen), out_(out) {}
+
+ void NewHypothesis(PartialEdge partial) {
+ out_.Add(combine_, partial);
+ }
+
+ void FinishedSearch() {
+ gen_.root_.InitRoot();
+ NBestComplete completed(out_.Complete(combine_));
+ gen_.root_.SetEnd(completed.history, completed.score);
+ }
+
+ private:
+ Vertex &gen_;
+
+ typename Output::Combine combine_;
+ Output &out_;
+};
+
+} // namespace search
+#endif // SEARCH_VERTEX_GENERATOR__