diff options
Diffstat (limited to 'klm/search')
-rw-r--r-- | klm/search/Jamfile | 5 | ||||
-rw-r--r-- | klm/search/Makefile.am | 11 | ||||
-rw-r--r-- | klm/search/config.hh | 25 | ||||
-rw-r--r-- | klm/search/context.hh | 65 | ||||
-rw-r--r-- | klm/search/edge.hh | 54 | ||||
-rw-r--r-- | klm/search/edge_generator.cc | 110 | ||||
-rw-r--r-- | klm/search/edge_generator.hh | 57 | ||||
-rw-r--r-- | klm/search/final.hh | 36 | ||||
-rw-r--r-- | klm/search/header.hh | 57 | ||||
-rw-r--r-- | klm/search/note.hh | 12 | ||||
-rw-r--r-- | klm/search/rule.cc | 43 | ||||
-rw-r--r-- | klm/search/rule.hh | 20 | ||||
-rw-r--r-- | klm/search/types.hh | 14 | ||||
-rw-r--r-- | klm/search/vertex.cc | 42 | ||||
-rw-r--r-- | klm/search/vertex.hh | 159 | ||||
-rw-r--r-- | klm/search/vertex_generator.cc | 94 | ||||
-rw-r--r-- | klm/search/vertex_generator.hh | 46 | ||||
-rw-r--r-- | klm/search/weights.cc | 71 | ||||
-rw-r--r-- | klm/search/weights.hh | 52 | ||||
-rw-r--r-- | klm/search/weights_test.cc | 38 |
20 files changed, 1011 insertions, 0 deletions
diff --git a/klm/search/Jamfile b/klm/search/Jamfile new file mode 100644 index 00000000..bc95c53a --- /dev/null +++ b/klm/search/Jamfile @@ -0,0 +1,5 @@ +lib search : weights.cc vertex.cc vertex_generator.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : <include>.. ; + +import testing ; + +unit-test weights_test : weights_test.cc search /top//boost_unit_test_framework ; diff --git a/klm/search/Makefile.am b/klm/search/Makefile.am new file mode 100644 index 00000000..ccc5b7f6 --- /dev/null +++ b/klm/search/Makefile.am @@ -0,0 +1,11 @@ +noinst_LIBRARIES = libksearch.a + +libksearch_a_SOURCES = \ + edge_generator.cc \ + rule.cc \ + vertex.cc \ + vertex_generator.cc \ + weights.cc + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. + diff --git a/klm/search/config.hh b/klm/search/config.hh new file mode 100644 index 00000000..ef8e2354 --- /dev/null +++ b/klm/search/config.hh @@ -0,0 +1,25 @@ +#ifndef SEARCH_CONFIG__ +#define SEARCH_CONFIG__ + +#include "search/weights.hh" +#include "util/string_piece.hh" + +namespace search { + +class Config { + public: + Config(const Weights &weights, unsigned int pop_limit) : + weights_(weights), pop_limit_(pop_limit) {} + + const Weights &GetWeights() const { return weights_; } + + unsigned int PopLimit() const { return pop_limit_; } + + private: + Weights weights_; + unsigned int pop_limit_; +}; + +} // namespace search + +#endif // SEARCH_CONFIG__ diff --git a/klm/search/context.hh b/klm/search/context.hh new file mode 100644 index 00000000..62163144 --- /dev/null +++ b/klm/search/context.hh @@ -0,0 +1,65 @@ +#ifndef SEARCH_CONTEXT__ +#define SEARCH_CONTEXT__ + +#include "lm/model.hh" +#include "search/config.hh" +#include "search/final.hh" +#include "search/types.hh" +#include "search/vertex.hh" +#include "util/exception.hh" +#include "util/pool.hh" + +#include <boost/pool/object_pool.hpp> +#include <boost/ptr_container/ptr_vector.hpp> + +#include <vector> + +namespace search { + +class Weights; + +class ContextBase { + public: + explicit ContextBase(const Config &config) : pop_limit_(config.PopLimit()), weights_(config.GetWeights()) {} + + util::Pool &FinalPool() { + return final_pool_; + } + + VertexNode *NewVertexNode() { + VertexNode *ret = vertex_node_pool_.construct(); + assert(ret); + return ret; + } + + void DeleteVertexNode(VertexNode *node) { + vertex_node_pool_.destroy(node); + } + + unsigned int PopLimit() const { return pop_limit_; } + + const Weights &GetWeights() const { return weights_; } + + private: + util::Pool final_pool_; + + boost::object_pool<VertexNode> vertex_node_pool_; + + unsigned int pop_limit_; + + const Weights &weights_; +}; + +template <class Model> class Context : public ContextBase { + public: + Context(const Config &config, const Model &model) : ContextBase(config), model_(model) {} + + const Model &LanguageModel() const { return model_; } + + private: + const Model &model_; +}; + +} // namespace search + +#endif // SEARCH_CONTEXT__ diff --git a/klm/search/edge.hh b/klm/search/edge.hh new file mode 100644 index 00000000..187904bf --- /dev/null +++ b/klm/search/edge.hh @@ -0,0 +1,54 @@ +#ifndef SEARCH_EDGE__ +#define SEARCH_EDGE__ + +#include "lm/state.hh" +#include "search/header.hh" +#include "search/types.hh" +#include "search/vertex.hh" +#include "util/pool.hh" + +#include <functional> + +#include <stdint.h> + +namespace search { + +// Copyable, but the copy will be shallow. +class PartialEdge : public Header { + public: + // Allow default construction for STL. + PartialEdge() {} + + PartialEdge(util::Pool &pool, Arity arity) + : Header(pool.Allocate(Size(arity, arity + 1)), arity) {} + + PartialEdge(util::Pool &pool, Arity arity, Arity chart_states) + : Header(pool.Allocate(Size(arity, chart_states)), arity) {} + + // Non-terminals + const PartialVertex *NT() const { + return reinterpret_cast<const PartialVertex*>(After()); + } + PartialVertex *NT() { + return reinterpret_cast<PartialVertex*>(After()); + } + + const lm::ngram::ChartState &CompletedState() const { + return *Between(); + } + const lm::ngram::ChartState *Between() const { + return reinterpret_cast<const lm::ngram::ChartState*>(After() + GetArity() * sizeof(PartialVertex)); + } + lm::ngram::ChartState *Between() { + return reinterpret_cast<lm::ngram::ChartState*>(After() + GetArity() * sizeof(PartialVertex)); + } + + private: + static std::size_t Size(Arity arity, Arity chart_states) { + return kHeaderSize + arity * sizeof(PartialVertex) + chart_states * sizeof(lm::ngram::ChartState); + } +}; + + +} // namespace search +#endif // SEARCH_EDGE__ diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc new file mode 100644 index 00000000..260159b1 --- /dev/null +++ b/klm/search/edge_generator.cc @@ -0,0 +1,110 @@ +#include "search/edge_generator.hh" + +#include "lm/left.hh" +#include "lm/partial.hh" +#include "search/context.hh" +#include "search/vertex.hh" + +#include <numeric> + +namespace search { + +namespace { + +template <class Model> void FastScore(const Context<Model> &context, Arity victim, Arity before_idx, Arity incomplete, const PartialVertex &previous_vertex, PartialEdge update) { + lm::ngram::ChartState *between = update.Between(); + lm::ngram::ChartState *before = &between[before_idx], *after = &between[before_idx + 1]; + + float adjustment = 0.0; + const lm::ngram::ChartState &previous_reveal = previous_vertex.State(); + const PartialVertex &update_nt = update.NT()[victim]; + const lm::ngram::ChartState &update_reveal = update_nt.State(); + if ((update_reveal.left.length > previous_reveal.left.length) || (update_reveal.left.full && !previous_reveal.left.full)) { + adjustment += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length); + } + if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous_vertex.RightFull())) { + adjustment += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right); + } + if (update_nt.Complete()) { + if (update_reveal.left.full) { + before->left.full = true; + } else { + assert(update_reveal.left.length == update_reveal.right.length); + adjustment += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length); + } + before->right = after->right; + // Shift the others shifted one down, covering after. + for (lm::ngram::ChartState *cover = after; cover < between + incomplete; ++cover) { + *cover = *(cover + 1); + } + } + update.SetScore(update.GetScore() + adjustment * context.GetWeights().LM()); +} + +} // namespace + +template <class Model> PartialEdge EdgeGenerator::Pop(Context<Model> &context) { + assert(!generate_.empty()); + PartialEdge top = generate_.top(); + generate_.pop(); + PartialVertex *const top_nt = top.NT(); + const Arity arity = top.GetArity(); + + Arity victim = 0; + Arity victim_completed; + Arity incomplete; + // Select victim or return if complete. + { + Arity completed = 0; + unsigned char lowest_length = 255; + for (Arity i = 0; i != arity; ++i) { + if (top_nt[i].Complete()) { + ++completed; + } else if (top_nt[i].Length() < lowest_length) { + lowest_length = top_nt[i].Length(); + victim = i; + victim_completed = completed; + } + } + if (lowest_length == 255) { + return top; + } + incomplete = arity - completed; + } + + PartialVertex old_value(top_nt[victim]); + PartialVertex alternate_changed; + if (top_nt[victim].Split(alternate_changed)) { + PartialEdge alternate(partial_edge_pool_, arity, incomplete + 1); + alternate.SetScore(top.GetScore() + alternate_changed.Bound() - old_value.Bound()); + + alternate.SetNote(top.GetNote()); + + PartialVertex *alternate_nt = alternate.NT(); + for (Arity i = 0; i < victim; ++i) alternate_nt[i] = top_nt[i]; + alternate_nt[victim] = alternate_changed; + for (Arity i = victim + 1; i < arity; ++i) alternate_nt[i] = top_nt[i]; + + memcpy(alternate.Between(), top.Between(), sizeof(lm::ngram::ChartState) * (incomplete + 1)); + + // TODO: dedupe? + generate_.push(alternate); + } + + // top is now the continuation. + FastScore(context, victim, victim - victim_completed, incomplete, old_value, top); + // TODO: dedupe? + generate_.push(top); + + // Invalid indicates no new hypothesis generated. + return PartialEdge(); +} + +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context); +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context); +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::TrieModel> &context); +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::QuantTrieModel> &context); +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::ArrayTrieModel> &context); +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::QuantArrayTrieModel> &context); + +} // namespace search diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh new file mode 100644 index 00000000..582c78b7 --- /dev/null +++ b/klm/search/edge_generator.hh @@ -0,0 +1,57 @@ +#ifndef SEARCH_EDGE_GENERATOR__ +#define SEARCH_EDGE_GENERATOR__ + +#include "search/edge.hh" +#include "search/note.hh" +#include "search/types.hh" + +#include <queue> + +namespace lm { +namespace ngram { +class ChartState; +} // namespace ngram +} // namespace lm + +namespace search { + +template <class Model> class Context; + +class EdgeGenerator { + public: + EdgeGenerator() {} + + PartialEdge AllocateEdge(Arity arity) { + return PartialEdge(partial_edge_pool_, arity); + } + + void AddEdge(PartialEdge edge) { + generate_.push(edge); + } + + bool Empty() const { return generate_.empty(); } + + // Pop. If there's a complete hypothesis, return it. Otherwise return an invalid PartialEdge. + template <class Model> PartialEdge Pop(Context<Model> &context); + + template <class Model, class Output> void Search(Context<Model> &context, Output &output) { + unsigned to_pop = context.PopLimit(); + while (to_pop > 0 && !generate_.empty()) { + PartialEdge got(Pop(context)); + if (got.Valid()) { + output.NewHypothesis(got); + --to_pop; + } + } + output.FinishedSearch(); + } + + private: + util::Pool partial_edge_pool_; + + typedef std::priority_queue<PartialEdge> Generate; + Generate generate_; +}; + +} // namespace search +#endif // SEARCH_EDGE_GENERATOR__ diff --git a/klm/search/final.hh b/klm/search/final.hh new file mode 100644 index 00000000..50e62cf2 --- /dev/null +++ b/klm/search/final.hh @@ -0,0 +1,36 @@ +#ifndef SEARCH_FINAL__ +#define SEARCH_FINAL__ + +#include "search/header.hh" +#include "util/pool.hh" + +namespace search { + +// A full hypothesis with pointers to children. +class Final : public Header { + public: + Final() {} + + Final(util::Pool &pool, Score score, Arity arity, Note note) + : Header(pool.Allocate(Size(arity)), arity) { + SetScore(score); + SetNote(note); + } + + // These are arrays of length GetArity(). + Final *Children() { + return reinterpret_cast<Final*>(After()); + } + const Final *Children() const { + return reinterpret_cast<const Final*>(After()); + } + + private: + static std::size_t Size(Arity arity) { + return kHeaderSize + arity * sizeof(const Final); + } +}; + +} // namespace search + +#endif // SEARCH_FINAL__ diff --git a/klm/search/header.hh b/klm/search/header.hh new file mode 100644 index 00000000..25550dbe --- /dev/null +++ b/klm/search/header.hh @@ -0,0 +1,57 @@ +#ifndef SEARCH_HEADER__ +#define SEARCH_HEADER__ + +// Header consisting of Score, Arity, and Note + +#include "search/note.hh" +#include "search/types.hh" + +#include <stdint.h> + +namespace search { + +// Copying is shallow. +class Header { + public: + bool Valid() const { return base_; } + + Score GetScore() const { + return *reinterpret_cast<const float*>(base_); + } + void SetScore(Score to) { + *reinterpret_cast<float*>(base_) = to; + } + bool operator<(const Header &other) const { + return GetScore() < other.GetScore(); + } + + Arity GetArity() const { + return *reinterpret_cast<const Arity*>(base_ + sizeof(Score)); + } + + Note GetNote() const { + return *reinterpret_cast<const Note*>(base_ + sizeof(Score) + sizeof(Arity)); + } + void SetNote(Note to) { + *reinterpret_cast<Note*>(base_ + sizeof(Score) + sizeof(Arity)) = to; + } + + protected: + Header() : base_(NULL) {} + + Header(void *base, Arity arity) : base_(static_cast<uint8_t*>(base)) { + *reinterpret_cast<Arity*>(base_ + sizeof(Score)) = arity; + } + + static const std::size_t kHeaderSize = sizeof(Score) + sizeof(Arity) + sizeof(Note); + + uint8_t *After() { return base_ + kHeaderSize; } + const uint8_t *After() const { return base_ + kHeaderSize; } + + private: + uint8_t *base_; +}; + +} // namespace search + +#endif // SEARCH_HEADER__ diff --git a/klm/search/note.hh b/klm/search/note.hh new file mode 100644 index 00000000..50bed06e --- /dev/null +++ b/klm/search/note.hh @@ -0,0 +1,12 @@ +#ifndef SEARCH_NOTE__ +#define SEARCH_NOTE__ + +namespace search { + +union Note { + const void *vp; +}; + +} // namespace search + +#endif // SEARCH_NOTE__ diff --git a/klm/search/rule.cc b/klm/search/rule.cc new file mode 100644 index 00000000..5b00207e --- /dev/null +++ b/klm/search/rule.cc @@ -0,0 +1,43 @@ +#include "search/rule.hh" + +#include "search/context.hh" +#include "search/final.hh" + +#include <ostream> + +#include <math.h> + +namespace search { + +template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing) { + unsigned int oov_count = 0; + float prob = 0.0; + const Model &model = context.LanguageModel(); + const lm::WordIndex oov = model.GetVocabulary().NotFound(); + for (std::vector<lm::WordIndex>::const_iterator word = words.begin(); ; ++word) { + lm::ngram::RuleScore<Model> scorer(model, *(writing++)); + // TODO: optimize + if (prepend_bos && (word == words.begin())) { + scorer.BeginSentence(); + } + for (; ; ++word) { + if (word == words.end()) { + prob += scorer.Finish(); + return static_cast<float>(oov_count) * context.GetWeights().OOV() + prob * context.GetWeights().LM(); + } + if (*word == kNonTerminal) break; + if (*word == oov) ++oov_count; + scorer.Terminal(*word); + } + prob += scorer.Finish(); + } +} + +template float ScoreRule(const Context<lm::ngram::RestProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context<lm::ngram::ProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context<lm::ngram::TrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context<lm::ngram::QuantTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context<lm::ngram::ArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context<lm::ngram::QuantArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); + +} // namespace search diff --git a/klm/search/rule.hh b/klm/search/rule.hh new file mode 100644 index 00000000..0ce2794d --- /dev/null +++ b/klm/search/rule.hh @@ -0,0 +1,20 @@ +#ifndef SEARCH_RULE__ +#define SEARCH_RULE__ + +#include "lm/left.hh" +#include "lm/word_index.hh" +#include "search/types.hh" + +#include <vector> + +namespace search { + +template <class Model> class Context; + +const lm::WordIndex kNonTerminal = lm::kMaxWordIndex; + +template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *state_out); + +} // namespace search + +#endif // SEARCH_RULE__ diff --git a/klm/search/types.hh b/klm/search/types.hh new file mode 100644 index 00000000..06eb5bfa --- /dev/null +++ b/klm/search/types.hh @@ -0,0 +1,14 @@ +#ifndef SEARCH_TYPES__ +#define SEARCH_TYPES__ + +#include <stdint.h> + +namespace search { + +typedef float Score; + +typedef uint32_t Arity; + +} // namespace search + +#endif // SEARCH_TYPES__ diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc new file mode 100644 index 00000000..11f4631f --- /dev/null +++ b/klm/search/vertex.cc @@ -0,0 +1,42 @@ +#include "search/vertex.hh" + +#include "search/context.hh" + +#include <algorithm> +#include <functional> + +#include <assert.h> + +namespace search { + +namespace { + +struct GreaterByBound : public std::binary_function<const VertexNode *, const VertexNode *, bool> { + bool operator()(const VertexNode *first, const VertexNode *second) const { + return first->Bound() > second->Bound(); + } +}; + +} // namespace + +void VertexNode::SortAndSet(ContextBase &context, VertexNode **parent_ptr) { + if (Complete()) { + assert(end_.Valid()); + assert(extend_.empty()); + bound_ = end_.GetScore(); + return; + } + if (extend_.size() == 1 && parent_ptr) { + *parent_ptr = extend_[0]; + extend_[0]->SortAndSet(context, parent_ptr); + context.DeleteVertexNode(this); + return; + } + for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) { + (*i)->SortAndSet(context, &*i); + } + std::sort(extend_.begin(), extend_.end(), GreaterByBound()); + bound_ = extend_.front()->Bound(); +} + +} // namespace search diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh new file mode 100644 index 00000000..52bc1dfe --- /dev/null +++ b/klm/search/vertex.hh @@ -0,0 +1,159 @@ +#ifndef SEARCH_VERTEX__ +#define SEARCH_VERTEX__ + +#include "lm/left.hh" +#include "search/final.hh" +#include "search/types.hh" + +#include <boost/unordered_set.hpp> + +#include <queue> +#include <vector> + +#include <stdint.h> + +namespace search { + +class ContextBase; + +class VertexNode { + public: + VertexNode() {} + + void InitRoot() { + extend_.clear(); + state_.left.full = false; + state_.left.length = 0; + state_.right.length = 0; + right_full_ = false; + end_ = Final(); + } + + lm::ngram::ChartState &MutableState() { return state_; } + bool &MutableRightFull() { return right_full_; } + + void AddExtend(VertexNode *next) { + extend_.push_back(next); + } + + void SetEnd(Final end) { + assert(!end_.Valid()); + end_ = end; + } + + void SortAndSet(ContextBase &context, VertexNode **parent_pointer); + + // Should only happen to a root node when the entire vertex is empty. + bool Empty() const { + return !end_.Valid() && extend_.empty(); + } + + bool Complete() const { + return end_.Valid(); + } + + const lm::ngram::ChartState &State() const { return state_; } + bool RightFull() const { return right_full_; } + + Score Bound() const { + return bound_; + } + + unsigned char Length() const { + return state_.left.length + state_.right.length; + } + + // Will be invalid unless this is a leaf. + const Final End() const { return end_; } + + const VertexNode &operator[](size_t index) const { + return *extend_[index]; + } + + size_t Size() const { + return extend_.size(); + } + + private: + std::vector<VertexNode*> extend_; + + lm::ngram::ChartState state_; + bool right_full_; + + Score bound_; + Final end_; +}; + +class PartialVertex { + public: + PartialVertex() {} + + explicit PartialVertex(const VertexNode &back) : back_(&back), index_(0) {} + + bool Empty() const { return back_->Empty(); } + + bool Complete() const { return back_->Complete(); } + + const lm::ngram::ChartState &State() const { return back_->State(); } + bool RightFull() const { return back_->RightFull(); } + + Score Bound() const { return Complete() ? back_->End().GetScore() : (*back_)[index_].Bound(); } + + unsigned char Length() const { return back_->Length(); } + + bool HasAlternative() const { + return index_ + 1 < back_->Size(); + } + + // Split into continuation and alternative, rendering this the continuation. + bool Split(PartialVertex &alternative) { + assert(!Complete()); + bool ret; + if (index_ + 1 < back_->Size()) { + alternative.index_ = index_ + 1; + alternative.back_ = back_; + ret = true; + } else { + ret = false; + } + back_ = &((*back_)[index_]); + index_ = 0; + return ret; + } + + const Final End() const { + return back_->End(); + } + + private: + const VertexNode *back_; + unsigned int index_; +}; + +class Vertex { + public: + Vertex() {} + + PartialVertex RootPartial() const { return PartialVertex(root_); } + + const Final BestChild() const { + PartialVertex top(RootPartial()); + if (top.Empty()) { + return Final(); + } else { + PartialVertex continuation; + while (!top.Complete()) { + top.Split(continuation); + } + return top.End(); + } + } + + private: + friend class VertexGenerator; + + VertexNode root_; +}; + +} // namespace search +#endif // SEARCH_VERTEX__ diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc new file mode 100644 index 00000000..0945fe55 --- /dev/null +++ b/klm/search/vertex_generator.cc @@ -0,0 +1,94 @@ +#include "search/vertex_generator.hh" + +#include "lm/left.hh" +#include "search/context.hh" +#include "search/edge.hh" + +#include <stdint.h> + +namespace search { + +VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) { + gen.root_.InitRoot(); +} + +namespace { + +const uint64_t kCompleteAdd = static_cast<uint64_t>(-1); + +// Parallel structure to VertexNode. +struct Trie { + Trie() : under(NULL) {} + + VertexNode *under; + boost::unordered_map<uint64_t, Trie> extend; +}; + +Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { + Trie &next = node.extend[added]; + if (!next.under) { + next.under = context.NewVertexNode(); + lm::ngram::ChartState &writing = next.under->MutableState(); + writing = state; + writing.left.full &= left_full && state.left.full; + next.under->MutableRightFull() = right_full && state.left.full; + writing.left.length = left; + writing.right.length = right; + node.under->AddExtend(next.under); + } + return next; +} + +void CompleteTransition(ContextBase &context, Trie &starter, PartialEdge partial) { + Final final(context.FinalPool(), partial.GetScore(), partial.GetArity(), partial.GetNote()); + Final *child_out = final.Children(); + const PartialVertex *part = partial.NT(); + const PartialVertex *const part_end_loop = part + partial.GetArity(); + for (; part != part_end_loop; ++part, ++child_out) + *child_out = part->End(); + + starter.under->SetEnd(final); +} + +void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) { + const lm::ngram::ChartState &state = partial.CompletedState(); + + unsigned char left = 0, right = 0; + Trie *node = &root; + while (true) { + if (left == state.left.length) { + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, true, right, false); + for (; right < state.right.length; ++right) { + node = &FindOrInsert(context, *node, state.right.words[right], state, left, true, right + 1, false); + } + break; + } + node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, false); + left++; + if (right == state.right.length) { + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, false, right, true); + for (; left < state.left.length; ++left) { + node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, true); + } + break; + } + node = &FindOrInsert(context, *node, state.right.words[right], state, left, false, right + 1, false); + right++; + } + + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); + CompleteTransition(context, *node, partial); +} + +} // namespace + +void VertexGenerator::FinishedSearch() { + Trie root; + root.under = &gen_.root_; + for (Existing::const_iterator i(existing_.begin()); i != existing_.end(); ++i) { + AddHypothesis(context_, root, i->second); + } + root.under->SortAndSet(context_, NULL); +} + +} // namespace search diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh new file mode 100644 index 00000000..60e86112 --- /dev/null +++ b/klm/search/vertex_generator.hh @@ -0,0 +1,46 @@ +#ifndef SEARCH_VERTEX_GENERATOR__ +#define SEARCH_VERTEX_GENERATOR__ + +#include "search/edge.hh" +#include "search/vertex.hh" + +#include <boost/unordered_map.hpp> + +namespace lm { +namespace ngram { +class ChartState; +} // namespace ngram +} // namespace lm + +namespace search { + +class ContextBase; +class Final; + +class VertexGenerator { + public: + VertexGenerator(ContextBase &context, Vertex &gen); + + void NewHypothesis(PartialEdge partial) { + const lm::ngram::ChartState &state = partial.CompletedState(); + std::pair<Existing::iterator, bool> ret(existing_.insert(std::make_pair(hash_value(state), partial))); + if (!ret.second && ret.first->second < partial) { + ret.first->second = partial; + } + } + + void FinishedSearch(); + + const Vertex &Generating() const { return gen_; } + + private: + ContextBase &context_; + + Vertex &gen_; + + typedef boost::unordered_map<uint64_t, PartialEdge> Existing; + Existing existing_; +}; + +} // namespace search +#endif // SEARCH_VERTEX_GENERATOR__ diff --git a/klm/search/weights.cc b/klm/search/weights.cc new file mode 100644 index 00000000..d65471ad --- /dev/null +++ b/klm/search/weights.cc @@ -0,0 +1,71 @@ +#include "search/weights.hh" +#include "util/tokenize_piece.hh" + +#include <cstdlib> + +namespace search { + +namespace { +struct Insert { + void operator()(boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) const { + std::string copy(name.data(), name.size()); + map[copy] = score; + } +}; + +struct DotProduct { + search::Score total; + DotProduct() : total(0.0) {} + + void operator()(const boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) { + boost::unordered_map<std::string, search::Score>::const_iterator i(FindStringPiece(map, name)); + if (i != map.end()) + total += score * i->second; + } +}; + +template <class Map, class Op> void Parse(StringPiece text, Map &map, Op &op) { + for (util::TokenIter<util::SingleCharacter, true> spaces(text, ' '); spaces; ++spaces) { + util::TokenIter<util::SingleCharacter> equals(*spaces, '='); + UTIL_THROW_IF(!equals, WeightParseException, "Bad weight token " << *spaces); + StringPiece name(*equals); + UTIL_THROW_IF(!++equals, WeightParseException, "Bad weight token " << *spaces); + char *end; + // Assumes proper termination. + double value = std::strtod(equals->data(), &end); + UTIL_THROW_IF(end != equals->data() + equals->size(), WeightParseException, "Failed to parse weight" << *equals); + UTIL_THROW_IF(++equals, WeightParseException, "Too many equals in " << *spaces); + op(map, name, value); + } +} + +} // namespace + +Weights::Weights(StringPiece text) { + Insert op; + Parse<Map, Insert>(text, map_, op); + lm_ = Steal("LanguageModel"); + oov_ = Steal("OOV"); + word_penalty_ = Steal("WordPenalty"); +} + +Weights::Weights(Score lm, Score oov, Score word_penalty) : lm_(lm), oov_(oov), word_penalty_(word_penalty) {} + +search::Score Weights::DotNoLM(StringPiece text) const { + DotProduct dot; + Parse<const Map, DotProduct>(text, map_, dot); + return dot.total; +} + +float Weights::Steal(const std::string &str) { + Map::iterator i(map_.find(str)); + if (i == map_.end()) { + return 0.0; + } else { + float ret = i->second; + map_.erase(i); + return ret; + } +} + +} // namespace search diff --git a/klm/search/weights.hh b/klm/search/weights.hh new file mode 100644 index 00000000..df1c419f --- /dev/null +++ b/klm/search/weights.hh @@ -0,0 +1,52 @@ +// For now, the individual features are not kept. +#ifndef SEARCH_WEIGHTS__ +#define SEARCH_WEIGHTS__ + +#include "search/types.hh" +#include "util/exception.hh" +#include "util/string_piece.hh" + +#include <boost/unordered_map.hpp> + +#include <string> + +namespace search { + +class WeightParseException : public util::Exception { + public: + WeightParseException() {} + ~WeightParseException() throw() {} +}; + +class Weights { + public: + // Parses weights, sets lm_weight_, removes it from map_. + explicit Weights(StringPiece text); + + // Just the three scores we care about adding. + Weights(Score lm, Score oov, Score word_penalty); + + Score DotNoLM(StringPiece text) const; + + Score LM() const { return lm_; } + + Score OOV() const { return oov_; } + + Score WordPenalty() const { return word_penalty_; } + + // Mostly for testing. + const boost::unordered_map<std::string, Score> &GetMap() const { return map_; } + + private: + float Steal(const std::string &str); + + typedef boost::unordered_map<std::string, Score> Map; + + Map map_; + + Score lm_, oov_, word_penalty_; +}; + +} // namespace search + +#endif // SEARCH_WEIGHTS__ diff --git a/klm/search/weights_test.cc b/klm/search/weights_test.cc new file mode 100644 index 00000000..4811ff06 --- /dev/null +++ b/klm/search/weights_test.cc @@ -0,0 +1,38 @@ +#include "search/weights.hh" + +#define BOOST_TEST_MODULE WeightTest +#include <boost/test/unit_test.hpp> +#include <boost/test/floating_point_comparison.hpp> + +namespace search { +namespace { + +#define CHECK_WEIGHT(value, string) \ + i = parsed.find(string); \ + BOOST_REQUIRE(i != parsed.end()); \ + BOOST_CHECK_CLOSE((value), i->second, 0.001); + +BOOST_AUTO_TEST_CASE(parse) { + // These are not real feature weights. + Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5"); + const boost::unordered_map<std::string, search::Score> &parsed = w.GetMap(); + boost::unordered_map<std::string, search::Score>::const_iterator i; + CHECK_WEIGHT(0.0, "rarity"); + CHECK_WEIGHT(0.0, "phrase-SGT"); + CHECK_WEIGHT(9.45117, "phrase-TGS"); + CHECK_WEIGHT(2.33833, "lexical-SGT"); + BOOST_CHECK(parsed.end() == parsed.find("lm")); + BOOST_CHECK_CLOSE(3.0, w.LM(), 0.001); + CHECK_WEIGHT(-28.3317, "lexical-TGS"); + CHECK_WEIGHT(5.0, "glue?"); +} + +BOOST_AUTO_TEST_CASE(dot) { + Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5"); + BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0"), 0.001); + BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0 LanguageModel=10"), 0.001); + BOOST_CHECK_CLOSE(9.45117 * 3.0 + 28.3317 * 17.4, w.DotNoLM("rarity=5 phrase-TGS=3.0 LanguageModel=10 lexical-TGS=-17.4"), 0.001); +} + +} // namespace +} // namespace search |