diff options
Diffstat (limited to 'klm')
-rw-r--r-- | klm/lm/fragment.cc | 37 | ||||
-rw-r--r-- | klm/search/Jamfile | 2 | ||||
-rw-r--r-- | klm/search/edge.hh | 31 | ||||
-rw-r--r-- | klm/search/edge_generator.cc | 53 | ||||
-rw-r--r-- | klm/search/edge_generator.hh | 24 | ||||
-rw-r--r-- | klm/search/edge_queue.cc | 25 | ||||
-rw-r--r-- | klm/search/edge_queue.hh | 73 | ||||
-rw-r--r-- | klm/search/final.hh | 11 | ||||
-rw-r--r-- | klm/search/note.hh | 12 | ||||
-rw-r--r-- | klm/search/rule.cc | 32 | ||||
-rw-r--r-- | klm/search/rule.hh | 31 | ||||
-rw-r--r-- | klm/search/vertex.hh | 45 | ||||
-rw-r--r-- | klm/search/vertex_generator.cc | 32 | ||||
-rw-r--r-- | klm/search/vertex_generator.hh | 35 |
14 files changed, 250 insertions, 193 deletions
diff --git a/klm/lm/fragment.cc b/klm/lm/fragment.cc new file mode 100644 index 00000000..0267cd4e --- /dev/null +++ b/klm/lm/fragment.cc @@ -0,0 +1,37 @@ +#include "lm/binary_format.hh" +#include "lm/model.hh" +#include "lm/left.hh" +#include "util/tokenize_piece.hh" + +template <class Model> void Query(const char *name) { + Model model(name); + std::string line; + lm::ngram::ChartState ignored; + while (getline(std::cin, line)) { + lm::ngram::RuleScore<Model> scorer(model, ignored); + for (util::TokenIter<util::SingleCharacter, true> i(line, ' '); i; ++i) { + scorer.Terminal(model.GetVocabulary().Index(*i)); + } + std::cout << scorer.Finish() << '\n'; + } +} + +int main(int argc, char *argv[]) { + if (argc != 2) { + std::cerr << "Expected model file name." << std::endl; + return 1; + } + const char *name = argv[1]; + lm::ngram::ModelType model_type = lm::ngram::PROBING; + lm::ngram::RecognizeBinary(name, model_type); + switch (model_type) { + case lm::ngram::PROBING: + Query<lm::ngram::ProbingModel>(name); + break; + case lm::ngram::REST_PROBING: + Query<lm::ngram::RestProbingModel>(name); + break; + default: + std::cerr << "Model type not supported yet." << std::endl; + } +} diff --git a/klm/search/Jamfile b/klm/search/Jamfile index ac47c249..e8b14363 100644 --- a/klm/search/Jamfile +++ b/klm/search/Jamfile @@ -1,4 +1,4 @@ -lib search : weights.cc vertex.cc vertex_generator.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil : : : <include>.. ; +lib search : weights.cc vertex.cc vertex_generator.cc edge_queue.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : <include>.. ; import testing ; diff --git a/klm/search/edge.hh b/klm/search/edge.hh index 4d2a5cbf..77ab0ade 100644 --- a/klm/search/edge.hh +++ b/klm/search/edge.hh @@ -11,33 +11,6 @@ namespace search { -class Edge { - public: - Edge() { - end_to_ = to_; - } - - Rule &InitRule() { return rule_; } - - void Add(Vertex &vertex) { - assert(end_to_ - to_ < kMaxArity); - *(end_to_++) = &vertex; - } - - const Vertex &GetVertex(std::size_t index) const { - return *to_[index]; - } - - const Rule &GetRule() const { return rule_; } - - private: - // Rule and pointers to rule arguments. - Rule rule_; - - Vertex *to_[kMaxArity]; - Vertex **end_to_; -}; - struct PartialEdge { Score score; // Terminals @@ -45,6 +18,10 @@ struct PartialEdge { // Non-terminals PartialVertex nt[kMaxArity]; + const lm::ngram::ChartState &CompletedState() const { + return between[0]; + } + bool operator<(const PartialEdge &other) const { return score < other.score; } diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc index d135899a..56239dfb 100644 --- a/klm/search/edge_generator.cc +++ b/klm/search/edge_generator.cc @@ -10,28 +10,15 @@ namespace search { -bool EdgeGenerator::Init(Edge &edge, VertexGenerator &parent) { - from_ = &edge; - for (unsigned int i = 0; i < GetRule().Arity(); ++i) { - if (edge.GetVertex(i).RootPartial().Empty()) return false; - } - PartialEdge &root = *parent.MallocPartialEdge(); - root.score = GetRule().Bound(); - for (unsigned int i = 0; i < GetRule().Arity(); ++i) { +EdgeGenerator::EdgeGenerator(PartialEdge &root, unsigned char arity, Note note) : arity_(arity), note_(note) { +/* for (unsigned char i = 0; i < edge.Arity(); ++i) { root.nt[i] = edge.GetVertex(i).RootPartial(); - root.score += root.nt[i].Bound(); } - for (unsigned int i = GetRule().Arity(); i < 2; ++i) { + for (unsigned char i = edge.Arity(); i < 2; ++i) { root.nt[i] = kBlankPartialVertex; - } - for (unsigned int i = 0; i < GetRule().Arity() + 1; ++i) { - root.between[i] = GetRule().Lexical(i); - } - // wtf no clear method? - generate_ = Generate(); + }*/ generate_.push(&root); - top_ = root.score; - return true; + top_score_ = root.score; } namespace { @@ -78,13 +65,13 @@ template <class Model> float FastScore(const Context<Model> &context, unsigned c } // namespace -template <class Model> bool EdgeGenerator::Pop(Context<Model> &context, VertexGenerator &parent) { +template <class Model> PartialEdge *EdgeGenerator::Pop(Context<Model> &context, boost::pool<> &partial_edge_pool) { assert(!generate_.empty()); PartialEdge &top = *generate_.top(); generate_.pop(); unsigned int victim = 0; unsigned char lowest_length = 255; - for (unsigned int i = 0; i != GetRule().Arity(); ++i) { + for (unsigned char i = 0; i != arity_; ++i) { if (!top.nt[i].Complete() && top.nt[i].Length() < lowest_length) { lowest_length = top.nt[i].Length(); victim = i; @@ -92,21 +79,21 @@ template <class Model> bool EdgeGenerator::Pop(Context<Model> &context, VertexGe } if (lowest_length == 255) { // All states report complete. - top.between[0].right = top.between[GetRule().Arity()].right; - parent.NewHypothesis(top.between[0], *from_, top); - top_ = generate_.empty() ? -kScoreInf : generate_.top()->score; - return !generate_.empty(); + top.between[0].right = top.between[arity_].right; + // Now top.between[0] is the full edge state. + top_score_ = generate_.empty() ? -kScoreInf : generate_.top()->score; + return ⊤ } unsigned int stay = !victim; - PartialEdge &continuation = *parent.MallocPartialEdge(); + PartialEdge &continuation = *static_cast<PartialEdge*>(partial_edge_pool.malloc()); float old_bound = top.nt[victim].Bound(); // The alternate's score will change because alternate.nt[victim] changes. bool split = top.nt[victim].Split(continuation.nt[victim]); // top is now the alternate. continuation.nt[stay] = top.nt[stay]; - continuation.score = FastScore(context, victim, GetRule().Arity(), top, continuation); + continuation.score = FastScore(context, victim, arity_, top, continuation); // TODO: dedupe? generate_.push(&continuation); @@ -116,14 +103,18 @@ template <class Model> bool EdgeGenerator::Pop(Context<Model> &context, VertexGe // TODO: dedupe? generate_.push(&top); } else { - parent.FreePartialEdge(&top); + partial_edge_pool.free(&top); } - top_ = generate_.top()->score; - return true; + top_score_ = generate_.top()->score; + return NULL; } -template bool EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context, VertexGenerator &parent); -template bool EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context, VertexGenerator &parent); +template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::TrieModel> &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::QuantTrieModel> &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::ArrayTrieModel> &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::QuantArrayTrieModel> &context, boost::pool<> &partial_edge_pool); } // namespace search diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh index e306dc61..875ccc5e 100644 --- a/klm/search/edge_generator.hh +++ b/klm/search/edge_generator.hh @@ -2,7 +2,9 @@ #define SEARCH_EDGE_GENERATOR__ #include "search/edge.hh" +#include "search/note.hh" +#include <boost/pool/pool.hpp> #include <boost/unordered_map.hpp> #include <functional> @@ -28,26 +30,28 @@ struct PartialEdgePointerLess : std::binary_function<const PartialEdge *, const class EdgeGenerator { public: - // True if it has a hypothesis. - bool Init(Edge &edge, VertexGenerator &parent); + EdgeGenerator(PartialEdge &root, unsigned char arity, Note note); - Score Top() const { - return top_; + Score TopScore() const { + return top_score_; } - template <class Model> bool Pop(Context<Model> &context, VertexGenerator &parent); + Note GetNote() const { + return note_; + } + + // Pop. If there's a complete hypothesis, return it. Otherwise return NULL. + template <class Model> PartialEdge *Pop(Context<Model> &context, boost::pool<> &partial_edge_pool); private: - const Rule &GetRule() const { - return from_->GetRule(); - } + Score top_score_; - Score top_; + unsigned char arity_; typedef std::priority_queue<PartialEdge*, std::vector<PartialEdge*>, PartialEdgePointerLess> Generate; Generate generate_; - Edge *from_; + Note note_; }; } // namespace search diff --git a/klm/search/edge_queue.cc b/klm/search/edge_queue.cc new file mode 100644 index 00000000..e3ae6ebf --- /dev/null +++ b/klm/search/edge_queue.cc @@ -0,0 +1,25 @@ +#include "search/edge_queue.hh" + +#include "lm/left.hh" +#include "search/context.hh" + +#include <stdint.h> + +namespace search { + +EdgeQueue::EdgeQueue(unsigned int pop_limit_hint) : partial_edge_pool_(sizeof(PartialEdge), pop_limit_hint * 2) { + take_ = static_cast<PartialEdge*>(partial_edge_pool_.malloc()); +} + +/*void EdgeQueue::AddEdge(PartialEdge &root, unsigned char arity, Note note) { + // Ignore empty edges. + for (unsigned char i = 0; i < edge.Arity(); ++i) { + PartialVertex root(edge.GetVertex(i).RootPartial()); + if (root.Empty()) return; + total_score += root.Bound(); + } + PartialEdge &allocated = *static_cast<PartialEdge*>(partial_edge_pool_.malloc()); + allocated.score = total_score; +}*/ + +} // namespace search diff --git a/klm/search/edge_queue.hh b/klm/search/edge_queue.hh new file mode 100644 index 00000000..187eaed7 --- /dev/null +++ b/klm/search/edge_queue.hh @@ -0,0 +1,73 @@ +#ifndef SEARCH_EDGE_QUEUE__ +#define SEARCH_EDGE_QUEUE__ + +#include "search/edge.hh" +#include "search/edge_generator.hh" +#include "search/note.hh" + +#include <boost/pool/pool.hpp> +#include <boost/pool/object_pool.hpp> + +#include <queue> + +namespace search { + +template <class Model> class Context; + +class EdgeQueue { + public: + explicit EdgeQueue(unsigned int pop_limit_hint); + + PartialEdge &InitializeEdge() { + return *take_; + } + + void AddEdge(unsigned char arity, Note note) { + generate_.push(edge_pool_.construct(*take_, arity, note)); + take_ = static_cast<PartialEdge*>(partial_edge_pool_.malloc()); + } + + bool Empty() const { return generate_.empty(); } + + /* Generate hypotheses and send them to output. Normally, output is a + * VertexGenerator, but the decoder may want to route edges to different + * vertices i.e. if they have different LHS non-terminal labels. + */ + template <class Model, class Output> void Search(Context<Model> &context, Output &output) { + int to_pop = context.PopLimit(); + while (to_pop > 0 && !generate_.empty()) { + EdgeGenerator *top = generate_.top(); + generate_.pop(); + PartialEdge *ret = top->Pop(context, partial_edge_pool_); + if (ret) { + output.NewHypothesis(*ret, top->GetNote()); + --to_pop; + if (top->TopScore() != -kScoreInf) { + generate_.push(top); + } + } else { + generate_.push(top); + } + } + output.FinishedSearch(); + } + + private: + boost::object_pool<EdgeGenerator> edge_pool_; + + struct LessByTopScore : public std::binary_function<const EdgeGenerator *, const EdgeGenerator *, bool> { + bool operator()(const EdgeGenerator *first, const EdgeGenerator *second) const { + return first->TopScore() < second->TopScore(); + } + }; + + typedef std::priority_queue<EdgeGenerator*, std::vector<EdgeGenerator*>, LessByTopScore> Generate; + Generate generate_; + + boost::pool<> partial_edge_pool_; + + PartialEdge *take_; +}; + +} // namespace search +#endif // SEARCH_EDGE_QUEUE__ diff --git a/klm/search/final.hh b/klm/search/final.hh index 823b8c1a..1b3092ac 100644 --- a/klm/search/final.hh +++ b/klm/search/final.hh @@ -2,35 +2,34 @@ #define SEARCH_FINAL__ #include "search/arity.hh" +#include "search/note.hh" #include "search/types.hh" #include <boost/array.hpp> namespace search { -class Edge; - class Final { public: typedef boost::array<const Final*, search::kMaxArity> ChildArray; - void Reset(Score bound, const Edge &from, const Final &left, const Final &right) { + void Reset(Score bound, Note note, const Final &left, const Final &right) { bound_ = bound; - from_ = &from; + note_ = note; children_[0] = &left; children_[1] = &right; } const ChildArray &Children() const { return children_; } - const Edge &From() const { return *from_; } + Note GetNote() const { return note_; } Score Bound() const { return bound_; } private: Score bound_; - const Edge *from_; + Note note_; ChildArray children_; }; diff --git a/klm/search/note.hh b/klm/search/note.hh new file mode 100644 index 00000000..50bed06e --- /dev/null +++ b/klm/search/note.hh @@ -0,0 +1,12 @@ +#ifndef SEARCH_NOTE__ +#define SEARCH_NOTE__ + +namespace search { + +union Note { + const void *vp; +}; + +} // namespace search + +#endif // SEARCH_NOTE__ diff --git a/klm/search/rule.cc b/klm/search/rule.cc index 0a941527..5b00207e 100644 --- a/klm/search/rule.cc +++ b/klm/search/rule.cc @@ -9,35 +9,35 @@ namespace search { -template <class Model> void Rule::Init(const Context<Model> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos) { - additive_ = additive; - Score lm_score = 0.0; - lexical_.clear(); - const lm::WordIndex oov = context.LanguageModel().GetVocabulary().NotFound(); - +template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing) { + unsigned int oov_count = 0; + float prob = 0.0; + const Model &model = context.LanguageModel(); + const lm::WordIndex oov = model.GetVocabulary().NotFound(); for (std::vector<lm::WordIndex>::const_iterator word = words.begin(); ; ++word) { - lexical_.resize(lexical_.size() + 1); - lm::ngram::RuleScore<Model> scorer(context.LanguageModel(), lexical_.back()); + lm::ngram::RuleScore<Model> scorer(model, *(writing++)); // TODO: optimize if (prepend_bos && (word == words.begin())) { scorer.BeginSentence(); } for (; ; ++word) { if (word == words.end()) { - lm_score += scorer.Finish(); - bound_ = additive_ + context.GetWeights().LM() * lm_score; - arity_ = lexical_.size() - 1; - return; + prob += scorer.Finish(); + return static_cast<float>(oov_count) * context.GetWeights().OOV() + prob * context.GetWeights().LM(); } if (*word == kNonTerminal) break; - if (*word == oov) additive_ += context.GetWeights().OOV(); + if (*word == oov) ++oov_count; scorer.Terminal(*word); } - lm_score += scorer.Finish(); + prob += scorer.Finish(); } } -template void Rule::Init(const Context<lm::ngram::RestProbingModel> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos); -template void Rule::Init(const Context<lm::ngram::ProbingModel> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos); +template float ScoreRule(const Context<lm::ngram::RestProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context<lm::ngram::ProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context<lm::ngram::TrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context<lm::ngram::QuantTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context<lm::ngram::ArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context<lm::ngram::QuantArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); } // namespace search diff --git a/klm/search/rule.hh b/klm/search/rule.hh index 920c64a7..0ce2794d 100644 --- a/klm/search/rule.hh +++ b/klm/search/rule.hh @@ -3,44 +3,17 @@ #include "lm/left.hh" #include "lm/word_index.hh" -#include "search/arity.hh" #include "search/types.hh" -#include <boost/array.hpp> - -#include <iosfwd> #include <vector> namespace search { template <class Model> class Context; -class Rule { - public: - Rule() : arity_(0) {} - - static const lm::WordIndex kNonTerminal = lm::kMaxWordIndex; - - // Use kNonTerminal for non-terminals. - template <class Model> void Init(const Context<Model> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos); - - Score Bound() const { return bound_; } - - Score Additive() const { return additive_; } - - unsigned int Arity() const { return arity_; } - - const lm::ngram::ChartState &Lexical(unsigned int index) const { - return lexical_[index]; - } - - private: - Score bound_, additive_; - - unsigned int arity_; +const lm::WordIndex kNonTerminal = lm::kMaxWordIndex; - std::vector<lm::ngram::ChartState> lexical_; -}; +template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *state_out); } // namespace search diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh index 7ef29efc..e1a9ad11 100644 --- a/klm/search/vertex.hh +++ b/klm/search/vertex.hh @@ -16,8 +16,6 @@ namespace search { class ContextBase; -class Edge; - class VertexNode { public: VertexNode() : end_(NULL) {} @@ -103,6 +101,10 @@ class PartialVertex { unsigned char Length() const { return back_->Length(); } + bool HasAlternative() const { + return index_ + 1 < back_->Size(); + } + // Split into continuation and alternative, rendering this the alternative. bool Split(PartialVertex &continuation) { assert(!Complete()); @@ -128,35 +130,26 @@ extern PartialVertex kBlankPartialVertex; class Vertex { public: - Vertex() -#ifdef DEBUG - : finished_adding_(false) -#endif - {} - - void Add(Edge &edge) { -#ifdef DEBUG - assert(!finished_adding_); -#endif - edges_.push_back(&edge); - } - - void FinishedAdding() { -#ifdef DEBUG - assert(!finished_adding_); - finished_adding_ = true; -#endif - } + Vertex() {} PartialVertex RootPartial() const { return PartialVertex(root_); } + const Final *BestChild() const { + PartialVertex top(RootPartial()); + if (top.Empty()) { + return NULL; + } else { + PartialVertex continuation; + while (!top.Complete()) { + top.Split(continuation); + top = continuation; + } + return &top.End(); + } + } + private: friend class VertexGenerator; - std::vector<Edge*> edges_; - -#ifdef DEBUG - bool finished_adding_; -#endif VertexNode root_; }; diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc index 78948c97..d94e6e06 100644 --- a/klm/search/vertex_generator.cc +++ b/klm/search/vertex_generator.cc @@ -2,45 +2,30 @@ #include "lm/left.hh" #include "search/context.hh" +#include "search/edge.hh" #include <stdint.h> namespace search { -template <class Model> VertexGenerator::VertexGenerator(Context<Model> &context, Vertex &gen) : context_(context), edges_(gen.edges_.size()), partial_edge_pool_(sizeof(PartialEdge), context.PopLimit() * 2) { - for (std::size_t i = 0; i < gen.edges_.size(); ++i) { - if (edges_[i].Init(*gen.edges_[i], *this)) - generate_.push(&edges_[i]); - } +VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) { gen.root_.InitRoot(); root_.under = &gen.root_; - to_pop_ = context.PopLimit(); - while (to_pop_ > 0 && !generate_.empty()) { - EdgeGenerator *top = generate_.top(); - generate_.pop(); - if (top->Pop(context, *this)) { - generate_.push(top); - } - } - gen.root_.SortAndSet(context, NULL); } -template VertexGenerator::VertexGenerator(Context<lm::ngram::ProbingModel> &context, Vertex &gen); -template VertexGenerator::VertexGenerator(Context<lm::ngram::RestProbingModel> &context, Vertex &gen); - namespace { const uint64_t kCompleteAdd = static_cast<uint64_t>(-1); } // namespace -void VertexGenerator::NewHypothesis(const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial) { +void VertexGenerator::NewHypothesis(const PartialEdge &partial, Note note) { + const lm::ngram::ChartState &state = partial.CompletedState(); std::pair<Existing::iterator, bool> got(existing_.insert(std::pair<uint64_t, Final*>(hash_value(state), NULL))); if (!got.second) { // Found it already. Final &exists = *got.first->second; if (exists.Bound() < partial.score) { - exists.Reset(partial.score, from, partial.nt[0].End(), partial.nt[1].End()); + exists.Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End()); } - --to_pop_; return; } unsigned char left = 0, right = 0; @@ -67,8 +52,7 @@ void VertexGenerator::NewHypothesis(const lm::ngram::ChartState &state, const Ed } node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); - got.first->second = CompleteTransition(*node, state, from, partial); - --to_pop_; + got.first->second = CompleteTransition(*node, state, note, partial); } VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { @@ -86,12 +70,12 @@ VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node return next; } -Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial) { +Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial) { VertexNode &node = *starter.under; assert(node.State().left.full == state.left.full); assert(!node.End()); Final *final = context_.NewFinal(); - final->Reset(partial.score, from, partial.nt[0].End(), partial.nt[1].End()); + final->Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End()); node.SetEnd(final); return final; } diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh index 8cdf1420..6b98da3e 100644 --- a/klm/search/vertex_generator.hh +++ b/klm/search/vertex_generator.hh @@ -1,10 +1,9 @@ #ifndef SEARCH_VERTEX_GENERATOR__ #define SEARCH_VERTEX_GENERATOR__ -#include "search/edge.hh" -#include "search/edge_generator.hh" +#include "search/note.hh" +#include "search/vertex.hh" -#include <boost/pool/pool.hpp> #include <boost/unordered_map.hpp> #include <queue> @@ -17,18 +16,21 @@ class ChartState; namespace search { -template <class Model> class Context; class ContextBase; class Final; +struct PartialEdge; class VertexGenerator { public: - template <class Model> VertexGenerator(Context<Model> &context, Vertex &gen); + VertexGenerator(ContextBase &context, Vertex &gen); - PartialEdge *MallocPartialEdge() { return static_cast<PartialEdge*>(partial_edge_pool_.malloc()); } - void FreePartialEdge(PartialEdge *value) { partial_edge_pool_.free(value); } + void NewHypothesis(const PartialEdge &partial, Note note); - void NewHypothesis(const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial); + void FinishedSearch() { + root_.under->SortAndSet(context_, NULL); + } + + const Vertex &Generating() const { return gen_; } private: // Parallel structure to VertexNode. @@ -41,29 +43,16 @@ class VertexGenerator { Trie &FindOrInsert(Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full); - Final *CompleteTransition(Trie &node, const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial); + Final *CompleteTransition(Trie &node, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial); ContextBase &context_; - std::vector<EdgeGenerator> edges_; - - struct LessByTop : public std::binary_function<const EdgeGenerator *, const EdgeGenerator *, bool> { - bool operator()(const EdgeGenerator *first, const EdgeGenerator *second) const { - return first->Top() < second->Top(); - } - }; - - typedef std::priority_queue<EdgeGenerator*, std::vector<EdgeGenerator*>, LessByTop> Generate; - Generate generate_; + Vertex &gen_; Trie root_; typedef boost::unordered_map<uint64_t, Final*> Existing; Existing existing_; - - int to_pop_; - - boost::pool<> partial_edge_pool_; }; } // namespace search |