From 58d7f847cd5b3c56682e834a2d9b897c6943fafc Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Tue, 11 Sep 2012 14:30:16 +0100 Subject: Add search library to cdec (not used yet) --- klm/search/Jamfile | 5 ++ klm/search/arity.hh | 8 ++ klm/search/config.hh | 25 +++++++ klm/search/context.hh | 66 +++++++++++++++++ klm/search/edge.hh | 54 ++++++++++++++ klm/search/edge_generator.cc | 129 ++++++++++++++++++++++++++++++++ klm/search/edge_generator.hh | 54 ++++++++++++++ klm/search/final.hh | 40 ++++++++++ klm/search/rule.cc | 55 ++++++++++++++ klm/search/rule.hh | 60 +++++++++++++++ klm/search/source.hh | 48 ++++++++++++ klm/search/types.hh | 18 +++++ klm/search/vertex.cc | 48 ++++++++++++ klm/search/vertex.hh | 165 +++++++++++++++++++++++++++++++++++++++++ klm/search/vertex_generator.cc | 99 +++++++++++++++++++++++++ klm/search/vertex_generator.hh | 70 +++++++++++++++++ klm/search/weights.cc | 69 +++++++++++++++++ klm/search/weights.hh | 49 ++++++++++++ klm/search/weights_test.cc | 38 ++++++++++ klm/search/word.hh | 47 ++++++++++++ 20 files changed, 1147 insertions(+) create mode 100644 klm/search/Jamfile create mode 100644 klm/search/arity.hh create mode 100644 klm/search/config.hh create mode 100644 klm/search/context.hh create mode 100644 klm/search/edge.hh create mode 100644 klm/search/edge_generator.cc create mode 100644 klm/search/edge_generator.hh create mode 100644 klm/search/final.hh create mode 100644 klm/search/rule.cc create mode 100644 klm/search/rule.hh create mode 100644 klm/search/source.hh create mode 100644 klm/search/types.hh create mode 100644 klm/search/vertex.cc create mode 100644 klm/search/vertex.hh create mode 100644 klm/search/vertex_generator.cc create mode 100644 klm/search/vertex_generator.hh create mode 100644 klm/search/weights.cc create mode 100644 klm/search/weights.hh create mode 100644 klm/search/weights_test.cc create mode 100644 klm/search/word.hh (limited to 'klm/search') diff --git a/klm/search/Jamfile b/klm/search/Jamfile new file mode 100644 index 00000000..ac47c249 --- /dev/null +++ b/klm/search/Jamfile @@ -0,0 +1,5 @@ +lib search : weights.cc vertex.cc vertex_generator.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil : : : .. ; + +import testing ; + +unit-test weights_test : weights_test.cc search /top//boost_unit_test_framework ; diff --git a/klm/search/arity.hh b/klm/search/arity.hh new file mode 100644 index 00000000..09c2c671 --- /dev/null +++ b/klm/search/arity.hh @@ -0,0 +1,8 @@ +#ifndef SEARCH_ARITY__ +#define SEARCH_ARITY__ +namespace search { + +const unsigned int kMaxArity = 2; + +} // namespace search +#endif // SEARCH_ARITY__ diff --git a/klm/search/config.hh b/klm/search/config.hh new file mode 100644 index 00000000..e21e4b7c --- /dev/null +++ b/klm/search/config.hh @@ -0,0 +1,25 @@ +#ifndef SEARCH_CONFIG__ +#define SEARCH_CONFIG__ + +#include "search/weights.hh" +#include "util/string_piece.hh" + +namespace search { + +class Config { + public: + Config(StringPiece weight_str, unsigned int pop_limit) : + weights_(weight_str), pop_limit_(pop_limit) {} + + const Weights &GetWeights() const { return weights_; } + + unsigned int PopLimit() const { return pop_limit_; } + + private: + search::Weights weights_; + unsigned int pop_limit_; +}; + +} // namespace search + +#endif // SEARCH_CONFIG__ diff --git a/klm/search/context.hh b/klm/search/context.hh new file mode 100644 index 00000000..ae248549 --- /dev/null +++ b/klm/search/context.hh @@ -0,0 +1,66 @@ +#ifndef SEARCH_CONTEXT__ +#define SEARCH_CONTEXT__ + +#include "lm/model.hh" +#include "search/config.hh" +#include "search/final.hh" +#include "search/types.hh" +#include "search/vertex.hh" +#include "search/word.hh" +#include "util/exception.hh" + +#include +#include + +#include + +namespace search { + +class Weights; + +class ContextBase { + public: + explicit ContextBase(const Config &config) : pop_limit_(config.PopLimit()), weights_(config.GetWeights()) {} + + Final *NewFinal() { + Final *ret = final_pool_.construct(); + assert(ret); + return ret; + } + + VertexNode *NewVertexNode() { + VertexNode *ret = vertex_node_pool_.construct(); + assert(ret); + return ret; + } + + void DeleteVertexNode(VertexNode *node) { + vertex_node_pool_.destroy(node); + } + + unsigned int PopLimit() const { return pop_limit_; } + + const Weights &GetWeights() const { return weights_; } + + private: + boost::object_pool final_pool_; + boost::object_pool vertex_node_pool_; + + unsigned int pop_limit_; + + const Weights &weights_; +}; + +template class Context : public ContextBase { + public: + Context(const Config &config, const Model &model) : ContextBase(config), model_(model) {} + + const Model &LanguageModel() const { return model_; } + + private: + const Model &model_; +}; + +} // namespace search + +#endif // SEARCH_CONTEXT__ diff --git a/klm/search/edge.hh b/klm/search/edge.hh new file mode 100644 index 00000000..4d2a5cbf --- /dev/null +++ b/klm/search/edge.hh @@ -0,0 +1,54 @@ +#ifndef SEARCH_EDGE__ +#define SEARCH_EDGE__ + +#include "lm/state.hh" +#include "search/arity.hh" +#include "search/rule.hh" +#include "search/types.hh" +#include "search/vertex.hh" + +#include + +namespace search { + +class Edge { + public: + Edge() { + end_to_ = to_; + } + + Rule &InitRule() { return rule_; } + + void Add(Vertex &vertex) { + assert(end_to_ - to_ < kMaxArity); + *(end_to_++) = &vertex; + } + + const Vertex &GetVertex(std::size_t index) const { + return *to_[index]; + } + + const Rule &GetRule() const { return rule_; } + + private: + // Rule and pointers to rule arguments. + Rule rule_; + + Vertex *to_[kMaxArity]; + Vertex **end_to_; +}; + +struct PartialEdge { + Score score; + // Terminals + lm::ngram::ChartState between[kMaxArity + 1]; + // Non-terminals + PartialVertex nt[kMaxArity]; + + bool operator<(const PartialEdge &other) const { + return score < other.score; + } +}; + +} // namespace search +#endif // SEARCH_EDGE__ diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc new file mode 100644 index 00000000..d135899a --- /dev/null +++ b/klm/search/edge_generator.cc @@ -0,0 +1,129 @@ +#include "search/edge_generator.hh" + +#include "lm/left.hh" +#include "lm/partial.hh" +#include "search/context.hh" +#include "search/vertex.hh" +#include "search/vertex_generator.hh" + +#include + +namespace search { + +bool EdgeGenerator::Init(Edge &edge, VertexGenerator &parent) { + from_ = &edge; + for (unsigned int i = 0; i < GetRule().Arity(); ++i) { + if (edge.GetVertex(i).RootPartial().Empty()) return false; + } + PartialEdge &root = *parent.MallocPartialEdge(); + root.score = GetRule().Bound(); + for (unsigned int i = 0; i < GetRule().Arity(); ++i) { + root.nt[i] = edge.GetVertex(i).RootPartial(); + root.score += root.nt[i].Bound(); + } + for (unsigned int i = GetRule().Arity(); i < 2; ++i) { + root.nt[i] = kBlankPartialVertex; + } + for (unsigned int i = 0; i < GetRule().Arity() + 1; ++i) { + root.between[i] = GetRule().Lexical(i); + } + // wtf no clear method? + generate_ = Generate(); + generate_.push(&root); + top_ = root.score; + return true; +} + +namespace { + +template float FastScore(const Context &context, unsigned char victim, unsigned char arity, const PartialEdge &previous, PartialEdge &update) { + memcpy(update.between, previous.between, sizeof(lm::ngram::ChartState) * (arity + 1)); + + float ret = 0.0; + lm::ngram::ChartState *before, *after; + if (victim == 0) { + before = &update.between[0]; + after = &update.between[(arity == 2 && previous.nt[1].Complete()) ? 2 : 1]; + } else { + assert(victim == 1); + assert(arity == 2); + before = &update.between[previous.nt[0].Complete() ? 0 : 1]; + after = &update.between[2]; + } + const lm::ngram::ChartState &previous_reveal = previous.nt[victim].State(); + const PartialVertex &update_nt = update.nt[victim]; + const lm::ngram::ChartState &update_reveal = update_nt.State(); + float just_after = 0.0; + if ((update_reveal.left.length > previous_reveal.left.length) || (update_reveal.left.full && !previous_reveal.left.full)) { + just_after += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length); + } + if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous.nt[victim].RightFull())) { + ret += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right); + } + if (update_nt.Complete()) { + if (update_reveal.left.full) { + before->left.full = true; + } else { + assert(update_reveal.left.length == update_reveal.right.length); + ret += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length); + } + if (victim == 0) { + update.between[0].right = after->right; + } else { + update.between[2].left = before->left; + } + } + return previous.score + (ret + just_after) * context.GetWeights().LM(); +} + +} // namespace + +template bool EdgeGenerator::Pop(Context &context, VertexGenerator &parent) { + assert(!generate_.empty()); + PartialEdge &top = *generate_.top(); + generate_.pop(); + unsigned int victim = 0; + unsigned char lowest_length = 255; + for (unsigned int i = 0; i != GetRule().Arity(); ++i) { + if (!top.nt[i].Complete() && top.nt[i].Length() < lowest_length) { + lowest_length = top.nt[i].Length(); + victim = i; + } + } + if (lowest_length == 255) { + // All states report complete. + top.between[0].right = top.between[GetRule().Arity()].right; + parent.NewHypothesis(top.between[0], *from_, top); + top_ = generate_.empty() ? -kScoreInf : generate_.top()->score; + return !generate_.empty(); + } + + unsigned int stay = !victim; + PartialEdge &continuation = *parent.MallocPartialEdge(); + float old_bound = top.nt[victim].Bound(); + // The alternate's score will change because alternate.nt[victim] changes. + bool split = top.nt[victim].Split(continuation.nt[victim]); + // top is now the alternate. + + continuation.nt[stay] = top.nt[stay]; + continuation.score = FastScore(context, victim, GetRule().Arity(), top, continuation); + // TODO: dedupe? + generate_.push(&continuation); + + if (split) { + // We have an alternate. + top.score += top.nt[victim].Bound() - old_bound; + // TODO: dedupe? + generate_.push(&top); + } else { + parent.FreePartialEdge(&top); + } + + top_ = generate_.top()->score; + return true; +} + +template bool EdgeGenerator::Pop(Context &context, VertexGenerator &parent); +template bool EdgeGenerator::Pop(Context &context, VertexGenerator &parent); + +} // namespace search diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh new file mode 100644 index 00000000..e306dc61 --- /dev/null +++ b/klm/search/edge_generator.hh @@ -0,0 +1,54 @@ +#ifndef SEARCH_EDGE_GENERATOR__ +#define SEARCH_EDGE_GENERATOR__ + +#include "search/edge.hh" + +#include + +#include +#include + +namespace lm { +namespace ngram { +class ChartState; +} // namespace ngram +} // namespace lm + +namespace search { + +template class Context; + +class VertexGenerator; + +struct PartialEdgePointerLess : std::binary_function { + bool operator()(const PartialEdge *first, const PartialEdge *second) const { + return *first < *second; + } +}; + +class EdgeGenerator { + public: + // True if it has a hypothesis. + bool Init(Edge &edge, VertexGenerator &parent); + + Score Top() const { + return top_; + } + + template bool Pop(Context &context, VertexGenerator &parent); + + private: + const Rule &GetRule() const { + return from_->GetRule(); + } + + Score top_; + + typedef std::priority_queue, PartialEdgePointerLess> Generate; + Generate generate_; + + Edge *from_; +}; + +} // namespace search +#endif // SEARCH_EDGE_GENERATOR__ diff --git a/klm/search/final.hh b/klm/search/final.hh new file mode 100644 index 00000000..24e6f0a5 --- /dev/null +++ b/klm/search/final.hh @@ -0,0 +1,40 @@ +#ifndef SEARCH_FINAL__ +#define SEARCH_FINAL__ + +#include "search/rule.hh" +#include "search/types.hh" + +#include + +namespace search { + +class Final { + public: + typedef boost::array ChildArray; + + void Reset(Score bound, const Rule &from, const Final &left, const Final &right) { + bound_ = bound; + from_ = &from; + children_[0] = &left; + children_[1] = &right; + } + + const ChildArray &Children() const { return children_; } + + unsigned int ChildCount() const { return from_->Arity(); } + + const Rule &From() const { return *from_; } + + Score Bound() const { return bound_; } + + private: + Score bound_; + + const Rule *from_; + + ChildArray children_; +}; + +} // namespace search + +#endif // SEARCH_FINAL__ diff --git a/klm/search/rule.cc b/klm/search/rule.cc new file mode 100644 index 00000000..a8b993eb --- /dev/null +++ b/klm/search/rule.cc @@ -0,0 +1,55 @@ +#include "search/rule.hh" + +#include "search/context.hh" +#include "search/final.hh" + +#include + +#include + +namespace search { + +template void Rule::FinishedAdding(const Context &context, Score additive, bool prepend_bos) { + additive_ = additive; + Score lm_score = 0.0; + lexical_.clear(); + const lm::WordIndex oov = context.LanguageModel().GetVocabulary().NotFound(); + + for (std::vector::const_iterator word = items_.begin(); ; ++word) { + lexical_.resize(lexical_.size() + 1); + lm::ngram::RuleScore scorer(context.LanguageModel(), lexical_.back()); + // TODO: optimize + if (prepend_bos && (word == items_.begin())) { + scorer.BeginSentence(); + } + for (; ; ++word) { + if (word == items_.end()) { + lm_score += scorer.Finish(); + bound_ = additive_ + context.GetWeights().LM() * lm_score; + assert(lexical_.size() == arity_ + 1); + return; + } + if (!word->Terminal()) break; + if (word->Index() == oov) additive_ += context.GetWeights().OOV(); + scorer.Terminal(word->Index()); + } + lm_score += scorer.Finish(); + } +} + +template void Rule::FinishedAdding(const Context &context, Score additive, bool prepend_bos); +template void Rule::FinishedAdding(const Context &context, Score additive, bool prepend_bos); + +std::ostream &operator<<(std::ostream &o, const Rule &rule) { + const Rule::ItemsRet &items = rule.Items(); + for (Rule::ItemsRet::const_iterator i = items.begin(); i != items.end(); ++i) { + if (i->Terminal()) { + o << i->String() << ' '; + } else { + o << "[] "; + } + } + return o; +} + +} // namespace search diff --git a/klm/search/rule.hh b/klm/search/rule.hh new file mode 100644 index 00000000..79192d40 --- /dev/null +++ b/klm/search/rule.hh @@ -0,0 +1,60 @@ +#ifndef SEARCH_RULE__ +#define SEARCH_RULE__ + +#include "lm/left.hh" +#include "search/arity.hh" +#include "search/types.hh" +#include "search/word.hh" + +#include + +#include +#include + +namespace search { + +template class Context; + +class Rule { + public: + Rule() : arity_(0) {} + + void AppendTerminal(Word w) { items_.push_back(w); } + + void AppendNonTerminal() { + items_.resize(items_.size() + 1); + ++arity_; + } + + template void FinishedAdding(const Context &context, Score additive, bool prepend_bos); + + Score Bound() const { return bound_; } + + Score Additive() const { return additive_; } + + unsigned int Arity() const { return arity_; } + + const lm::ngram::ChartState &Lexical(unsigned int index) const { + return lexical_[index]; + } + + // For printing. + typedef const std::vector ItemsRet; + ItemsRet &Items() const { return items_; } + + private: + Score bound_, additive_; + + unsigned int arity_; + + // TODO: pool? + std::vector items_; + + std::vector lexical_; +}; + +std::ostream &operator<<(std::ostream &o, const Rule &rule); + +} // namespace search + +#endif // SEARCH_RULE__ diff --git a/klm/search/source.hh b/klm/search/source.hh new file mode 100644 index 00000000..11839f7b --- /dev/null +++ b/klm/search/source.hh @@ -0,0 +1,48 @@ +#ifndef SEARCH_SOURCE__ +#define SEARCH_SOURCE__ + +#include "search/types.hh" + +#include +#include + +namespace search { + +template class Source { + public: + Source() : bound_(kScoreInf) {} + + Index Size() const { + return final_.size(); + } + + Score Bound() const { + return bound_; + } + + const Final &operator[](Index index) const { + return *final_[index]; + } + + Score ScoreOrBound(Index index) const { + return Size() > index ? final_[index]->Total() : Bound(); + } + + protected: + void AddFinal(const Final &store) { + final_.push_back(&store); + } + + void SetBound(Score to) { + assert(to <= bound_ + 0.001); + bound_ = to; + } + + private: + std::vector final_; + + Score bound_; +}; + +} // namespace search +#endif // SEARCH_SOURCE__ diff --git a/klm/search/types.hh b/klm/search/types.hh new file mode 100644 index 00000000..9726379f --- /dev/null +++ b/klm/search/types.hh @@ -0,0 +1,18 @@ +#ifndef SEARCH_TYPES__ +#define SEARCH_TYPES__ + +#include + +namespace search { + +typedef float Score; +const Score kScoreInf = INFINITY; + +// This could have been an enum but gcc wants 4 bytes. +typedef bool ExtendDirection; +const ExtendDirection kExtendLeft = 0; +const ExtendDirection kExtendRight = 1; + +} // namespace search + +#endif // SEARCH_TYPES__ diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc new file mode 100644 index 00000000..cc53c0dd --- /dev/null +++ b/klm/search/vertex.cc @@ -0,0 +1,48 @@ +#include "search/vertex.hh" + +#include "search/context.hh" + +#include +#include + +#include + +namespace search { + +namespace { + +struct GreaterByBound : public std::binary_function { + bool operator()(const VertexNode *first, const VertexNode *second) const { + return first->Bound() > second->Bound(); + } +}; + +} // namespace + +void VertexNode::SortAndSet(ContextBase &context, VertexNode **parent_ptr) { + if (Complete()) { + assert(end_); + assert(extend_.empty()); + bound_ = end_->Bound(); + return; + } + if (extend_.size() == 1 && parent_ptr) { + *parent_ptr = extend_[0]; + extend_[0]->SortAndSet(context, parent_ptr); + context.DeleteVertexNode(this); + return; + } + for (std::vector::iterator i = extend_.begin(); i != extend_.end(); ++i) { + (*i)->SortAndSet(context, &*i); + } + std::sort(extend_.begin(), extend_.end(), GreaterByBound()); + bound_ = extend_.front()->Bound(); +} + +namespace { +VertexNode kBlankVertexNode; +} // namespace + +PartialVertex kBlankPartialVertex(kBlankVertexNode); + +} // namespace search diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh new file mode 100644 index 00000000..7ef29efc --- /dev/null +++ b/klm/search/vertex.hh @@ -0,0 +1,165 @@ +#ifndef SEARCH_VERTEX__ +#define SEARCH_VERTEX__ + +#include "lm/left.hh" +#include "search/final.hh" +#include "search/types.hh" + +#include + +#include +#include + +#include + +namespace search { + +class ContextBase; + +class Edge; + +class VertexNode { + public: + VertexNode() : end_(NULL) {} + + void InitRoot() { + extend_.clear(); + state_.left.full = false; + state_.left.length = 0; + state_.right.length = 0; + right_full_ = false; + bound_ = -kScoreInf; + end_ = NULL; + } + + lm::ngram::ChartState &MutableState() { return state_; } + bool &MutableRightFull() { return right_full_; } + + void AddExtend(VertexNode *next) { + extend_.push_back(next); + } + + void SetEnd(Final *end) { end_ = end; } + + Final &MutableEnd() { return *end_; } + + void SortAndSet(ContextBase &context, VertexNode **parent_pointer); + + // Should only happen to a root node when the entire vertex is empty. + bool Empty() const { + return !end_ && extend_.empty(); + } + + bool Complete() const { + return end_; + } + + const lm::ngram::ChartState &State() const { return state_; } + bool RightFull() const { return right_full_; } + + Score Bound() const { + return bound_; + } + + unsigned char Length() const { + return state_.left.length + state_.right.length; + } + + // May be NULL. + const Final *End() const { return end_; } + + const VertexNode &operator[](size_t index) const { + return *extend_[index]; + } + + size_t Size() const { + return extend_.size(); + } + + private: + std::vector extend_; + + lm::ngram::ChartState state_; + bool right_full_; + + Score bound_; + Final *end_; +}; + +class PartialVertex { + public: + PartialVertex() {} + + explicit PartialVertex(const VertexNode &back) : back_(&back), index_(0) {} + + bool Empty() const { return back_->Empty(); } + + bool Complete() const { return back_->Complete(); } + + const lm::ngram::ChartState &State() const { return back_->State(); } + bool RightFull() const { return back_->RightFull(); } + + Score Bound() const { return Complete() ? back_->End()->Bound() : (*back_)[index_].Bound(); } + + unsigned char Length() const { return back_->Length(); } + + // Split into continuation and alternative, rendering this the alternative. + bool Split(PartialVertex &continuation) { + assert(!Complete()); + continuation.back_ = &((*back_)[index_]); + continuation.index_ = 0; + if (index_ + 1 < back_->Size()) { + ++index_; + return true; + } + return false; + } + + const Final &End() const { + return *back_->End(); + } + + private: + const VertexNode *back_; + unsigned int index_; +}; + +extern PartialVertex kBlankPartialVertex; + +class Vertex { + public: + Vertex() +#ifdef DEBUG + : finished_adding_(false) +#endif + {} + + void Add(Edge &edge) { +#ifdef DEBUG + assert(!finished_adding_); +#endif + edges_.push_back(&edge); + } + + void FinishedAdding() { +#ifdef DEBUG + assert(!finished_adding_); + finished_adding_ = true; +#endif + } + + PartialVertex RootPartial() const { return PartialVertex(root_); } + + private: + friend class VertexGenerator; + std::vector edges_; + +#ifdef DEBUG + bool finished_adding_; +#endif + + VertexNode root_; +}; + +} // namespace search +#endif // SEARCH_VERTEX__ diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc new file mode 100644 index 00000000..0281fc37 --- /dev/null +++ b/klm/search/vertex_generator.cc @@ -0,0 +1,99 @@ +#include "search/vertex_generator.hh" + +#include "lm/left.hh" +#include "search/context.hh" + +#include + +namespace search { + +template VertexGenerator::VertexGenerator(Context &context, Vertex &gen) : context_(context), edges_(gen.edges_.size()), partial_edge_pool_(sizeof(PartialEdge), context.PopLimit() * 2) { + for (std::size_t i = 0; i < gen.edges_.size(); ++i) { + if (edges_[i].Init(*gen.edges_[i], *this)) + generate_.push(&edges_[i]); + } + gen.root_.InitRoot(); + root_.under = &gen.root_; + to_pop_ = context.PopLimit(); + while (to_pop_ > 0 && !generate_.empty()) { + EdgeGenerator *top = generate_.top(); + generate_.pop(); + if (top->Pop(context, *this)) { + generate_.push(top); + } + } + gen.root_.SortAndSet(context, NULL); +} + +template VertexGenerator::VertexGenerator(Context &context, Vertex &gen); +template VertexGenerator::VertexGenerator(Context &context, Vertex &gen); + +namespace { +const uint64_t kCompleteAdd = static_cast(-1); +} // namespace + +void VertexGenerator::NewHypothesis(const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial) { + std::pair got(existing_.insert(std::pair(hash_value(state), NULL))); + if (!got.second) { + // Found it already. + Final &exists = *got.first->second; + if (exists.Bound() < partial.score) { + exists.Reset(partial.score, from.GetRule(), partial.nt[0].End(), partial.nt[1].End()); + } + --to_pop_; + return; + } + unsigned char left = 0, right = 0; + Trie *node = &root_; + while (true) { + if (left == state.left.length) { + node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, true, right, false); + for (; right < state.right.length; ++right) { + node = &FindOrInsert(*node, state.right.words[right], state, left, true, right + 1, false); + } + break; + } + node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, false); + left++; + if (right == state.right.length) { + node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, false, right, true); + for (; left < state.left.length; ++left) { + node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, true); + } + break; + } + node = &FindOrInsert(*node, state.right.words[right], state, left, false, right + 1, false); + right++; + } + + node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); + got.first->second = CompleteTransition(*node, state, from, partial); + --to_pop_; +} + +VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { + VertexGenerator::Trie &next = node.extend[added]; + if (!next.under) { + next.under = context_.NewVertexNode(); + lm::ngram::ChartState &writing = next.under->MutableState(); + writing = state; + writing.left.full &= left_full && state.left.full; + next.under->MutableRightFull() = right_full && state.left.full; + writing.left.length = left; + writing.right.length = right; + node.under->AddExtend(next.under); + } + return next; +} + +Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial) { + VertexNode &node = *starter.under; + assert(node.State().left.full == state.left.full); + assert(!node.End()); + Final *final = context_.NewFinal(); + final->Reset(partial.score, from.GetRule(), partial.nt[0].End(), partial.nt[1].End()); + node.SetEnd(final); + return final; +} + +} // namespace search diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh new file mode 100644 index 00000000..8cdf1420 --- /dev/null +++ b/klm/search/vertex_generator.hh @@ -0,0 +1,70 @@ +#ifndef SEARCH_VERTEX_GENERATOR__ +#define SEARCH_VERTEX_GENERATOR__ + +#include "search/edge.hh" +#include "search/edge_generator.hh" + +#include +#include + +#include + +namespace lm { +namespace ngram { +class ChartState; +} // namespace ngram +} // namespace lm + +namespace search { + +template class Context; +class ContextBase; +class Final; + +class VertexGenerator { + public: + template VertexGenerator(Context &context, Vertex &gen); + + PartialEdge *MallocPartialEdge() { return static_cast(partial_edge_pool_.malloc()); } + void FreePartialEdge(PartialEdge *value) { partial_edge_pool_.free(value); } + + void NewHypothesis(const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial); + + private: + // Parallel structure to VertexNode. + struct Trie { + Trie() : under(NULL) {} + + VertexNode *under; + boost::unordered_map extend; + }; + + Trie &FindOrInsert(Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full); + + Final *CompleteTransition(Trie &node, const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial); + + ContextBase &context_; + + std::vector edges_; + + struct LessByTop : public std::binary_function { + bool operator()(const EdgeGenerator *first, const EdgeGenerator *second) const { + return first->Top() < second->Top(); + } + }; + + typedef std::priority_queue, LessByTop> Generate; + Generate generate_; + + Trie root_; + + typedef boost::unordered_map Existing; + Existing existing_; + + int to_pop_; + + boost::pool<> partial_edge_pool_; +}; + +} // namespace search +#endif // SEARCH_VERTEX_GENERATOR__ diff --git a/klm/search/weights.cc b/klm/search/weights.cc new file mode 100644 index 00000000..82ff3f12 --- /dev/null +++ b/klm/search/weights.cc @@ -0,0 +1,69 @@ +#include "search/weights.hh" +#include "util/tokenize_piece.hh" + +#include + +namespace search { + +namespace { +struct Insert { + void operator()(boost::unordered_map &map, StringPiece name, search::Score score) const { + std::string copy(name.data(), name.size()); + map[copy] = score; + } +}; + +struct DotProduct { + search::Score total; + DotProduct() : total(0.0) {} + + void operator()(const boost::unordered_map &map, StringPiece name, search::Score score) { + boost::unordered_map::const_iterator i(FindStringPiece(map, name)); + if (i != map.end()) + total += score * i->second; + } +}; + +template void Parse(StringPiece text, Map &map, Op &op) { + for (util::TokenIter spaces(text, ' '); spaces; ++spaces) { + util::TokenIter equals(*spaces, '='); + UTIL_THROW_IF(!equals, WeightParseException, "Bad weight token " << *spaces); + StringPiece name(*equals); + UTIL_THROW_IF(!++equals, WeightParseException, "Bad weight token " << *spaces); + char *end; + // Assumes proper termination. + double value = std::strtod(equals->data(), &end); + UTIL_THROW_IF(end != equals->data() + equals->size(), WeightParseException, "Failed to parse weight" << *equals); + UTIL_THROW_IF(++equals, WeightParseException, "Too many equals in " << *spaces); + op(map, name, value); + } +} + +} // namespace + +Weights::Weights(StringPiece text) { + Insert op; + Parse(text, map_, op); + lm_ = Steal("LanguageModel"); + oov_ = Steal("OOV"); + word_penalty_ = Steal("WordPenalty"); +} + +search::Score Weights::DotNoLM(StringPiece text) const { + DotProduct dot; + Parse(text, map_, dot); + return dot.total; +} + +float Weights::Steal(const std::string &str) { + Map::iterator i(map_.find(str)); + if (i == map_.end()) { + return 0.0; + } else { + float ret = i->second; + map_.erase(i); + return ret; + } +} + +} // namespace search diff --git a/klm/search/weights.hh b/klm/search/weights.hh new file mode 100644 index 00000000..4a4388c7 --- /dev/null +++ b/klm/search/weights.hh @@ -0,0 +1,49 @@ +// For now, the individual features are not kept. +#ifndef SEARCH_WEIGHTS__ +#define SEARCH_WEIGHTS__ + +#include "search/types.hh" +#include "util/exception.hh" +#include "util/string_piece.hh" + +#include + +#include + +namespace search { + +class WeightParseException : public util::Exception { + public: + WeightParseException() {} + ~WeightParseException() throw() {} +}; + +class Weights { + public: + // Parses weights, sets lm_weight_, removes it from map_. + explicit Weights(StringPiece text); + + search::Score DotNoLM(StringPiece text) const; + + search::Score LM() const { return lm_; } + + search::Score OOV() const { return oov_; } + + search::Score WordPenalty() const { return word_penalty_; } + + // Mostly for testing. + const boost::unordered_map &GetMap() const { return map_; } + + private: + float Steal(const std::string &str); + + typedef boost::unordered_map Map; + + Map map_; + + search::Score lm_, oov_, word_penalty_; +}; + +} // namespace search + +#endif // SEARCH_WEIGHTS__ diff --git a/klm/search/weights_test.cc b/klm/search/weights_test.cc new file mode 100644 index 00000000..4811ff06 --- /dev/null +++ b/klm/search/weights_test.cc @@ -0,0 +1,38 @@ +#include "search/weights.hh" + +#define BOOST_TEST_MODULE WeightTest +#include +#include + +namespace search { +namespace { + +#define CHECK_WEIGHT(value, string) \ + i = parsed.find(string); \ + BOOST_REQUIRE(i != parsed.end()); \ + BOOST_CHECK_CLOSE((value), i->second, 0.001); + +BOOST_AUTO_TEST_CASE(parse) { + // These are not real feature weights. + Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5"); + const boost::unordered_map &parsed = w.GetMap(); + boost::unordered_map::const_iterator i; + CHECK_WEIGHT(0.0, "rarity"); + CHECK_WEIGHT(0.0, "phrase-SGT"); + CHECK_WEIGHT(9.45117, "phrase-TGS"); + CHECK_WEIGHT(2.33833, "lexical-SGT"); + BOOST_CHECK(parsed.end() == parsed.find("lm")); + BOOST_CHECK_CLOSE(3.0, w.LM(), 0.001); + CHECK_WEIGHT(-28.3317, "lexical-TGS"); + CHECK_WEIGHT(5.0, "glue?"); +} + +BOOST_AUTO_TEST_CASE(dot) { + Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5"); + BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0"), 0.001); + BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0 LanguageModel=10"), 0.001); + BOOST_CHECK_CLOSE(9.45117 * 3.0 + 28.3317 * 17.4, w.DotNoLM("rarity=5 phrase-TGS=3.0 LanguageModel=10 lexical-TGS=-17.4"), 0.001); +} + +} // namespace +} // namespace search diff --git a/klm/search/word.hh b/klm/search/word.hh new file mode 100644 index 00000000..e7a15be9 --- /dev/null +++ b/klm/search/word.hh @@ -0,0 +1,47 @@ +#ifndef SEARCH_WORD__ +#define SEARCH_WORD__ + +#include "lm/word_index.hh" + +#include + +#include +#include + +namespace search { + +class Word { + public: + // Construct a non-terminal. + Word() : entry_(NULL) {} + + explicit Word(const std::pair &entry) { + entry_ = &entry; + } + + // Returns true for two non-terminals even if their labels are different (since we don't care about labels). + bool operator==(const Word &other) const { + return entry_ == other.entry_; + } + + bool Terminal() const { return entry_ != NULL; } + + const std::string &String() const { return entry_->first; } + + lm::WordIndex Index() const { return entry_->second; } + + protected: + friend size_t hash_value(const Word &word); + + const std::pair *Entry() const { return entry_; } + + private: + const std::pair *entry_; +}; + +inline size_t hash_value(const Word &word) { + return boost::hash_value(word.Entry()); +} + +} // namespace search +#endif // SEARCH_WORD__ -- cgit v1.2.3 From c26c35a9bcbb4d42ae50ad0a75c1b5fb59702bd1 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Wed, 12 Sep 2012 12:01:26 +0100 Subject: Refactor search so that it knows even less, but keeps track of edge pointers --- klm/lm/word_index.hh | 3 +++ klm/search/context.hh | 1 - klm/search/final.hh | 12 +++++------ klm/search/rule.cc | 32 +++++++++------------------- klm/search/rule.hh | 21 ++++--------------- klm/search/vertex_generator.cc | 4 ++-- klm/search/word.hh | 47 ------------------------------------------ 7 files changed, 25 insertions(+), 95 deletions(-) delete mode 100644 klm/search/word.hh (limited to 'klm/search') diff --git a/klm/lm/word_index.hh b/klm/lm/word_index.hh index 67841c30..e09557a7 100644 --- a/klm/lm/word_index.hh +++ b/klm/lm/word_index.hh @@ -2,8 +2,11 @@ #ifndef LM_WORD_INDEX__ #define LM_WORD_INDEX__ +#include + namespace lm { typedef unsigned int WordIndex; +const WordIndex kMaxWordIndex = UINT_MAX; } // namespace lm typedef lm::WordIndex LMWordIndex; diff --git a/klm/search/context.hh b/klm/search/context.hh index ae248549..27940053 100644 --- a/klm/search/context.hh +++ b/klm/search/context.hh @@ -6,7 +6,6 @@ #include "search/final.hh" #include "search/types.hh" #include "search/vertex.hh" -#include "search/word.hh" #include "util/exception.hh" #include diff --git a/klm/search/final.hh b/klm/search/final.hh index 24e6f0a5..823b8c1a 100644 --- a/klm/search/final.hh +++ b/klm/search/final.hh @@ -1,18 +1,20 @@ #ifndef SEARCH_FINAL__ #define SEARCH_FINAL__ -#include "search/rule.hh" +#include "search/arity.hh" #include "search/types.hh" #include namespace search { +class Edge; + class Final { public: typedef boost::array ChildArray; - void Reset(Score bound, const Rule &from, const Final &left, const Final &right) { + void Reset(Score bound, const Edge &from, const Final &left, const Final &right) { bound_ = bound; from_ = &from; children_[0] = &left; @@ -21,16 +23,14 @@ class Final { const ChildArray &Children() const { return children_; } - unsigned int ChildCount() const { return from_->Arity(); } - - const Rule &From() const { return *from_; } + const Edge &From() const { return *from_; } Score Bound() const { return bound_; } private: Score bound_; - const Rule *from_; + const Edge *from_; ChildArray children_; }; diff --git a/klm/search/rule.cc b/klm/search/rule.cc index a8b993eb..0a941527 100644 --- a/klm/search/rule.cc +++ b/klm/search/rule.cc @@ -9,47 +9,35 @@ namespace search { -template void Rule::FinishedAdding(const Context &context, Score additive, bool prepend_bos) { +template void Rule::Init(const Context &context, Score additive, const std::vector &words, bool prepend_bos) { additive_ = additive; Score lm_score = 0.0; lexical_.clear(); const lm::WordIndex oov = context.LanguageModel().GetVocabulary().NotFound(); - for (std::vector::const_iterator word = items_.begin(); ; ++word) { + for (std::vector::const_iterator word = words.begin(); ; ++word) { lexical_.resize(lexical_.size() + 1); lm::ngram::RuleScore scorer(context.LanguageModel(), lexical_.back()); // TODO: optimize - if (prepend_bos && (word == items_.begin())) { + if (prepend_bos && (word == words.begin())) { scorer.BeginSentence(); } for (; ; ++word) { - if (word == items_.end()) { + if (word == words.end()) { lm_score += scorer.Finish(); bound_ = additive_ + context.GetWeights().LM() * lm_score; - assert(lexical_.size() == arity_ + 1); + arity_ = lexical_.size() - 1; return; } - if (!word->Terminal()) break; - if (word->Index() == oov) additive_ += context.GetWeights().OOV(); - scorer.Terminal(word->Index()); + if (*word == kNonTerminal) break; + if (*word == oov) additive_ += context.GetWeights().OOV(); + scorer.Terminal(*word); } lm_score += scorer.Finish(); } } -template void Rule::FinishedAdding(const Context &context, Score additive, bool prepend_bos); -template void Rule::FinishedAdding(const Context &context, Score additive, bool prepend_bos); - -std::ostream &operator<<(std::ostream &o, const Rule &rule) { - const Rule::ItemsRet &items = rule.Items(); - for (Rule::ItemsRet::const_iterator i = items.begin(); i != items.end(); ++i) { - if (i->Terminal()) { - o << i->String() << ' '; - } else { - o << "[] "; - } - } - return o; -} +template void Rule::Init(const Context &context, Score additive, const std::vector &words, bool prepend_bos); +template void Rule::Init(const Context &context, Score additive, const std::vector &words, bool prepend_bos); } // namespace search diff --git a/klm/search/rule.hh b/klm/search/rule.hh index 79192d40..920c64a7 100644 --- a/klm/search/rule.hh +++ b/klm/search/rule.hh @@ -2,9 +2,9 @@ #define SEARCH_RULE__ #include "lm/left.hh" +#include "lm/word_index.hh" #include "search/arity.hh" #include "search/types.hh" -#include "search/word.hh" #include @@ -19,14 +19,10 @@ class Rule { public: Rule() : arity_(0) {} - void AppendTerminal(Word w) { items_.push_back(w); } + static const lm::WordIndex kNonTerminal = lm::kMaxWordIndex; - void AppendNonTerminal() { - items_.resize(items_.size() + 1); - ++arity_; - } - - template void FinishedAdding(const Context &context, Score additive, bool prepend_bos); + // Use kNonTerminal for non-terminals. + template void Init(const Context &context, Score additive, const std::vector &words, bool prepend_bos); Score Bound() const { return bound_; } @@ -38,23 +34,14 @@ class Rule { return lexical_[index]; } - // For printing. - typedef const std::vector ItemsRet; - ItemsRet &Items() const { return items_; } - private: Score bound_, additive_; unsigned int arity_; - // TODO: pool? - std::vector items_; - std::vector lexical_; }; -std::ostream &operator<<(std::ostream &o, const Rule &rule); - } // namespace search #endif // SEARCH_RULE__ diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc index 0281fc37..78948c97 100644 --- a/klm/search/vertex_generator.cc +++ b/klm/search/vertex_generator.cc @@ -38,7 +38,7 @@ void VertexGenerator::NewHypothesis(const lm::ngram::ChartState &state, const Ed // Found it already. Final &exists = *got.first->second; if (exists.Bound() < partial.score) { - exists.Reset(partial.score, from.GetRule(), partial.nt[0].End(), partial.nt[1].End()); + exists.Reset(partial.score, from, partial.nt[0].End(), partial.nt[1].End()); } --to_pop_; return; @@ -91,7 +91,7 @@ Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const assert(node.State().left.full == state.left.full); assert(!node.End()); Final *final = context_.NewFinal(); - final->Reset(partial.score, from.GetRule(), partial.nt[0].End(), partial.nt[1].End()); + final->Reset(partial.score, from, partial.nt[0].End(), partial.nt[1].End()); node.SetEnd(final); return final; } diff --git a/klm/search/word.hh b/klm/search/word.hh deleted file mode 100644 index e7a15be9..00000000 --- a/klm/search/word.hh +++ /dev/null @@ -1,47 +0,0 @@ -#ifndef SEARCH_WORD__ -#define SEARCH_WORD__ - -#include "lm/word_index.hh" - -#include - -#include -#include - -namespace search { - -class Word { - public: - // Construct a non-terminal. - Word() : entry_(NULL) {} - - explicit Word(const std::pair &entry) { - entry_ = &entry; - } - - // Returns true for two non-terminals even if their labels are different (since we don't care about labels). - bool operator==(const Word &other) const { - return entry_ == other.entry_; - } - - bool Terminal() const { return entry_ != NULL; } - - const std::string &String() const { return entry_->first; } - - lm::WordIndex Index() const { return entry_->second; } - - protected: - friend size_t hash_value(const Word &word); - - const std::pair *Entry() const { return entry_; } - - private: - const std::pair *entry_; -}; - -inline size_t hash_value(const Word &word) { - return boost::hash_value(word.Entry()); -} - -} // namespace search -#endif // SEARCH_WORD__ -- cgit v1.2.3 From 8505fdfdf0bc4ce9acec42e1980a2fdd4f254109 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Thu, 13 Sep 2012 11:15:32 +0100 Subject: It compiles. --- decoder/Jamfile | 2 ++ decoder/decoder.cc | 4 +++ decoder/lazy.cc | 78 +++++++++++++++++++++++++++++++++++++-------------- decoder/lazy.h | 5 +++- klm/search/config.hh | 6 ++-- klm/search/weights.cc | 2 ++ klm/search/weights.hh | 17 ++++++----- 7 files changed, 82 insertions(+), 32 deletions(-) (limited to 'klm/search') diff --git a/decoder/Jamfile b/decoder/Jamfile index da02d063..d778dc7f 100644 --- a/decoder/Jamfile +++ b/decoder/Jamfile @@ -58,10 +58,12 @@ lib decoder : rescore_translator.cc hg_remove_eps.cc hg_union.cc + lazy.cc $(glc) ..//utils ..//mteval ../klm/lm//kenlm + ../klm/search//search ..//boost_program_options : . : : diff --git a/decoder/decoder.cc b/decoder/decoder.cc index a69a6d05..3a410cf2 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -38,6 +38,7 @@ #include "sampler.h" #include "forest_writer.h" // TODO this section should probably be handled by an Observer +#include "lazy.h" #include "hg_io.h" #include "aligner.h" @@ -832,6 +833,9 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { if (conf.count("show_target_graph")) HypergraphIO::WriteTarget(conf["show_target_graph"].as(), sent_id, forest); + if (conf.count("lazy_search")) + PassToLazy(forest, CurrentWeightVector()); + for (int pass = 0; pass < rescoring_passes.size(); ++pass) { const RescoringPass& rp = rescoring_passes[pass]; const vector& cur_weights = *rp.weight_vector; diff --git a/decoder/lazy.cc b/decoder/lazy.cc index f5b61c75..4776c1b8 100644 --- a/decoder/lazy.cc +++ b/decoder/lazy.cc @@ -1,15 +1,23 @@ #include "hg.h" #include "lazy.h" +#include "fdict.h" #include "tdict.h" #include "lm/enumerate_vocab.hh" #include "lm/model.hh" +#include "search/config.hh" +#include "search/context.hh" #include "search/edge.hh" #include "search/vertex.hh" +#include "search/vertex_generator.hh" #include "util/exception.hh" +#include #include +#include +#include + namespace { struct MapVocab : public lm::EnumerateVocab { @@ -19,13 +27,13 @@ struct MapVocab : public lm::EnumerateVocab { // Do not call after Lookup. void Add(lm::WordIndex index, const StringPiece &str) { const WordID cdec_id = TD::Convert(str.as_string()); - if (cdec_id >= out_->size()) out_.resize(cdec_id + 1); + if (cdec_id >= out_.size()) out_.resize(cdec_id + 1); out_[cdec_id] = index; } // Assumes Add has been called and will never be called again. lm::WordIndex FromCDec(WordID id) const { - return out_[out.size() > id ? id : 0]; + return out_[out_.size() > id ? id : 0]; } private: @@ -34,44 +42,50 @@ struct MapVocab : public lm::EnumerateVocab { class LazyBase { public: - LazyBase() {} + LazyBase(const std::vector &weights) : + cdec_weights_(weights), + config_(search::Weights(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]), 1000) {} virtual ~LazyBase() {} virtual void Search(const Hypergraph &hg) const = 0; - static LazyBase *Load(const char *model_file); + static LazyBase *Load(const char *model_file, const std::vector &weights); protected: - lm::ngram::Config GetConfig() const { + lm::ngram::Config GetConfig() { lm::ngram::Config ret; ret.enumerate_vocab = &vocab_; return ret; } MapVocab vocab_; + + const std::vector &cdec_weights_; + + const search::Config config_; }; template class Lazy : public LazyBase { public: - explicit Lazy(const char *model_file) : m_(model_file, GetConfig()) {} + Lazy(const char *model_file, const std::vector &weights) : LazyBase(weights), m_(model_file, GetConfig()) {} void Search(const Hypergraph &hg) const; private: - void ConvertEdge(const Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::Edge &out) const; + void ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::Edge &out) const; const Model m_; }; -static LazyBase *LazyBase::Load(const char *model_file) { +LazyBase *LazyBase::Load(const char *model_file, const std::vector &weights) { lm::ngram::ModelType model_type; - if (!lm::ngram::RecognizeBinary(lm_name, model_type)) model_type = lm::ngram::PROBING; + if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING; switch (model_type) { case lm::ngram::PROBING: - return new Lazy(model_file); + return new Lazy(model_file, weights); case lm::ngram::REST_PROBING: - return new Lazy(model_file); + return new Lazy(model_file, weights); default: UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet."); } @@ -80,25 +94,41 @@ static LazyBase *LazyBase::Load(const char *model_file) { template void Lazy::Search(const Hypergraph &hg) const { boost::scoped_array out_vertices(new search::Vertex[hg.nodes_.size()]); boost::scoped_array out_edges(new search::Edge[hg.edges_.size()]); + + search::Context context(config_, m_); + for (unsigned int i = 0; i < hg.nodes_.size(); ++i) { - search::Vertex *out_vertex = out_vertices[i]; + search::Vertex &out_vertex = out_vertices[i]; const Hypergraph::EdgesVector &down_edges = hg.nodes_[i].in_edges_; - for (unsigned int j = 0; j < edges.size(); ++j) { + for (unsigned int j = 0; j < down_edges.size(); ++j) { unsigned int edge_index = down_edges[j]; - const Hypergraph::Edge &in_edge = hg.edges_[edge_index]; - search::Edge &out_edge = out_edges[edge_index]; + ConvertEdge(context, i == hg.nodes_.size() - 1, out_vertices.get(), hg.edges_[edge_index], out_edges[edge_index]); + out_vertex.Add(out_edges[edge_index]); } + out_vertex.FinishedAdding(); + search::VertexGenerator(context, out_vertex); + } + search::PartialVertex top = out_vertices[hg.nodes_.size() - 1].RootPartial(); + if (top.Empty()) { + std::cout << "NO PATH FOUND"; + } else { + search::PartialVertex continuation; + while (!top.Complete()) { + top.Split(continuation); + top = continuation; + } + std::cout << top.End().Bound() << std::endl; } } // TODO: get weights into here somehow. -template void Lazy::ConvertEdge(const Context &context, bool final, search::Vertices *vertices, const Hypergraph::Edge &in, search::Edge &out) const { - const std::vector &e = in_edge.rule_->e(); +template void Lazy::ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::Edge &out) const { + const std::vector &e = in.rule_->e(); std::vector words; unsigned int terminals = 0; for (std::vector::const_iterator word = e.begin(); word != e.end(); ++word) { if (*word <= 0) { - out.Add(vertices[edge.tail_nodes_[-*word]]); + out.Add(vertices[in.tail_nodes_[-*word]]); words.push_back(lm::kMaxWordIndex); } else { ++terminals; @@ -110,13 +140,19 @@ template void Lazy::ConvertEdge(const Context &conte words.push_back(m_.GetVocabulary().EndSentence()); } - float additive = edge.rule_->GetFeatureValues().dot(weight_vector); + float additive = in.rule_->GetFeatureValues().dot(cdec_weights_); + additive -= terminals * context.GetWeights().WordPenalty() * static_cast(terminals) / M_LN10; out.InitRule().Init(context, additive, words, final); } -} // namespace +boost::scoped_ptr AwfulGlobalLazy; -void PassToLazy(const Hypergraph &hg) { +} // namespace +void PassToLazy(const Hypergraph &hg, const std::vector &weights) { + if (!AwfulGlobalLazy.get()) { + AwfulGlobalLazy.reset(LazyBase::Load("lm", weights)); + } + AwfulGlobalLazy->Search(hg); } diff --git a/decoder/lazy.h b/decoder/lazy.h index aecd030d..3e71a3b0 100644 --- a/decoder/lazy.h +++ b/decoder/lazy.h @@ -1,8 +1,11 @@ #ifndef _LAZY_H_ #define _LAZY_H_ +#include "weights.h" +#include + class Hypergraph; -void PassToLazy(const Hypergraph &hg); +void PassToLazy(const Hypergraph &hg, const std::vector &weights); #endif // _LAZY_H_ diff --git a/klm/search/config.hh b/klm/search/config.hh index e21e4b7c..ef8e2354 100644 --- a/klm/search/config.hh +++ b/klm/search/config.hh @@ -8,15 +8,15 @@ namespace search { class Config { public: - Config(StringPiece weight_str, unsigned int pop_limit) : - weights_(weight_str), pop_limit_(pop_limit) {} + Config(const Weights &weights, unsigned int pop_limit) : + weights_(weights), pop_limit_(pop_limit) {} const Weights &GetWeights() const { return weights_; } unsigned int PopLimit() const { return pop_limit_; } private: - search::Weights weights_; + Weights weights_; unsigned int pop_limit_; }; diff --git a/klm/search/weights.cc b/klm/search/weights.cc index 82ff3f12..d65471ad 100644 --- a/klm/search/weights.cc +++ b/klm/search/weights.cc @@ -49,6 +49,8 @@ Weights::Weights(StringPiece text) { word_penalty_ = Steal("WordPenalty"); } +Weights::Weights(Score lm, Score oov, Score word_penalty) : lm_(lm), oov_(oov), word_penalty_(word_penalty) {} + search::Score Weights::DotNoLM(StringPiece text) const { DotProduct dot; Parse(text, map_, dot); diff --git a/klm/search/weights.hh b/klm/search/weights.hh index 4a4388c7..df1c419f 100644 --- a/klm/search/weights.hh +++ b/klm/search/weights.hh @@ -23,25 +23,28 @@ class Weights { // Parses weights, sets lm_weight_, removes it from map_. explicit Weights(StringPiece text); - search::Score DotNoLM(StringPiece text) const; + // Just the three scores we care about adding. + Weights(Score lm, Score oov, Score word_penalty); - search::Score LM() const { return lm_; } + Score DotNoLM(StringPiece text) const; - search::Score OOV() const { return oov_; } + Score LM() const { return lm_; } - search::Score WordPenalty() const { return word_penalty_; } + Score OOV() const { return oov_; } + + Score WordPenalty() const { return word_penalty_; } // Mostly for testing. - const boost::unordered_map &GetMap() const { return map_; } + const boost::unordered_map &GetMap() const { return map_; } private: float Steal(const std::string &str); - typedef boost::unordered_map Map; + typedef boost::unordered_map Map; Map map_; - search::Score lm_, oov_, word_penalty_; + Score lm_, oov_, word_penalty_; }; } // namespace search -- cgit v1.2.3 From 9b99cb844e3e379b557ff8578df27893ce147f1a Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Sun, 14 Oct 2012 10:46:34 +0100 Subject: Update to faster but less cute search --- decoder/lazy.cc | 57 +++++++++++++++++---------------- klm/lm/fragment.cc | 37 +++++++++++++++++++++ klm/search/Jamfile | 2 +- klm/search/edge.hh | 31 +++--------------- klm/search/edge_generator.cc | 53 +++++++++++++----------------- klm/search/edge_generator.hh | 24 ++++++++------ klm/search/edge_queue.cc | 25 +++++++++++++++ klm/search/edge_queue.hh | 73 ++++++++++++++++++++++++++++++++++++++++++ klm/search/final.hh | 11 +++---- klm/search/note.hh | 12 +++++++ klm/search/rule.cc | 32 +++++++++--------- klm/search/rule.hh | 31 ++---------------- klm/search/vertex.hh | 45 +++++++++++--------------- klm/search/vertex_generator.cc | 32 +++++------------- klm/search/vertex_generator.hh | 35 +++++++------------- 15 files changed, 279 insertions(+), 221 deletions(-) create mode 100644 klm/lm/fragment.cc create mode 100644 klm/search/edge_queue.cc create mode 100644 klm/search/edge_queue.hh create mode 100644 klm/search/note.hh (limited to 'klm/search') diff --git a/decoder/lazy.cc b/decoder/lazy.cc index c4138d7b..9dc657d6 100644 --- a/decoder/lazy.cc +++ b/decoder/lazy.cc @@ -8,6 +8,7 @@ #include "search/config.hh" #include "search/context.hh" #include "search/edge.hh" +#include "search/edge_queue.hh" #include "search/vertex.hh" #include "search/vertex_generator.hh" #include "util/exception.hh" @@ -75,7 +76,7 @@ template class Lazy : public LazyBase { void Search(unsigned int pop_limit, const Hypergraph &hg) const; private: - void ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::Edge &out) const; + unsigned char ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::PartialEdge &out) const; const Model m_; }; @@ -93,73 +94,73 @@ LazyBase *LazyBase::Load(const char *model_file, const std::vector &we } } -void PrintFinal(const Hypergraph &hg, const search::Edge *edge_base, const search::Final &final) { - const std::vector &words = hg.edges_[&final.From() - edge_base].rule_->e(); +void PrintFinal(const Hypergraph &hg, const search::Final &final) { + const std::vector &words = static_cast(final.GetNote().vp)->rule_->e(); boost::array::const_iterator child(final.Children().begin()); for (std::vector::const_iterator i = words.begin(); i != words.end(); ++i) { if (*i > 0) { std::cout << TD::Convert(*i) << ' '; } else { - PrintFinal(hg, edge_base, **child++); + PrintFinal(hg, **child++); } } } template void Lazy::Search(unsigned int pop_limit, const Hypergraph &hg) const { boost::scoped_array out_vertices(new search::Vertex[hg.nodes_.size()]); - boost::scoped_array out_edges(new search::Edge[hg.edges_.size()]); search::Config config(weights_, pop_limit); search::Context context(config, m_); for (unsigned int i = 0; i < hg.nodes_.size() - 1; ++i) { - search::Vertex &out_vertex = out_vertices[i]; + search::EdgeQueue queue(context.PopLimit()); const Hypergraph::EdgesVector &down_edges = hg.nodes_[i].in_edges_; for (unsigned int j = 0; j < down_edges.size(); ++j) { unsigned int edge_index = down_edges[j]; - ConvertEdge(context, i == hg.nodes_.size() - 2, out_vertices.get(), hg.edges_[edge_index], out_edges[edge_index]); - out_vertex.Add(out_edges[edge_index]); + unsigned char arity = ConvertEdge(context, i == hg.nodes_.size() - 2, out_vertices.get(), hg.edges_[edge_index], queue.InitializeEdge()); + search::Note note; + note.vp = &hg.edges_[edge_index]; + if (arity != 255) queue.AddEdge(arity, note); } - out_vertex.FinishedAdding(); - search::VertexGenerator(context, out_vertex); + search::VertexGenerator vertex_gen(context, out_vertices[i]); + queue.Search(context, vertex_gen); } - search::PartialVertex top = out_vertices[hg.nodes_.size() - 2].RootPartial(); - if (top.Empty()) { - std::cout << "NO PATH FOUND"; + const search::Final *top = out_vertices[hg.nodes_.size() - 2].BestChild(); + if (!top) { + std::cout << "NO PATH FOUND" << std::endl; } else { - search::PartialVertex continuation; - while (!top.Complete()) { - top.Split(continuation); - top = continuation; - } - PrintFinal(hg, out_edges.get(), top.End()); - std::cout << "||| " << top.End().Bound() << std::endl; + PrintFinal(hg, *top); + std::cout << "||| " << top->Bound() << std::endl; } } -// TODO: get weights into here somehow. -template void Lazy::ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::Edge &out) const { +template unsigned char Lazy::ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::PartialEdge &out) const { const std::vector &e = in.rule_->e(); std::vector words; unsigned int terminals = 0; + unsigned char nt = 0; for (std::vector::const_iterator word = e.begin(); word != e.end(); ++word) { if (*word <= 0) { - out.Add(vertices[in.tail_nodes_[-*word]]); + out.nt[nt] = vertices[in.tail_nodes_[-*word]].RootPartial(); + if (out.nt[nt].Empty()) return 255; + ++nt; words.push_back(lm::kMaxWordIndex); } else { ++terminals; words.push_back(vocab_.FromCDec(*word)); } } + for (unsigned char fill = nt; fill < search::kMaxArity; ++fill) { + out.nt[nt] = search::kBlankPartialVertex; + } if (final) { words.push_back(m_.GetVocabulary().EndSentence()); } - float additive = in.rule_->GetFeatureValues().dot(cdec_weights_); - UTIL_THROW_IF(isnan(additive), util::Exception, "Bad dot product"); - additive -= static_cast(terminals) * context.GetWeights().WordPenalty() / M_LN10; - - out.InitRule().Init(context, additive, words, final); + out.score = in.rule_->GetFeatureValues().dot(cdec_weights_); + out.score -= static_cast(terminals) * context.GetWeights().WordPenalty() / M_LN10; + out.score += search::ScoreRule(context, words, final, out.between); + return nt; } boost::scoped_ptr AwfulGlobalLazy; diff --git a/klm/lm/fragment.cc b/klm/lm/fragment.cc new file mode 100644 index 00000000..0267cd4e --- /dev/null +++ b/klm/lm/fragment.cc @@ -0,0 +1,37 @@ +#include "lm/binary_format.hh" +#include "lm/model.hh" +#include "lm/left.hh" +#include "util/tokenize_piece.hh" + +template void Query(const char *name) { + Model model(name); + std::string line; + lm::ngram::ChartState ignored; + while (getline(std::cin, line)) { + lm::ngram::RuleScore scorer(model, ignored); + for (util::TokenIter i(line, ' '); i; ++i) { + scorer.Terminal(model.GetVocabulary().Index(*i)); + } + std::cout << scorer.Finish() << '\n'; + } +} + +int main(int argc, char *argv[]) { + if (argc != 2) { + std::cerr << "Expected model file name." << std::endl; + return 1; + } + const char *name = argv[1]; + lm::ngram::ModelType model_type = lm::ngram::PROBING; + lm::ngram::RecognizeBinary(name, model_type); + switch (model_type) { + case lm::ngram::PROBING: + Query(name); + break; + case lm::ngram::REST_PROBING: + Query(name); + break; + default: + std::cerr << "Model type not supported yet." << std::endl; + } +} diff --git a/klm/search/Jamfile b/klm/search/Jamfile index ac47c249..e8b14363 100644 --- a/klm/search/Jamfile +++ b/klm/search/Jamfile @@ -1,4 +1,4 @@ -lib search : weights.cc vertex.cc vertex_generator.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil : : : .. ; +lib search : weights.cc vertex.cc vertex_generator.cc edge_queue.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : .. ; import testing ; diff --git a/klm/search/edge.hh b/klm/search/edge.hh index 4d2a5cbf..77ab0ade 100644 --- a/klm/search/edge.hh +++ b/klm/search/edge.hh @@ -11,33 +11,6 @@ namespace search { -class Edge { - public: - Edge() { - end_to_ = to_; - } - - Rule &InitRule() { return rule_; } - - void Add(Vertex &vertex) { - assert(end_to_ - to_ < kMaxArity); - *(end_to_++) = &vertex; - } - - const Vertex &GetVertex(std::size_t index) const { - return *to_[index]; - } - - const Rule &GetRule() const { return rule_; } - - private: - // Rule and pointers to rule arguments. - Rule rule_; - - Vertex *to_[kMaxArity]; - Vertex **end_to_; -}; - struct PartialEdge { Score score; // Terminals @@ -45,6 +18,10 @@ struct PartialEdge { // Non-terminals PartialVertex nt[kMaxArity]; + const lm::ngram::ChartState &CompletedState() const { + return between[0]; + } + bool operator<(const PartialEdge &other) const { return score < other.score; } diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc index d135899a..56239dfb 100644 --- a/klm/search/edge_generator.cc +++ b/klm/search/edge_generator.cc @@ -10,28 +10,15 @@ namespace search { -bool EdgeGenerator::Init(Edge &edge, VertexGenerator &parent) { - from_ = &edge; - for (unsigned int i = 0; i < GetRule().Arity(); ++i) { - if (edge.GetVertex(i).RootPartial().Empty()) return false; - } - PartialEdge &root = *parent.MallocPartialEdge(); - root.score = GetRule().Bound(); - for (unsigned int i = 0; i < GetRule().Arity(); ++i) { +EdgeGenerator::EdgeGenerator(PartialEdge &root, unsigned char arity, Note note) : arity_(arity), note_(note) { +/* for (unsigned char i = 0; i < edge.Arity(); ++i) { root.nt[i] = edge.GetVertex(i).RootPartial(); - root.score += root.nt[i].Bound(); } - for (unsigned int i = GetRule().Arity(); i < 2; ++i) { + for (unsigned char i = edge.Arity(); i < 2; ++i) { root.nt[i] = kBlankPartialVertex; - } - for (unsigned int i = 0; i < GetRule().Arity() + 1; ++i) { - root.between[i] = GetRule().Lexical(i); - } - // wtf no clear method? - generate_ = Generate(); + }*/ generate_.push(&root); - top_ = root.score; - return true; + top_score_ = root.score; } namespace { @@ -78,13 +65,13 @@ template float FastScore(const Context &context, unsigned c } // namespace -template bool EdgeGenerator::Pop(Context &context, VertexGenerator &parent) { +template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool) { assert(!generate_.empty()); PartialEdge &top = *generate_.top(); generate_.pop(); unsigned int victim = 0; unsigned char lowest_length = 255; - for (unsigned int i = 0; i != GetRule().Arity(); ++i) { + for (unsigned char i = 0; i != arity_; ++i) { if (!top.nt[i].Complete() && top.nt[i].Length() < lowest_length) { lowest_length = top.nt[i].Length(); victim = i; @@ -92,21 +79,21 @@ template bool EdgeGenerator::Pop(Context &context, VertexGe } if (lowest_length == 255) { // All states report complete. - top.between[0].right = top.between[GetRule().Arity()].right; - parent.NewHypothesis(top.between[0], *from_, top); - top_ = generate_.empty() ? -kScoreInf : generate_.top()->score; - return !generate_.empty(); + top.between[0].right = top.between[arity_].right; + // Now top.between[0] is the full edge state. + top_score_ = generate_.empty() ? -kScoreInf : generate_.top()->score; + return ⊤ } unsigned int stay = !victim; - PartialEdge &continuation = *parent.MallocPartialEdge(); + PartialEdge &continuation = *static_cast(partial_edge_pool.malloc()); float old_bound = top.nt[victim].Bound(); // The alternate's score will change because alternate.nt[victim] changes. bool split = top.nt[victim].Split(continuation.nt[victim]); // top is now the alternate. continuation.nt[stay] = top.nt[stay]; - continuation.score = FastScore(context, victim, GetRule().Arity(), top, continuation); + continuation.score = FastScore(context, victim, arity_, top, continuation); // TODO: dedupe? generate_.push(&continuation); @@ -116,14 +103,18 @@ template bool EdgeGenerator::Pop(Context &context, VertexGe // TODO: dedupe? generate_.push(&top); } else { - parent.FreePartialEdge(&top); + partial_edge_pool.free(&top); } - top_ = generate_.top()->score; - return true; + top_score_ = generate_.top()->score; + return NULL; } -template bool EdgeGenerator::Pop(Context &context, VertexGenerator &parent); -template bool EdgeGenerator::Pop(Context &context, VertexGenerator &parent); +template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); } // namespace search diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh index e306dc61..875ccc5e 100644 --- a/klm/search/edge_generator.hh +++ b/klm/search/edge_generator.hh @@ -2,7 +2,9 @@ #define SEARCH_EDGE_GENERATOR__ #include "search/edge.hh" +#include "search/note.hh" +#include #include #include @@ -28,26 +30,28 @@ struct PartialEdgePointerLess : std::binary_function bool Pop(Context &context, VertexGenerator &parent); + Note GetNote() const { + return note_; + } + + // Pop. If there's a complete hypothesis, return it. Otherwise return NULL. + template PartialEdge *Pop(Context &context, boost::pool<> &partial_edge_pool); private: - const Rule &GetRule() const { - return from_->GetRule(); - } + Score top_score_; - Score top_; + unsigned char arity_; typedef std::priority_queue, PartialEdgePointerLess> Generate; Generate generate_; - Edge *from_; + Note note_; }; } // namespace search diff --git a/klm/search/edge_queue.cc b/klm/search/edge_queue.cc new file mode 100644 index 00000000..e3ae6ebf --- /dev/null +++ b/klm/search/edge_queue.cc @@ -0,0 +1,25 @@ +#include "search/edge_queue.hh" + +#include "lm/left.hh" +#include "search/context.hh" + +#include + +namespace search { + +EdgeQueue::EdgeQueue(unsigned int pop_limit_hint) : partial_edge_pool_(sizeof(PartialEdge), pop_limit_hint * 2) { + take_ = static_cast(partial_edge_pool_.malloc()); +} + +/*void EdgeQueue::AddEdge(PartialEdge &root, unsigned char arity, Note note) { + // Ignore empty edges. + for (unsigned char i = 0; i < edge.Arity(); ++i) { + PartialVertex root(edge.GetVertex(i).RootPartial()); + if (root.Empty()) return; + total_score += root.Bound(); + } + PartialEdge &allocated = *static_cast(partial_edge_pool_.malloc()); + allocated.score = total_score; +}*/ + +} // namespace search diff --git a/klm/search/edge_queue.hh b/klm/search/edge_queue.hh new file mode 100644 index 00000000..187eaed7 --- /dev/null +++ b/klm/search/edge_queue.hh @@ -0,0 +1,73 @@ +#ifndef SEARCH_EDGE_QUEUE__ +#define SEARCH_EDGE_QUEUE__ + +#include "search/edge.hh" +#include "search/edge_generator.hh" +#include "search/note.hh" + +#include +#include + +#include + +namespace search { + +template class Context; + +class EdgeQueue { + public: + explicit EdgeQueue(unsigned int pop_limit_hint); + + PartialEdge &InitializeEdge() { + return *take_; + } + + void AddEdge(unsigned char arity, Note note) { + generate_.push(edge_pool_.construct(*take_, arity, note)); + take_ = static_cast(partial_edge_pool_.malloc()); + } + + bool Empty() const { return generate_.empty(); } + + /* Generate hypotheses and send them to output. Normally, output is a + * VertexGenerator, but the decoder may want to route edges to different + * vertices i.e. if they have different LHS non-terminal labels. + */ + template void Search(Context &context, Output &output) { + int to_pop = context.PopLimit(); + while (to_pop > 0 && !generate_.empty()) { + EdgeGenerator *top = generate_.top(); + generate_.pop(); + PartialEdge *ret = top->Pop(context, partial_edge_pool_); + if (ret) { + output.NewHypothesis(*ret, top->GetNote()); + --to_pop; + if (top->TopScore() != -kScoreInf) { + generate_.push(top); + } + } else { + generate_.push(top); + } + } + output.FinishedSearch(); + } + + private: + boost::object_pool edge_pool_; + + struct LessByTopScore : public std::binary_function { + bool operator()(const EdgeGenerator *first, const EdgeGenerator *second) const { + return first->TopScore() < second->TopScore(); + } + }; + + typedef std::priority_queue, LessByTopScore> Generate; + Generate generate_; + + boost::pool<> partial_edge_pool_; + + PartialEdge *take_; +}; + +} // namespace search +#endif // SEARCH_EDGE_QUEUE__ diff --git a/klm/search/final.hh b/klm/search/final.hh index 823b8c1a..1b3092ac 100644 --- a/klm/search/final.hh +++ b/klm/search/final.hh @@ -2,35 +2,34 @@ #define SEARCH_FINAL__ #include "search/arity.hh" +#include "search/note.hh" #include "search/types.hh" #include namespace search { -class Edge; - class Final { public: typedef boost::array ChildArray; - void Reset(Score bound, const Edge &from, const Final &left, const Final &right) { + void Reset(Score bound, Note note, const Final &left, const Final &right) { bound_ = bound; - from_ = &from; + note_ = note; children_[0] = &left; children_[1] = &right; } const ChildArray &Children() const { return children_; } - const Edge &From() const { return *from_; } + Note GetNote() const { return note_; } Score Bound() const { return bound_; } private: Score bound_; - const Edge *from_; + Note note_; ChildArray children_; }; diff --git a/klm/search/note.hh b/klm/search/note.hh new file mode 100644 index 00000000..50bed06e --- /dev/null +++ b/klm/search/note.hh @@ -0,0 +1,12 @@ +#ifndef SEARCH_NOTE__ +#define SEARCH_NOTE__ + +namespace search { + +union Note { + const void *vp; +}; + +} // namespace search + +#endif // SEARCH_NOTE__ diff --git a/klm/search/rule.cc b/klm/search/rule.cc index 0a941527..5b00207e 100644 --- a/klm/search/rule.cc +++ b/klm/search/rule.cc @@ -9,35 +9,35 @@ namespace search { -template void Rule::Init(const Context &context, Score additive, const std::vector &words, bool prepend_bos) { - additive_ = additive; - Score lm_score = 0.0; - lexical_.clear(); - const lm::WordIndex oov = context.LanguageModel().GetVocabulary().NotFound(); - +template float ScoreRule(const Context &context, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing) { + unsigned int oov_count = 0; + float prob = 0.0; + const Model &model = context.LanguageModel(); + const lm::WordIndex oov = model.GetVocabulary().NotFound(); for (std::vector::const_iterator word = words.begin(); ; ++word) { - lexical_.resize(lexical_.size() + 1); - lm::ngram::RuleScore scorer(context.LanguageModel(), lexical_.back()); + lm::ngram::RuleScore scorer(model, *(writing++)); // TODO: optimize if (prepend_bos && (word == words.begin())) { scorer.BeginSentence(); } for (; ; ++word) { if (word == words.end()) { - lm_score += scorer.Finish(); - bound_ = additive_ + context.GetWeights().LM() * lm_score; - arity_ = lexical_.size() - 1; - return; + prob += scorer.Finish(); + return static_cast(oov_count) * context.GetWeights().OOV() + prob * context.GetWeights().LM(); } if (*word == kNonTerminal) break; - if (*word == oov) additive_ += context.GetWeights().OOV(); + if (*word == oov) ++oov_count; scorer.Terminal(*word); } - lm_score += scorer.Finish(); + prob += scorer.Finish(); } } -template void Rule::Init(const Context &context, Score additive, const std::vector &words, bool prepend_bos); -template void Rule::Init(const Context &context, Score additive, const std::vector &words, bool prepend_bos); +template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); } // namespace search diff --git a/klm/search/rule.hh b/klm/search/rule.hh index 920c64a7..0ce2794d 100644 --- a/klm/search/rule.hh +++ b/klm/search/rule.hh @@ -3,44 +3,17 @@ #include "lm/left.hh" #include "lm/word_index.hh" -#include "search/arity.hh" #include "search/types.hh" -#include - -#include #include namespace search { template class Context; -class Rule { - public: - Rule() : arity_(0) {} - - static const lm::WordIndex kNonTerminal = lm::kMaxWordIndex; - - // Use kNonTerminal for non-terminals. - template void Init(const Context &context, Score additive, const std::vector &words, bool prepend_bos); - - Score Bound() const { return bound_; } - - Score Additive() const { return additive_; } - - unsigned int Arity() const { return arity_; } - - const lm::ngram::ChartState &Lexical(unsigned int index) const { - return lexical_[index]; - } - - private: - Score bound_, additive_; - - unsigned int arity_; +const lm::WordIndex kNonTerminal = lm::kMaxWordIndex; - std::vector lexical_; -}; +template float ScoreRule(const Context &context, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *state_out); } // namespace search diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh index 7ef29efc..e1a9ad11 100644 --- a/klm/search/vertex.hh +++ b/klm/search/vertex.hh @@ -16,8 +16,6 @@ namespace search { class ContextBase; -class Edge; - class VertexNode { public: VertexNode() : end_(NULL) {} @@ -103,6 +101,10 @@ class PartialVertex { unsigned char Length() const { return back_->Length(); } + bool HasAlternative() const { + return index_ + 1 < back_->Size(); + } + // Split into continuation and alternative, rendering this the alternative. bool Split(PartialVertex &continuation) { assert(!Complete()); @@ -128,35 +130,26 @@ extern PartialVertex kBlankPartialVertex; class Vertex { public: - Vertex() -#ifdef DEBUG - : finished_adding_(false) -#endif - {} - - void Add(Edge &edge) { -#ifdef DEBUG - assert(!finished_adding_); -#endif - edges_.push_back(&edge); - } - - void FinishedAdding() { -#ifdef DEBUG - assert(!finished_adding_); - finished_adding_ = true; -#endif - } + Vertex() {} PartialVertex RootPartial() const { return PartialVertex(root_); } + const Final *BestChild() const { + PartialVertex top(RootPartial()); + if (top.Empty()) { + return NULL; + } else { + PartialVertex continuation; + while (!top.Complete()) { + top.Split(continuation); + top = continuation; + } + return &top.End(); + } + } + private: friend class VertexGenerator; - std::vector edges_; - -#ifdef DEBUG - bool finished_adding_; -#endif VertexNode root_; }; diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc index 78948c97..d94e6e06 100644 --- a/klm/search/vertex_generator.cc +++ b/klm/search/vertex_generator.cc @@ -2,45 +2,30 @@ #include "lm/left.hh" #include "search/context.hh" +#include "search/edge.hh" #include namespace search { -template VertexGenerator::VertexGenerator(Context &context, Vertex &gen) : context_(context), edges_(gen.edges_.size()), partial_edge_pool_(sizeof(PartialEdge), context.PopLimit() * 2) { - for (std::size_t i = 0; i < gen.edges_.size(); ++i) { - if (edges_[i].Init(*gen.edges_[i], *this)) - generate_.push(&edges_[i]); - } +VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) { gen.root_.InitRoot(); root_.under = &gen.root_; - to_pop_ = context.PopLimit(); - while (to_pop_ > 0 && !generate_.empty()) { - EdgeGenerator *top = generate_.top(); - generate_.pop(); - if (top->Pop(context, *this)) { - generate_.push(top); - } - } - gen.root_.SortAndSet(context, NULL); } -template VertexGenerator::VertexGenerator(Context &context, Vertex &gen); -template VertexGenerator::VertexGenerator(Context &context, Vertex &gen); - namespace { const uint64_t kCompleteAdd = static_cast(-1); } // namespace -void VertexGenerator::NewHypothesis(const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial) { +void VertexGenerator::NewHypothesis(const PartialEdge &partial, Note note) { + const lm::ngram::ChartState &state = partial.CompletedState(); std::pair got(existing_.insert(std::pair(hash_value(state), NULL))); if (!got.second) { // Found it already. Final &exists = *got.first->second; if (exists.Bound() < partial.score) { - exists.Reset(partial.score, from, partial.nt[0].End(), partial.nt[1].End()); + exists.Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End()); } - --to_pop_; return; } unsigned char left = 0, right = 0; @@ -67,8 +52,7 @@ void VertexGenerator::NewHypothesis(const lm::ngram::ChartState &state, const Ed } node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); - got.first->second = CompleteTransition(*node, state, from, partial); - --to_pop_; + got.first->second = CompleteTransition(*node, state, note, partial); } VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { @@ -86,12 +70,12 @@ VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node return next; } -Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial) { +Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial) { VertexNode &node = *starter.under; assert(node.State().left.full == state.left.full); assert(!node.End()); Final *final = context_.NewFinal(); - final->Reset(partial.score, from, partial.nt[0].End(), partial.nt[1].End()); + final->Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End()); node.SetEnd(final); return final; } diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh index 8cdf1420..6b98da3e 100644 --- a/klm/search/vertex_generator.hh +++ b/klm/search/vertex_generator.hh @@ -1,10 +1,9 @@ #ifndef SEARCH_VERTEX_GENERATOR__ #define SEARCH_VERTEX_GENERATOR__ -#include "search/edge.hh" -#include "search/edge_generator.hh" +#include "search/note.hh" +#include "search/vertex.hh" -#include #include #include @@ -17,18 +16,21 @@ class ChartState; namespace search { -template class Context; class ContextBase; class Final; +struct PartialEdge; class VertexGenerator { public: - template VertexGenerator(Context &context, Vertex &gen); + VertexGenerator(ContextBase &context, Vertex &gen); - PartialEdge *MallocPartialEdge() { return static_cast(partial_edge_pool_.malloc()); } - void FreePartialEdge(PartialEdge *value) { partial_edge_pool_.free(value); } + void NewHypothesis(const PartialEdge &partial, Note note); - void NewHypothesis(const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial); + void FinishedSearch() { + root_.under->SortAndSet(context_, NULL); + } + + const Vertex &Generating() const { return gen_; } private: // Parallel structure to VertexNode. @@ -41,29 +43,16 @@ class VertexGenerator { Trie &FindOrInsert(Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full); - Final *CompleteTransition(Trie &node, const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial); + Final *CompleteTransition(Trie &node, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial); ContextBase &context_; - std::vector edges_; - - struct LessByTop : public std::binary_function { - bool operator()(const EdgeGenerator *first, const EdgeGenerator *second) const { - return first->Top() < second->Top(); - } - }; - - typedef std::priority_queue, LessByTop> Generate; - Generate generate_; + Vertex &gen_; Trie root_; typedef boost::unordered_map Existing; Existing existing_; - - int to_pop_; - - boost::pool<> partial_edge_pool_; }; } // namespace search -- cgit v1.2.3 From 1fb7bfbbe287e868522613871ed6ca74369ed2a1 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Mon, 22 Oct 2012 14:04:27 +0100 Subject: Update search, make it compile --- Makefile.am | 1 + configure.ac | 6 +- decoder/Makefile.am | 3 +- decoder/decoder.cc | 8 +- decoder/incremental.cc | 184 +++++++++++++++++++++++++++++++++++++++ decoder/incremental.h | 11 +++ decoder/lazy.cc | 178 -------------------------------------- decoder/lazy.h | 11 --- dtrain/Makefile.am | 2 +- klm/alone/Jamfile | 4 - klm/alone/assemble.cc | 76 ---------------- klm/alone/assemble.hh | 21 ----- klm/alone/graph.hh | 87 ------------------- klm/alone/just_vocab.cc | 14 --- klm/alone/labeled_edge.hh | 30 ------- klm/alone/main.cc | 85 ------------------ klm/alone/read.cc | 118 ------------------------- klm/alone/read.hh | 29 ------- klm/alone/threading.cc | 80 ----------------- klm/alone/threading.hh | 129 --------------------------- klm/alone/vocab.cc | 19 ---- klm/alone/vocab.hh | 34 -------- klm/lm/model.cc | 2 +- klm/lm/vocab.cc | 4 +- klm/lm/vocab.hh | 2 +- klm/search/Jamfile | 2 +- klm/search/Makefile.am | 11 +++ klm/search/arity.hh | 8 -- klm/search/context.hh | 10 +-- klm/search/edge.hh | 53 ++++++++---- klm/search/edge_generator.cc | 144 ++++++++++++++----------------- klm/search/edge_generator.hh | 49 ++++++----- klm/search/edge_queue.cc | 25 ------ klm/search/edge_queue.hh | 73 ---------------- klm/search/final.hh | 41 ++++----- klm/search/header.hh | 57 ++++++++++++ klm/search/source.hh | 48 ----------- klm/search/types.hh | 8 +- klm/search/vertex.cc | 10 +-- klm/search/vertex.hh | 55 ++++++------ klm/search/vertex_generator.cc | 97 ++++++++++++--------- klm/search/vertex_generator.hh | 33 +++---- klm/util/Makefile.am | 2 + klm/util/ersatz_progress.hh | 2 +- klm/util/exception.hh | 2 +- klm/util/pool.cc | 35 ++++++++ klm/util/pool.hh | 45 ++++++++++ klm/util/probing_hash_table.hh | 2 +- klm/util/string_piece.cc | 192 +++++++++++++++++++++++++++++++++++++++++ klm/util/tokenize_piece.hh | 12 +++ mira/Makefile.am | 2 +- training/Makefile.am | 38 ++++---- 52 files changed, 838 insertions(+), 1356 deletions(-) create mode 100644 decoder/incremental.cc create mode 100644 decoder/incremental.h delete mode 100644 decoder/lazy.cc delete mode 100644 decoder/lazy.h delete mode 100644 klm/alone/Jamfile delete mode 100644 klm/alone/assemble.cc delete mode 100644 klm/alone/assemble.hh delete mode 100644 klm/alone/graph.hh delete mode 100644 klm/alone/just_vocab.cc delete mode 100644 klm/alone/labeled_edge.hh delete mode 100644 klm/alone/main.cc delete mode 100644 klm/alone/read.cc delete mode 100644 klm/alone/read.hh delete mode 100644 klm/alone/threading.cc delete mode 100644 klm/alone/threading.hh delete mode 100644 klm/alone/vocab.cc delete mode 100644 klm/alone/vocab.hh create mode 100644 klm/search/Makefile.am delete mode 100644 klm/search/arity.hh delete mode 100644 klm/search/edge_queue.cc delete mode 100644 klm/search/edge_queue.hh create mode 100644 klm/search/header.hh delete mode 100644 klm/search/source.hh create mode 100644 klm/util/pool.cc create mode 100644 klm/util/pool.hh create mode 100644 klm/util/string_piece.cc (limited to 'klm/search') diff --git a/Makefile.am b/Makefile.am index 3e0103a8..fefc470d 100644 --- a/Makefile.am +++ b/Makefile.am @@ -6,6 +6,7 @@ SUBDIRS = \ mteval \ klm/util \ klm/lm \ + klm/search \ decoder \ training \ training/liblbfgs \ diff --git a/configure.ac b/configure.ac index 03a0ee87..cb132d66 100644 --- a/configure.ac +++ b/configure.ac @@ -12,6 +12,7 @@ AC_PROG_CXX AC_LANG_CPLUSPLUS BOOST_REQUIRE([1.44]) BOOST_PROGRAM_OPTIONS +BOOST_SYSTEM BOOST_TEST AM_PATH_PYTHON AC_CHECK_HEADER(dlfcn.h,AC_DEFINE(HAVE_DLFCN_H)) @@ -73,9 +74,9 @@ fi #BOOST_THREADS CPPFLAGS="$CPPFLAGS $BOOST_CPPFLAGS" -LDFLAGS="$LDFLAGS $BOOST_PROGRAM_OPTIONS_LDFLAGS" +LDFLAGS="$LDFLAGS $BOOST_PROGRAM_OPTIONS_LDFLAGS $BOOST_SYSTEM_LDFLAGS" # $BOOST_THREAD_LDFLAGS" -LIBS="$LIBS $BOOST_PROGRAM_OPTIONS_LIBS" +LIBS="$LIBS $BOOST_PROGRAM_OPTIONS_LIBS $BOOST_SYSTEM_LIBS" # $BOOST_THREAD_LIBS" AC_CHECK_HEADER(google/dense_hash_map, @@ -123,6 +124,7 @@ AC_CONFIG_FILES([rampion/Makefile]) AC_CONFIG_FILES([minrisk/Makefile]) AC_CONFIG_FILES([klm/util/Makefile]) AC_CONFIG_FILES([klm/lm/Makefile]) +AC_CONFIG_FILES([klm/search/Makefile]) AC_CONFIG_FILES([mira/Makefile]) AC_CONFIG_FILES([dtrain/Makefile]) AC_CONFIG_FILES([example_extff/Makefile]) diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 5c0a1964..f8f427d3 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -17,7 +17,7 @@ trule_test_SOURCES = trule_test.cc trule_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a -lz cdec_SOURCES = cdec.cc -cdec_LDADD = libcdec.a ../mteval/libmteval.a ../utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +cdec_LDADD = libcdec.a ../mteval/libmteval.a ../utils/libutils.a ../klm/search/libksearch.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz AM_CPPFLAGS = -DBOOST_TEST_DYN_LINK -W -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. -I../mteval -I../utils -I../klm @@ -73,6 +73,7 @@ libcdec_a_SOURCES = \ ff_source_syntax.cc \ ff_bleu.cc \ ff_factory.cc \ + incremental.cc \ lexalign.cc \ lextrans.cc \ tagger.cc \ diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 052823ca..fe812011 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -39,7 +39,7 @@ #include "sampler.h" #include "forest_writer.h" // TODO this section should probably be handled by an Observer -#include "lazy.h" +#include "incremental.h" #include "hg_io.h" #include "aligner.h" @@ -412,7 +412,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("show_conditional_prob", "Output the conditional log prob to STDOUT instead of a translation") ("show_cfg_search_space", "Show the search space as a CFG") ("show_target_graph", po::value(), "Directory to write the target hypergraphs to") - ("lazy_search", po::value(), "Run lazy search with this language model file") + ("incremental_search", po::value(), "Run lazy search with this language model file") ("coarse_to_fine_beam_prune", po::value(), "Prune paths from coarse parse forest before fine parse, keeping paths within exp(alpha>=0)") ("ctf_beam_widen", po::value()->default_value(2.0), "Expand coarse pass beam by this factor if no fine parse is found") ("ctf_num_widenings", po::value()->default_value(2), "Widen coarse beam this many times before backing off to full parse") @@ -828,8 +828,8 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { if (conf.count("show_target_graph")) HypergraphIO::WriteTarget(conf["show_target_graph"].as(), sent_id, forest); - if (conf.count("lazy_search")) { - PassToLazy(conf["lazy_search"].as().c_str(), CurrentWeightVector(), pop_limit, forest); + if (conf.count("incremental_search")) { + PassToIncremental(conf["incremental_search"].as().c_str(), CurrentWeightVector(), pop_limit, forest); o->NotifyDecodingComplete(smeta); return true; } diff --git a/decoder/incremental.cc b/decoder/incremental.cc new file mode 100644 index 00000000..768bbd65 --- /dev/null +++ b/decoder/incremental.cc @@ -0,0 +1,184 @@ +#include "incremental.h" + +#include "hg.h" +#include "fdict.h" +#include "tdict.h" + +#include "lm/enumerate_vocab.hh" +#include "lm/model.hh" +#include "search/config.hh" +#include "search/context.hh" +#include "search/edge.hh" +#include "search/edge_generator.hh" +#include "search/rule.hh" +#include "search/vertex.hh" +#include "search/vertex_generator.hh" +#include "util/exception.hh" + +#include +#include + +#include +#include + +namespace { + +struct MapVocab : public lm::EnumerateVocab { + public: + MapVocab() {} + + // Do not call after Lookup. + void Add(lm::WordIndex index, const StringPiece &str) { + const WordID cdec_id = TD::Convert(str.as_string()); + if (cdec_id >= out_.size()) out_.resize(cdec_id + 1); + out_[cdec_id] = index; + } + + // Assumes Add has been called and will never be called again. + lm::WordIndex FromCDec(WordID id) const { + return out_[out_.size() > id ? id : 0]; + } + + private: + std::vector out_; +}; + +class IncrementalBase { + public: + IncrementalBase(const std::vector &weights) : + cdec_weights_(weights), + weights_(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]) { + std::cerr << "Weights KLanguageModel " << weights_.LM() << " KLanguageModel_OOV " << weights_.OOV() << " WordPenalty " << weights_.WordPenalty() << std::endl; + } + + virtual ~IncrementalBase() {} + + virtual void Search(unsigned int pop_limit, const Hypergraph &hg) const = 0; + + static IncrementalBase *Load(const char *model_file, const std::vector &weights); + + protected: + lm::ngram::Config GetConfig() { + lm::ngram::Config ret; + ret.enumerate_vocab = &vocab_; + return ret; + } + + MapVocab vocab_; + + const std::vector &cdec_weights_; + + const search::Weights weights_; +}; + +template class Incremental : public IncrementalBase { + public: + Incremental(const char *model_file, const std::vector &weights) : IncrementalBase(weights), m_(model_file, GetConfig()) {} + + void Search(unsigned int pop_limit, const Hypergraph &hg) const; + + private: + void ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const; + + const Model m_; +}; + +IncrementalBase *IncrementalBase::Load(const char *model_file, const std::vector &weights) { + lm::ngram::ModelType model_type; + if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING; + switch (model_type) { + case lm::ngram::PROBING: + return new Incremental(model_file, weights); + case lm::ngram::REST_PROBING: + return new Incremental(model_file, weights); + default: + UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet."); + } +} + +void PrintFinal(const Hypergraph &hg, const search::Final final) { + const std::vector &words = static_cast(final.GetNote().vp)->rule_->e(); + const search::Final *child(final.Children()); + for (std::vector::const_iterator i = words.begin(); i != words.end(); ++i) { + if (*i > 0) { + std::cout << TD::Convert(*i) << ' '; + } else { + PrintFinal(hg, *child++); + } + } +} + +template void Incremental::Search(unsigned int pop_limit, const Hypergraph &hg) const { + boost::scoped_array out_vertices(new search::Vertex[hg.nodes_.size()]); + search::Config config(weights_, pop_limit); + search::Context context(config, m_); + + for (unsigned int i = 0; i < hg.nodes_.size() - 1; ++i) { + search::EdgeGenerator gen; + const Hypergraph::EdgesVector &down_edges = hg.nodes_[i].in_edges_; + for (unsigned int j = 0; j < down_edges.size(); ++j) { + unsigned int edge_index = down_edges[j]; + ConvertEdge(context, i == hg.nodes_.size() - 2, out_vertices.get(), hg.edges_[edge_index], gen); + } + search::VertexGenerator vertex_gen(context, out_vertices[i]); + gen.Search(context, vertex_gen); + } + const search::Final top = out_vertices[hg.nodes_.size() - 2].BestChild(); + if (top.Valid()) { + std::cout << "NO PATH FOUND" << std::endl; + } else { + PrintFinal(hg, top); + std::cout << "||| " << top.GetScore() << std::endl; + } +} + +template void Incremental::ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const { + const std::vector &e = in.rule_->e(); + std::vector words; + words.reserve(e.size()); + std::vector nts; + unsigned int terminals = 0; + float score = 0.0; + for (std::vector::const_iterator word = e.begin(); word != e.end(); ++word) { + if (*word <= 0) { + nts.push_back(vertices[in.tail_nodes_[-*word]].RootPartial()); + if (nts.back().Empty()) return; + score += nts.back().Bound(); + words.push_back(lm::kMaxWordIndex); + } else { + ++terminals; + words.push_back(vocab_.FromCDec(*word)); + } + } + + if (final) { + words.push_back(m_.GetVocabulary().EndSentence()); + } + + search::PartialEdge out(gen.AllocateEdge(nts.size())); + + memcpy(out.NT(), &nts[0], sizeof(search::PartialVertex) * nts.size()); + + search::Note note; + note.vp = ∈ + out.SetNote(note); + + score += in.rule_->GetFeatureValues().dot(cdec_weights_); + score -= static_cast(terminals) * context.GetWeights().WordPenalty() / M_LN10; + score += search::ScoreRule(context, words, final, out.Between()); + out.SetScore(score); + + gen.AddEdge(out); +} + +boost::scoped_ptr AwfulGlobalIncremental; + +} // namespace + +void PassToIncremental(const char *model_file, const std::vector &weights, unsigned int pop_limit, const Hypergraph &hg) { + if (!AwfulGlobalIncremental.get()) { + std::cerr << "Pop limit " << pop_limit << std::endl; + AwfulGlobalIncremental.reset(IncrementalBase::Load(model_file, weights)); + } + AwfulGlobalIncremental->Search(pop_limit, hg); +} diff --git a/decoder/incremental.h b/decoder/incremental.h new file mode 100644 index 00000000..180383ce --- /dev/null +++ b/decoder/incremental.h @@ -0,0 +1,11 @@ +#ifndef _INCREMENTAL_H_ +#define _INCREMENTAL_H_ + +#include "weights.h" +#include + +class Hypergraph; + +void PassToIncremental(const char *model_file, const std::vector &weights, unsigned int pop_limit, const Hypergraph &hg); + +#endif // _INCREMENTAL_H_ diff --git a/decoder/lazy.cc b/decoder/lazy.cc deleted file mode 100644 index 1e6a94fe..00000000 --- a/decoder/lazy.cc +++ /dev/null @@ -1,178 +0,0 @@ -#include "hg.h" -#include "lazy.h" -#include "fdict.h" -#include "tdict.h" - -#include "lm/enumerate_vocab.hh" -#include "lm/model.hh" -#include "search/config.hh" -#include "search/context.hh" -#include "search/edge.hh" -#include "search/edge_queue.hh" -#include "search/vertex.hh" -#include "search/vertex_generator.hh" -#include "util/exception.hh" - -#include -#include - -#include -#include - -namespace { - -struct MapVocab : public lm::EnumerateVocab { - public: - MapVocab() {} - - // Do not call after Lookup. - void Add(lm::WordIndex index, const StringPiece &str) { - const WordID cdec_id = TD::Convert(str.as_string()); - if (cdec_id >= out_.size()) out_.resize(cdec_id + 1); - out_[cdec_id] = index; - } - - // Assumes Add has been called and will never be called again. - lm::WordIndex FromCDec(WordID id) const { - return out_[out_.size() > id ? id : 0]; - } - - private: - std::vector out_; -}; - -class LazyBase { - public: - LazyBase(const std::vector &weights) : - cdec_weights_(weights), - weights_(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]) { - std::cerr << "Weights KLanguageModel " << weights_.LM() << " KLanguageModel_OOV " << weights_.OOV() << " WordPenalty " << weights_.WordPenalty() << std::endl; - } - - virtual ~LazyBase() {} - - virtual void Search(unsigned int pop_limit, const Hypergraph &hg) const = 0; - - static LazyBase *Load(const char *model_file, const std::vector &weights); - - protected: - lm::ngram::Config GetConfig() { - lm::ngram::Config ret; - ret.enumerate_vocab = &vocab_; - return ret; - } - - MapVocab vocab_; - - const std::vector &cdec_weights_; - - const search::Weights weights_; -}; - -template class Lazy : public LazyBase { - public: - Lazy(const char *model_file, const std::vector &weights) : LazyBase(weights), m_(model_file, GetConfig()) {} - - void Search(unsigned int pop_limit, const Hypergraph &hg) const; - - private: - unsigned char ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::PartialEdge &out) const; - - const Model m_; -}; - -LazyBase *LazyBase::Load(const char *model_file, const std::vector &weights) { - lm::ngram::ModelType model_type; - if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING; - switch (model_type) { - case lm::ngram::PROBING: - return new Lazy(model_file, weights); - case lm::ngram::REST_PROBING: - return new Lazy(model_file, weights); - default: - UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet."); - } -} - -void PrintFinal(const Hypergraph &hg, const search::Final &final) { - const std::vector &words = static_cast(final.GetNote().vp)->rule_->e(); - boost::array::const_iterator child(final.Children().begin()); - for (std::vector::const_iterator i = words.begin(); i != words.end(); ++i) { - if (*i > 0) { - std::cout << TD::Convert(*i) << ' '; - } else { - PrintFinal(hg, **child++); - } - } -} - -template void Lazy::Search(unsigned int pop_limit, const Hypergraph &hg) const { - boost::scoped_array out_vertices(new search::Vertex[hg.nodes_.size()]); - search::Config config(weights_, pop_limit); - search::Context context(config, m_); - - for (unsigned int i = 0; i < hg.nodes_.size() - 1; ++i) { - search::EdgeQueue queue(context.PopLimit()); - const Hypergraph::EdgesVector &down_edges = hg.nodes_[i].in_edges_; - for (unsigned int j = 0; j < down_edges.size(); ++j) { - unsigned int edge_index = down_edges[j]; - unsigned char arity = ConvertEdge(context, i == hg.nodes_.size() - 2, out_vertices.get(), hg.edges_[edge_index], queue.InitializeEdge()); - search::Note note; - note.vp = &hg.edges_[edge_index]; - if (arity != 255) queue.AddEdge(arity, note); - } - search::VertexGenerator vertex_gen(context, out_vertices[i]); - queue.Search(context, vertex_gen); - } - const search::Final *top = out_vertices[hg.nodes_.size() - 2].BestChild(); - if (!top) { - std::cout << "NO PATH FOUND" << std::endl; - } else { - PrintFinal(hg, *top); - std::cout << "||| " << top->Bound() << std::endl; - } -} - -template unsigned char Lazy::ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::PartialEdge &out) const { - const std::vector &e = in.rule_->e(); - std::vector words; - unsigned int terminals = 0; - unsigned char nt = 0; - out.score = 0.0; - for (std::vector::const_iterator word = e.begin(); word != e.end(); ++word) { - if (*word <= 0) { - out.nt[nt] = vertices[in.tail_nodes_[-*word]].RootPartial(); - if (out.nt[nt].Empty()) return 255; - out.score += out.nt[nt].Bound(); - ++nt; - words.push_back(lm::kMaxWordIndex); - } else { - ++terminals; - words.push_back(vocab_.FromCDec(*word)); - } - } - for (unsigned char fill = nt; fill < search::kMaxArity; ++fill) { - out.nt[fill] = search::kBlankPartialVertex; - } - - if (final) { - words.push_back(m_.GetVocabulary().EndSentence()); - } - - out.score += in.rule_->GetFeatureValues().dot(cdec_weights_); - out.score -= static_cast(terminals) * context.GetWeights().WordPenalty() / M_LN10; - out.score += search::ScoreRule(context, words, final, out.between); - return nt; -} - -boost::scoped_ptr AwfulGlobalLazy; - -} // namespace - -void PassToLazy(const char *model_file, const std::vector &weights, unsigned int pop_limit, const Hypergraph &hg) { - if (!AwfulGlobalLazy.get()) { - std::cerr << "Pop limit " << pop_limit << std::endl; - AwfulGlobalLazy.reset(LazyBase::Load(model_file, weights)); - } - AwfulGlobalLazy->Search(pop_limit, hg); -} diff --git a/decoder/lazy.h b/decoder/lazy.h deleted file mode 100644 index 94895b19..00000000 --- a/decoder/lazy.h +++ /dev/null @@ -1,11 +0,0 @@ -#ifndef _LAZY_H_ -#define _LAZY_H_ - -#include "weights.h" -#include - -class Hypergraph; - -void PassToLazy(const char *model_file, const std::vector &weights, unsigned int pop_limit, const Hypergraph &hg); - -#endif // _LAZY_H_ diff --git a/dtrain/Makefile.am b/dtrain/Makefile.am index 64fef489..ca9581f5 100644 --- a/dtrain/Makefile.am +++ b/dtrain/Makefile.am @@ -1,7 +1,7 @@ bin_PROGRAMS = dtrain dtrain_SOURCES = dtrain.cc score.cc -dtrain_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +dtrain_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/klm/alone/Jamfile b/klm/alone/Jamfile deleted file mode 100644 index 2cc90c05..00000000 --- a/klm/alone/Jamfile +++ /dev/null @@ -1,4 +0,0 @@ -lib standalone : assemble.cc read.cc threading.cc vocab.cc ../lm//kenlm ../util//kenutil ../search//search : .. : : .. ../search//search ../lm//kenlm ; - -exe decode : main.cc standalone main.cc : multi:..//boost_thread ; -exe just_vocab : just_vocab.cc standalone : multi:..//boost_thread ; diff --git a/klm/alone/assemble.cc b/klm/alone/assemble.cc deleted file mode 100644 index 2ae72ce9..00000000 --- a/klm/alone/assemble.cc +++ /dev/null @@ -1,76 +0,0 @@ -#include "alone/assemble.hh" - -#include "alone/labeled_edge.hh" -#include "search/final.hh" - -#include - -namespace alone { - -std::ostream &operator<<(std::ostream &o, const search::Final &final) { - const std::vector &words = static_cast(final.From()).Words(); - if (words.empty()) return o; - const search::Final *const *child = final.Children().data(); - std::vector::const_iterator i(words.begin()); - for (; i != words.end() - 1; ++i) { - if (*i) { - o << **i << ' '; - } else { - o << **child << ' '; - ++child; - } - } - - if (*i) { - if (**i != "") { - o << **i; - } - } else { - o << **child; - } - - return o; -} - -namespace { - -void MakeIndent(std::ostream &o, const char *indent_str, unsigned int level) { - for (unsigned int i = 0; i < level; ++i) - o << indent_str; -} - -void DetailedFinalInternal(std::ostream &o, const search::Final &final, const char *indent_str, unsigned int indent) { - o << "(\n"; - MakeIndent(o, indent_str, indent); - const std::vector &words = static_cast(final.From()).Words(); - const search::Final *const *child = final.Children().data(); - for (std::vector::const_iterator i(words.begin()); i != words.end(); ++i) { - if (*i) { - o << **i; - if (i == words.end() - 1) { - o << '\n'; - MakeIndent(o, indent_str, indent); - } else { - o << ' '; - } - } else { - // One extra indent from the line we're currently on. - o << indent_str; - DetailedFinalInternal(o, **child, indent_str, indent + 1); - for (unsigned int i = 0; i < indent; ++i) o << indent_str; - ++child; - } - } - o << ")=" << final.Bound() << '\n'; -} -} // namespace - -void DetailedFinal(std::ostream &o, const search::Final &final, const char *indent_str) { - DetailedFinalInternal(o, final, indent_str, 0); -} - -void PrintFinal(const search::Final &final) { - std::cout << final << std::endl; -} - -} // namespace alone diff --git a/klm/alone/assemble.hh b/klm/alone/assemble.hh deleted file mode 100644 index e6b0ad5c..00000000 --- a/klm/alone/assemble.hh +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef ALONE_ASSEMBLE__ -#define ALONE_ASSEMBLE__ - -#include - -namespace search { -class Final; -} // namespace search - -namespace alone { - -std::ostream &operator<<(std::ostream &o, const search::Final &final); - -void DetailedFinal(std::ostream &o, const search::Final &final, const char *indent_str = " "); - -// This isn't called anywhere but makes it easy to print from gdb. -void PrintFinal(const search::Final &final); - -} // namespace alone - -#endif // ALONE_ASSEMBLE__ diff --git a/klm/alone/graph.hh b/klm/alone/graph.hh deleted file mode 100644 index 788352c9..00000000 --- a/klm/alone/graph.hh +++ /dev/null @@ -1,87 +0,0 @@ -#ifndef ALONE_GRAPH__ -#define ALONE_GRAPH__ - -#include "alone/labeled_edge.hh" -#include "search/rule.hh" -#include "search/types.hh" -#include "search/vertex.hh" -#include "util/exception.hh" - -#include -#include -#include - -namespace alone { - -template class FixedAllocator : boost::noncopyable { - public: - FixedAllocator() : current_(NULL), end_(NULL) {} - - void Init(std::size_t count) { - assert(!current_); - array_.reset(new T[count]); - current_ = array_.get(); - end_ = current_ + count; - } - - T &operator[](std::size_t idx) { - return array_.get()[idx]; - } - - T *New() { - T *ret = current_++; - UTIL_THROW_IF(ret >= end_, util::Exception, "Allocating past end"); - return ret; - } - - std::size_t Size() const { - return end_ - array_.get(); - } - - private: - boost::scoped_array array_; - T *current_, *end_; -}; - -class Graph : boost::noncopyable { - public: - typedef LabeledEdge Edge; - typedef search::Vertex Vertex; - - Graph() {} - - void SetCounts(std::size_t vertices, std::size_t edges) { - vertices_.Init(vertices); - edges_.Init(edges); - } - - Vertex *NewVertex() { - return vertices_.New(); - } - - std::size_t VertexSize() const { return vertices_.Size(); } - - Vertex &MutableVertex(std::size_t index) { - return vertices_[index]; - } - - Edge *NewEdge() { - return edges_.New(); - } - - std::size_t EdgeSize() const { return edges_.Size(); } - - void SetRoot(Vertex *root) { root_ = root; } - - Vertex &Root() { return *root_; } - - private: - FixedAllocator vertices_; - FixedAllocator edges_; - - Vertex *root_; -}; - -} // namespace alone - -#endif // ALONE_GRAPH__ diff --git a/klm/alone/just_vocab.cc b/klm/alone/just_vocab.cc deleted file mode 100644 index 35aea5ed..00000000 --- a/klm/alone/just_vocab.cc +++ /dev/null @@ -1,14 +0,0 @@ -#include "alone/read.hh" -#include "util/file_piece.hh" - -#include - -int main() { - util::FilePiece f(0, "stdin", &std::cerr); - while (true) { - try { - alone::JustVocab(f, std::cout); - } catch (const util::EndOfFileException &e) { break; } - std::cout << '\n'; - } -} diff --git a/klm/alone/labeled_edge.hh b/klm/alone/labeled_edge.hh deleted file mode 100644 index 94d8cbdf..00000000 --- a/klm/alone/labeled_edge.hh +++ /dev/null @@ -1,30 +0,0 @@ -#ifndef ALONE_LABELED_EDGE__ -#define ALONE_LABELED_EDGE__ - -#include "search/edge.hh" - -#include -#include - -namespace alone { - -class LabeledEdge : public search::Edge { - public: - LabeledEdge() {} - - void AppendWord(const std::string *word) { - words_.push_back(word); - } - - const std::vector &Words() const { - return words_; - } - - private: - // NULL for non-terminals. - std::vector words_; -}; - -} // namespace alone - -#endif // ALONE_LABELED_EDGE__ diff --git a/klm/alone/main.cc b/klm/alone/main.cc deleted file mode 100644 index e09ab01d..00000000 --- a/klm/alone/main.cc +++ /dev/null @@ -1,85 +0,0 @@ -#include "alone/threading.hh" -#include "search/config.hh" -#include "search/context.hh" -#include "util/exception.hh" -#include "util/file_piece.hh" -#include "util/usage.hh" - -#include - -#include -#include - -namespace alone { - -template void ReadLoop(const std::string &graph_prefix, Control &control) { - for (unsigned int sentence = 0; ; ++sentence) { - std::stringstream name; - name << graph_prefix << '/' << sentence; - std::auto_ptr file; - try { - file.reset(new util::FilePiece(name.str().c_str())); - } catch (const util::ErrnoException &e) { - if (e.Error() == ENOENT) return; - throw; - } - control.Add(file.release()); - } -} - -template void RunWithModelType(const char *graph_prefix, const char *model_file, StringPiece weight_str, unsigned int pop_limit, unsigned int threads) { - Model model(model_file); - search::Weights weights(weight_str); - search::Config config(weights, pop_limit); - - if (threads > 1) { -#ifdef WITH_THREADS - Controller controller(config, model, threads, std::cout); - ReadLoop(graph_prefix, controller); -#else - UTIL_THROW(util::Exception, "Threading support not compiled in."); -#endif - } else { - InThread controller(config, model, std::cout); - ReadLoop(graph_prefix, controller); - } -} - -void Run(const char *graph_prefix, const char *lm_name, StringPiece weight_str, unsigned int pop_limit, unsigned int threads) { - lm::ngram::ModelType model_type; - if (!lm::ngram::RecognizeBinary(lm_name, model_type)) model_type = lm::ngram::PROBING; - switch (model_type) { - case lm::ngram::PROBING: - RunWithModelType(graph_prefix, lm_name, weight_str, pop_limit, threads); - break; - case lm::ngram::REST_PROBING: - RunWithModelType(graph_prefix, lm_name, weight_str, pop_limit, threads); - break; - default: - UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet."); - } -} - -} // namespace alone - -int main(int argc, char *argv[]) { - if (argc < 5 || argc > 6) { - std::cerr << argv[0] << " graph_prefix lm \"weights\" pop [threads]" << std::endl; - return 1; - } - -#ifdef WITH_THREADS - unsigned thread_count = boost::thread::hardware_concurrency(); -#else - unsigned thread_count = 1; -#endif - if (argc == 6) { - thread_count = boost::lexical_cast(argv[5]); - UTIL_THROW_IF(!thread_count, util::Exception, "Thread count 0"); - } - UTIL_THROW_IF(!thread_count, util::Exception, "Boost doesn't know how many threads there are. Pass it on the command line."); - alone::Run(argv[1], argv[2], argv[3], boost::lexical_cast(argv[4]), thread_count); - - util::PrintUsage(std::cerr); - return 0; -} diff --git a/klm/alone/read.cc b/klm/alone/read.cc deleted file mode 100644 index 0b20be35..00000000 --- a/klm/alone/read.cc +++ /dev/null @@ -1,118 +0,0 @@ -#include "alone/read.hh" - -#include "alone/graph.hh" -#include "alone/vocab.hh" -#include "search/arity.hh" -#include "search/context.hh" -#include "search/weights.hh" -#include "util/file_piece.hh" - -#include -#include - -#include - -namespace alone { - -namespace { - -template Graph::Edge &ReadEdge(search::Context &context, util::FilePiece &from, Graph &to, Vocab &vocab, bool final) { - Graph::Edge *ret = to.NewEdge(); - - StringPiece got; - - std::vector words; - unsigned long int terminals = 0; - while ("|||" != (got = from.ReadDelimited())) { - if ('[' == *got.data() && ']' == got.data()[got.size() - 1]) { - // non-terminal - char *end_ptr; - unsigned long int child = std::strtoul(got.data() + 1, &end_ptr, 10); - UTIL_THROW_IF(end_ptr != got.data() + got.size() - 1, FormatException, "Bad non-terminal" << got); - UTIL_THROW_IF(child >= to.VertexSize(), FormatException, "Reference to vertex " << child << " but we only have " << to.VertexSize() << " vertices. Is the file in bottom-up format?"); - ret->Add(to.MutableVertex(child)); - words.push_back(lm::kMaxWordIndex); - ret->AppendWord(NULL); - } else { - const std::pair &found = vocab.FindOrAdd(got); - words.push_back(found.second); - ret->AppendWord(&found.first); - ++terminals; - } - } - if (final) { - // This is not counted for the word penalty. - words.push_back(vocab.EndSentence().second); - ret->AppendWord(&vocab.EndSentence().first); - } - // Hard-coded word penalty. - float additive = context.GetWeights().DotNoLM(from.ReadLine()) - context.GetWeights().WordPenalty() * static_cast(terminals) / M_LN10; - ret->InitRule().Init(context, additive, words, final); - unsigned int arity = ret->GetRule().Arity(); - UTIL_THROW_IF(arity > search::kMaxArity, util::Exception, "Edit search/arity.hh and increase " << search::kMaxArity << " to at least " << arity); - return *ret; -} - -} // namespace - -// TODO: refactor -void JustVocab(util::FilePiece &from, std::ostream &out) { - boost::unordered_set seen; - unsigned long int vertices = from.ReadULong(); - from.ReadULong(); // edges - UTIL_THROW_IF(vertices == 0, FormatException, "Vertex count is zero"); - UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected newline after counts"); - std::string temp; - for (unsigned long int i = 0; i < vertices; ++i) { - unsigned long int edge_count = from.ReadULong(); - UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected after edge count"); - for (unsigned long int e = 0; e < edge_count; ++e) { - StringPiece got; - while ("|||" != (got = from.ReadDelimited())) { - if ('[' == *got.data() && ']' == got.data()[got.size() - 1]) continue; - temp.assign(got.data(), got.size()); - if (seen.insert(temp).second) out << temp << ' '; - } - from.ReadLine(); // weights - } - } - // Eat sentence - from.ReadLine(); -} - -template bool ReadCDec(search::Context &context, util::FilePiece &from, Graph &to, Vocab &vocab) { - unsigned long int vertices; - try { - vertices = from.ReadULong(); - } catch (const util::EndOfFileException &e) { return false; } - unsigned long int edges = from.ReadULong(); - UTIL_THROW_IF(vertices < 2, FormatException, "Vertex count is " << vertices); - UTIL_THROW_IF(edges == 0, FormatException, "Edge count is " << edges); - --vertices; - --edges; - UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected newline after counts"); - to.SetCounts(vertices, edges); - Graph::Vertex *vertex; - for (unsigned long int i = 0; ; ++i) { - vertex = to.NewVertex(); - unsigned long int edge_count = from.ReadULong(); - bool root = (i == vertices - 1); - UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected after edge count"); - for (unsigned long int e = 0; e < edge_count; ++e) { - vertex->Add(ReadEdge(context, from, to, vocab, root)); - } - vertex->FinishedAdding(); - if (root) break; - } - to.SetRoot(vertex); - StringPiece str = from.ReadLine(); - UTIL_THROW_IF("1" != str, FormatException, "Expected one edge to root"); - // The edge - from.ReadLine(); - return true; -} - -template bool ReadCDec(search::Context &context, util::FilePiece &from, Graph &to, Vocab &vocab); -template bool ReadCDec(search::Context &context, util::FilePiece &from, Graph &to, Vocab &vocab); - -} // namespace alone diff --git a/klm/alone/read.hh b/klm/alone/read.hh deleted file mode 100644 index 10769a86..00000000 --- a/klm/alone/read.hh +++ /dev/null @@ -1,29 +0,0 @@ -#ifndef ALONE_READ__ -#define ALONE_READ__ - -#include "util/exception.hh" - -#include - -namespace util { class FilePiece; } - -namespace search { template class Context; } - -namespace alone { - -class Graph; -class Vocab; - -class FormatException : public util::Exception { - public: - FormatException() {} - ~FormatException() throw() {} -}; - -void JustVocab(util::FilePiece &from, std::ostream &to); - -template bool ReadCDec(search::Context &context, util::FilePiece &from, Graph &to, Vocab &vocab); - -} // namespace alone - -#endif // ALONE_READ__ diff --git a/klm/alone/threading.cc b/klm/alone/threading.cc deleted file mode 100644 index 475386b6..00000000 --- a/klm/alone/threading.cc +++ /dev/null @@ -1,80 +0,0 @@ -#include "alone/threading.hh" - -#include "alone/assemble.hh" -#include "alone/graph.hh" -#include "alone/read.hh" -#include "alone/vocab.hh" -#include "lm/model.hh" -#include "search/context.hh" -#include "search/vertex_generator.hh" - -#include -#include -#include - -#include - -namespace alone { -template void Decode(const search::Config &config, const Model &model, util::FilePiece *in_ptr, std::ostream &out) { - search::Context context(config, model); - Graph graph; - Vocab vocab(model.GetVocabulary()); - { - boost::scoped_ptr in(in_ptr); - ReadCDec(context, *in, graph, vocab); - } - - for (std::size_t i = 0; i < graph.VertexSize(); ++i) { - search::VertexGenerator(context, graph.MutableVertex(i)); - } - search::PartialVertex top = graph.Root().RootPartial(); - if (top.Empty()) { - out << "NO PATH FOUND"; - } else { - search::PartialVertex continuation; - while (!top.Complete()) { - top.Split(continuation); - top = continuation; - } - out << top.End() << " ||| " << top.End().Bound() << std::endl; - } -} - -template void Decode(const search::Config &config, const lm::ngram::ProbingModel &model, util::FilePiece *in_ptr, std::ostream &out); -template void Decode(const search::Config &config, const lm::ngram::RestProbingModel &model, util::FilePiece *in_ptr, std::ostream &out); - -#ifdef WITH_THREADS -template void DecodeHandler::operator()(Input message) { - std::stringstream assemble; - Decode(config_, model_, message.file, assemble); - Produce(message.sentence_id, assemble.str()); -} - -template void DecodeHandler::Produce(unsigned int sentence_id, const std::string &str) { - Output out; - out.sentence_id = sentence_id; - out.str = new std::string(str); - out_.Produce(out); -} - -void PrintHandler::operator()(Output message) { - unsigned int relative = message.sentence_id - done_; - if (waiting_.size() <= relative) waiting_.resize(relative + 1); - waiting_[relative] = message.str; - for (std::string *lead; !waiting_.empty() && (lead = waiting_[0]); waiting_.pop_front(), ++done_) { - out_ << *lead; - delete lead; - } -} - -template Controller::Controller(const search::Config &config, const Model &model, size_t decode_workers, std::ostream &to) : - sentence_id_(0), - printer_(decode_workers, 1, boost::ref(to), Output::Poison()), - decoder_(3, decode_workers, boost::in_place(boost::ref(config), boost::ref(model), boost::ref(printer_.In())), Input::Poison()) {} - -template class Controller; -template class Controller; - -#endif - -} // namespace alone diff --git a/klm/alone/threading.hh b/klm/alone/threading.hh deleted file mode 100644 index 0ab0f739..00000000 --- a/klm/alone/threading.hh +++ /dev/null @@ -1,129 +0,0 @@ -#ifndef ALONE_THREADING__ -#define ALONE_THREADING__ - -#ifdef WITH_THREADS -#include "util/pcqueue.hh" -#include "util/pool.hh" -#endif - -#include -#include -#include - -namespace util { -class FilePiece; -} // namespace util - -namespace search { -class Config; -template class Context; -} // namespace search - -namespace alone { - -template void Decode(const search::Config &config, const Model &model, util::FilePiece *in_ptr, std::ostream &out); - -class Graph; - -#ifdef WITH_THREADS -struct SentenceID { - unsigned int sentence_id; - bool operator==(const SentenceID &other) const { - return sentence_id == other.sentence_id; - } -}; - -struct Input : public SentenceID { - util::FilePiece *file; - static Input Poison() { - Input ret; - ret.sentence_id = static_cast(-1); - ret.file = NULL; - return ret; - } -}; - -struct Output : public SentenceID { - std::string *str; - static Output Poison() { - Output ret; - ret.sentence_id = static_cast(-1); - ret.str = NULL; - return ret; - } -}; - -template class DecodeHandler { - public: - typedef Input Request; - - DecodeHandler(const search::Config &config, const Model &model, util::PCQueue &out) : config_(config), model_(model), out_(out) {} - - void operator()(Input message); - - private: - void Produce(unsigned int sentence_id, const std::string &str); - - const search::Config &config_; - - const Model &model_; - - util::PCQueue &out_; -}; - -class PrintHandler { - public: - typedef Output Request; - - explicit PrintHandler(std::ostream &o) : out_(o), done_(0) {} - - void operator()(Output message); - - private: - std::ostream &out_; - std::deque waiting_; - unsigned int done_; -}; - -template class Controller { - public: - // This config must remain valid. - explicit Controller(const search::Config &config, const Model &model, size_t decode_workers, std::ostream &to); - - // Takes ownership of in. - void Add(util::FilePiece *in) { - Input input; - input.sentence_id = sentence_id_++; - input.file = in; - decoder_.Produce(input); - } - - private: - unsigned int sentence_id_; - - util::Pool printer_; - - util::Pool > decoder_; -}; -#endif - -// Same API as controller. -template class InThread { - public: - InThread(const search::Config &config, const Model &model, std::ostream &to) : config_(config), model_(model), to_(to) {} - - // Takes ownership of in. - void Add(util::FilePiece *in) { - Decode(config_, model_, in, to_); - } - - private: - const search::Config &config_; - - const Model &model_; - - std::ostream &to_; -}; - -} // namespace alone -#endif // ALONE_THREADING__ diff --git a/klm/alone/vocab.cc b/klm/alone/vocab.cc deleted file mode 100644 index ffe55301..00000000 --- a/klm/alone/vocab.cc +++ /dev/null @@ -1,19 +0,0 @@ -#include "alone/vocab.hh" - -#include "lm/virtual_interface.hh" -#include "util/string_piece.hh" - -namespace alone { - -Vocab::Vocab(const lm::base::Vocabulary &backing) : backing_(backing), end_sentence_(FindOrAdd("")) {} - -const std::pair &Vocab::FindOrAdd(const StringPiece &str) { - Map::const_iterator i(FindStringPiece(map_, str)); - if (i != map_.end()) return *i; - std::pair to_ins; - to_ins.first.assign(str.data(), str.size()); - to_ins.second = backing_.Index(str); - return *map_.insert(to_ins).first; -} - -} // namespace alone diff --git a/klm/alone/vocab.hh b/klm/alone/vocab.hh deleted file mode 100644 index 3ac0f542..00000000 --- a/klm/alone/vocab.hh +++ /dev/null @@ -1,34 +0,0 @@ -#ifndef ALONE_VOCAB__ -#define ALONE_VOCAB__ - -#include "lm/word_index.hh" -#include "util/string_piece.hh" - -#include -#include - -#include - -namespace lm { namespace base { class Vocabulary; } } - -namespace alone { - -class Vocab { - public: - explicit Vocab(const lm::base::Vocabulary &backing); - - const std::pair &FindOrAdd(const StringPiece &str); - - const std::pair &EndSentence() const { return end_sentence_; } - - private: - typedef boost::unordered_map Map; - Map map_; - - const lm::base::Vocabulary &backing_; - - const std::pair &end_sentence_; -}; - -} // namespace alone -#endif // ALONE_VCOAB__ diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 40af8a63..2fd20481 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -87,7 +87,7 @@ template void GenericModel.. ; +lib search : weights.cc vertex.cc vertex_generator.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : .. ; import testing ; diff --git a/klm/search/Makefile.am b/klm/search/Makefile.am new file mode 100644 index 00000000..ccc5b7f6 --- /dev/null +++ b/klm/search/Makefile.am @@ -0,0 +1,11 @@ +noinst_LIBRARIES = libksearch.a + +libksearch_a_SOURCES = \ + edge_generator.cc \ + rule.cc \ + vertex.cc \ + vertex_generator.cc \ + weights.cc + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. + diff --git a/klm/search/arity.hh b/klm/search/arity.hh deleted file mode 100644 index 09c2c671..00000000 --- a/klm/search/arity.hh +++ /dev/null @@ -1,8 +0,0 @@ -#ifndef SEARCH_ARITY__ -#define SEARCH_ARITY__ -namespace search { - -const unsigned int kMaxArity = 2; - -} // namespace search -#endif // SEARCH_ARITY__ diff --git a/klm/search/context.hh b/klm/search/context.hh index 27940053..62163144 100644 --- a/klm/search/context.hh +++ b/klm/search/context.hh @@ -7,6 +7,7 @@ #include "search/types.hh" #include "search/vertex.hh" #include "util/exception.hh" +#include "util/pool.hh" #include #include @@ -21,10 +22,8 @@ class ContextBase { public: explicit ContextBase(const Config &config) : pop_limit_(config.PopLimit()), weights_(config.GetWeights()) {} - Final *NewFinal() { - Final *ret = final_pool_.construct(); - assert(ret); - return ret; + util::Pool &FinalPool() { + return final_pool_; } VertexNode *NewVertexNode() { @@ -42,7 +41,8 @@ class ContextBase { const Weights &GetWeights() const { return weights_; } private: - boost::object_pool final_pool_; + util::Pool final_pool_; + boost::object_pool vertex_node_pool_; unsigned int pop_limit_; diff --git a/klm/search/edge.hh b/klm/search/edge.hh index 77ab0ade..187904bf 100644 --- a/klm/search/edge.hh +++ b/klm/search/edge.hh @@ -2,30 +2,53 @@ #define SEARCH_EDGE__ #include "lm/state.hh" -#include "search/arity.hh" -#include "search/rule.hh" +#include "search/header.hh" #include "search/types.hh" #include "search/vertex.hh" +#include "util/pool.hh" -#include +#include + +#include namespace search { -struct PartialEdge { - Score score; - // Terminals - lm::ngram::ChartState between[kMaxArity + 1]; - // Non-terminals - PartialVertex nt[kMaxArity]; +// Copyable, but the copy will be shallow. +class PartialEdge : public Header { + public: + // Allow default construction for STL. + PartialEdge() {} + + PartialEdge(util::Pool &pool, Arity arity) + : Header(pool.Allocate(Size(arity, arity + 1)), arity) {} + + PartialEdge(util::Pool &pool, Arity arity, Arity chart_states) + : Header(pool.Allocate(Size(arity, chart_states)), arity) {} - const lm::ngram::ChartState &CompletedState() const { - return between[0]; - } + // Non-terminals + const PartialVertex *NT() const { + return reinterpret_cast(After()); + } + PartialVertex *NT() { + return reinterpret_cast(After()); + } - bool operator<(const PartialEdge &other) const { - return score < other.score; - } + const lm::ngram::ChartState &CompletedState() const { + return *Between(); + } + const lm::ngram::ChartState *Between() const { + return reinterpret_cast(After() + GetArity() * sizeof(PartialVertex)); + } + lm::ngram::ChartState *Between() { + return reinterpret_cast(After() + GetArity() * sizeof(PartialVertex)); + } + + private: + static std::size_t Size(Arity arity, Arity chart_states) { + return kHeaderSize + arity * sizeof(PartialVertex) + chart_states * sizeof(lm::ngram::ChartState); + } }; + } // namespace search #endif // SEARCH_EDGE__ diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc index 56239dfb..260159b1 100644 --- a/klm/search/edge_generator.cc +++ b/klm/search/edge_generator.cc @@ -4,117 +4,107 @@ #include "lm/partial.hh" #include "search/context.hh" #include "search/vertex.hh" -#include "search/vertex_generator.hh" #include namespace search { -EdgeGenerator::EdgeGenerator(PartialEdge &root, unsigned char arity, Note note) : arity_(arity), note_(note) { -/* for (unsigned char i = 0; i < edge.Arity(); ++i) { - root.nt[i] = edge.GetVertex(i).RootPartial(); - } - for (unsigned char i = edge.Arity(); i < 2; ++i) { - root.nt[i] = kBlankPartialVertex; - }*/ - generate_.push(&root); - top_score_ = root.score; -} - namespace { -template float FastScore(const Context &context, unsigned char victim, unsigned char arity, const PartialEdge &previous, PartialEdge &update) { - memcpy(update.between, previous.between, sizeof(lm::ngram::ChartState) * (arity + 1)); - - float ret = 0.0; - lm::ngram::ChartState *before, *after; - if (victim == 0) { - before = &update.between[0]; - after = &update.between[(arity == 2 && previous.nt[1].Complete()) ? 2 : 1]; - } else { - assert(victim == 1); - assert(arity == 2); - before = &update.between[previous.nt[0].Complete() ? 0 : 1]; - after = &update.between[2]; - } - const lm::ngram::ChartState &previous_reveal = previous.nt[victim].State(); - const PartialVertex &update_nt = update.nt[victim]; +template void FastScore(const Context &context, Arity victim, Arity before_idx, Arity incomplete, const PartialVertex &previous_vertex, PartialEdge update) { + lm::ngram::ChartState *between = update.Between(); + lm::ngram::ChartState *before = &between[before_idx], *after = &between[before_idx + 1]; + + float adjustment = 0.0; + const lm::ngram::ChartState &previous_reveal = previous_vertex.State(); + const PartialVertex &update_nt = update.NT()[victim]; const lm::ngram::ChartState &update_reveal = update_nt.State(); - float just_after = 0.0; if ((update_reveal.left.length > previous_reveal.left.length) || (update_reveal.left.full && !previous_reveal.left.full)) { - just_after += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length); + adjustment += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length); } - if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous.nt[victim].RightFull())) { - ret += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right); + if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous_vertex.RightFull())) { + adjustment += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right); } if (update_nt.Complete()) { if (update_reveal.left.full) { before->left.full = true; } else { assert(update_reveal.left.length == update_reveal.right.length); - ret += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length); + adjustment += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length); } - if (victim == 0) { - update.between[0].right = after->right; - } else { - update.between[2].left = before->left; + before->right = after->right; + // Shift the others shifted one down, covering after. + for (lm::ngram::ChartState *cover = after; cover < between + incomplete; ++cover) { + *cover = *(cover + 1); } } - return previous.score + (ret + just_after) * context.GetWeights().LM(); + update.SetScore(update.GetScore() + adjustment * context.GetWeights().LM()); } } // namespace -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool) { +template PartialEdge EdgeGenerator::Pop(Context &context) { assert(!generate_.empty()); - PartialEdge &top = *generate_.top(); + PartialEdge top = generate_.top(); generate_.pop(); - unsigned int victim = 0; - unsigned char lowest_length = 255; - for (unsigned char i = 0; i != arity_; ++i) { - if (!top.nt[i].Complete() && top.nt[i].Length() < lowest_length) { - lowest_length = top.nt[i].Length(); - victim = i; + PartialVertex *const top_nt = top.NT(); + const Arity arity = top.GetArity(); + + Arity victim = 0; + Arity victim_completed; + Arity incomplete; + // Select victim or return if complete. + { + Arity completed = 0; + unsigned char lowest_length = 255; + for (Arity i = 0; i != arity; ++i) { + if (top_nt[i].Complete()) { + ++completed; + } else if (top_nt[i].Length() < lowest_length) { + lowest_length = top_nt[i].Length(); + victim = i; + victim_completed = completed; + } } - } - if (lowest_length == 255) { - // All states report complete. - top.between[0].right = top.between[arity_].right; - // Now top.between[0] is the full edge state. - top_score_ = generate_.empty() ? -kScoreInf : generate_.top()->score; - return ⊤ + if (lowest_length == 255) { + return top; + } + incomplete = arity - completed; } - unsigned int stay = !victim; - PartialEdge &continuation = *static_cast(partial_edge_pool.malloc()); - float old_bound = top.nt[victim].Bound(); - // The alternate's score will change because alternate.nt[victim] changes. - bool split = top.nt[victim].Split(continuation.nt[victim]); - // top is now the alternate. + PartialVertex old_value(top_nt[victim]); + PartialVertex alternate_changed; + if (top_nt[victim].Split(alternate_changed)) { + PartialEdge alternate(partial_edge_pool_, arity, incomplete + 1); + alternate.SetScore(top.GetScore() + alternate_changed.Bound() - old_value.Bound()); - continuation.nt[stay] = top.nt[stay]; - continuation.score = FastScore(context, victim, arity_, top, continuation); - // TODO: dedupe? - generate_.push(&continuation); + alternate.SetNote(top.GetNote()); + + PartialVertex *alternate_nt = alternate.NT(); + for (Arity i = 0; i < victim; ++i) alternate_nt[i] = top_nt[i]; + alternate_nt[victim] = alternate_changed; + for (Arity i = victim + 1; i < arity; ++i) alternate_nt[i] = top_nt[i]; + + memcpy(alternate.Between(), top.Between(), sizeof(lm::ngram::ChartState) * (incomplete + 1)); - if (split) { - // We have an alternate. - top.score += top.nt[victim].Bound() - old_bound; // TODO: dedupe? - generate_.push(&top); - } else { - partial_edge_pool.free(&top); + generate_.push(alternate); } - top_score_ = generate_.top()->score; - return NULL; + // top is now the continuation. + FastScore(context, victim, victim - victim_completed, incomplete, old_value, top); + // TODO: dedupe? + generate_.push(top); + + // Invalid indicates no new hypothesis generated. + return PartialEdge(); } -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); +template PartialEdge EdgeGenerator::Pop(Context &context); +template PartialEdge EdgeGenerator::Pop(Context &context); +template PartialEdge EdgeGenerator::Pop(Context &context); +template PartialEdge EdgeGenerator::Pop(Context &context); +template PartialEdge EdgeGenerator::Pop(Context &context); +template PartialEdge EdgeGenerator::Pop(Context &context); } // namespace search diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh index 875ccc5e..582c78b7 100644 --- a/klm/search/edge_generator.hh +++ b/klm/search/edge_generator.hh @@ -3,11 +3,8 @@ #include "search/edge.hh" #include "search/note.hh" +#include "search/types.hh" -#include -#include - -#include #include namespace lm { @@ -20,38 +17,40 @@ namespace search { template class Context; -class VertexGenerator; - -struct PartialEdgePointerLess : std::binary_function { - bool operator()(const PartialEdge *first, const PartialEdge *second) const { - return *first < *second; - } -}; - class EdgeGenerator { public: - EdgeGenerator(PartialEdge &root, unsigned char arity, Note note); + EdgeGenerator() {} - Score TopScore() const { - return top_score_; + PartialEdge AllocateEdge(Arity arity) { + return PartialEdge(partial_edge_pool_, arity); } - Note GetNote() const { - return note_; + void AddEdge(PartialEdge edge) { + generate_.push(edge); } - // Pop. If there's a complete hypothesis, return it. Otherwise return NULL. - template PartialEdge *Pop(Context &context, boost::pool<> &partial_edge_pool); + bool Empty() const { return generate_.empty(); } + + // Pop. If there's a complete hypothesis, return it. Otherwise return an invalid PartialEdge. + template PartialEdge Pop(Context &context); + + template void Search(Context &context, Output &output) { + unsigned to_pop = context.PopLimit(); + while (to_pop > 0 && !generate_.empty()) { + PartialEdge got(Pop(context)); + if (got.Valid()) { + output.NewHypothesis(got); + --to_pop; + } + } + output.FinishedSearch(); + } private: - Score top_score_; - - unsigned char arity_; + util::Pool partial_edge_pool_; - typedef std::priority_queue, PartialEdgePointerLess> Generate; + typedef std::priority_queue Generate; Generate generate_; - - Note note_; }; } // namespace search diff --git a/klm/search/edge_queue.cc b/klm/search/edge_queue.cc deleted file mode 100644 index e3ae6ebf..00000000 --- a/klm/search/edge_queue.cc +++ /dev/null @@ -1,25 +0,0 @@ -#include "search/edge_queue.hh" - -#include "lm/left.hh" -#include "search/context.hh" - -#include - -namespace search { - -EdgeQueue::EdgeQueue(unsigned int pop_limit_hint) : partial_edge_pool_(sizeof(PartialEdge), pop_limit_hint * 2) { - take_ = static_cast(partial_edge_pool_.malloc()); -} - -/*void EdgeQueue::AddEdge(PartialEdge &root, unsigned char arity, Note note) { - // Ignore empty edges. - for (unsigned char i = 0; i < edge.Arity(); ++i) { - PartialVertex root(edge.GetVertex(i).RootPartial()); - if (root.Empty()) return; - total_score += root.Bound(); - } - PartialEdge &allocated = *static_cast(partial_edge_pool_.malloc()); - allocated.score = total_score; -}*/ - -} // namespace search diff --git a/klm/search/edge_queue.hh b/klm/search/edge_queue.hh deleted file mode 100644 index 187eaed7..00000000 --- a/klm/search/edge_queue.hh +++ /dev/null @@ -1,73 +0,0 @@ -#ifndef SEARCH_EDGE_QUEUE__ -#define SEARCH_EDGE_QUEUE__ - -#include "search/edge.hh" -#include "search/edge_generator.hh" -#include "search/note.hh" - -#include -#include - -#include - -namespace search { - -template class Context; - -class EdgeQueue { - public: - explicit EdgeQueue(unsigned int pop_limit_hint); - - PartialEdge &InitializeEdge() { - return *take_; - } - - void AddEdge(unsigned char arity, Note note) { - generate_.push(edge_pool_.construct(*take_, arity, note)); - take_ = static_cast(partial_edge_pool_.malloc()); - } - - bool Empty() const { return generate_.empty(); } - - /* Generate hypotheses and send them to output. Normally, output is a - * VertexGenerator, but the decoder may want to route edges to different - * vertices i.e. if they have different LHS non-terminal labels. - */ - template void Search(Context &context, Output &output) { - int to_pop = context.PopLimit(); - while (to_pop > 0 && !generate_.empty()) { - EdgeGenerator *top = generate_.top(); - generate_.pop(); - PartialEdge *ret = top->Pop(context, partial_edge_pool_); - if (ret) { - output.NewHypothesis(*ret, top->GetNote()); - --to_pop; - if (top->TopScore() != -kScoreInf) { - generate_.push(top); - } - } else { - generate_.push(top); - } - } - output.FinishedSearch(); - } - - private: - boost::object_pool edge_pool_; - - struct LessByTopScore : public std::binary_function { - bool operator()(const EdgeGenerator *first, const EdgeGenerator *second) const { - return first->TopScore() < second->TopScore(); - } - }; - - typedef std::priority_queue, LessByTopScore> Generate; - Generate generate_; - - boost::pool<> partial_edge_pool_; - - PartialEdge *take_; -}; - -} // namespace search -#endif // SEARCH_EDGE_QUEUE__ diff --git a/klm/search/final.hh b/klm/search/final.hh index 1b3092ac..50e62cf2 100644 --- a/klm/search/final.hh +++ b/klm/search/final.hh @@ -1,37 +1,34 @@ #ifndef SEARCH_FINAL__ #define SEARCH_FINAL__ -#include "search/arity.hh" -#include "search/note.hh" -#include "search/types.hh" - -#include +#include "search/header.hh" +#include "util/pool.hh" namespace search { -class Final { +// A full hypothesis with pointers to children. +class Final : public Header { public: - typedef boost::array ChildArray; + Final() {} - void Reset(Score bound, Note note, const Final &left, const Final &right) { - bound_ = bound; - note_ = note; - children_[0] = &left; - children_[1] = &right; + Final(util::Pool &pool, Score score, Arity arity, Note note) + : Header(pool.Allocate(Size(arity)), arity) { + SetScore(score); + SetNote(note); } - const ChildArray &Children() const { return children_; } - - Note GetNote() const { return note_; } - - Score Bound() const { return bound_; } + // These are arrays of length GetArity(). + Final *Children() { + return reinterpret_cast(After()); + } + const Final *Children() const { + return reinterpret_cast(After()); + } private: - Score bound_; - - Note note_; - - ChildArray children_; + static std::size_t Size(Arity arity) { + return kHeaderSize + arity * sizeof(const Final); + } }; } // namespace search diff --git a/klm/search/header.hh b/klm/search/header.hh new file mode 100644 index 00000000..25550dbe --- /dev/null +++ b/klm/search/header.hh @@ -0,0 +1,57 @@ +#ifndef SEARCH_HEADER__ +#define SEARCH_HEADER__ + +// Header consisting of Score, Arity, and Note + +#include "search/note.hh" +#include "search/types.hh" + +#include + +namespace search { + +// Copying is shallow. +class Header { + public: + bool Valid() const { return base_; } + + Score GetScore() const { + return *reinterpret_cast(base_); + } + void SetScore(Score to) { + *reinterpret_cast(base_) = to; + } + bool operator<(const Header &other) const { + return GetScore() < other.GetScore(); + } + + Arity GetArity() const { + return *reinterpret_cast(base_ + sizeof(Score)); + } + + Note GetNote() const { + return *reinterpret_cast(base_ + sizeof(Score) + sizeof(Arity)); + } + void SetNote(Note to) { + *reinterpret_cast(base_ + sizeof(Score) + sizeof(Arity)) = to; + } + + protected: + Header() : base_(NULL) {} + + Header(void *base, Arity arity) : base_(static_cast(base)) { + *reinterpret_cast(base_ + sizeof(Score)) = arity; + } + + static const std::size_t kHeaderSize = sizeof(Score) + sizeof(Arity) + sizeof(Note); + + uint8_t *After() { return base_ + kHeaderSize; } + const uint8_t *After() const { return base_ + kHeaderSize; } + + private: + uint8_t *base_; +}; + +} // namespace search + +#endif // SEARCH_HEADER__ diff --git a/klm/search/source.hh b/klm/search/source.hh deleted file mode 100644 index 11839f7b..00000000 --- a/klm/search/source.hh +++ /dev/null @@ -1,48 +0,0 @@ -#ifndef SEARCH_SOURCE__ -#define SEARCH_SOURCE__ - -#include "search/types.hh" - -#include -#include - -namespace search { - -template class Source { - public: - Source() : bound_(kScoreInf) {} - - Index Size() const { - return final_.size(); - } - - Score Bound() const { - return bound_; - } - - const Final &operator[](Index index) const { - return *final_[index]; - } - - Score ScoreOrBound(Index index) const { - return Size() > index ? final_[index]->Total() : Bound(); - } - - protected: - void AddFinal(const Final &store) { - final_.push_back(&store); - } - - void SetBound(Score to) { - assert(to <= bound_ + 0.001); - bound_ = to; - } - - private: - std::vector final_; - - Score bound_; -}; - -} // namespace search -#endif // SEARCH_SOURCE__ diff --git a/klm/search/types.hh b/klm/search/types.hh index 9726379f..06eb5bfa 100644 --- a/klm/search/types.hh +++ b/klm/search/types.hh @@ -1,17 +1,13 @@ #ifndef SEARCH_TYPES__ #define SEARCH_TYPES__ -#include +#include namespace search { typedef float Score; -const Score kScoreInf = INFINITY; -// This could have been an enum but gcc wants 4 bytes. -typedef bool ExtendDirection; -const ExtendDirection kExtendLeft = 0; -const ExtendDirection kExtendRight = 1; +typedef uint32_t Arity; } // namespace search diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc index cc53c0dd..11f4631f 100644 --- a/klm/search/vertex.cc +++ b/klm/search/vertex.cc @@ -21,9 +21,9 @@ struct GreaterByBound : public std::binary_functionBound(); + bound_ = end_.GetScore(); return; } if (extend_.size() == 1 && parent_ptr) { @@ -39,10 +39,4 @@ void VertexNode::SortAndSet(ContextBase &context, VertexNode **parent_ptr) { bound_ = extend_.front()->Bound(); } -namespace { -VertexNode kBlankVertexNode; -} // namespace - -PartialVertex kBlankPartialVertex(kBlankVertexNode); - } // namespace search diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh index e1a9ad11..52bc1dfe 100644 --- a/klm/search/vertex.hh +++ b/klm/search/vertex.hh @@ -18,7 +18,7 @@ class ContextBase; class VertexNode { public: - VertexNode() : end_(NULL) {} + VertexNode() {} void InitRoot() { extend_.clear(); @@ -26,8 +26,7 @@ class VertexNode { state_.left.length = 0; state_.right.length = 0; right_full_ = false; - bound_ = -kScoreInf; - end_ = NULL; + end_ = Final(); } lm::ngram::ChartState &MutableState() { return state_; } @@ -37,19 +36,20 @@ class VertexNode { extend_.push_back(next); } - void SetEnd(Final *end) { end_ = end; } + void SetEnd(Final end) { + assert(!end_.Valid()); + end_ = end; + } - Final &MutableEnd() { return *end_; } - void SortAndSet(ContextBase &context, VertexNode **parent_pointer); // Should only happen to a root node when the entire vertex is empty. bool Empty() const { - return !end_ && extend_.empty(); + return !end_.Valid() && extend_.empty(); } bool Complete() const { - return end_; + return end_.Valid(); } const lm::ngram::ChartState &State() const { return state_; } @@ -63,8 +63,8 @@ class VertexNode { return state_.left.length + state_.right.length; } - // May be NULL. - const Final *End() const { return end_; } + // Will be invalid unless this is a leaf. + const Final End() const { return end_; } const VertexNode &operator[](size_t index) const { return *extend_[index]; @@ -81,7 +81,7 @@ class VertexNode { bool right_full_; Score bound_; - Final *end_; + Final end_; }; class PartialVertex { @@ -97,7 +97,7 @@ class PartialVertex { const lm::ngram::ChartState &State() const { return back_->State(); } bool RightFull() const { return back_->RightFull(); } - Score Bound() const { return Complete() ? back_->End()->Bound() : (*back_)[index_].Bound(); } + Score Bound() const { return Complete() ? back_->End().GetScore() : (*back_)[index_].Bound(); } unsigned char Length() const { return back_->Length(); } @@ -105,20 +105,24 @@ class PartialVertex { return index_ + 1 < back_->Size(); } - // Split into continuation and alternative, rendering this the alternative. - bool Split(PartialVertex &continuation) { + // Split into continuation and alternative, rendering this the continuation. + bool Split(PartialVertex &alternative) { assert(!Complete()); - continuation.back_ = &((*back_)[index_]); - continuation.index_ = 0; + bool ret; if (index_ + 1 < back_->Size()) { - ++index_; - return true; + alternative.index_ = index_ + 1; + alternative.back_ = back_; + ret = true; + } else { + ret = false; } - return false; + back_ = &((*back_)[index_]); + index_ = 0; + return ret; } - const Final &End() const { - return *back_->End(); + const Final End() const { + return back_->End(); } private: @@ -126,25 +130,22 @@ class PartialVertex { unsigned int index_; }; -extern PartialVertex kBlankPartialVertex; - class Vertex { public: Vertex() {} PartialVertex RootPartial() const { return PartialVertex(root_); } - const Final *BestChild() const { + const Final BestChild() const { PartialVertex top(RootPartial()); if (top.Empty()) { - return NULL; + return Final(); } else { PartialVertex continuation; while (!top.Complete()) { top.Split(continuation); - top = continuation; } - return &top.End(); + return top.End(); } } diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc index d94e6e06..0945fe55 100644 --- a/klm/search/vertex_generator.cc +++ b/klm/search/vertex_generator.cc @@ -10,74 +10,85 @@ namespace search { VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) { gen.root_.InitRoot(); - root_.under = &gen.root_; } namespace { + const uint64_t kCompleteAdd = static_cast(-1); -} // namespace -void VertexGenerator::NewHypothesis(const PartialEdge &partial, Note note) { - const lm::ngram::ChartState &state = partial.CompletedState(); - std::pair got(existing_.insert(std::pair(hash_value(state), NULL))); - if (!got.second) { - // Found it already. - Final &exists = *got.first->second; - if (exists.Bound() < partial.score) { - exists.Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End()); - } - return; +// Parallel structure to VertexNode. +struct Trie { + Trie() : under(NULL) {} + + VertexNode *under; + boost::unordered_map extend; +}; + +Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { + Trie &next = node.extend[added]; + if (!next.under) { + next.under = context.NewVertexNode(); + lm::ngram::ChartState &writing = next.under->MutableState(); + writing = state; + writing.left.full &= left_full && state.left.full; + next.under->MutableRightFull() = right_full && state.left.full; + writing.left.length = left; + writing.right.length = right; + node.under->AddExtend(next.under); } + return next; +} + +void CompleteTransition(ContextBase &context, Trie &starter, PartialEdge partial) { + Final final(context.FinalPool(), partial.GetScore(), partial.GetArity(), partial.GetNote()); + Final *child_out = final.Children(); + const PartialVertex *part = partial.NT(); + const PartialVertex *const part_end_loop = part + partial.GetArity(); + for (; part != part_end_loop; ++part, ++child_out) + *child_out = part->End(); + + starter.under->SetEnd(final); +} + +void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) { + const lm::ngram::ChartState &state = partial.CompletedState(); + unsigned char left = 0, right = 0; - Trie *node = &root_; + Trie *node = &root; while (true) { if (left == state.left.length) { - node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, true, right, false); + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, true, right, false); for (; right < state.right.length; ++right) { - node = &FindOrInsert(*node, state.right.words[right], state, left, true, right + 1, false); + node = &FindOrInsert(context, *node, state.right.words[right], state, left, true, right + 1, false); } break; } - node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, false); + node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, false); left++; if (right == state.right.length) { - node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, false, right, true); + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, false, right, true); for (; left < state.left.length; ++left) { - node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, true); + node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, true); } break; } - node = &FindOrInsert(*node, state.right.words[right], state, left, false, right + 1, false); + node = &FindOrInsert(context, *node, state.right.words[right], state, left, false, right + 1, false); right++; } - node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); - got.first->second = CompleteTransition(*node, state, note, partial); + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); + CompleteTransition(context, *node, partial); } -VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { - VertexGenerator::Trie &next = node.extend[added]; - if (!next.under) { - next.under = context_.NewVertexNode(); - lm::ngram::ChartState &writing = next.under->MutableState(); - writing = state; - writing.left.full &= left_full && state.left.full; - next.under->MutableRightFull() = right_full && state.left.full; - writing.left.length = left; - writing.right.length = right; - node.under->AddExtend(next.under); - } - return next; -} +} // namespace -Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial) { - VertexNode &node = *starter.under; - assert(node.State().left.full == state.left.full); - assert(!node.End()); - Final *final = context_.NewFinal(); - final->Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End()); - node.SetEnd(final); - return final; +void VertexGenerator::FinishedSearch() { + Trie root; + root.under = &gen_.root_; + for (Existing::const_iterator i(existing_.begin()); i != existing_.end(); ++i) { + AddHypothesis(context_, root, i->second); + } + root.under->SortAndSet(context_, NULL); } } // namespace search diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh index 6b98da3e..60e86112 100644 --- a/klm/search/vertex_generator.hh +++ b/klm/search/vertex_generator.hh @@ -1,13 +1,11 @@ #ifndef SEARCH_VERTEX_GENERATOR__ #define SEARCH_VERTEX_GENERATOR__ -#include "search/note.hh" +#include "search/edge.hh" #include "search/vertex.hh" #include -#include - namespace lm { namespace ngram { class ChartState; @@ -18,40 +16,29 @@ namespace search { class ContextBase; class Final; -struct PartialEdge; class VertexGenerator { public: VertexGenerator(ContextBase &context, Vertex &gen); - void NewHypothesis(const PartialEdge &partial, Note note); - - void FinishedSearch() { - root_.under->SortAndSet(context_, NULL); + void NewHypothesis(PartialEdge partial) { + const lm::ngram::ChartState &state = partial.CompletedState(); + std::pair ret(existing_.insert(std::make_pair(hash_value(state), partial))); + if (!ret.second && ret.first->second < partial) { + ret.first->second = partial; + } } + void FinishedSearch(); + const Vertex &Generating() const { return gen_; } private: - // Parallel structure to VertexNode. - struct Trie { - Trie() : under(NULL) {} - - VertexNode *under; - boost::unordered_map extend; - }; - - Trie &FindOrInsert(Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full); - - Final *CompleteTransition(Trie &node, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial); - ContextBase &context_; Vertex &gen_; - Trie root_; - - typedef boost::unordered_map Existing; + typedef boost::unordered_map Existing; Existing existing_; }; diff --git a/klm/util/Makefile.am b/klm/util/Makefile.am index 5ceccf2c..5306850f 100644 --- a/klm/util/Makefile.am +++ b/klm/util/Makefile.am @@ -26,6 +26,8 @@ libklm_util_a_SOURCES = \ file_piece.cc \ mmap.cc \ murmur_hash.cc \ + pool.cc \ + string_piece.cc \ usage.cc AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. diff --git a/klm/util/ersatz_progress.hh b/klm/util/ersatz_progress.hh index ff4d590f..9909736d 100644 --- a/klm/util/ersatz_progress.hh +++ b/klm/util/ersatz_progress.hh @@ -4,7 +4,7 @@ #include #include -#include +#include // Ersatz version of boost::progress so core language model doesn't depend on // boost. Also adds option to print nothing. diff --git a/klm/util/exception.hh b/klm/util/exception.hh index 83f99cd6..053a850b 100644 --- a/klm/util/exception.hh +++ b/klm/util/exception.hh @@ -6,7 +6,7 @@ #include #include -#include +#include namespace util { diff --git a/klm/util/pool.cc b/klm/util/pool.cc new file mode 100644 index 00000000..2dffd06f --- /dev/null +++ b/klm/util/pool.cc @@ -0,0 +1,35 @@ +#include "util/pool.hh" + +#include + +namespace util { + +Pool::Pool() { + current_ = NULL; + current_end_ = NULL; +} + +Pool::~Pool() { + FreeAll(); +} + +void Pool::FreeAll() { + for (std::vector::const_iterator i(free_list_.begin()); i != free_list_.end(); ++i) { + free(*i); + } + free_list_.clear(); + current_ = NULL; + current_end_ = NULL; +} + +void *Pool::More(std::size_t size) { + std::size_t amount = std::max(static_cast(32) << free_list_.size(), size); + uint8_t *ret = static_cast(malloc(amount)); + if (!ret) throw std::bad_alloc(); + free_list_.push_back(ret); + current_ = ret + size; + current_end_ = ret + amount; + return ret; +} + +} // namespace util diff --git a/klm/util/pool.hh b/klm/util/pool.hh new file mode 100644 index 00000000..72f8a0c8 --- /dev/null +++ b/klm/util/pool.hh @@ -0,0 +1,45 @@ +// Very simple pool. It can only allocate memory. And all of the memory it +// allocates must be freed at the same time. + +#ifndef UTIL_POOL__ +#define UTIL_POOL__ + +#include + +#include + +namespace util { + +class Pool { + public: + Pool(); + + ~Pool(); + + void *Allocate(std::size_t size) { + void *ret = current_; + current_ += size; + if (current_ < current_end_) { + return ret; + } else { + return More(size); + } + } + + void FreeAll(); + + private: + void *More(std::size_t size); + + std::vector free_list_; + + uint8_t *current_, *current_end_; + + // no copying + Pool(const Pool &); + Pool &operator=(const Pool &); +}; + +} // namespace util + +#endif // UTIL_POOL__ diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh index 770faa7e..4a8aff35 100644 --- a/klm/util/probing_hash_table.hh +++ b/klm/util/probing_hash_table.hh @@ -8,7 +8,7 @@ #include #include -#include +#include namespace util { diff --git a/klm/util/string_piece.cc b/klm/util/string_piece.cc new file mode 100644 index 00000000..b422cefc --- /dev/null +++ b/klm/util/string_piece.cc @@ -0,0 +1,192 @@ +// Copyright 2004 The RE2 Authors. All Rights Reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in string_piece.hh. + +#include "util/string_piece.hh" + +#include + +#include + +#ifndef HAVE_ICU + +typedef StringPiece::size_type size_type; + +void StringPiece::CopyToString(std::string* target) const { + target->assign(ptr_, length_); +} + +size_type StringPiece::find(const StringPiece& s, size_type pos) const { + if (length_ < 0 || pos > static_cast(length_)) + return npos; + + const char* result = std::search(ptr_ + pos, ptr_ + length_, + s.ptr_, s.ptr_ + s.length_); + const size_type xpos = result - ptr_; + return xpos + s.length_ <= length_ ? xpos : npos; +} + +size_type StringPiece::find(char c, size_type pos) const { + if (length_ <= 0 || pos >= static_cast(length_)) { + return npos; + } + const char* result = std::find(ptr_ + pos, ptr_ + length_, c); + return result != ptr_ + length_ ? result - ptr_ : npos; +} + +size_type StringPiece::rfind(const StringPiece& s, size_type pos) const { + if (length_ < s.length_) return npos; + const size_t ulen = length_; + if (s.length_ == 0) return std::min(ulen, pos); + + const char* last = ptr_ + std::min(ulen - s.length_, pos) + s.length_; + const char* result = std::find_end(ptr_, last, s.ptr_, s.ptr_ + s.length_); + return result != last ? result - ptr_ : npos; +} + +size_type StringPiece::rfind(char c, size_type pos) const { + if (length_ <= 0) return npos; + for (int i = std::min(pos, static_cast(length_ - 1)); + i >= 0; --i) { + if (ptr_[i] == c) { + return i; + } + } + return npos; +} + +// For each character in characters_wanted, sets the index corresponding +// to the ASCII code of that character to 1 in table. This is used by +// the find_.*_of methods below to tell whether or not a character is in +// the lookup table in constant time. +// The argument `table' must be an array that is large enough to hold all +// the possible values of an unsigned char. Thus it should be be declared +// as follows: +// bool table[UCHAR_MAX + 1] +static inline void BuildLookupTable(const StringPiece& characters_wanted, + bool* table) { + const size_type length = characters_wanted.length(); + const char* const data = characters_wanted.data(); + for (size_type i = 0; i < length; ++i) { + table[static_cast(data[i])] = true; + } +} + +size_type StringPiece::find_first_of(const StringPiece& s, + size_type pos) const { + if (length_ == 0 || s.length_ == 0) + return npos; + + // Avoid the cost of BuildLookupTable() for a single-character search. + if (s.length_ == 1) + return find_first_of(s.ptr_[0], pos); + + bool lookup[UCHAR_MAX + 1] = { false }; + BuildLookupTable(s, lookup); + for (size_type i = pos; i < length_; ++i) { + if (lookup[static_cast(ptr_[i])]) { + return i; + } + } + return npos; +} + +size_type StringPiece::find_first_not_of(const StringPiece& s, + size_type pos) const { + if (length_ == 0) + return npos; + + if (s.length_ == 0) + return 0; + + // Avoid the cost of BuildLookupTable() for a single-character search. + if (s.length_ == 1) + return find_first_not_of(s.ptr_[0], pos); + + bool lookup[UCHAR_MAX + 1] = { false }; + BuildLookupTable(s, lookup); + for (size_type i = pos; i < length_; ++i) { + if (!lookup[static_cast(ptr_[i])]) { + return i; + } + } + return npos; +} + +size_type StringPiece::find_first_not_of(char c, size_type pos) const { + if (length_ == 0) + return npos; + + for (; pos < length_; ++pos) { + if (ptr_[pos] != c) { + return pos; + } + } + return npos; +} + +size_type StringPiece::find_last_of(const StringPiece& s, size_type pos) const { + if (length_ == 0 || s.length_ == 0) + return npos; + + // Avoid the cost of BuildLookupTable() for a single-character search. + if (s.length_ == 1) + return find_last_of(s.ptr_[0], pos); + + bool lookup[UCHAR_MAX + 1] = { false }; + BuildLookupTable(s, lookup); + for (size_type i = std::min(pos, length_ - 1); ; --i) { + if (lookup[static_cast(ptr_[i])]) + return i; + if (i == 0) + break; + } + return npos; +} + +size_type StringPiece::find_last_not_of(const StringPiece& s, + size_type pos) const { + if (length_ == 0) + return npos; + + size_type i = std::min(pos, length_ - 1); + if (s.length_ == 0) + return i; + + // Avoid the cost of BuildLookupTable() for a single-character search. + if (s.length_ == 1) + return find_last_not_of(s.ptr_[0], pos); + + bool lookup[UCHAR_MAX + 1] = { false }; + BuildLookupTable(s, lookup); + for (; ; --i) { + if (!lookup[static_cast(ptr_[i])]) + return i; + if (i == 0) + break; + } + return npos; +} + +size_type StringPiece::find_last_not_of(char c, size_type pos) const { + if (length_ == 0) + return npos; + + for (size_type i = std::min(pos, length_ - 1); ; --i) { + if (ptr_[i] != c) + return i; + if (i == 0) + break; + } + return npos; +} + +StringPiece StringPiece::substr(size_type pos, size_type n) const { + if (pos > length_) pos = length_; + if (n > length_ - pos) n = length_ - pos; + return StringPiece(ptr_ + pos, n); +} + +const size_type StringPiece::npos = size_type(-1); + +#endif // !HAVE_ICU diff --git a/klm/util/tokenize_piece.hh b/klm/util/tokenize_piece.hh index c7e1c863..4a7f5460 100644 --- a/klm/util/tokenize_piece.hh +++ b/klm/util/tokenize_piece.hh @@ -54,6 +54,18 @@ class AnyCharacter { StringPiece chars_; }; +class AnyCharacterLast { + public: + explicit AnyCharacterLast(const StringPiece &chars) : chars_(chars) {} + + StringPiece Find(const StringPiece &in) const { + return StringPiece(std::find_end(in.data(), in.data() + in.size(), chars_.data(), chars_.data() + chars_.size()), 1); + } + + private: + StringPiece chars_; +}; + template class TokenIter : public boost::iterator_facade, const StringPiece, boost::forward_traversal_tag> { public: TokenIter() {} diff --git a/mira/Makefile.am b/mira/Makefile.am index 7b4a4e12..3f8f17cd 100644 --- a/mira/Makefile.am +++ b/mira/Makefile.am @@ -1,6 +1,6 @@ bin_PROGRAMS = kbest_mira kbest_mira_SOURCES = kbest_mira.cc -kbest_mira_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +kbest_mira_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/training/Makefile.am b/training/Makefile.am index 5254333a..f9c25391 100644 --- a/training/Makefile.am +++ b/training/Makefile.am @@ -32,60 +32,60 @@ libtraining_a_SOURCES = \ risk.cc mpi_online_optimize_SOURCES = mpi_online_optimize.cc -mpi_online_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_online_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz mpi_flex_optimize_SOURCES = mpi_flex_optimize.cc -mpi_flex_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_flex_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz mpi_extract_reachable_SOURCES = mpi_extract_reachable.cc -mpi_extract_reachable_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_extract_reachable_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz mpi_extract_features_SOURCES = mpi_extract_features.cc -mpi_extract_features_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_extract_features_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz mpi_batch_optimize_SOURCES = mpi_batch_optimize.cc cllh_observer.cc -mpi_batch_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_batch_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz mpi_compute_cllh_SOURCES = mpi_compute_cllh.cc cllh_observer.cc -mpi_compute_cllh_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_compute_cllh_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz augment_grammar_SOURCES = augment_grammar.cc -augment_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +augment_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz test_ngram_SOURCES = test_ngram.cc -test_ngram_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +test_ngram_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz fast_align_SOURCES = fast_align.cc ttables.cc -fast_align_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +fast_align_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz lbl_model_SOURCES = lbl_model.cc -lbl_model_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +lbl_model_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz grammar_convert_SOURCES = grammar_convert.cc -grammar_convert_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +grammar_convert_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz optimize_test_SOURCES = optimize_test.cc -optimize_test_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +optimize_test_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz collapse_weights_SOURCES = collapse_weights.cc -collapse_weights_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +collapse_weights_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz lbfgs_test_SOURCES = lbfgs_test.cc -lbfgs_test_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +lbfgs_test_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz mr_optimize_reduce_SOURCES = mr_optimize_reduce.cc -mr_optimize_reduce_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +mr_optimize_reduce_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz mr_em_map_adapter_SOURCES = mr_em_map_adapter.cc -mr_em_map_adapter_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +mr_em_map_adapter_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz mr_reduce_to_weights_SOURCES = mr_reduce_to_weights.cc -mr_reduce_to_weights_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +mr_reduce_to_weights_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz mr_em_adapted_reduce_SOURCES = mr_em_adapted_reduce.cc -mr_em_adapted_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +mr_em_adapted_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz plftools_SOURCES = plftools.cc -plftools_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +plftools_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I../klm -- cgit v1.2.3