summaryrefslogtreecommitdiff
path: root/klm/search
diff options
context:
space:
mode:
Diffstat (limited to 'klm/search')
-rw-r--r--klm/search/Jamfile2
-rw-r--r--klm/search/Makefile.am11
-rw-r--r--klm/search/arity.hh8
-rw-r--r--klm/search/context.hh10
-rw-r--r--klm/search/edge.hh53
-rw-r--r--klm/search/edge_generator.cc144
-rw-r--r--klm/search/edge_generator.hh49
-rw-r--r--klm/search/edge_queue.cc25
-rw-r--r--klm/search/edge_queue.hh73
-rw-r--r--klm/search/final.hh41
-rw-r--r--klm/search/header.hh57
-rw-r--r--klm/search/source.hh48
-rw-r--r--klm/search/types.hh8
-rw-r--r--klm/search/vertex.cc10
-rw-r--r--klm/search/vertex.hh55
-rw-r--r--klm/search/vertex_generator.cc97
-rw-r--r--klm/search/vertex_generator.hh33
17 files changed, 318 insertions, 406 deletions
diff --git a/klm/search/Jamfile b/klm/search/Jamfile
index e8b14363..bc95c53a 100644
--- a/klm/search/Jamfile
+++ b/klm/search/Jamfile
@@ -1,4 +1,4 @@
-lib search : weights.cc vertex.cc vertex_generator.cc edge_queue.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : <include>.. ;
+lib search : weights.cc vertex.cc vertex_generator.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : <include>.. ;
import testing ;
diff --git a/klm/search/Makefile.am b/klm/search/Makefile.am
new file mode 100644
index 00000000..ccc5b7f6
--- /dev/null
+++ b/klm/search/Makefile.am
@@ -0,0 +1,11 @@
+noinst_LIBRARIES = libksearch.a
+
+libksearch_a_SOURCES = \
+ edge_generator.cc \
+ rule.cc \
+ vertex.cc \
+ vertex_generator.cc \
+ weights.cc
+
+AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I..
+
diff --git a/klm/search/arity.hh b/klm/search/arity.hh
deleted file mode 100644
index 09c2c671..00000000
--- a/klm/search/arity.hh
+++ /dev/null
@@ -1,8 +0,0 @@
-#ifndef SEARCH_ARITY__
-#define SEARCH_ARITY__
-namespace search {
-
-const unsigned int kMaxArity = 2;
-
-} // namespace search
-#endif // SEARCH_ARITY__
diff --git a/klm/search/context.hh b/klm/search/context.hh
index 27940053..62163144 100644
--- a/klm/search/context.hh
+++ b/klm/search/context.hh
@@ -7,6 +7,7 @@
#include "search/types.hh"
#include "search/vertex.hh"
#include "util/exception.hh"
+#include "util/pool.hh"
#include <boost/pool/object_pool.hpp>
#include <boost/ptr_container/ptr_vector.hpp>
@@ -21,10 +22,8 @@ class ContextBase {
public:
explicit ContextBase(const Config &config) : pop_limit_(config.PopLimit()), weights_(config.GetWeights()) {}
- Final *NewFinal() {
- Final *ret = final_pool_.construct();
- assert(ret);
- return ret;
+ util::Pool &FinalPool() {
+ return final_pool_;
}
VertexNode *NewVertexNode() {
@@ -42,7 +41,8 @@ class ContextBase {
const Weights &GetWeights() const { return weights_; }
private:
- boost::object_pool<Final> final_pool_;
+ util::Pool final_pool_;
+
boost::object_pool<VertexNode> vertex_node_pool_;
unsigned int pop_limit_;
diff --git a/klm/search/edge.hh b/klm/search/edge.hh
index 77ab0ade..187904bf 100644
--- a/klm/search/edge.hh
+++ b/klm/search/edge.hh
@@ -2,30 +2,53 @@
#define SEARCH_EDGE__
#include "lm/state.hh"
-#include "search/arity.hh"
-#include "search/rule.hh"
+#include "search/header.hh"
#include "search/types.hh"
#include "search/vertex.hh"
+#include "util/pool.hh"
-#include <queue>
+#include <functional>
+
+#include <stdint.h>
namespace search {
-struct PartialEdge {
- Score score;
- // Terminals
- lm::ngram::ChartState between[kMaxArity + 1];
- // Non-terminals
- PartialVertex nt[kMaxArity];
+// Copyable, but the copy will be shallow.
+class PartialEdge : public Header {
+ public:
+ // Allow default construction for STL.
+ PartialEdge() {}
+
+ PartialEdge(util::Pool &pool, Arity arity)
+ : Header(pool.Allocate(Size(arity, arity + 1)), arity) {}
+
+ PartialEdge(util::Pool &pool, Arity arity, Arity chart_states)
+ : Header(pool.Allocate(Size(arity, chart_states)), arity) {}
- const lm::ngram::ChartState &CompletedState() const {
- return between[0];
- }
+ // Non-terminals
+ const PartialVertex *NT() const {
+ return reinterpret_cast<const PartialVertex*>(After());
+ }
+ PartialVertex *NT() {
+ return reinterpret_cast<PartialVertex*>(After());
+ }
- bool operator<(const PartialEdge &other) const {
- return score < other.score;
- }
+ const lm::ngram::ChartState &CompletedState() const {
+ return *Between();
+ }
+ const lm::ngram::ChartState *Between() const {
+ return reinterpret_cast<const lm::ngram::ChartState*>(After() + GetArity() * sizeof(PartialVertex));
+ }
+ lm::ngram::ChartState *Between() {
+ return reinterpret_cast<lm::ngram::ChartState*>(After() + GetArity() * sizeof(PartialVertex));
+ }
+
+ private:
+ static std::size_t Size(Arity arity, Arity chart_states) {
+ return kHeaderSize + arity * sizeof(PartialVertex) + chart_states * sizeof(lm::ngram::ChartState);
+ }
};
+
} // namespace search
#endif // SEARCH_EDGE__
diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc
index 56239dfb..260159b1 100644
--- a/klm/search/edge_generator.cc
+++ b/klm/search/edge_generator.cc
@@ -4,117 +4,107 @@
#include "lm/partial.hh"
#include "search/context.hh"
#include "search/vertex.hh"
-#include "search/vertex_generator.hh"
#include <numeric>
namespace search {
-EdgeGenerator::EdgeGenerator(PartialEdge &root, unsigned char arity, Note note) : arity_(arity), note_(note) {
-/* for (unsigned char i = 0; i < edge.Arity(); ++i) {
- root.nt[i] = edge.GetVertex(i).RootPartial();
- }
- for (unsigned char i = edge.Arity(); i < 2; ++i) {
- root.nt[i] = kBlankPartialVertex;
- }*/
- generate_.push(&root);
- top_score_ = root.score;
-}
-
namespace {
-template <class Model> float FastScore(const Context<Model> &context, unsigned char victim, unsigned char arity, const PartialEdge &previous, PartialEdge &update) {
- memcpy(update.between, previous.between, sizeof(lm::ngram::ChartState) * (arity + 1));
-
- float ret = 0.0;
- lm::ngram::ChartState *before, *after;
- if (victim == 0) {
- before = &update.between[0];
- after = &update.between[(arity == 2 && previous.nt[1].Complete()) ? 2 : 1];
- } else {
- assert(victim == 1);
- assert(arity == 2);
- before = &update.between[previous.nt[0].Complete() ? 0 : 1];
- after = &update.between[2];
- }
- const lm::ngram::ChartState &previous_reveal = previous.nt[victim].State();
- const PartialVertex &update_nt = update.nt[victim];
+template <class Model> void FastScore(const Context<Model> &context, Arity victim, Arity before_idx, Arity incomplete, const PartialVertex &previous_vertex, PartialEdge update) {
+ lm::ngram::ChartState *between = update.Between();
+ lm::ngram::ChartState *before = &between[before_idx], *after = &between[before_idx + 1];
+
+ float adjustment = 0.0;
+ const lm::ngram::ChartState &previous_reveal = previous_vertex.State();
+ const PartialVertex &update_nt = update.NT()[victim];
const lm::ngram::ChartState &update_reveal = update_nt.State();
- float just_after = 0.0;
if ((update_reveal.left.length > previous_reveal.left.length) || (update_reveal.left.full && !previous_reveal.left.full)) {
- just_after += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length);
+ adjustment += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length);
}
- if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous.nt[victim].RightFull())) {
- ret += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right);
+ if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous_vertex.RightFull())) {
+ adjustment += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right);
}
if (update_nt.Complete()) {
if (update_reveal.left.full) {
before->left.full = true;
} else {
assert(update_reveal.left.length == update_reveal.right.length);
- ret += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length);
+ adjustment += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length);
}
- if (victim == 0) {
- update.between[0].right = after->right;
- } else {
- update.between[2].left = before->left;
+ before->right = after->right;
+ // Shift the others shifted one down, covering after.
+ for (lm::ngram::ChartState *cover = after; cover < between + incomplete; ++cover) {
+ *cover = *(cover + 1);
}
}
- return previous.score + (ret + just_after) * context.GetWeights().LM();
+ update.SetScore(update.GetScore() + adjustment * context.GetWeights().LM());
}
} // namespace
-template <class Model> PartialEdge *EdgeGenerator::Pop(Context<Model> &context, boost::pool<> &partial_edge_pool) {
+template <class Model> PartialEdge EdgeGenerator::Pop(Context<Model> &context) {
assert(!generate_.empty());
- PartialEdge &top = *generate_.top();
+ PartialEdge top = generate_.top();
generate_.pop();
- unsigned int victim = 0;
- unsigned char lowest_length = 255;
- for (unsigned char i = 0; i != arity_; ++i) {
- if (!top.nt[i].Complete() && top.nt[i].Length() < lowest_length) {
- lowest_length = top.nt[i].Length();
- victim = i;
+ PartialVertex *const top_nt = top.NT();
+ const Arity arity = top.GetArity();
+
+ Arity victim = 0;
+ Arity victim_completed;
+ Arity incomplete;
+ // Select victim or return if complete.
+ {
+ Arity completed = 0;
+ unsigned char lowest_length = 255;
+ for (Arity i = 0; i != arity; ++i) {
+ if (top_nt[i].Complete()) {
+ ++completed;
+ } else if (top_nt[i].Length() < lowest_length) {
+ lowest_length = top_nt[i].Length();
+ victim = i;
+ victim_completed = completed;
+ }
}
- }
- if (lowest_length == 255) {
- // All states report complete.
- top.between[0].right = top.between[arity_].right;
- // Now top.between[0] is the full edge state.
- top_score_ = generate_.empty() ? -kScoreInf : generate_.top()->score;
- return &top;
+ if (lowest_length == 255) {
+ return top;
+ }
+ incomplete = arity - completed;
}
- unsigned int stay = !victim;
- PartialEdge &continuation = *static_cast<PartialEdge*>(partial_edge_pool.malloc());
- float old_bound = top.nt[victim].Bound();
- // The alternate's score will change because alternate.nt[victim] changes.
- bool split = top.nt[victim].Split(continuation.nt[victim]);
- // top is now the alternate.
+ PartialVertex old_value(top_nt[victim]);
+ PartialVertex alternate_changed;
+ if (top_nt[victim].Split(alternate_changed)) {
+ PartialEdge alternate(partial_edge_pool_, arity, incomplete + 1);
+ alternate.SetScore(top.GetScore() + alternate_changed.Bound() - old_value.Bound());
- continuation.nt[stay] = top.nt[stay];
- continuation.score = FastScore(context, victim, arity_, top, continuation);
- // TODO: dedupe?
- generate_.push(&continuation);
+ alternate.SetNote(top.GetNote());
+
+ PartialVertex *alternate_nt = alternate.NT();
+ for (Arity i = 0; i < victim; ++i) alternate_nt[i] = top_nt[i];
+ alternate_nt[victim] = alternate_changed;
+ for (Arity i = victim + 1; i < arity; ++i) alternate_nt[i] = top_nt[i];
+
+ memcpy(alternate.Between(), top.Between(), sizeof(lm::ngram::ChartState) * (incomplete + 1));
- if (split) {
- // We have an alternate.
- top.score += top.nt[victim].Bound() - old_bound;
// TODO: dedupe?
- generate_.push(&top);
- } else {
- partial_edge_pool.free(&top);
+ generate_.push(alternate);
}
- top_score_ = generate_.top()->score;
- return NULL;
+ // top is now the continuation.
+ FastScore(context, victim, victim - victim_completed, incomplete, old_value, top);
+ // TODO: dedupe?
+ generate_.push(top);
+
+ // Invalid indicates no new hypothesis generated.
+ return PartialEdge();
}
-template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context, boost::pool<> &partial_edge_pool);
-template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context, boost::pool<> &partial_edge_pool);
-template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::TrieModel> &context, boost::pool<> &partial_edge_pool);
-template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::QuantTrieModel> &context, boost::pool<> &partial_edge_pool);
-template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::ArrayTrieModel> &context, boost::pool<> &partial_edge_pool);
-template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::QuantArrayTrieModel> &context, boost::pool<> &partial_edge_pool);
+template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context);
+template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context);
+template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::TrieModel> &context);
+template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::QuantTrieModel> &context);
+template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::ArrayTrieModel> &context);
+template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::QuantArrayTrieModel> &context);
} // namespace search
diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh
index 875ccc5e..582c78b7 100644
--- a/klm/search/edge_generator.hh
+++ b/klm/search/edge_generator.hh
@@ -3,11 +3,8 @@
#include "search/edge.hh"
#include "search/note.hh"
+#include "search/types.hh"
-#include <boost/pool/pool.hpp>
-#include <boost/unordered_map.hpp>
-
-#include <functional>
#include <queue>
namespace lm {
@@ -20,38 +17,40 @@ namespace search {
template <class Model> class Context;
-class VertexGenerator;
-
-struct PartialEdgePointerLess : std::binary_function<const PartialEdge *, const PartialEdge *, bool> {
- bool operator()(const PartialEdge *first, const PartialEdge *second) const {
- return *first < *second;
- }
-};
-
class EdgeGenerator {
public:
- EdgeGenerator(PartialEdge &root, unsigned char arity, Note note);
+ EdgeGenerator() {}
- Score TopScore() const {
- return top_score_;
+ PartialEdge AllocateEdge(Arity arity) {
+ return PartialEdge(partial_edge_pool_, arity);
}
- Note GetNote() const {
- return note_;
+ void AddEdge(PartialEdge edge) {
+ generate_.push(edge);
}
- // Pop. If there's a complete hypothesis, return it. Otherwise return NULL.
- template <class Model> PartialEdge *Pop(Context<Model> &context, boost::pool<> &partial_edge_pool);
+ bool Empty() const { return generate_.empty(); }
+
+ // Pop. If there's a complete hypothesis, return it. Otherwise return an invalid PartialEdge.
+ template <class Model> PartialEdge Pop(Context<Model> &context);
+
+ template <class Model, class Output> void Search(Context<Model> &context, Output &output) {
+ unsigned to_pop = context.PopLimit();
+ while (to_pop > 0 && !generate_.empty()) {
+ PartialEdge got(Pop(context));
+ if (got.Valid()) {
+ output.NewHypothesis(got);
+ --to_pop;
+ }
+ }
+ output.FinishedSearch();
+ }
private:
- Score top_score_;
-
- unsigned char arity_;
+ util::Pool partial_edge_pool_;
- typedef std::priority_queue<PartialEdge*, std::vector<PartialEdge*>, PartialEdgePointerLess> Generate;
+ typedef std::priority_queue<PartialEdge> Generate;
Generate generate_;
-
- Note note_;
};
} // namespace search
diff --git a/klm/search/edge_queue.cc b/klm/search/edge_queue.cc
deleted file mode 100644
index e3ae6ebf..00000000
--- a/klm/search/edge_queue.cc
+++ /dev/null
@@ -1,25 +0,0 @@
-#include "search/edge_queue.hh"
-
-#include "lm/left.hh"
-#include "search/context.hh"
-
-#include <stdint.h>
-
-namespace search {
-
-EdgeQueue::EdgeQueue(unsigned int pop_limit_hint) : partial_edge_pool_(sizeof(PartialEdge), pop_limit_hint * 2) {
- take_ = static_cast<PartialEdge*>(partial_edge_pool_.malloc());
-}
-
-/*void EdgeQueue::AddEdge(PartialEdge &root, unsigned char arity, Note note) {
- // Ignore empty edges.
- for (unsigned char i = 0; i < edge.Arity(); ++i) {
- PartialVertex root(edge.GetVertex(i).RootPartial());
- if (root.Empty()) return;
- total_score += root.Bound();
- }
- PartialEdge &allocated = *static_cast<PartialEdge*>(partial_edge_pool_.malloc());
- allocated.score = total_score;
-}*/
-
-} // namespace search
diff --git a/klm/search/edge_queue.hh b/klm/search/edge_queue.hh
deleted file mode 100644
index 187eaed7..00000000
--- a/klm/search/edge_queue.hh
+++ /dev/null
@@ -1,73 +0,0 @@
-#ifndef SEARCH_EDGE_QUEUE__
-#define SEARCH_EDGE_QUEUE__
-
-#include "search/edge.hh"
-#include "search/edge_generator.hh"
-#include "search/note.hh"
-
-#include <boost/pool/pool.hpp>
-#include <boost/pool/object_pool.hpp>
-
-#include <queue>
-
-namespace search {
-
-template <class Model> class Context;
-
-class EdgeQueue {
- public:
- explicit EdgeQueue(unsigned int pop_limit_hint);
-
- PartialEdge &InitializeEdge() {
- return *take_;
- }
-
- void AddEdge(unsigned char arity, Note note) {
- generate_.push(edge_pool_.construct(*take_, arity, note));
- take_ = static_cast<PartialEdge*>(partial_edge_pool_.malloc());
- }
-
- bool Empty() const { return generate_.empty(); }
-
- /* Generate hypotheses and send them to output. Normally, output is a
- * VertexGenerator, but the decoder may want to route edges to different
- * vertices i.e. if they have different LHS non-terminal labels.
- */
- template <class Model, class Output> void Search(Context<Model> &context, Output &output) {
- int to_pop = context.PopLimit();
- while (to_pop > 0 && !generate_.empty()) {
- EdgeGenerator *top = generate_.top();
- generate_.pop();
- PartialEdge *ret = top->Pop(context, partial_edge_pool_);
- if (ret) {
- output.NewHypothesis(*ret, top->GetNote());
- --to_pop;
- if (top->TopScore() != -kScoreInf) {
- generate_.push(top);
- }
- } else {
- generate_.push(top);
- }
- }
- output.FinishedSearch();
- }
-
- private:
- boost::object_pool<EdgeGenerator> edge_pool_;
-
- struct LessByTopScore : public std::binary_function<const EdgeGenerator *, const EdgeGenerator *, bool> {
- bool operator()(const EdgeGenerator *first, const EdgeGenerator *second) const {
- return first->TopScore() < second->TopScore();
- }
- };
-
- typedef std::priority_queue<EdgeGenerator*, std::vector<EdgeGenerator*>, LessByTopScore> Generate;
- Generate generate_;
-
- boost::pool<> partial_edge_pool_;
-
- PartialEdge *take_;
-};
-
-} // namespace search
-#endif // SEARCH_EDGE_QUEUE__
diff --git a/klm/search/final.hh b/klm/search/final.hh
index 1b3092ac..50e62cf2 100644
--- a/klm/search/final.hh
+++ b/klm/search/final.hh
@@ -1,37 +1,34 @@
#ifndef SEARCH_FINAL__
#define SEARCH_FINAL__
-#include "search/arity.hh"
-#include "search/note.hh"
-#include "search/types.hh"
-
-#include <boost/array.hpp>
+#include "search/header.hh"
+#include "util/pool.hh"
namespace search {
-class Final {
+// A full hypothesis with pointers to children.
+class Final : public Header {
public:
- typedef boost::array<const Final*, search::kMaxArity> ChildArray;
+ Final() {}
- void Reset(Score bound, Note note, const Final &left, const Final &right) {
- bound_ = bound;
- note_ = note;
- children_[0] = &left;
- children_[1] = &right;
+ Final(util::Pool &pool, Score score, Arity arity, Note note)
+ : Header(pool.Allocate(Size(arity)), arity) {
+ SetScore(score);
+ SetNote(note);
}
- const ChildArray &Children() const { return children_; }
-
- Note GetNote() const { return note_; }
-
- Score Bound() const { return bound_; }
+ // These are arrays of length GetArity().
+ Final *Children() {
+ return reinterpret_cast<Final*>(After());
+ }
+ const Final *Children() const {
+ return reinterpret_cast<const Final*>(After());
+ }
private:
- Score bound_;
-
- Note note_;
-
- ChildArray children_;
+ static std::size_t Size(Arity arity) {
+ return kHeaderSize + arity * sizeof(const Final);
+ }
};
} // namespace search
diff --git a/klm/search/header.hh b/klm/search/header.hh
new file mode 100644
index 00000000..25550dbe
--- /dev/null
+++ b/klm/search/header.hh
@@ -0,0 +1,57 @@
+#ifndef SEARCH_HEADER__
+#define SEARCH_HEADER__
+
+// Header consisting of Score, Arity, and Note
+
+#include "search/note.hh"
+#include "search/types.hh"
+
+#include <stdint.h>
+
+namespace search {
+
+// Copying is shallow.
+class Header {
+ public:
+ bool Valid() const { return base_; }
+
+ Score GetScore() const {
+ return *reinterpret_cast<const float*>(base_);
+ }
+ void SetScore(Score to) {
+ *reinterpret_cast<float*>(base_) = to;
+ }
+ bool operator<(const Header &other) const {
+ return GetScore() < other.GetScore();
+ }
+
+ Arity GetArity() const {
+ return *reinterpret_cast<const Arity*>(base_ + sizeof(Score));
+ }
+
+ Note GetNote() const {
+ return *reinterpret_cast<const Note*>(base_ + sizeof(Score) + sizeof(Arity));
+ }
+ void SetNote(Note to) {
+ *reinterpret_cast<Note*>(base_ + sizeof(Score) + sizeof(Arity)) = to;
+ }
+
+ protected:
+ Header() : base_(NULL) {}
+
+ Header(void *base, Arity arity) : base_(static_cast<uint8_t*>(base)) {
+ *reinterpret_cast<Arity*>(base_ + sizeof(Score)) = arity;
+ }
+
+ static const std::size_t kHeaderSize = sizeof(Score) + sizeof(Arity) + sizeof(Note);
+
+ uint8_t *After() { return base_ + kHeaderSize; }
+ const uint8_t *After() const { return base_ + kHeaderSize; }
+
+ private:
+ uint8_t *base_;
+};
+
+} // namespace search
+
+#endif // SEARCH_HEADER__
diff --git a/klm/search/source.hh b/klm/search/source.hh
deleted file mode 100644
index 11839f7b..00000000
--- a/klm/search/source.hh
+++ /dev/null
@@ -1,48 +0,0 @@
-#ifndef SEARCH_SOURCE__
-#define SEARCH_SOURCE__
-
-#include "search/types.hh"
-
-#include <assert.h>
-#include <vector>
-
-namespace search {
-
-template <class Final> class Source {
- public:
- Source() : bound_(kScoreInf) {}
-
- Index Size() const {
- return final_.size();
- }
-
- Score Bound() const {
- return bound_;
- }
-
- const Final &operator[](Index index) const {
- return *final_[index];
- }
-
- Score ScoreOrBound(Index index) const {
- return Size() > index ? final_[index]->Total() : Bound();
- }
-
- protected:
- void AddFinal(const Final &store) {
- final_.push_back(&store);
- }
-
- void SetBound(Score to) {
- assert(to <= bound_ + 0.001);
- bound_ = to;
- }
-
- private:
- std::vector<const Final *> final_;
-
- Score bound_;
-};
-
-} // namespace search
-#endif // SEARCH_SOURCE__
diff --git a/klm/search/types.hh b/klm/search/types.hh
index 9726379f..06eb5bfa 100644
--- a/klm/search/types.hh
+++ b/klm/search/types.hh
@@ -1,17 +1,13 @@
#ifndef SEARCH_TYPES__
#define SEARCH_TYPES__
-#include <cmath>
+#include <stdint.h>
namespace search {
typedef float Score;
-const Score kScoreInf = INFINITY;
-// This could have been an enum but gcc wants 4 bytes.
-typedef bool ExtendDirection;
-const ExtendDirection kExtendLeft = 0;
-const ExtendDirection kExtendRight = 1;
+typedef uint32_t Arity;
} // namespace search
diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc
index cc53c0dd..11f4631f 100644
--- a/klm/search/vertex.cc
+++ b/klm/search/vertex.cc
@@ -21,9 +21,9 @@ struct GreaterByBound : public std::binary_function<const VertexNode *, const Ve
void VertexNode::SortAndSet(ContextBase &context, VertexNode **parent_ptr) {
if (Complete()) {
- assert(end_);
+ assert(end_.Valid());
assert(extend_.empty());
- bound_ = end_->Bound();
+ bound_ = end_.GetScore();
return;
}
if (extend_.size() == 1 && parent_ptr) {
@@ -39,10 +39,4 @@ void VertexNode::SortAndSet(ContextBase &context, VertexNode **parent_ptr) {
bound_ = extend_.front()->Bound();
}
-namespace {
-VertexNode kBlankVertexNode;
-} // namespace
-
-PartialVertex kBlankPartialVertex(kBlankVertexNode);
-
} // namespace search
diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh
index e1a9ad11..52bc1dfe 100644
--- a/klm/search/vertex.hh
+++ b/klm/search/vertex.hh
@@ -18,7 +18,7 @@ class ContextBase;
class VertexNode {
public:
- VertexNode() : end_(NULL) {}
+ VertexNode() {}
void InitRoot() {
extend_.clear();
@@ -26,8 +26,7 @@ class VertexNode {
state_.left.length = 0;
state_.right.length = 0;
right_full_ = false;
- bound_ = -kScoreInf;
- end_ = NULL;
+ end_ = Final();
}
lm::ngram::ChartState &MutableState() { return state_; }
@@ -37,19 +36,20 @@ class VertexNode {
extend_.push_back(next);
}
- void SetEnd(Final *end) { end_ = end; }
+ void SetEnd(Final end) {
+ assert(!end_.Valid());
+ end_ = end;
+ }
- Final &MutableEnd() { return *end_; }
-
void SortAndSet(ContextBase &context, VertexNode **parent_pointer);
// Should only happen to a root node when the entire vertex is empty.
bool Empty() const {
- return !end_ && extend_.empty();
+ return !end_.Valid() && extend_.empty();
}
bool Complete() const {
- return end_;
+ return end_.Valid();
}
const lm::ngram::ChartState &State() const { return state_; }
@@ -63,8 +63,8 @@ class VertexNode {
return state_.left.length + state_.right.length;
}
- // May be NULL.
- const Final *End() const { return end_; }
+ // Will be invalid unless this is a leaf.
+ const Final End() const { return end_; }
const VertexNode &operator[](size_t index) const {
return *extend_[index];
@@ -81,7 +81,7 @@ class VertexNode {
bool right_full_;
Score bound_;
- Final *end_;
+ Final end_;
};
class PartialVertex {
@@ -97,7 +97,7 @@ class PartialVertex {
const lm::ngram::ChartState &State() const { return back_->State(); }
bool RightFull() const { return back_->RightFull(); }
- Score Bound() const { return Complete() ? back_->End()->Bound() : (*back_)[index_].Bound(); }
+ Score Bound() const { return Complete() ? back_->End().GetScore() : (*back_)[index_].Bound(); }
unsigned char Length() const { return back_->Length(); }
@@ -105,20 +105,24 @@ class PartialVertex {
return index_ + 1 < back_->Size();
}
- // Split into continuation and alternative, rendering this the alternative.
- bool Split(PartialVertex &continuation) {
+ // Split into continuation and alternative, rendering this the continuation.
+ bool Split(PartialVertex &alternative) {
assert(!Complete());
- continuation.back_ = &((*back_)[index_]);
- continuation.index_ = 0;
+ bool ret;
if (index_ + 1 < back_->Size()) {
- ++index_;
- return true;
+ alternative.index_ = index_ + 1;
+ alternative.back_ = back_;
+ ret = true;
+ } else {
+ ret = false;
}
- return false;
+ back_ = &((*back_)[index_]);
+ index_ = 0;
+ return ret;
}
- const Final &End() const {
- return *back_->End();
+ const Final End() const {
+ return back_->End();
}
private:
@@ -126,25 +130,22 @@ class PartialVertex {
unsigned int index_;
};
-extern PartialVertex kBlankPartialVertex;
-
class Vertex {
public:
Vertex() {}
PartialVertex RootPartial() const { return PartialVertex(root_); }
- const Final *BestChild() const {
+ const Final BestChild() const {
PartialVertex top(RootPartial());
if (top.Empty()) {
- return NULL;
+ return Final();
} else {
PartialVertex continuation;
while (!top.Complete()) {
top.Split(continuation);
- top = continuation;
}
- return &top.End();
+ return top.End();
}
}
diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc
index d94e6e06..0945fe55 100644
--- a/klm/search/vertex_generator.cc
+++ b/klm/search/vertex_generator.cc
@@ -10,74 +10,85 @@ namespace search {
VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) {
gen.root_.InitRoot();
- root_.under = &gen.root_;
}
namespace {
+
const uint64_t kCompleteAdd = static_cast<uint64_t>(-1);
-} // namespace
-void VertexGenerator::NewHypothesis(const PartialEdge &partial, Note note) {
- const lm::ngram::ChartState &state = partial.CompletedState();
- std::pair<Existing::iterator, bool> got(existing_.insert(std::pair<uint64_t, Final*>(hash_value(state), NULL)));
- if (!got.second) {
- // Found it already.
- Final &exists = *got.first->second;
- if (exists.Bound() < partial.score) {
- exists.Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End());
- }
- return;
+// Parallel structure to VertexNode.
+struct Trie {
+ Trie() : under(NULL) {}
+
+ VertexNode *under;
+ boost::unordered_map<uint64_t, Trie> extend;
+};
+
+Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) {
+ Trie &next = node.extend[added];
+ if (!next.under) {
+ next.under = context.NewVertexNode();
+ lm::ngram::ChartState &writing = next.under->MutableState();
+ writing = state;
+ writing.left.full &= left_full && state.left.full;
+ next.under->MutableRightFull() = right_full && state.left.full;
+ writing.left.length = left;
+ writing.right.length = right;
+ node.under->AddExtend(next.under);
}
+ return next;
+}
+
+void CompleteTransition(ContextBase &context, Trie &starter, PartialEdge partial) {
+ Final final(context.FinalPool(), partial.GetScore(), partial.GetArity(), partial.GetNote());
+ Final *child_out = final.Children();
+ const PartialVertex *part = partial.NT();
+ const PartialVertex *const part_end_loop = part + partial.GetArity();
+ for (; part != part_end_loop; ++part, ++child_out)
+ *child_out = part->End();
+
+ starter.under->SetEnd(final);
+}
+
+void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) {
+ const lm::ngram::ChartState &state = partial.CompletedState();
+
unsigned char left = 0, right = 0;
- Trie *node = &root_;
+ Trie *node = &root;
while (true) {
if (left == state.left.length) {
- node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, true, right, false);
+ node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, true, right, false);
for (; right < state.right.length; ++right) {
- node = &FindOrInsert(*node, state.right.words[right], state, left, true, right + 1, false);
+ node = &FindOrInsert(context, *node, state.right.words[right], state, left, true, right + 1, false);
}
break;
}
- node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, false);
+ node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, false);
left++;
if (right == state.right.length) {
- node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, false, right, true);
+ node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, false, right, true);
for (; left < state.left.length; ++left) {
- node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, true);
+ node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, true);
}
break;
}
- node = &FindOrInsert(*node, state.right.words[right], state, left, false, right + 1, false);
+ node = &FindOrInsert(context, *node, state.right.words[right], state, left, false, right + 1, false);
right++;
}
- node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true);
- got.first->second = CompleteTransition(*node, state, note, partial);
+ node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true);
+ CompleteTransition(context, *node, partial);
}
-VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) {
- VertexGenerator::Trie &next = node.extend[added];
- if (!next.under) {
- next.under = context_.NewVertexNode();
- lm::ngram::ChartState &writing = next.under->MutableState();
- writing = state;
- writing.left.full &= left_full && state.left.full;
- next.under->MutableRightFull() = right_full && state.left.full;
- writing.left.length = left;
- writing.right.length = right;
- node.under->AddExtend(next.under);
- }
- return next;
-}
+} // namespace
-Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial) {
- VertexNode &node = *starter.under;
- assert(node.State().left.full == state.left.full);
- assert(!node.End());
- Final *final = context_.NewFinal();
- final->Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End());
- node.SetEnd(final);
- return final;
+void VertexGenerator::FinishedSearch() {
+ Trie root;
+ root.under = &gen_.root_;
+ for (Existing::const_iterator i(existing_.begin()); i != existing_.end(); ++i) {
+ AddHypothesis(context_, root, i->second);
+ }
+ root.under->SortAndSet(context_, NULL);
}
} // namespace search
diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh
index 6b98da3e..60e86112 100644
--- a/klm/search/vertex_generator.hh
+++ b/klm/search/vertex_generator.hh
@@ -1,13 +1,11 @@
#ifndef SEARCH_VERTEX_GENERATOR__
#define SEARCH_VERTEX_GENERATOR__
-#include "search/note.hh"
+#include "search/edge.hh"
#include "search/vertex.hh"
#include <boost/unordered_map.hpp>
-#include <queue>
-
namespace lm {
namespace ngram {
class ChartState;
@@ -18,40 +16,29 @@ namespace search {
class ContextBase;
class Final;
-struct PartialEdge;
class VertexGenerator {
public:
VertexGenerator(ContextBase &context, Vertex &gen);
- void NewHypothesis(const PartialEdge &partial, Note note);
-
- void FinishedSearch() {
- root_.under->SortAndSet(context_, NULL);
+ void NewHypothesis(PartialEdge partial) {
+ const lm::ngram::ChartState &state = partial.CompletedState();
+ std::pair<Existing::iterator, bool> ret(existing_.insert(std::make_pair(hash_value(state), partial)));
+ if (!ret.second && ret.first->second < partial) {
+ ret.first->second = partial;
+ }
}
+ void FinishedSearch();
+
const Vertex &Generating() const { return gen_; }
private:
- // Parallel structure to VertexNode.
- struct Trie {
- Trie() : under(NULL) {}
-
- VertexNode *under;
- boost::unordered_map<uint64_t, Trie> extend;
- };
-
- Trie &FindOrInsert(Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full);
-
- Final *CompleteTransition(Trie &node, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial);
-
ContextBase &context_;
Vertex &gen_;
- Trie root_;
-
- typedef boost::unordered_map<uint64_t, Final*> Existing;
+ typedef boost::unordered_map<uint64_t, PartialEdge> Existing;
Existing existing_;
};