summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2012-10-14 10:46:34 +0100
committerKenneth Heafield <github@kheafield.com>2012-10-14 10:46:34 +0100
commit9b99cb844e3e379b557ff8578df27893ce147f1a (patch)
treeca60a3b35312cf6d9741cd05ca958714d2c4e17c
parent28403a7d3cbca2de743a7d654ffb9e1600ce7c5c (diff)
Update to faster but less cute search
-rw-r--r--decoder/lazy.cc57
-rw-r--r--klm/lm/fragment.cc37
-rw-r--r--klm/search/Jamfile2
-rw-r--r--klm/search/edge.hh31
-rw-r--r--klm/search/edge_generator.cc53
-rw-r--r--klm/search/edge_generator.hh24
-rw-r--r--klm/search/edge_queue.cc25
-rw-r--r--klm/search/edge_queue.hh73
-rw-r--r--klm/search/final.hh11
-rw-r--r--klm/search/note.hh12
-rw-r--r--klm/search/rule.cc32
-rw-r--r--klm/search/rule.hh31
-rw-r--r--klm/search/vertex.hh45
-rw-r--r--klm/search/vertex_generator.cc32
-rw-r--r--klm/search/vertex_generator.hh35
15 files changed, 279 insertions, 221 deletions
diff --git a/decoder/lazy.cc b/decoder/lazy.cc
index c4138d7b..9dc657d6 100644
--- a/decoder/lazy.cc
+++ b/decoder/lazy.cc
@@ -8,6 +8,7 @@
#include "search/config.hh"
#include "search/context.hh"
#include "search/edge.hh"
+#include "search/edge_queue.hh"
#include "search/vertex.hh"
#include "search/vertex_generator.hh"
#include "util/exception.hh"
@@ -75,7 +76,7 @@ template <class Model> class Lazy : public LazyBase {
void Search(unsigned int pop_limit, const Hypergraph &hg) const;
private:
- void ConvertEdge(const search::Context<Model> &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::Edge &out) const;
+ unsigned char ConvertEdge(const search::Context<Model> &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::PartialEdge &out) const;
const Model m_;
};
@@ -93,73 +94,73 @@ LazyBase *LazyBase::Load(const char *model_file, const std::vector<weight_t> &we
}
}
-void PrintFinal(const Hypergraph &hg, const search::Edge *edge_base, const search::Final &final) {
- const std::vector<WordID> &words = hg.edges_[&final.From() - edge_base].rule_->e();
+void PrintFinal(const Hypergraph &hg, const search::Final &final) {
+ const std::vector<WordID> &words = static_cast<const Hypergraph::Edge*>(final.GetNote().vp)->rule_->e();
boost::array<const search::Final*, search::kMaxArity>::const_iterator child(final.Children().begin());
for (std::vector<WordID>::const_iterator i = words.begin(); i != words.end(); ++i) {
if (*i > 0) {
std::cout << TD::Convert(*i) << ' ';
} else {
- PrintFinal(hg, edge_base, **child++);
+ PrintFinal(hg, **child++);
}
}
}
template <class Model> void Lazy<Model>::Search(unsigned int pop_limit, const Hypergraph &hg) const {
boost::scoped_array<search::Vertex> out_vertices(new search::Vertex[hg.nodes_.size()]);
- boost::scoped_array<search::Edge> out_edges(new search::Edge[hg.edges_.size()]);
search::Config config(weights_, pop_limit);
search::Context<Model> context(config, m_);
for (unsigned int i = 0; i < hg.nodes_.size() - 1; ++i) {
- search::Vertex &out_vertex = out_vertices[i];
+ search::EdgeQueue queue(context.PopLimit());
const Hypergraph::EdgesVector &down_edges = hg.nodes_[i].in_edges_;
for (unsigned int j = 0; j < down_edges.size(); ++j) {
unsigned int edge_index = down_edges[j];
- ConvertEdge(context, i == hg.nodes_.size() - 2, out_vertices.get(), hg.edges_[edge_index], out_edges[edge_index]);
- out_vertex.Add(out_edges[edge_index]);
+ unsigned char arity = ConvertEdge(context, i == hg.nodes_.size() - 2, out_vertices.get(), hg.edges_[edge_index], queue.InitializeEdge());
+ search::Note note;
+ note.vp = &hg.edges_[edge_index];
+ if (arity != 255) queue.AddEdge(arity, note);
}
- out_vertex.FinishedAdding();
- search::VertexGenerator(context, out_vertex);
+ search::VertexGenerator vertex_gen(context, out_vertices[i]);
+ queue.Search(context, vertex_gen);
}
- search::PartialVertex top = out_vertices[hg.nodes_.size() - 2].RootPartial();
- if (top.Empty()) {
- std::cout << "NO PATH FOUND";
+ const search::Final *top = out_vertices[hg.nodes_.size() - 2].BestChild();
+ if (!top) {
+ std::cout << "NO PATH FOUND" << std::endl;
} else {
- search::PartialVertex continuation;
- while (!top.Complete()) {
- top.Split(continuation);
- top = continuation;
- }
- PrintFinal(hg, out_edges.get(), top.End());
- std::cout << "||| " << top.End().Bound() << std::endl;
+ PrintFinal(hg, *top);
+ std::cout << "||| " << top->Bound() << std::endl;
}
}
-// TODO: get weights into here somehow.
-template <class Model> void Lazy<Model>::ConvertEdge(const search::Context<Model> &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::Edge &out) const {
+template <class Model> unsigned char Lazy<Model>::ConvertEdge(const search::Context<Model> &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::PartialEdge &out) const {
const std::vector<WordID> &e = in.rule_->e();
std::vector<lm::WordIndex> words;
unsigned int terminals = 0;
+ unsigned char nt = 0;
for (std::vector<WordID>::const_iterator word = e.begin(); word != e.end(); ++word) {
if (*word <= 0) {
- out.Add(vertices[in.tail_nodes_[-*word]]);
+ out.nt[nt] = vertices[in.tail_nodes_[-*word]].RootPartial();
+ if (out.nt[nt].Empty()) return 255;
+ ++nt;
words.push_back(lm::kMaxWordIndex);
} else {
++terminals;
words.push_back(vocab_.FromCDec(*word));
}
}
+ for (unsigned char fill = nt; fill < search::kMaxArity; ++fill) {
+ out.nt[fill] = search::kBlankPartialVertex; // index by fill so every unused slot is blanked
+ }
if (final) {
words.push_back(m_.GetVocabulary().EndSentence());
}
- float additive = in.rule_->GetFeatureValues().dot(cdec_weights_);
- UTIL_THROW_IF(isnan(additive), util::Exception, "Bad dot product");
- additive -= static_cast<float>(terminals) * context.GetWeights().WordPenalty() / M_LN10;
-
- out.InitRule().Init(context, additive, words, final);
+ out.score = in.rule_->GetFeatureValues().dot(cdec_weights_);
+ out.score -= static_cast<float>(terminals) * context.GetWeights().WordPenalty() / M_LN10;
+ out.score += search::ScoreRule(context, words, final, out.between);
+ return nt;
}
boost::scoped_ptr<LazyBase> AwfulGlobalLazy;
diff --git a/klm/lm/fragment.cc b/klm/lm/fragment.cc
new file mode 100644
index 00000000..0267cd4e
--- /dev/null
+++ b/klm/lm/fragment.cc
@@ -0,0 +1,37 @@
+#include "lm/binary_format.hh"
+#include "lm/model.hh"
+#include "lm/left.hh"
+#include "util/tokenize_piece.hh"
+
+template <class Model> void Query(const char *name) {
+ Model model(name);
+ std::string line;
+ lm::ngram::ChartState ignored;
+ while (getline(std::cin, line)) {
+ lm::ngram::RuleScore<Model> scorer(model, ignored);
+ for (util::TokenIter<util::SingleCharacter, true> i(line, ' '); i; ++i) {
+ scorer.Terminal(model.GetVocabulary().Index(*i));
+ }
+ std::cout << scorer.Finish() << '\n';
+ }
+}
+
+int main(int argc, char *argv[]) {
+ if (argc != 2) {
+ std::cerr << "Expected model file name." << std::endl;
+ return 1;
+ }
+ const char *name = argv[1];
+ lm::ngram::ModelType model_type = lm::ngram::PROBING;
+ lm::ngram::RecognizeBinary(name, model_type);
+ switch (model_type) {
+ case lm::ngram::PROBING:
+ Query<lm::ngram::ProbingModel>(name);
+ break;
+ case lm::ngram::REST_PROBING:
+ Query<lm::ngram::RestProbingModel>(name);
+ break;
+ default:
+ std::cerr << "Model type not supported yet." << std::endl;
+ }
+}
diff --git a/klm/search/Jamfile b/klm/search/Jamfile
index ac47c249..e8b14363 100644
--- a/klm/search/Jamfile
+++ b/klm/search/Jamfile
@@ -1,4 +1,4 @@
-lib search : weights.cc vertex.cc vertex_generator.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil : : : <include>.. ;
+lib search : weights.cc vertex.cc vertex_generator.cc edge_queue.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : <include>.. ;
import testing ;
diff --git a/klm/search/edge.hh b/klm/search/edge.hh
index 4d2a5cbf..77ab0ade 100644
--- a/klm/search/edge.hh
+++ b/klm/search/edge.hh
@@ -11,33 +11,6 @@
namespace search {
-class Edge {
- public:
- Edge() {
- end_to_ = to_;
- }
-
- Rule &InitRule() { return rule_; }
-
- void Add(Vertex &vertex) {
- assert(end_to_ - to_ < kMaxArity);
- *(end_to_++) = &vertex;
- }
-
- const Vertex &GetVertex(std::size_t index) const {
- return *to_[index];
- }
-
- const Rule &GetRule() const { return rule_; }
-
- private:
- // Rule and pointers to rule arguments.
- Rule rule_;
-
- Vertex *to_[kMaxArity];
- Vertex **end_to_;
-};
-
struct PartialEdge {
Score score;
// Terminals
@@ -45,6 +18,10 @@ struct PartialEdge {
// Non-terminals
PartialVertex nt[kMaxArity];
+ const lm::ngram::ChartState &CompletedState() const {
+ return between[0];
+ }
+
bool operator<(const PartialEdge &other) const {
return score < other.score;
}
diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc
index d135899a..56239dfb 100644
--- a/klm/search/edge_generator.cc
+++ b/klm/search/edge_generator.cc
@@ -10,28 +10,15 @@
namespace search {
-bool EdgeGenerator::Init(Edge &edge, VertexGenerator &parent) {
- from_ = &edge;
- for (unsigned int i = 0; i < GetRule().Arity(); ++i) {
- if (edge.GetVertex(i).RootPartial().Empty()) return false;
- }
- PartialEdge &root = *parent.MallocPartialEdge();
- root.score = GetRule().Bound();
- for (unsigned int i = 0; i < GetRule().Arity(); ++i) {
+EdgeGenerator::EdgeGenerator(PartialEdge &root, unsigned char arity, Note note) : arity_(arity), note_(note) {
+/* for (unsigned char i = 0; i < edge.Arity(); ++i) {
root.nt[i] = edge.GetVertex(i).RootPartial();
- root.score += root.nt[i].Bound();
}
- for (unsigned int i = GetRule().Arity(); i < 2; ++i) {
+ for (unsigned char i = edge.Arity(); i < 2; ++i) {
root.nt[i] = kBlankPartialVertex;
- }
- for (unsigned int i = 0; i < GetRule().Arity() + 1; ++i) {
- root.between[i] = GetRule().Lexical(i);
- }
- // wtf no clear method?
- generate_ = Generate();
+ }*/
generate_.push(&root);
- top_ = root.score;
- return true;
+ top_score_ = root.score;
}
namespace {
@@ -78,13 +65,13 @@ template <class Model> float FastScore(const Context<Model> &context, unsigned c
} // namespace
-template <class Model> bool EdgeGenerator::Pop(Context<Model> &context, VertexGenerator &parent) {
+template <class Model> PartialEdge *EdgeGenerator::Pop(Context<Model> &context, boost::pool<> &partial_edge_pool) {
assert(!generate_.empty());
PartialEdge &top = *generate_.top();
generate_.pop();
unsigned int victim = 0;
unsigned char lowest_length = 255;
- for (unsigned int i = 0; i != GetRule().Arity(); ++i) {
+ for (unsigned char i = 0; i != arity_; ++i) {
if (!top.nt[i].Complete() && top.nt[i].Length() < lowest_length) {
lowest_length = top.nt[i].Length();
victim = i;
@@ -92,21 +79,21 @@ template <class Model> bool EdgeGenerator::Pop(Context<Model> &context, VertexGe
}
if (lowest_length == 255) {
// All states report complete.
- top.between[0].right = top.between[GetRule().Arity()].right;
- parent.NewHypothesis(top.between[0], *from_, top);
- top_ = generate_.empty() ? -kScoreInf : generate_.top()->score;
- return !generate_.empty();
+ top.between[0].right = top.between[arity_].right;
+ // Now top.between[0] is the full edge state.
+ top_score_ = generate_.empty() ? -kScoreInf : generate_.top()->score;
+ return &top;
}
unsigned int stay = !victim;
- PartialEdge &continuation = *parent.MallocPartialEdge();
+ PartialEdge &continuation = *static_cast<PartialEdge*>(partial_edge_pool.malloc());
float old_bound = top.nt[victim].Bound();
// The alternate's score will change because alternate.nt[victim] changes.
bool split = top.nt[victim].Split(continuation.nt[victim]);
// top is now the alternate.
continuation.nt[stay] = top.nt[stay];
- continuation.score = FastScore(context, victim, GetRule().Arity(), top, continuation);
+ continuation.score = FastScore(context, victim, arity_, top, continuation);
// TODO: dedupe?
generate_.push(&continuation);
@@ -116,14 +103,18 @@ template <class Model> bool EdgeGenerator::Pop(Context<Model> &context, VertexGe
// TODO: dedupe?
generate_.push(&top);
} else {
- parent.FreePartialEdge(&top);
+ partial_edge_pool.free(&top);
}
- top_ = generate_.top()->score;
- return true;
+ top_score_ = generate_.top()->score;
+ return NULL;
}
-template bool EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context, VertexGenerator &parent);
-template bool EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context, VertexGenerator &parent);
+template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context, boost::pool<> &partial_edge_pool);
+template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context, boost::pool<> &partial_edge_pool);
+template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::TrieModel> &context, boost::pool<> &partial_edge_pool);
+template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::QuantTrieModel> &context, boost::pool<> &partial_edge_pool);
+template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::ArrayTrieModel> &context, boost::pool<> &partial_edge_pool);
+template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::QuantArrayTrieModel> &context, boost::pool<> &partial_edge_pool);
} // namespace search
diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh
index e306dc61..875ccc5e 100644
--- a/klm/search/edge_generator.hh
+++ b/klm/search/edge_generator.hh
@@ -2,7 +2,9 @@
#define SEARCH_EDGE_GENERATOR__
#include "search/edge.hh"
+#include "search/note.hh"
+#include <boost/pool/pool.hpp>
#include <boost/unordered_map.hpp>
#include <functional>
@@ -28,26 +30,28 @@ struct PartialEdgePointerLess : std::binary_function<const PartialEdge *, const
class EdgeGenerator {
public:
- // True if it has a hypothesis.
- bool Init(Edge &edge, VertexGenerator &parent);
+ EdgeGenerator(PartialEdge &root, unsigned char arity, Note note);
- Score Top() const {
- return top_;
+ Score TopScore() const {
+ return top_score_;
}
- template <class Model> bool Pop(Context<Model> &context, VertexGenerator &parent);
+ Note GetNote() const {
+ return note_;
+ }
+
+ // Pop. If there's a complete hypothesis, return it. Otherwise return NULL.
+ template <class Model> PartialEdge *Pop(Context<Model> &context, boost::pool<> &partial_edge_pool);
private:
- const Rule &GetRule() const {
- return from_->GetRule();
- }
+ Score top_score_;
- Score top_;
+ unsigned char arity_;
typedef std::priority_queue<PartialEdge*, std::vector<PartialEdge*>, PartialEdgePointerLess> Generate;
Generate generate_;
- Edge *from_;
+ Note note_;
};
} // namespace search
diff --git a/klm/search/edge_queue.cc b/klm/search/edge_queue.cc
new file mode 100644
index 00000000..e3ae6ebf
--- /dev/null
+++ b/klm/search/edge_queue.cc
@@ -0,0 +1,25 @@
+#include "search/edge_queue.hh"
+
+#include "lm/left.hh"
+#include "search/context.hh"
+
+#include <stdint.h>
+
+namespace search {
+
+EdgeQueue::EdgeQueue(unsigned int pop_limit_hint) : partial_edge_pool_(sizeof(PartialEdge), pop_limit_hint * 2) {
+ take_ = static_cast<PartialEdge*>(partial_edge_pool_.malloc());
+}
+
+/*void EdgeQueue::AddEdge(PartialEdge &root, unsigned char arity, Note note) {
+ // Ignore empty edges.
+ for (unsigned char i = 0; i < edge.Arity(); ++i) {
+ PartialVertex root(edge.GetVertex(i).RootPartial());
+ if (root.Empty()) return;
+ total_score += root.Bound();
+ }
+ PartialEdge &allocated = *static_cast<PartialEdge*>(partial_edge_pool_.malloc());
+ allocated.score = total_score;
+}*/
+
+} // namespace search
diff --git a/klm/search/edge_queue.hh b/klm/search/edge_queue.hh
new file mode 100644
index 00000000..187eaed7
--- /dev/null
+++ b/klm/search/edge_queue.hh
@@ -0,0 +1,73 @@
+#ifndef SEARCH_EDGE_QUEUE__
+#define SEARCH_EDGE_QUEUE__
+
+#include "search/edge.hh"
+#include "search/edge_generator.hh"
+#include "search/note.hh"
+
+#include <boost/pool/pool.hpp>
+#include <boost/pool/object_pool.hpp>
+
+#include <queue>
+
+namespace search {
+
+template <class Model> class Context;
+
+class EdgeQueue {
+ public:
+ explicit EdgeQueue(unsigned int pop_limit_hint);
+
+ PartialEdge &InitializeEdge() {
+ return *take_;
+ }
+
+ void AddEdge(unsigned char arity, Note note) {
+ generate_.push(edge_pool_.construct(*take_, arity, note));
+ take_ = static_cast<PartialEdge*>(partial_edge_pool_.malloc());
+ }
+
+ bool Empty() const { return generate_.empty(); }
+
+ /* Generate hypotheses and send them to output. Normally, output is a
+ * VertexGenerator, but the decoder may want to route edges to different
+ * vertices i.e. if they have different LHS non-terminal labels.
+ */
+ template <class Model, class Output> void Search(Context<Model> &context, Output &output) {
+ int to_pop = context.PopLimit();
+ while (to_pop > 0 && !generate_.empty()) {
+ EdgeGenerator *top = generate_.top();
+ generate_.pop();
+ PartialEdge *ret = top->Pop(context, partial_edge_pool_);
+ if (ret) {
+ output.NewHypothesis(*ret, top->GetNote());
+ --to_pop;
+ if (top->TopScore() != -kScoreInf) {
+ generate_.push(top);
+ }
+ } else {
+ generate_.push(top);
+ }
+ }
+ output.FinishedSearch();
+ }
+
+ private:
+ boost::object_pool<EdgeGenerator> edge_pool_;
+
+ struct LessByTopScore : public std::binary_function<const EdgeGenerator *, const EdgeGenerator *, bool> {
+ bool operator()(const EdgeGenerator *first, const EdgeGenerator *second) const {
+ return first->TopScore() < second->TopScore();
+ }
+ };
+
+ typedef std::priority_queue<EdgeGenerator*, std::vector<EdgeGenerator*>, LessByTopScore> Generate;
+ Generate generate_;
+
+ boost::pool<> partial_edge_pool_;
+
+ PartialEdge *take_;
+};
+
+} // namespace search
+#endif // SEARCH_EDGE_QUEUE__
diff --git a/klm/search/final.hh b/klm/search/final.hh
index 823b8c1a..1b3092ac 100644
--- a/klm/search/final.hh
+++ b/klm/search/final.hh
@@ -2,35 +2,34 @@
#define SEARCH_FINAL__
#include "search/arity.hh"
+#include "search/note.hh"
#include "search/types.hh"
#include <boost/array.hpp>
namespace search {
-class Edge;
-
class Final {
public:
typedef boost::array<const Final*, search::kMaxArity> ChildArray;
- void Reset(Score bound, const Edge &from, const Final &left, const Final &right) {
+ void Reset(Score bound, Note note, const Final &left, const Final &right) {
bound_ = bound;
- from_ = &from;
+ note_ = note;
children_[0] = &left;
children_[1] = &right;
}
const ChildArray &Children() const { return children_; }
- const Edge &From() const { return *from_; }
+ Note GetNote() const { return note_; }
Score Bound() const { return bound_; }
private:
Score bound_;
- const Edge *from_;
+ Note note_;
ChildArray children_;
};
diff --git a/klm/search/note.hh b/klm/search/note.hh
new file mode 100644
index 00000000..50bed06e
--- /dev/null
+++ b/klm/search/note.hh
@@ -0,0 +1,12 @@
+#ifndef SEARCH_NOTE__
+#define SEARCH_NOTE__
+
+namespace search {
+
+union Note {
+ const void *vp;
+};
+
+} // namespace search
+
+#endif // SEARCH_NOTE__
diff --git a/klm/search/rule.cc b/klm/search/rule.cc
index 0a941527..5b00207e 100644
--- a/klm/search/rule.cc
+++ b/klm/search/rule.cc
@@ -9,35 +9,35 @@
namespace search {
-template <class Model> void Rule::Init(const Context<Model> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos) {
- additive_ = additive;
- Score lm_score = 0.0;
- lexical_.clear();
- const lm::WordIndex oov = context.LanguageModel().GetVocabulary().NotFound();
-
+template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing) {
+ unsigned int oov_count = 0;
+ float prob = 0.0;
+ const Model &model = context.LanguageModel();
+ const lm::WordIndex oov = model.GetVocabulary().NotFound();
for (std::vector<lm::WordIndex>::const_iterator word = words.begin(); ; ++word) {
- lexical_.resize(lexical_.size() + 1);
- lm::ngram::RuleScore<Model> scorer(context.LanguageModel(), lexical_.back());
+ lm::ngram::RuleScore<Model> scorer(model, *(writing++));
// TODO: optimize
if (prepend_bos && (word == words.begin())) {
scorer.BeginSentence();
}
for (; ; ++word) {
if (word == words.end()) {
- lm_score += scorer.Finish();
- bound_ = additive_ + context.GetWeights().LM() * lm_score;
- arity_ = lexical_.size() - 1;
- return;
+ prob += scorer.Finish();
+ return static_cast<float>(oov_count) * context.GetWeights().OOV() + prob * context.GetWeights().LM();
}
if (*word == kNonTerminal) break;
- if (*word == oov) additive_ += context.GetWeights().OOV();
+ if (*word == oov) ++oov_count;
scorer.Terminal(*word);
}
- lm_score += scorer.Finish();
+ prob += scorer.Finish();
}
}
-template void Rule::Init(const Context<lm::ngram::RestProbingModel> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos);
-template void Rule::Init(const Context<lm::ngram::ProbingModel> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos);
+template float ScoreRule(const Context<lm::ngram::RestProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template float ScoreRule(const Context<lm::ngram::ProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template float ScoreRule(const Context<lm::ngram::TrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template float ScoreRule(const Context<lm::ngram::QuantTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template float ScoreRule(const Context<lm::ngram::ArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
+template float ScoreRule(const Context<lm::ngram::QuantArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);
} // namespace search
diff --git a/klm/search/rule.hh b/klm/search/rule.hh
index 920c64a7..0ce2794d 100644
--- a/klm/search/rule.hh
+++ b/klm/search/rule.hh
@@ -3,44 +3,17 @@
#include "lm/left.hh"
#include "lm/word_index.hh"
-#include "search/arity.hh"
#include "search/types.hh"
-#include <boost/array.hpp>
-
-#include <iosfwd>
#include <vector>
namespace search {
template <class Model> class Context;
-class Rule {
- public:
- Rule() : arity_(0) {}
-
- static const lm::WordIndex kNonTerminal = lm::kMaxWordIndex;
-
- // Use kNonTerminal for non-terminals.
- template <class Model> void Init(const Context<Model> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos);
-
- Score Bound() const { return bound_; }
-
- Score Additive() const { return additive_; }
-
- unsigned int Arity() const { return arity_; }
-
- const lm::ngram::ChartState &Lexical(unsigned int index) const {
- return lexical_[index];
- }
-
- private:
- Score bound_, additive_;
-
- unsigned int arity_;
+const lm::WordIndex kNonTerminal = lm::kMaxWordIndex;
- std::vector<lm::ngram::ChartState> lexical_;
-};
+template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *state_out);
} // namespace search
diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh
index 7ef29efc..e1a9ad11 100644
--- a/klm/search/vertex.hh
+++ b/klm/search/vertex.hh
@@ -16,8 +16,6 @@ namespace search {
class ContextBase;
-class Edge;
-
class VertexNode {
public:
VertexNode() : end_(NULL) {}
@@ -103,6 +101,10 @@ class PartialVertex {
unsigned char Length() const { return back_->Length(); }
+ bool HasAlternative() const {
+ return index_ + 1 < back_->Size();
+ }
+
// Split into continuation and alternative, rendering this the alternative.
bool Split(PartialVertex &continuation) {
assert(!Complete());
@@ -128,35 +130,26 @@ extern PartialVertex kBlankPartialVertex;
class Vertex {
public:
- Vertex()
-#ifdef DEBUG
- : finished_adding_(false)
-#endif
- {}
-
- void Add(Edge &edge) {
-#ifdef DEBUG
- assert(!finished_adding_);
-#endif
- edges_.push_back(&edge);
- }
-
- void FinishedAdding() {
-#ifdef DEBUG
- assert(!finished_adding_);
- finished_adding_ = true;
-#endif
- }
+ Vertex() {}
PartialVertex RootPartial() const { return PartialVertex(root_); }
+ const Final *BestChild() const {
+ PartialVertex top(RootPartial());
+ if (top.Empty()) {
+ return NULL;
+ } else {
+ PartialVertex continuation;
+ while (!top.Complete()) {
+ top.Split(continuation);
+ top = continuation;
+ }
+ return &top.End();
+ }
+ }
+
private:
friend class VertexGenerator;
- std::vector<Edge*> edges_;
-
-#ifdef DEBUG
- bool finished_adding_;
-#endif
VertexNode root_;
};
diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc
index 78948c97..d94e6e06 100644
--- a/klm/search/vertex_generator.cc
+++ b/klm/search/vertex_generator.cc
@@ -2,45 +2,30 @@
#include "lm/left.hh"
#include "search/context.hh"
+#include "search/edge.hh"
#include <stdint.h>
namespace search {
-template <class Model> VertexGenerator::VertexGenerator(Context<Model> &context, Vertex &gen) : context_(context), edges_(gen.edges_.size()), partial_edge_pool_(sizeof(PartialEdge), context.PopLimit() * 2) {
- for (std::size_t i = 0; i < gen.edges_.size(); ++i) {
- if (edges_[i].Init(*gen.edges_[i], *this))
- generate_.push(&edges_[i]);
- }
+VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) {
gen.root_.InitRoot();
root_.under = &gen.root_;
- to_pop_ = context.PopLimit();
- while (to_pop_ > 0 && !generate_.empty()) {
- EdgeGenerator *top = generate_.top();
- generate_.pop();
- if (top->Pop(context, *this)) {
- generate_.push(top);
- }
- }
- gen.root_.SortAndSet(context, NULL);
}
-template VertexGenerator::VertexGenerator(Context<lm::ngram::ProbingModel> &context, Vertex &gen);
-template VertexGenerator::VertexGenerator(Context<lm::ngram::RestProbingModel> &context, Vertex &gen);
-
namespace {
const uint64_t kCompleteAdd = static_cast<uint64_t>(-1);
} // namespace
-void VertexGenerator::NewHypothesis(const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial) {
+void VertexGenerator::NewHypothesis(const PartialEdge &partial, Note note) {
+ const lm::ngram::ChartState &state = partial.CompletedState();
std::pair<Existing::iterator, bool> got(existing_.insert(std::pair<uint64_t, Final*>(hash_value(state), NULL)));
if (!got.second) {
// Found it already.
Final &exists = *got.first->second;
if (exists.Bound() < partial.score) {
- exists.Reset(partial.score, from, partial.nt[0].End(), partial.nt[1].End());
+ exists.Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End());
}
- --to_pop_;
return;
}
unsigned char left = 0, right = 0;
@@ -67,8 +52,7 @@ void VertexGenerator::NewHypothesis(const lm::ngram::ChartState &state, const Ed
}
node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true);
- got.first->second = CompleteTransition(*node, state, from, partial);
- --to_pop_;
+ got.first->second = CompleteTransition(*node, state, note, partial);
}
VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) {
@@ -86,12 +70,12 @@ VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node
return next;
}
-Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial) {
+Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial) {
VertexNode &node = *starter.under;
assert(node.State().left.full == state.left.full);
assert(!node.End());
Final *final = context_.NewFinal();
- final->Reset(partial.score, from, partial.nt[0].End(), partial.nt[1].End());
+ final->Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End());
node.SetEnd(final);
return final;
}
diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh
index 8cdf1420..6b98da3e 100644
--- a/klm/search/vertex_generator.hh
+++ b/klm/search/vertex_generator.hh
@@ -1,10 +1,9 @@
#ifndef SEARCH_VERTEX_GENERATOR__
#define SEARCH_VERTEX_GENERATOR__
-#include "search/edge.hh"
-#include "search/edge_generator.hh"
+#include "search/note.hh"
+#include "search/vertex.hh"
-#include <boost/pool/pool.hpp>
#include <boost/unordered_map.hpp>
#include <queue>
@@ -17,18 +16,21 @@ class ChartState;
namespace search {
-template <class Model> class Context;
class ContextBase;
class Final;
+struct PartialEdge;
class VertexGenerator {
public:
- template <class Model> VertexGenerator(Context<Model> &context, Vertex &gen);
+ VertexGenerator(ContextBase &context, Vertex &gen);
- PartialEdge *MallocPartialEdge() { return static_cast<PartialEdge*>(partial_edge_pool_.malloc()); }
- void FreePartialEdge(PartialEdge *value) { partial_edge_pool_.free(value); }
+ void NewHypothesis(const PartialEdge &partial, Note note);
- void NewHypothesis(const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial);
+ void FinishedSearch() {
+ root_.under->SortAndSet(context_, NULL);
+ }
+
+ const Vertex &Generating() const { return gen_; }
private:
// Parallel structure to VertexNode.
@@ -41,29 +43,16 @@ class VertexGenerator {
Trie &FindOrInsert(Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full);
- Final *CompleteTransition(Trie &node, const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial);
+ Final *CompleteTransition(Trie &node, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial);
ContextBase &context_;
- std::vector<EdgeGenerator> edges_;
-
- struct LessByTop : public std::binary_function<const EdgeGenerator *, const EdgeGenerator *, bool> {
- bool operator()(const EdgeGenerator *first, const EdgeGenerator *second) const {
- return first->Top() < second->Top();
- }
- };
-
- typedef std::priority_queue<EdgeGenerator*, std::vector<EdgeGenerator*>, LessByTop> Generate;
- Generate generate_;
+ Vertex &gen_;
Trie root_;
typedef boost::unordered_map<uint64_t, Final*> Existing;
Existing existing_;
-
- int to_pop_;
-
- boost::pool<> partial_edge_pool_;
};
} // namespace search