diff options
Diffstat (limited to 'klm/search')
-rw-r--r-- | klm/search/Makefile.am | 23 | ||||
-rw-r--r-- | klm/search/applied.hh | 86 | ||||
-rw-r--r-- | klm/search/config.hh | 38 | ||||
-rw-r--r-- | klm/search/context.hh | 49 | ||||
-rw-r--r-- | klm/search/dedupe.hh | 131 | ||||
-rw-r--r-- | klm/search/edge.hh | 54 | ||||
-rw-r--r-- | klm/search/edge_generator.cc | 111 | ||||
-rw-r--r-- | klm/search/edge_generator.hh | 56 | ||||
-rw-r--r-- | klm/search/header.hh | 64 | ||||
-rw-r--r-- | klm/search/nbest.cc | 106 | ||||
-rw-r--r-- | klm/search/nbest.hh | 81 | ||||
-rw-r--r-- | klm/search/rule.cc | 43 | ||||
-rw-r--r-- | klm/search/rule.hh | 25 | ||||
-rw-r--r-- | klm/search/types.hh | 31 | ||||
-rw-r--r-- | klm/search/vertex.cc | 55 | ||||
-rw-r--r-- | klm/search/vertex.hh | 164 | ||||
-rw-r--r-- | klm/search/vertex_generator.cc | 68 | ||||
-rw-r--r-- | klm/search/vertex_generator.hh | 99 |
18 files changed, 1284 insertions, 0 deletions
diff --git a/klm/search/Makefile.am b/klm/search/Makefile.am new file mode 100644 index 00000000..03554276 --- /dev/null +++ b/klm/search/Makefile.am @@ -0,0 +1,23 @@ +noinst_LIBRARIES = libksearch.a + +libksearch_a_SOURCES = \ + applied.hh \ + config.hh \ + context.hh \ + dedupe.hh \ + edge.hh \ + edge_generator.hh \ + header.hh \ + nbest.hh \ + rule.hh \ + types.hh \ + vertex.hh \ + vertex_generator.hh \ + edge_generator.cc \ + nbest.cc \ + rule.cc \ + vertex.cc \ + vertex_generator.cc + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/klm + diff --git a/klm/search/applied.hh b/klm/search/applied.hh new file mode 100644 index 00000000..bd659e5c --- /dev/null +++ b/klm/search/applied.hh @@ -0,0 +1,86 @@ +#ifndef SEARCH_APPLIED__ +#define SEARCH_APPLIED__ + +#include "search/edge.hh" +#include "search/header.hh" +#include "util/pool.hh" + +#include <math.h> + +namespace search { + +// A full hypothesis: a score, arity of the rule, a pointer to the decoder's rule (Note), and pointers to non-terminals that were substituted. +template <class Below> class GenericApplied : public Header { + public: + GenericApplied() {} + + GenericApplied(void *location, PartialEdge partial) + : Header(location) { + memcpy(Base(), partial.Base(), kHeaderSize); + Below *child_out = Children(); + const PartialVertex *part = partial.NT(); + const PartialVertex *const part_end_loop = part + partial.GetArity(); + for (; part != part_end_loop; ++part, ++child_out) + *child_out = Below(part->End()); + } + + GenericApplied(void *location, Score score, Arity arity, Note note) : Header(location, arity) { + SetScore(score); + SetNote(note); + } + + explicit GenericApplied(History from) : Header(from) {} + + + // These are arrays of length GetArity(). + Below *Children() { + return reinterpret_cast<Below*>(After()); + } + const Below *Children() const { + return reinterpret_cast<const Below*>(After()); + } + + static std::size_t Size(Arity arity) { + return kHeaderSize + arity * sizeof(const Below); + } +}; + +// Applied rule that references itself. +class Applied : public GenericApplied<Applied> { + private: + typedef GenericApplied<Applied> P; + + public: + Applied() {} + Applied(void *location, PartialEdge partial) : P(location, partial) {} + Applied(History from) : P(from) {} +}; + +// How to build single-best hypotheses. +class SingleBest { + public: + typedef PartialEdge Combine; + + void Add(PartialEdge &existing, PartialEdge add) const { + if (!existing.Valid() || existing.GetScore() < add.GetScore()) + existing = add; + } + + NBestComplete Complete(PartialEdge partial) { + if (!partial.Valid()) + return NBestComplete(NULL, lm::ngram::ChartState(), -INFINITY); + void *place_final = pool_.Allocate(Applied::Size(partial.GetArity())); + Applied(place_final, partial); + return NBestComplete( + place_final, + partial.CompletedState(), + partial.GetScore()); + } + + private: + util::Pool pool_; +}; + +} // namespace search + +#endif // SEARCH_APPLIED__ diff --git a/klm/search/config.hh b/klm/search/config.hh new file mode 100644 index 00000000..ba18c09e --- /dev/null +++ b/klm/search/config.hh @@ -0,0 +1,38 @@ +#ifndef SEARCH_CONFIG__ +#define SEARCH_CONFIG__ + +#include "search/types.hh" + +namespace search { + +struct NBestConfig { + explicit NBestConfig(unsigned int in_size) { + keep = in_size; + size = in_size; + } + + unsigned int keep, size; +}; + +class Config { + public: + Config(Score lm_weight, unsigned int pop_limit, const NBestConfig &nbest) : + lm_weight_(lm_weight), pop_limit_(pop_limit), nbest_(nbest) {} + + Score LMWeight() const { return lm_weight_; } + + unsigned int PopLimit() const { return pop_limit_; } + + const NBestConfig &GetNBest() const { return nbest_; } + + private: + Score lm_weight_; + + unsigned int pop_limit_; + + NBestConfig nbest_; +}; + +} // namespace search + +#endif // SEARCH_CONFIG__ diff --git a/klm/search/context.hh b/klm/search/context.hh new file mode 100644 index 00000000..08f21bbf --- /dev/null +++ b/klm/search/context.hh @@ -0,0 +1,49 @@ +#ifndef SEARCH_CONTEXT__ +#define SEARCH_CONTEXT__ + +#include "search/config.hh" +#include "search/vertex.hh" + +#include <boost/pool/object_pool.hpp> + +namespace search { + +class ContextBase { + public: + explicit ContextBase(const Config &config) : config_(config) {} + + VertexNode *NewVertexNode() { + VertexNode *ret = vertex_node_pool_.construct(); + assert(ret); + return ret; + } + + void DeleteVertexNode(VertexNode *node) { + vertex_node_pool_.destroy(node); + } + + unsigned int PopLimit() const { return config_.PopLimit(); } + + Score LMWeight() const { return config_.LMWeight(); } + + const Config &GetConfig() const { return config_; } + + private: + boost::object_pool<VertexNode> vertex_node_pool_; + + Config config_; +}; + +template <class Model> class Context : public ContextBase { + public: + Context(const Config &config, const Model &model) : ContextBase(config), model_(model) {} + + const Model &LanguageModel() const { return model_; } + + private: + const Model &model_; +}; + +} // namespace search + +#endif // SEARCH_CONTEXT__ diff --git a/klm/search/dedupe.hh b/klm/search/dedupe.hh new file mode 100644 index 00000000..7eaa3b95 --- /dev/null +++ b/klm/search/dedupe.hh @@ -0,0 +1,131 @@ +#ifndef SEARCH_DEDUPE__ +#define SEARCH_DEDUPE__ + +#include "lm/state.hh" +#include "search/edge_generator.hh" + +#include <boost/pool/object_pool.hpp> +#include <boost/unordered_map.hpp> + +namespace search { + +class Dedupe { + public: + Dedupe() {} + + PartialEdge AllocateEdge(Arity arity) { + return behind_.AllocateEdge(arity); + } + + void AddEdge(PartialEdge edge) { + edge.MutableFlags() = 0; + + uint64_t hash = 0; + const PartialVertex *v = edge.NT(); + const PartialVertex *v_end = v + edge.GetArity(); + for (; v != v_end; ++v) { + const void *ptr = v->Identify(); + hash = util::MurmurHashNative(&ptr, sizeof(const void*), hash); + } + + const lm::ngram::ChartState *c = edge.Between(); + const lm::ngram::ChartState *const c_end = c + edge.GetArity() + 1; + for (; c != c_end; ++c) hash = hash_value(*c, hash); + + std::pair<Table::iterator, bool> ret(table_.insert(std::make_pair(hash, edge))); + if (!ret.second) FoundDupe(ret.first->second, edge); + } + + bool Empty() const { return behind_.Empty(); } + + template <class Model, class Output> void Search(Context<Model> &context, Output &output) { + for (Table::const_iterator i(table_.begin()); i != table_.end(); ++i) { + behind_.AddEdge(i->second); + } + Unpack<Output> unpack(output, *this); + behind_.Search(context, unpack); + } + + private: + void FoundDupe(PartialEdge &table, PartialEdge adding) { + if (table.GetFlags() & kPackedFlag) { + Packed &packed = *static_cast<Packed*>(table.GetNote().mut); + if (table.GetScore() >= adding.GetScore()) { + packed.others.push_back(adding); + return; + } + Note original(packed.original); + packed.original = adding.GetNote(); + adding.SetNote(table.GetNote()); + table.SetNote(original); + packed.others.push_back(table); + packed.starting = adding.GetScore(); + table = adding; + table.MutableFlags() |= kPackedFlag; + return; + } + PartialEdge loser; + if (adding.GetScore() > table.GetScore()) { + loser = table; + table = adding; + } else { + loser = adding; + } + // table is winner, loser is loser... + packed_.construct(table, loser); + } + + struct Packed { + Packed(PartialEdge winner, PartialEdge loser) + : original(winner.GetNote()), starting(winner.GetScore()), others(1, loser) { + winner.MutableNote().vp = this; + winner.MutableFlags() |= kPackedFlag; + loser.MutableFlags() &= ~kPackedFlag; + } + Note original; + Score starting; + std::vector<PartialEdge> others; + }; + + template <class Output> class Unpack { + public: + explicit Unpack(Output &output, Dedupe &owner) : output_(output), owner_(owner) {} + + void NewHypothesis(PartialEdge edge) { + if (edge.GetFlags() & kPackedFlag) { + Packed &packed = *reinterpret_cast<Packed*>(edge.GetNote().mut); + edge.SetNote(packed.original); + edge.MutableFlags() = 0; + std::size_t copy_size = sizeof(PartialVertex) * edge.GetArity() + sizeof(lm::ngram::ChartState); + for (std::vector<PartialEdge>::iterator i = packed.others.begin(); i != packed.others.end(); ++i) { + PartialEdge copy(owner_.AllocateEdge(edge.GetArity())); + copy.SetScore(edge.GetScore() - packed.starting + i->GetScore()); + copy.MutableFlags() = 0; + copy.SetNote(i->GetNote()); + memcpy(copy.NT(), edge.NT(), copy_size); + output_.NewHypothesis(copy); + } + } + output_.NewHypothesis(edge); + } + + void FinishedSearch() { + output_.FinishedSearch(); + } + + private: + Output &output_; + Dedupe &owner_; + }; + + EdgeGenerator behind_; + + typedef boost::unordered_map<uint64_t, PartialEdge> Table; + Table table_; + + boost::object_pool<Packed> packed_; + + static const uint16_t kPackedFlag = 1; +}; +} // namespace search +#endif // SEARCH_DEDUPE__ diff --git a/klm/search/edge.hh b/klm/search/edge.hh new file mode 100644 index 00000000..187904bf --- /dev/null +++ b/klm/search/edge.hh @@ -0,0 +1,54 @@ +#ifndef SEARCH_EDGE__ +#define SEARCH_EDGE__ + +#include "lm/state.hh" +#include "search/header.hh" +#include "search/types.hh" +#include "search/vertex.hh" +#include "util/pool.hh" + +#include <functional> + +#include <stdint.h> + +namespace search { + +// Copyable, but the copy will be shallow. +class PartialEdge : public Header { + public: + // Allow default construction for STL. + PartialEdge() {} + + PartialEdge(util::Pool &pool, Arity arity) + : Header(pool.Allocate(Size(arity, arity + 1)), arity) {} + + PartialEdge(util::Pool &pool, Arity arity, Arity chart_states) + : Header(pool.Allocate(Size(arity, chart_states)), arity) {} + + // Non-terminals + const PartialVertex *NT() const { + return reinterpret_cast<const PartialVertex*>(After()); + } + PartialVertex *NT() { + return reinterpret_cast<PartialVertex*>(After()); + } + + const lm::ngram::ChartState &CompletedState() const { + return *Between(); + } + const lm::ngram::ChartState *Between() const { + return reinterpret_cast<const lm::ngram::ChartState*>(After() + GetArity() * sizeof(PartialVertex)); + } + lm::ngram::ChartState *Between() { + return reinterpret_cast<lm::ngram::ChartState*>(After() + GetArity() * sizeof(PartialVertex)); + } + + private: + static std::size_t Size(Arity arity, Arity chart_states) { + return kHeaderSize + arity * sizeof(PartialVertex) + chart_states * sizeof(lm::ngram::ChartState); + } +}; + + +} // namespace search +#endif // SEARCH_EDGE__ diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc new file mode 100644 index 00000000..eacf5de5 --- /dev/null +++ b/klm/search/edge_generator.cc @@ -0,0 +1,111 @@ +#include "search/edge_generator.hh" + +#include "lm/left.hh" +#include "lm/model.hh" +#include "lm/partial.hh" +#include "search/context.hh" +#include "search/vertex.hh" + +#include <numeric> + +namespace search { + +namespace { + +template <class Model> void FastScore(const Context<Model> &context, Arity victim, Arity before_idx, Arity incomplete, const PartialVertex &previous_vertex, PartialEdge update) { + lm::ngram::ChartState *between = update.Between(); + lm::ngram::ChartState *before = &between[before_idx], *after = &between[before_idx + 1]; + + float adjustment = 0.0; + const lm::ngram::ChartState &previous_reveal = previous_vertex.State(); + const PartialVertex &update_nt = update.NT()[victim]; + const lm::ngram::ChartState &update_reveal = update_nt.State(); + if ((update_reveal.left.length > previous_reveal.left.length) || (update_reveal.left.full && !previous_reveal.left.full)) { + adjustment += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length); + } + if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous_vertex.RightFull())) { + adjustment += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right); + } + if (update_nt.Complete()) { + if (update_reveal.left.full) { + before->left.full = true; + } else { + assert(update_reveal.left.length == update_reveal.right.length); + adjustment += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length); + } + before->right = after->right; + // Shift the others shifted one down, covering after. + for (lm::ngram::ChartState *cover = after; cover < between + incomplete; ++cover) { + *cover = *(cover + 1); + } + } + update.SetScore(update.GetScore() + adjustment * context.LMWeight()); +} + +} // namespace + +template <class Model> PartialEdge EdgeGenerator::Pop(Context<Model> &context) { + assert(!generate_.empty()); + PartialEdge top = generate_.top(); + generate_.pop(); + PartialVertex *const top_nt = top.NT(); + const Arity arity = top.GetArity(); + + Arity victim = 0; + Arity victim_completed; + Arity incomplete; + // Select victim or return if complete. + { + Arity completed = 0; + unsigned char lowest_length = 255; + for (Arity i = 0; i != arity; ++i) { + if (top_nt[i].Complete()) { + ++completed; + } else if (top_nt[i].Length() < lowest_length) { + lowest_length = top_nt[i].Length(); + victim = i; + victim_completed = completed; + } + } + if (lowest_length == 255) { + return top; + } + incomplete = arity - completed; + } + + PartialVertex old_value(top_nt[victim]); + PartialVertex alternate_changed; + if (top_nt[victim].Split(alternate_changed)) { + PartialEdge alternate(partial_edge_pool_, arity, incomplete + 1); + alternate.SetScore(top.GetScore() + alternate_changed.Bound() - old_value.Bound()); + + alternate.SetNote(top.GetNote()); + + PartialVertex *alternate_nt = alternate.NT(); + for (Arity i = 0; i < victim; ++i) alternate_nt[i] = top_nt[i]; + alternate_nt[victim] = alternate_changed; + for (Arity i = victim + 1; i < arity; ++i) alternate_nt[i] = top_nt[i]; + + memcpy(alternate.Between(), top.Between(), sizeof(lm::ngram::ChartState) * (incomplete + 1)); + + // TODO: dedupe? + generate_.push(alternate); + } + + // top is now the continuation. + FastScore(context, victim, victim - victim_completed, incomplete, old_value, top); + // TODO: dedupe? + generate_.push(top); + + // Invalid indicates no new hypothesis generated. + return PartialEdge(); +} + +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context); +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context); +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::TrieModel> &context); +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::QuantTrieModel> &context); +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::ArrayTrieModel> &context); +template PartialEdge EdgeGenerator::Pop(Context<lm::ngram::QuantArrayTrieModel> &context); + +} // namespace search diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh new file mode 100644 index 00000000..203942c6 --- /dev/null +++ b/klm/search/edge_generator.hh @@ -0,0 +1,56 @@ +#ifndef SEARCH_EDGE_GENERATOR__ +#define SEARCH_EDGE_GENERATOR__ + +#include "search/edge.hh" +#include "search/types.hh" + +#include <queue> + +namespace lm { +namespace ngram { +class ChartState; +} // namespace ngram +} // namespace lm + +namespace search { + +template <class Model> class Context; + +class EdgeGenerator { + public: + EdgeGenerator() {} + + PartialEdge AllocateEdge(Arity arity) { + return PartialEdge(partial_edge_pool_, arity); + } + + void AddEdge(PartialEdge edge) { + generate_.push(edge); + } + + bool Empty() const { return generate_.empty(); } + + // Pop. If there's a complete hypothesis, return it. Otherwise return an invalid PartialEdge. + template <class Model> PartialEdge Pop(Context<Model> &context); + + template <class Model, class Output> void Search(Context<Model> &context, Output &output) { + unsigned to_pop = context.PopLimit(); + while (to_pop > 0 && !generate_.empty()) { + PartialEdge got(Pop(context)); + if (got.Valid()) { + output.NewHypothesis(got); + --to_pop; + } + } + output.FinishedSearch(); + } + + private: + util::Pool partial_edge_pool_; + + typedef std::priority_queue<PartialEdge> Generate; + Generate generate_; +}; + +} // namespace search +#endif // SEARCH_EDGE_GENERATOR__ diff --git a/klm/search/header.hh b/klm/search/header.hh new file mode 100644 index 00000000..69f0eed0 --- /dev/null +++ b/klm/search/header.hh @@ -0,0 +1,64 @@ +#ifndef SEARCH_HEADER__ +#define SEARCH_HEADER__ + +// Header consisting of Score, Arity, and Note + +#include "search/types.hh" + +#include <stdint.h> + +namespace search { + +// Copying is shallow. +class Header { + public: + bool Valid() const { return base_; } + + Score GetScore() const { + return *reinterpret_cast<const float*>(base_); + } + void SetScore(Score to) { + *reinterpret_cast<float*>(base_) = to; + } + bool operator<(const Header &other) const { + return GetScore() < other.GetScore(); + } + bool operator>(const Header &other) const { + return GetScore() > other.GetScore(); + } + + Arity GetArity() const { + return *reinterpret_cast<const Arity*>(base_ + sizeof(Score)); + } + + Note GetNote() const { + return *reinterpret_cast<const Note*>(base_ + sizeof(Score) + sizeof(Arity)); + } + void SetNote(Note to) { + *reinterpret_cast<Note*>(base_ + sizeof(Score) + sizeof(Arity)) = to; + } + + uint8_t *Base() { return base_; } + const uint8_t *Base() const { return base_; } + + protected: + Header() : base_(NULL) {} + + explicit Header(void *base) : base_(static_cast<uint8_t*>(base)) {} + + Header(void *base, Arity arity) : base_(static_cast<uint8_t*>(base)) { + *reinterpret_cast<Arity*>(base_ + sizeof(Score)) = arity; + } + + static const std::size_t kHeaderSize = sizeof(Score) + sizeof(Arity) + sizeof(Note); + + uint8_t *After() { return base_ + kHeaderSize; } + const uint8_t *After() const { return base_ + kHeaderSize; } + + private: + uint8_t *base_; +}; + +} // namespace search + +#endif // SEARCH_HEADER__ diff --git a/klm/search/nbest.cc b/klm/search/nbest.cc new file mode 100644 index 00000000..ec3322c9 --- /dev/null +++ b/klm/search/nbest.cc @@ -0,0 +1,106 @@ +#include "search/nbest.hh" + +#include "util/pool.hh" + +#include <algorithm> +#include <functional> +#include <queue> + +#include <assert.h> +#include <math.h> + +namespace search { + +NBestList::NBestList(std::vector<PartialEdge> &partials, util::Pool &entry_pool, std::size_t keep) { + assert(!partials.empty()); + std::vector<PartialEdge>::iterator end; + if (partials.size() > keep) { + end = partials.begin() + keep; + std::nth_element(partials.begin(), end, partials.end(), std::greater<PartialEdge>()); + } else { + end = partials.end(); + } + for (std::vector<PartialEdge>::const_iterator i(partials.begin()); i != end; ++i) { + queue_.push(QueueEntry(entry_pool.Allocate(QueueEntry::Size(i->GetArity())), *i)); + } +} + +Score NBestList::TopAfterConstructor() const { + assert(revealed_.empty()); + return queue_.top().GetScore(); +} + +const std::vector<Applied> &NBestList::Extract(util::Pool &pool, std::size_t n) { + while (revealed_.size() < n && !queue_.empty()) { + MoveTop(pool); + } + return revealed_; +} + +Score NBestList::Visit(util::Pool &pool, std::size_t index) { + if (index + 1 < revealed_.size()) + return revealed_[index + 1].GetScore() - revealed_[index].GetScore(); + if (queue_.empty()) + return -INFINITY; + if (index + 1 == revealed_.size()) + return queue_.top().GetScore() - revealed_[index].GetScore(); + assert(index == revealed_.size()); + + MoveTop(pool); + + if (queue_.empty()) return -INFINITY; + return queue_.top().GetScore() - revealed_[index].GetScore(); +} + +Applied NBestList::Get(util::Pool &pool, std::size_t index) { + assert(index <= revealed_.size()); + if (index == revealed_.size()) MoveTop(pool); + return revealed_[index]; +} + +void NBestList::MoveTop(util::Pool &pool) { + assert(!queue_.empty()); + QueueEntry entry(queue_.top()); + queue_.pop(); + RevealedRef *const children_begin = entry.Children(); + RevealedRef *const children_end = children_begin + entry.GetArity(); + Score basis = entry.GetScore(); + for (RevealedRef *child = children_begin; child != children_end; ++child) { + Score change = child->in_->Visit(pool, child->index_); + if (change != -INFINITY) { + assert(change < 0.001); + QueueEntry new_entry(pool.Allocate(QueueEntry::Size(entry.GetArity())), basis + change, entry.GetArity(), entry.GetNote()); + std::copy(children_begin, child, new_entry.Children()); + RevealedRef *update = new_entry.Children() + (child - children_begin); + update->in_ = child->in_; + update->index_ = child->index_ + 1; + std::copy(child + 1, children_end, update + 1); + queue_.push(new_entry); + } + // Gesmundo, A. and Henderson, J. Faster Cube Pruning, IWSLT 2010. + if (child->index_) break; + } + + // Convert QueueEntry to Applied. This leaves some unused memory. + void *overwrite = entry.Children(); + for (unsigned int i = 0; i < entry.GetArity(); ++i) { + RevealedRef from(*(static_cast<const RevealedRef*>(overwrite) + i)); + *(static_cast<Applied*>(overwrite) + i) = from.in_->Get(pool, from.index_); + } + revealed_.push_back(Applied(entry.Base())); +} + +NBestComplete NBest::Complete(std::vector<PartialEdge> &partials) { + assert(!partials.empty()); + NBestList *list = list_pool_.construct(partials, entry_pool_, config_.keep); + return NBestComplete( + list, + partials.front().CompletedState(), // All partials have the same state + list->TopAfterConstructor()); +} + +const std::vector<Applied> &NBest::Extract(History history) { + return static_cast<NBestList*>(history)->Extract(entry_pool_, config_.size); +} + +} // namespace search diff --git a/klm/search/nbest.hh b/klm/search/nbest.hh new file mode 100644 index 00000000..cb7651bc --- /dev/null +++ b/klm/search/nbest.hh @@ -0,0 +1,81 @@ +#ifndef SEARCH_NBEST__ +#define SEARCH_NBEST__ + +#include "search/applied.hh" +#include "search/config.hh" +#include "search/edge.hh" + +#include <boost/pool/object_pool.hpp> + +#include <cstddef> +#include <queue> +#include <vector> + +#include <assert.h> + +namespace search { + +class NBestList; + +class NBestList { + private: + class RevealedRef { + public: + explicit RevealedRef(History history) + : in_(static_cast<NBestList*>(history)), index_(0) {} + + private: + friend class NBestList; + + NBestList *in_; + std::size_t index_; + }; + + typedef GenericApplied<RevealedRef> QueueEntry; + + public: + NBestList(std::vector<PartialEdge> &existing, util::Pool &entry_pool, std::size_t keep); + + Score TopAfterConstructor() const; + + const std::vector<Applied> &Extract(util::Pool &pool, std::size_t n); + + private: + Score Visit(util::Pool &pool, std::size_t index); + + Applied Get(util::Pool &pool, std::size_t index); + + void MoveTop(util::Pool &pool); + + typedef std::vector<Applied> Revealed; + Revealed revealed_; + + typedef std::priority_queue<QueueEntry> Queue; + Queue queue_; +}; + +class NBest { + public: + typedef std::vector<PartialEdge> Combine; + + explicit NBest(const NBestConfig &config) : config_(config) {} + + void Add(std::vector<PartialEdge> &existing, PartialEdge addition) const { + existing.push_back(addition); + } + + NBestComplete Complete(std::vector<PartialEdge> &partials); + + const std::vector<Applied> &Extract(History root); + + private: + const NBestConfig config_; + + boost::object_pool<NBestList> list_pool_; + + util::Pool entry_pool_; +}; + +} // namespace search + +#endif // SEARCH_NBEST__ diff --git a/klm/search/rule.cc b/klm/search/rule.cc new file mode 100644 index 00000000..0244a09f --- /dev/null +++ b/klm/search/rule.cc @@ -0,0 +1,43 @@ +#include "search/rule.hh" + +#include "lm/model.hh" +#include "search/context.hh" + +#include <ostream> + +#include <math.h> + +namespace search { + +template <class Model> ScoreRuleRet ScoreRule(const Model &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing) { + ScoreRuleRet ret; + ret.prob = 0.0; + ret.oov = 0; + const lm::WordIndex oov = model.GetVocabulary().NotFound(), bos = model.GetVocabulary().BeginSentence(); + lm::ngram::RuleScore<Model> scorer(model, *(writing++)); + std::vector<lm::WordIndex>::const_iterator word = words.begin(); + if (word != words.end() && *word == bos) { + scorer.BeginSentence(); + ++word; + } + for (; word != words.end(); ++word) { + if (*word == kNonTerminal) { + ret.prob += scorer.Finish(); + scorer.Reset(*(writing++)); + } else { + if (*word == oov) ++ret.oov; + scorer.Terminal(*word); + } + } + ret.prob += scorer.Finish(); + return ret; +} + +template ScoreRuleRet ScoreRule(const lm::ngram::RestProbingModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::ProbingModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::TrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::QuantTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::ArrayTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::QuantArrayTrieModel &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *writing); + +} // namespace search diff --git a/klm/search/rule.hh b/klm/search/rule.hh new file mode 100644 index 00000000..43ca6162 --- /dev/null +++ b/klm/search/rule.hh @@ -0,0 +1,25 @@ +#ifndef SEARCH_RULE__ +#define SEARCH_RULE__ + +#include "lm/left.hh" +#include "lm/word_index.hh" +#include "search/types.hh" + +#include <vector> + +namespace search { + +const lm::WordIndex kNonTerminal = lm::kMaxWordIndex; + +struct ScoreRuleRet { + Score prob; + unsigned int oov; +}; + +// Pass <s> and </s> normally. +// Indicate non-terminals with kNonTerminal. +template <class Model> ScoreRuleRet ScoreRule(const Model &model, const std::vector<lm::WordIndex> &words, lm::ngram::ChartState *state_out); + +} // namespace search + +#endif // SEARCH_RULE__ diff --git a/klm/search/types.hh b/klm/search/types.hh new file mode 100644 index 00000000..f9c849b3 --- /dev/null +++ b/klm/search/types.hh @@ -0,0 +1,31 @@ +#ifndef SEARCH_TYPES__ +#define SEARCH_TYPES__ + +#include <stdint.h> + +namespace lm { namespace ngram { class ChartState; } } + +namespace search { + +typedef float Score; + +typedef uint32_t Arity; + +union Note { + const void *vp; +}; + +typedef void *History; + +struct NBestComplete { + NBestComplete(History in_history, const lm::ngram::ChartState &in_state, Score in_score) + : history(in_history), state(&in_state), score(in_score) {} + + History history; + const lm::ngram::ChartState *state; + Score score; +}; + +} // namespace search + +#endif // SEARCH_TYPES__ diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc new file mode 100644 index 00000000..45842982 --- /dev/null +++ b/klm/search/vertex.cc @@ -0,0 +1,55 @@ +#include "search/vertex.hh" + +#include "search/context.hh" + +#include <algorithm> +#include <functional> + +#include <assert.h> + +namespace search { + +namespace { + +struct GreaterByBound : public std::binary_function<const VertexNode *, const VertexNode *, bool> { + bool operator()(const VertexNode *first, const VertexNode *second) const { + return first->Bound() > second->Bound(); + } +}; + +} // namespace + +void VertexNode::RecursiveSortAndSet(ContextBase &context, VertexNode *&parent_ptr) { + if (Complete()) { + assert(end_); + assert(extend_.empty()); + return; + } + if (extend_.size() == 1) { + parent_ptr = extend_[0]; + extend_[0]->RecursiveSortAndSet(context, parent_ptr); + context.DeleteVertexNode(this); + return; + } + for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) { + (*i)->RecursiveSortAndSet(context, *i); + } + std::sort(extend_.begin(), extend_.end(), GreaterByBound()); + bound_ = extend_.front()->Bound(); +} + +void VertexNode::SortAndSet(ContextBase &context) { + // This is the root. The root might be empty. + if (extend_.empty()) { + bound_ = -INFINITY; + return; + } + // The root cannot be replaced. There's always one transition. + for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) { + (*i)->RecursiveSortAndSet(context, *i); + } + std::sort(extend_.begin(), extend_.end(), GreaterByBound()); + bound_ = extend_.front()->Bound(); +} + +} // namespace search diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh new file mode 100644 index 00000000..ca9a4fcd --- /dev/null +++ b/klm/search/vertex.hh @@ -0,0 +1,164 @@ +#ifndef SEARCH_VERTEX__ +#define SEARCH_VERTEX__ + +#include "lm/left.hh" +#include "search/types.hh" + +#include <boost/unordered_set.hpp> + +#include <queue> +#include <vector> + +#include <math.h> +#include <stdint.h> + +namespace search { + +class ContextBase; + +class VertexNode { + public: + VertexNode() : end_() {} + + void InitRoot() { + extend_.clear(); + state_.left.full = false; + state_.left.length = 0; + state_.right.length = 0; + right_full_ = false; + end_ = History(); + } + + lm::ngram::ChartState &MutableState() { return state_; } + bool &MutableRightFull() { return right_full_; } + + void AddExtend(VertexNode *next) { + extend_.push_back(next); + } + + void SetEnd(History end, Score score) { + assert(!end_); + end_ = end; + bound_ = score; + } + + void SortAndSet(ContextBase &context); + + // Should only happen to a root node when the entire vertex is empty. + bool Empty() const { + return !end_ && extend_.empty(); + } + + bool Complete() const { + return end_; + } + + const lm::ngram::ChartState &State() const { return state_; } + bool RightFull() const { return right_full_; } + + Score Bound() const { + return bound_; + } + + unsigned char Length() const { + return state_.left.length + state_.right.length; + } + + // Will be invalid unless this is a leaf. + History End() const { return end_; } + + const VertexNode &operator[](size_t index) const { + return *extend_[index]; + } + + size_t Size() const { + return extend_.size(); + } + + private: + void RecursiveSortAndSet(ContextBase &context, VertexNode *&parent); + + std::vector<VertexNode*> extend_; + + lm::ngram::ChartState state_; + bool right_full_; + + Score bound_; + History end_; +}; + +class PartialVertex { + public: + PartialVertex() {} + + explicit PartialVertex(const VertexNode &back) : back_(&back), index_(0) {} + + bool Empty() const { return back_->Empty(); } + + bool Complete() const { return back_->Complete(); } + + const lm::ngram::ChartState &State() const { return back_->State(); } + bool RightFull() const { return back_->RightFull(); } + + Score Bound() const { return Complete() ? back_->Bound() : (*back_)[index_].Bound(); } + + unsigned char Length() const { return back_->Length(); } + + bool HasAlternative() const { + return index_ + 1 < back_->Size(); + } + + // Split into continuation and alternative, rendering this the continuation. + bool Split(PartialVertex &alternative) { + assert(!Complete()); + bool ret; + if (index_ + 1 < back_->Size()) { + alternative.index_ = index_ + 1; + alternative.back_ = back_; + ret = true; + } else { + ret = false; + } + back_ = &((*back_)[index_]); + index_ = 0; + return ret; + } + + History End() const { + return back_->End(); + } + + private: + const VertexNode *back_; + unsigned int index_; +}; + +template <class Output> class VertexGenerator; + +class Vertex { + public: + Vertex() {} + + PartialVertex RootPartial() const { return PartialVertex(root_); } + + History BestChild() const { + PartialVertex top(RootPartial()); + if (top.Empty()) { + return History(); + } else { + PartialVertex continuation; + while (!top.Complete()) { + top.Split(continuation); + } + return top.End(); + } + } + + private: + template <class Output> friend class VertexGenerator; + template <class Output> friend class RootVertexGenerator; + VertexNode root_; +}; + +} // namespace search +#endif // SEARCH_VERTEX__ diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc new file mode 100644 index 00000000..73139ffc --- /dev/null +++ b/klm/search/vertex_generator.cc @@ -0,0 +1,68 @@ +#include "search/vertex_generator.hh" + +#include "lm/left.hh" +#include "search/context.hh" +#include "search/edge.hh" + +#include <boost/unordered_map.hpp> +#include <boost/version.hpp> + +#include <stdint.h> + +namespace search { + +#if BOOST_VERSION > 104200 +namespace { + +const uint64_t kCompleteAdd = static_cast<uint64_t>(-1); + +Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { + Trie &next = node.extend[added]; + if (!next.under) { + next.under = context.NewVertexNode(); + lm::ngram::ChartState &writing = next.under->MutableState(); + writing = state; + writing.left.full &= left_full && state.left.full; + next.under->MutableRightFull() = right_full && state.left.full; + writing.left.length = left; + writing.right.length = right; + node.under->AddExtend(next.under); + } + return next; +} + +} // namespace + +void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end) { + const lm::ngram::ChartState &state = *end.state; + + unsigned char left = 0, right = 0; + Trie *node = &root; + while (true) { + if (left == state.left.length) { + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, true, right, false); + for (; right < state.right.length; ++right) { + node = &FindOrInsert(context, *node, state.right.words[right], state, left, true, right + 1, false); + } + break; + } + node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, false); + left++; + if (right == state.right.length) { + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, false, right, true); + for (; left < state.left.length; ++left) { + node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, true); + } + break; + } + node = &FindOrInsert(context, *node, state.right.words[right], state, left, false, right + 1, false); + right++; + } + + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); + node->under->SetEnd(end.history, end.score); +} + +#endif // BOOST_VERSION + +} // namespace search diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh new file mode 100644 index 00000000..646b8189 --- /dev/null +++ b/klm/search/vertex_generator.hh @@ -0,0 +1,99 @@ +#ifndef SEARCH_VERTEX_GENERATOR__ +#define SEARCH_VERTEX_GENERATOR__ + +#include "search/edge.hh" +#include "search/types.hh" +#include "search/vertex.hh" +#include "util/exception.hh" + +#include <boost/unordered_map.hpp> +#include <boost/version.hpp> + +namespace lm { +namespace ngram { +class ChartState; +} // namespace ngram +} // namespace lm + +namespace search { + +class ContextBase; + +#if BOOST_VERSION > 104200 +// Parallel structure to VertexNode. +struct Trie { + Trie() : under(NULL) {} + + VertexNode *under; + boost::unordered_map<uint64_t, Trie> extend; +}; + +void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end); + +#endif // BOOST_VERSION + +// Output makes the single-best or n-best list. +template <class Output> class VertexGenerator { + public: + VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) { + gen.root_.InitRoot(); + } + + void NewHypothesis(PartialEdge partial) { + nbest_.Add(existing_[hash_value(partial.CompletedState())], partial); + } + + void FinishedSearch() { +#if BOOST_VERSION > 104200 + Trie root; + root.under = &gen_.root_; + for (typename Existing::iterator i(existing_.begin()); i != existing_.end(); ++i) { + AddHypothesis(context_, root, nbest_.Complete(i->second)); + } + existing_.clear(); + root.under->SortAndSet(context_); +#else + UTIL_THROW(util::Exception, "Upgrade Boost to >= 1.42.0 to use incremental search."); +#endif + } + + const Vertex &Generating() const { return gen_; } + + private: + ContextBase &context_; + + Vertex &gen_; + + typedef boost::unordered_map<uint64_t, typename Output::Combine> Existing; + Existing existing_; + + Output &nbest_; +}; + +// Special case for root vertex: everything should come together into the root +// node. In theory, this should happen naturally due to state collapsing with +// <s> and </s>. If that's the case, VertexGenerator is fine, though it will +// make one connection. +template <class Output> class RootVertexGenerator { + public: + RootVertexGenerator(Vertex &gen, Output &out) : gen_(gen), out_(out) {} + + void NewHypothesis(PartialEdge partial) { + out_.Add(combine_, partial); + } + + void FinishedSearch() { + gen_.root_.InitRoot(); + NBestComplete completed(out_.Complete(combine_)); + gen_.root_.SetEnd(completed.history, completed.score); + } + + private: + Vertex &gen_; + + typename Output::Combine combine_; + Output &out_; +}; + +} // namespace search +#endif // SEARCH_VERTEX_GENERATOR__ |