From 5d42134cc676278189b0f77708908542fbb5ccc9 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Fri, 14 Dec 2012 12:39:04 -0800 Subject: Updated incremental, updated kenlm. Incremental assumes --- klm/search/Makefile.am | 4 +- klm/search/applied.hh | 86 +++++++++++++++++++++++++++ klm/search/config.hh | 25 ++++++-- klm/search/context.hh | 28 ++------- klm/search/dedupe.hh | 131 +++++++++++++++++++++++++++++++++++++++++ klm/search/edge_generator.cc | 3 +- klm/search/edge_generator.hh | 1 - klm/search/final.hh | 36 ----------- klm/search/header.hh | 9 ++- klm/search/nbest.cc | 106 +++++++++++++++++++++++++++++++++ klm/search/nbest.hh | 81 +++++++++++++++++++++++++ klm/search/note.hh | 12 ---- klm/search/rule.cc | 52 ++++++++-------- klm/search/rule.hh | 11 +++- klm/search/types.hh | 17 ++++++ klm/search/vertex.cc | 27 ++++++--- klm/search/vertex.hh | 37 +++++++----- klm/search/vertex_generator.cc | 44 +++----------- klm/search/vertex_generator.hh | 72 ++++++++++++++++++---- klm/search/weights.cc | 71 ---------------------- klm/search/weights.hh | 52 ---------------- klm/search/weights_test.cc | 38 ------------ 22 files changed, 604 insertions(+), 339 deletions(-) create mode 100644 klm/search/applied.hh create mode 100644 klm/search/dedupe.hh delete mode 100644 klm/search/final.hh create mode 100644 klm/search/nbest.cc create mode 100644 klm/search/nbest.hh delete mode 100644 klm/search/note.hh delete mode 100644 klm/search/weights.cc delete mode 100644 klm/search/weights.hh delete mode 100644 klm/search/weights_test.cc (limited to 'klm/search') diff --git a/klm/search/Makefile.am b/klm/search/Makefile.am index ccc5b7f6..5aea33c2 100644 --- a/klm/search/Makefile.am +++ b/klm/search/Makefile.am @@ -2,10 +2,10 @@ noinst_LIBRARIES = libksearch.a libksearch_a_SOURCES = \ edge_generator.cc \ + nbest.cc \ rule.cc \ vertex.cc \ - vertex_generator.cc \ - weights.cc + vertex_generator.cc AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. diff --git a/klm/search/applied.hh b/klm/search/applied.hh new file mode 100644 index 00000000..bd659e5c --- /dev/null +++ b/klm/search/applied.hh @@ -0,0 +1,86 @@ +#ifndef SEARCH_APPLIED__ +#define SEARCH_APPLIED__ + +#include "search/edge.hh" +#include "search/header.hh" +#include "util/pool.hh" + +#include + +namespace search { + +// A full hypothesis: a score, arity of the rule, a pointer to the decoder's rule (Note), and pointers to non-terminals that were substituted. +template class GenericApplied : public Header { + public: + GenericApplied() {} + + GenericApplied(void *location, PartialEdge partial) + : Header(location) { + memcpy(Base(), partial.Base(), kHeaderSize); + Below *child_out = Children(); + const PartialVertex *part = partial.NT(); + const PartialVertex *const part_end_loop = part + partial.GetArity(); + for (; part != part_end_loop; ++part, ++child_out) + *child_out = Below(part->End()); + } + + GenericApplied(void *location, Score score, Arity arity, Note note) : Header(location, arity) { + SetScore(score); + SetNote(note); + } + + explicit GenericApplied(History from) : Header(from) {} + + + // These are arrays of length GetArity(). + Below *Children() { + return reinterpret_cast(After()); + } + const Below *Children() const { + return reinterpret_cast(After()); + } + + static std::size_t Size(Arity arity) { + return kHeaderSize + arity * sizeof(const Below); + } +}; + +// Applied rule that references itself. +class Applied : public GenericApplied { + private: + typedef GenericApplied P; + + public: + Applied() {} + Applied(void *location, PartialEdge partial) : P(location, partial) {} + Applied(History from) : P(from) {} +}; + +// How to build single-best hypotheses. +class SingleBest { + public: + typedef PartialEdge Combine; + + void Add(PartialEdge &existing, PartialEdge add) const { + if (!existing.Valid() || existing.GetScore() < add.GetScore()) + existing = add; + } + + NBestComplete Complete(PartialEdge partial) { + if (!partial.Valid()) + return NBestComplete(NULL, lm::ngram::ChartState(), -INFINITY); + void *place_final = pool_.Allocate(Applied::Size(partial.GetArity())); + Applied(place_final, partial); + return NBestComplete( + place_final, + partial.CompletedState(), + partial.GetScore()); + } + + private: + util::Pool pool_; +}; + +} // namespace search + +#endif // SEARCH_APPLIED__ diff --git a/klm/search/config.hh b/klm/search/config.hh index ef8e2354..ba18c09e 100644 --- a/klm/search/config.hh +++ b/klm/search/config.hh @@ -1,23 +1,36 @@ #ifndef SEARCH_CONFIG__ #define SEARCH_CONFIG__ -#include "search/weights.hh" -#include "util/string_piece.hh" +#include "search/types.hh" namespace search { +struct NBestConfig { + explicit NBestConfig(unsigned int in_size) { + keep = in_size; + size = in_size; + } + + unsigned int keep, size; +}; + class Config { public: - Config(const Weights &weights, unsigned int pop_limit) : - weights_(weights), pop_limit_(pop_limit) {} + Config(Score lm_weight, unsigned int pop_limit, const NBestConfig &nbest) : + lm_weight_(lm_weight), pop_limit_(pop_limit), nbest_(nbest) {} - const Weights &GetWeights() const { return weights_; } + Score LMWeight() const { return lm_weight_; } unsigned int PopLimit() const { return pop_limit_; } + const NBestConfig &GetNBest() const { return nbest_; } + private: - Weights weights_; + Score lm_weight_; + unsigned int pop_limit_; + + NBestConfig nbest_; }; } // namespace search diff --git a/klm/search/context.hh b/klm/search/context.hh index 62163144..08f21bbf 100644 --- a/klm/search/context.hh +++ b/klm/search/context.hh @@ -1,30 +1,16 @@ #ifndef SEARCH_CONTEXT__ #define SEARCH_CONTEXT__ -#include "lm/model.hh" #include "search/config.hh" -#include "search/final.hh" -#include "search/types.hh" #include "search/vertex.hh" -#include "util/exception.hh" -#include "util/pool.hh" #include -#include - -#include namespace search { -class Weights; - class ContextBase { public: - explicit ContextBase(const Config &config) : pop_limit_(config.PopLimit()), weights_(config.GetWeights()) {} - - util::Pool &FinalPool() { - return final_pool_; - } + explicit ContextBase(const Config &config) : config_(config) {} VertexNode *NewVertexNode() { VertexNode *ret = vertex_node_pool_.construct(); @@ -36,18 +22,16 @@ class ContextBase { vertex_node_pool_.destroy(node); } - unsigned int PopLimit() const { return pop_limit_; } + unsigned int PopLimit() const { return config_.PopLimit(); } - const Weights &GetWeights() const { return weights_; } + Score LMWeight() const { return config_.LMWeight(); } - private: - util::Pool final_pool_; + const Config &GetConfig() const { return config_; } + private: boost::object_pool vertex_node_pool_; - unsigned int pop_limit_; - - const Weights &weights_; + Config config_; }; template class Context : public ContextBase { diff --git a/klm/search/dedupe.hh b/klm/search/dedupe.hh new file mode 100644 index 00000000..7eaa3b95 --- /dev/null +++ b/klm/search/dedupe.hh @@ -0,0 +1,131 @@ +#ifndef SEARCH_DEDUPE__ +#define SEARCH_DEDUPE__ + +#include "lm/state.hh" +#include "search/edge_generator.hh" + +#include +#include + +namespace search { + +class Dedupe { + public: + Dedupe() {} + + PartialEdge AllocateEdge(Arity arity) { + return behind_.AllocateEdge(arity); + } + + void AddEdge(PartialEdge edge) { + edge.MutableFlags() = 0; + + uint64_t hash = 0; + const PartialVertex *v = edge.NT(); + const PartialVertex *v_end = v + edge.GetArity(); + for (; v != v_end; ++v) { + const void *ptr = v->Identify(); + hash = util::MurmurHashNative(&ptr, sizeof(const void*), hash); + } + + const lm::ngram::ChartState *c = edge.Between(); + const lm::ngram::ChartState *const c_end = c + edge.GetArity() + 1; + for (; c != c_end; ++c) hash = hash_value(*c, hash); + + std::pair ret(table_.insert(std::make_pair(hash, edge))); + if (!ret.second) FoundDupe(ret.first->second, edge); + } + + bool Empty() const { return behind_.Empty(); } + + template void Search(Context &context, Output &output) { + for (Table::const_iterator i(table_.begin()); i != table_.end(); ++i) { + behind_.AddEdge(i->second); + } + Unpack unpack(output, *this); + behind_.Search(context, unpack); + } + + private: + void FoundDupe(PartialEdge &table, PartialEdge adding) { + if (table.GetFlags() & kPackedFlag) { + Packed &packed = *static_cast(table.GetNote().mut); + if (table.GetScore() >= adding.GetScore()) { + packed.others.push_back(adding); + return; + } + Note original(packed.original); + packed.original = adding.GetNote(); + adding.SetNote(table.GetNote()); + table.SetNote(original); + packed.others.push_back(table); + packed.starting = adding.GetScore(); + table = adding; + table.MutableFlags() |= kPackedFlag; + return; + } + PartialEdge loser; + if (adding.GetScore() > table.GetScore()) { + loser = table; + table = adding; + } else { + loser = adding; + } + // table is winner, loser is loser... + packed_.construct(table, loser); + } + + struct Packed { + Packed(PartialEdge winner, PartialEdge loser) + : original(winner.GetNote()), starting(winner.GetScore()), others(1, loser) { + winner.MutableNote().vp = this; + winner.MutableFlags() |= kPackedFlag; + loser.MutableFlags() &= ~kPackedFlag; + } + Note original; + Score starting; + std::vector others; + }; + + template class Unpack { + public: + explicit Unpack(Output &output, Dedupe &owner) : output_(output), owner_(owner) {} + + void NewHypothesis(PartialEdge edge) { + if (edge.GetFlags() & kPackedFlag) { + Packed &packed = *reinterpret_cast(edge.GetNote().mut); + edge.SetNote(packed.original); + edge.MutableFlags() = 0; + std::size_t copy_size = sizeof(PartialVertex) * edge.GetArity() + sizeof(lm::ngram::ChartState); + for (std::vector::iterator i = packed.others.begin(); i != packed.others.end(); ++i) { + PartialEdge copy(owner_.AllocateEdge(edge.GetArity())); + copy.SetScore(edge.GetScore() - packed.starting + i->GetScore()); + copy.MutableFlags() = 0; + copy.SetNote(i->GetNote()); + memcpy(copy.NT(), edge.NT(), copy_size); + output_.NewHypothesis(copy); + } + } + output_.NewHypothesis(edge); + } + + void FinishedSearch() { + output_.FinishedSearch(); + } + + private: + Output &output_; + Dedupe &owner_; + }; + + EdgeGenerator behind_; + + typedef boost::unordered_map Table; + Table table_; + + boost::object_pool packed_; + + static const uint16_t kPackedFlag = 1; +}; +} // namespace search +#endif // SEARCH_DEDUPE__ diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc index 260159b1..eacf5de5 100644 --- a/klm/search/edge_generator.cc +++ b/klm/search/edge_generator.cc @@ -1,6 +1,7 @@ #include "search/edge_generator.hh" #include "lm/left.hh" +#include "lm/model.hh" #include "lm/partial.hh" #include "search/context.hh" #include "search/vertex.hh" @@ -38,7 +39,7 @@ template void FastScore(const Context &context, Arity victi *cover = *(cover + 1); } } - update.SetScore(update.GetScore() + adjustment * context.GetWeights().LM()); + update.SetScore(update.GetScore() + adjustment * context.LMWeight()); } } // namespace diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh index 582c78b7..203942c6 100644 --- a/klm/search/edge_generator.hh +++ b/klm/search/edge_generator.hh @@ -2,7 +2,6 @@ #define SEARCH_EDGE_GENERATOR__ #include "search/edge.hh" -#include "search/note.hh" #include "search/types.hh" #include diff --git a/klm/search/final.hh b/klm/search/final.hh deleted file mode 100644 index 50e62cf2..00000000 --- a/klm/search/final.hh +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef SEARCH_FINAL__ -#define SEARCH_FINAL__ - -#include "search/header.hh" -#include "util/pool.hh" - -namespace search { - -// A full hypothesis with pointers to children. -class Final : public Header { - public: - Final() {} - - Final(util::Pool &pool, Score score, Arity arity, Note note) - : Header(pool.Allocate(Size(arity)), arity) { - SetScore(score); - SetNote(note); - } - - // These are arrays of length GetArity(). - Final *Children() { - return reinterpret_cast(After()); - } - const Final *Children() const { - return reinterpret_cast(After()); - } - - private: - static std::size_t Size(Arity arity) { - return kHeaderSize + arity * sizeof(const Final); - } -}; - -} // namespace search - -#endif // SEARCH_FINAL__ diff --git a/klm/search/header.hh b/klm/search/header.hh index 25550dbe..69f0eed0 100644 --- a/klm/search/header.hh +++ b/klm/search/header.hh @@ -3,7 +3,6 @@ // Header consisting of Score, Arity, and Note -#include "search/note.hh" #include "search/types.hh" #include @@ -24,6 +23,9 @@ class Header { bool operator<(const Header &other) const { return GetScore() < other.GetScore(); } + bool operator>(const Header &other) const { + return GetScore() > other.GetScore(); + } Arity GetArity() const { return *reinterpret_cast(base_ + sizeof(Score)); @@ -36,9 +38,14 @@ class Header { *reinterpret_cast(base_ + sizeof(Score) + sizeof(Arity)) = to; } + uint8_t *Base() { return base_; } + const uint8_t *Base() const { return base_; } + protected: Header() : base_(NULL) {} + explicit Header(void *base) : base_(static_cast(base)) {} + Header(void *base, Arity arity) : base_(static_cast(base)) { *reinterpret_cast(base_ + sizeof(Score)) = arity; } diff --git a/klm/search/nbest.cc b/klm/search/nbest.cc new file mode 100644 index 00000000..ec3322c9 --- /dev/null +++ b/klm/search/nbest.cc @@ -0,0 +1,106 @@ +#include "search/nbest.hh" + +#include "util/pool.hh" + +#include +#include +#include + +#include +#include + +namespace search { + +NBestList::NBestList(std::vector &partials, util::Pool &entry_pool, std::size_t keep) { + assert(!partials.empty()); + std::vector::iterator end; + if (partials.size() > keep) { + end = partials.begin() + keep; + std::nth_element(partials.begin(), end, partials.end(), std::greater()); + } else { + end = partials.end(); + } + for (std::vector::const_iterator i(partials.begin()); i != end; ++i) { + queue_.push(QueueEntry(entry_pool.Allocate(QueueEntry::Size(i->GetArity())), *i)); + } +} + +Score NBestList::TopAfterConstructor() const { + assert(revealed_.empty()); + return queue_.top().GetScore(); +} + +const std::vector &NBestList::Extract(util::Pool &pool, std::size_t n) { + while (revealed_.size() < n && !queue_.empty()) { + MoveTop(pool); + } + return revealed_; +} + +Score NBestList::Visit(util::Pool &pool, std::size_t index) { + if (index + 1 < revealed_.size()) + return revealed_[index + 1].GetScore() - revealed_[index].GetScore(); + if (queue_.empty()) + return -INFINITY; + if (index + 1 == revealed_.size()) + return queue_.top().GetScore() - revealed_[index].GetScore(); + assert(index == revealed_.size()); + + MoveTop(pool); + + if (queue_.empty()) return -INFINITY; + return queue_.top().GetScore() - revealed_[index].GetScore(); +} + +Applied NBestList::Get(util::Pool &pool, std::size_t index) { + assert(index <= revealed_.size()); + if (index == revealed_.size()) MoveTop(pool); + return revealed_[index]; +} + +void NBestList::MoveTop(util::Pool &pool) { + assert(!queue_.empty()); + QueueEntry entry(queue_.top()); + queue_.pop(); + RevealedRef *const children_begin = entry.Children(); + RevealedRef *const children_end = children_begin + entry.GetArity(); + Score basis = entry.GetScore(); + for (RevealedRef *child = children_begin; child != children_end; ++child) { + Score change = child->in_->Visit(pool, child->index_); + if (change != -INFINITY) { + assert(change < 0.001); + QueueEntry new_entry(pool.Allocate(QueueEntry::Size(entry.GetArity())), basis + change, entry.GetArity(), entry.GetNote()); + std::copy(children_begin, child, new_entry.Children()); + RevealedRef *update = new_entry.Children() + (child - children_begin); + update->in_ = child->in_; + update->index_ = child->index_ + 1; + std::copy(child + 1, children_end, update + 1); + queue_.push(new_entry); + } + // Gesmundo, A. and Henderson, J. Faster Cube Pruning, IWSLT 2010. + if (child->index_) break; + } + + // Convert QueueEntry to Applied. This leaves some unused memory. + void *overwrite = entry.Children(); + for (unsigned int i = 0; i < entry.GetArity(); ++i) { + RevealedRef from(*(static_cast(overwrite) + i)); + *(static_cast(overwrite) + i) = from.in_->Get(pool, from.index_); + } + revealed_.push_back(Applied(entry.Base())); +} + +NBestComplete NBest::Complete(std::vector &partials) { + assert(!partials.empty()); + NBestList *list = list_pool_.construct(partials, entry_pool_, config_.keep); + return NBestComplete( + list, + partials.front().CompletedState(), // All partials have the same state + list->TopAfterConstructor()); +} + +const std::vector &NBest::Extract(History history) { + return static_cast(history)->Extract(entry_pool_, config_.size); +} + +} // namespace search diff --git a/klm/search/nbest.hh b/klm/search/nbest.hh new file mode 100644 index 00000000..cb7651bc --- /dev/null +++ b/klm/search/nbest.hh @@ -0,0 +1,81 @@ +#ifndef SEARCH_NBEST__ +#define SEARCH_NBEST__ + +#include "search/applied.hh" +#include "search/config.hh" +#include "search/edge.hh" + +#include + +#include +#include +#include + +#include + +namespace search { + +class NBestList; + +class NBestList { + private: + class RevealedRef { + public: + explicit RevealedRef(History history) + : in_(static_cast(history)), index_(0) {} + + private: + friend class NBestList; + + NBestList *in_; + std::size_t index_; + }; + + typedef GenericApplied QueueEntry; + + public: + NBestList(std::vector &existing, util::Pool &entry_pool, std::size_t keep); + + Score TopAfterConstructor() const; + + const std::vector &Extract(util::Pool &pool, std::size_t n); + + private: + Score Visit(util::Pool &pool, std::size_t index); + + Applied Get(util::Pool &pool, std::size_t index); + + void MoveTop(util::Pool &pool); + + typedef std::vector Revealed; + Revealed revealed_; + + typedef std::priority_queue Queue; + Queue queue_; +}; + +class NBest { + public: + typedef std::vector Combine; + + explicit NBest(const NBestConfig &config) : config_(config) {} + + void Add(std::vector &existing, PartialEdge addition) const { + existing.push_back(addition); + } + + NBestComplete Complete(std::vector &partials); + + const std::vector &Extract(History root); + + private: + const NBestConfig config_; + + boost::object_pool list_pool_; + + util::Pool entry_pool_; +}; + +} // namespace search + +#endif // SEARCH_NBEST__ diff --git a/klm/search/note.hh b/klm/search/note.hh deleted file mode 100644 index 50bed06e..00000000 --- a/klm/search/note.hh +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef SEARCH_NOTE__ -#define SEARCH_NOTE__ - -namespace search { - -union Note { - const void *vp; -}; - -} // namespace search - -#endif // SEARCH_NOTE__ diff --git a/klm/search/rule.cc b/klm/search/rule.cc index 5b00207e..0244a09f 100644 --- a/klm/search/rule.cc +++ b/klm/search/rule.cc @@ -1,7 +1,7 @@ #include "search/rule.hh" +#include "lm/model.hh" #include "search/context.hh" -#include "search/final.hh" #include @@ -9,35 +9,35 @@ namespace search { -template float ScoreRule(const Context &context, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing) { - unsigned int oov_count = 0; - float prob = 0.0; - const Model &model = context.LanguageModel(); - const lm::WordIndex oov = model.GetVocabulary().NotFound(); - for (std::vector::const_iterator word = words.begin(); ; ++word) { - lm::ngram::RuleScore scorer(model, *(writing++)); - // TODO: optimize - if (prepend_bos && (word == words.begin())) { - scorer.BeginSentence(); - } - for (; ; ++word) { - if (word == words.end()) { - prob += scorer.Finish(); - return static_cast(oov_count) * context.GetWeights().OOV() + prob * context.GetWeights().LM(); - } - if (*word == kNonTerminal) break; - if (*word == oov) ++oov_count; +template ScoreRuleRet ScoreRule(const Model &model, const std::vector &words, lm::ngram::ChartState *writing) { + ScoreRuleRet ret; + ret.prob = 0.0; + ret.oov = 0; + const lm::WordIndex oov = model.GetVocabulary().NotFound(), bos = model.GetVocabulary().BeginSentence(); + lm::ngram::RuleScore scorer(model, *(writing++)); + std::vector::const_iterator word = words.begin(); + if (word != words.end() && *word == bos) { + scorer.BeginSentence(); + ++word; + } + for (; word != words.end(); ++word) { + if (*word == kNonTerminal) { + ret.prob += scorer.Finish(); + scorer.Reset(*(writing++)); + } else { + if (*word == oov) ++ret.oov; scorer.Terminal(*word); } - prob += scorer.Finish(); } + ret.prob += scorer.Finish(); + return ret; } -template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); -template float ScoreRule(const Context &model, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::RestProbingModel &model, const std::vector &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::ProbingModel &model, const std::vector &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::TrieModel &model, const std::vector &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::QuantTrieModel &model, const std::vector &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::ArrayTrieModel &model, const std::vector &words, lm::ngram::ChartState *writing); +template ScoreRuleRet ScoreRule(const lm::ngram::QuantArrayTrieModel &model, const std::vector &words, lm::ngram::ChartState *writing); } // namespace search diff --git a/klm/search/rule.hh b/klm/search/rule.hh index 0ce2794d..43ca6162 100644 --- a/klm/search/rule.hh +++ b/klm/search/rule.hh @@ -9,11 +9,16 @@ namespace search { -template class Context; - const lm::WordIndex kNonTerminal = lm::kMaxWordIndex; -template float ScoreRule(const Context &context, const std::vector &words, bool prepend_bos, lm::ngram::ChartState *state_out); +struct ScoreRuleRet { + Score prob; + unsigned int oov; +}; + +// Pass and normally. +// Indicate non-terminals with kNonTerminal. +template ScoreRuleRet ScoreRule(const Model &model, const std::vector &words, lm::ngram::ChartState *state_out); } // namespace search diff --git a/klm/search/types.hh b/klm/search/types.hh index 06eb5bfa..f9c849b3 100644 --- a/klm/search/types.hh +++ b/klm/search/types.hh @@ -3,12 +3,29 @@ #include +namespace lm { namespace ngram { class ChartState; } } + namespace search { typedef float Score; typedef uint32_t Arity; +union Note { + const void *vp; +}; + +typedef void *History; + +struct NBestComplete { + NBestComplete(History in_history, const lm::ngram::ChartState &in_state, Score in_score) + : history(in_history), state(&in_state), score(in_score) {} + + History history; + const lm::ngram::ChartState *state; + Score score; +}; + } // namespace search #endif // SEARCH_TYPES__ diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc index 11f4631f..45842982 100644 --- a/klm/search/vertex.cc +++ b/klm/search/vertex.cc @@ -19,21 +19,34 @@ struct GreaterByBound : public std::binary_functionSortAndSet(context, parent_ptr); + if (extend_.size() == 1) { + parent_ptr = extend_[0]; + extend_[0]->RecursiveSortAndSet(context, parent_ptr); context.DeleteVertexNode(this); return; } for (std::vector::iterator i = extend_.begin(); i != extend_.end(); ++i) { - (*i)->SortAndSet(context, &*i); + (*i)->RecursiveSortAndSet(context, *i); + } + std::sort(extend_.begin(), extend_.end(), GreaterByBound()); + bound_ = extend_.front()->Bound(); +} + +void VertexNode::SortAndSet(ContextBase &context) { + // This is the root. The root might be empty. + if (extend_.empty()) { + bound_ = -INFINITY; + return; + } + // The root cannot be replaced. There's always one transition. + for (std::vector::iterator i = extend_.begin(); i != extend_.end(); ++i) { + (*i)->RecursiveSortAndSet(context, *i); } std::sort(extend_.begin(), extend_.end(), GreaterByBound()); bound_ = extend_.front()->Bound(); diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh index 52bc1dfe..10b3339b 100644 --- a/klm/search/vertex.hh +++ b/klm/search/vertex.hh @@ -2,7 +2,6 @@ #define SEARCH_VERTEX__ #include "lm/left.hh" -#include "search/final.hh" #include "search/types.hh" #include @@ -10,6 +9,7 @@ #include #include +#include #include namespace search { @@ -18,7 +18,7 @@ class ContextBase; class VertexNode { public: - VertexNode() {} + VertexNode() : end_() {} void InitRoot() { extend_.clear(); @@ -26,7 +26,7 @@ class VertexNode { state_.left.length = 0; state_.right.length = 0; right_full_ = false; - end_ = Final(); + end_ = History(); } lm::ngram::ChartState &MutableState() { return state_; } @@ -36,20 +36,21 @@ class VertexNode { extend_.push_back(next); } - void SetEnd(Final end) { - assert(!end_.Valid()); + void SetEnd(History end, Score score) { + assert(!end_); end_ = end; + bound_ = score; } - void SortAndSet(ContextBase &context, VertexNode **parent_pointer); + void SortAndSet(ContextBase &context); // Should only happen to a root node when the entire vertex is empty. bool Empty() const { - return !end_.Valid() && extend_.empty(); + return !end_ && extend_.empty(); } bool Complete() const { - return end_.Valid(); + return end_; } const lm::ngram::ChartState &State() const { return state_; } @@ -64,7 +65,7 @@ class VertexNode { } // Will be invalid unless this is a leaf. - const Final End() const { return end_; } + const History End() const { return end_; } const VertexNode &operator[](size_t index) const { return *extend_[index]; @@ -75,13 +76,15 @@ class VertexNode { } private: + void RecursiveSortAndSet(ContextBase &context, VertexNode *&parent); + std::vector extend_; lm::ngram::ChartState state_; bool right_full_; Score bound_; - Final end_; + History end_; }; class PartialVertex { @@ -97,7 +100,7 @@ class PartialVertex { const lm::ngram::ChartState &State() const { return back_->State(); } bool RightFull() const { return back_->RightFull(); } - Score Bound() const { return Complete() ? back_->End().GetScore() : (*back_)[index_].Bound(); } + Score Bound() const { return Complete() ? back_->Bound() : (*back_)[index_].Bound(); } unsigned char Length() const { return back_->Length(); } @@ -121,7 +124,7 @@ class PartialVertex { return ret; } - const Final End() const { + const History End() const { return back_->End(); } @@ -130,16 +133,18 @@ class PartialVertex { unsigned int index_; }; +template class VertexGenerator; + class Vertex { public: Vertex() {} PartialVertex RootPartial() const { return PartialVertex(root_); } - const Final BestChild() const { + const History BestChild() const { PartialVertex top(RootPartial()); if (top.Empty()) { - return Final(); + return History(); } else { PartialVertex continuation; while (!top.Complete()) { @@ -150,8 +155,8 @@ class Vertex { } private: - friend class VertexGenerator; - + template friend class VertexGenerator; + template friend class RootVertexGenerator; VertexNode root_; }; diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc index 0945fe55..73139ffc 100644 --- a/klm/search/vertex_generator.cc +++ b/klm/search/vertex_generator.cc @@ -4,26 +4,18 @@ #include "search/context.hh" #include "search/edge.hh" +#include +#include + #include namespace search { -VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) { - gen.root_.InitRoot(); -} - +#if BOOST_VERSION > 104200 namespace { const uint64_t kCompleteAdd = static_cast(-1); -// Parallel structure to VertexNode. -struct Trie { - Trie() : under(NULL) {} - - VertexNode *under; - boost::unordered_map extend; -}; - Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { Trie &next = node.extend[added]; if (!next.under) { @@ -39,19 +31,10 @@ Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::n return next; } -void CompleteTransition(ContextBase &context, Trie &starter, PartialEdge partial) { - Final final(context.FinalPool(), partial.GetScore(), partial.GetArity(), partial.GetNote()); - Final *child_out = final.Children(); - const PartialVertex *part = partial.NT(); - const PartialVertex *const part_end_loop = part + partial.GetArity(); - for (; part != part_end_loop; ++part, ++child_out) - *child_out = part->End(); - - starter.under->SetEnd(final); -} +} // namespace -void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) { - const lm::ngram::ChartState &state = partial.CompletedState(); +void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end) { + const lm::ngram::ChartState &state = *end.state; unsigned char left = 0, right = 0; Trie *node = &root; @@ -77,18 +60,9 @@ void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) { } node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); - CompleteTransition(context, *node, partial); + node->under->SetEnd(end.history, end.score); } -} // namespace - -void VertexGenerator::FinishedSearch() { - Trie root; - root.under = &gen_.root_; - for (Existing::const_iterator i(existing_.begin()); i != existing_.end(); ++i) { - AddHypothesis(context_, root, i->second); - } - root.under->SortAndSet(context_, NULL); -} +#endif // BOOST_VERSION } // namespace search diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh index 60e86112..da563c2d 100644 --- a/klm/search/vertex_generator.hh +++ b/klm/search/vertex_generator.hh @@ -2,9 +2,11 @@ #define SEARCH_VERTEX_GENERATOR__ #include "search/edge.hh" +#include "search/types.hh" #include "search/vertex.hh" #include +#include namespace lm { namespace ngram { @@ -15,21 +17,44 @@ class ChartState; namespace search { class ContextBase; -class Final; -class VertexGenerator { +#if BOOST_VERSION > 104200 +// Parallel structure to VertexNode. +struct Trie { + Trie() : under(NULL) {} + + VertexNode *under; + boost::unordered_map extend; +}; + +void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end); + +#endif // BOOST_VERSION + +// Output makes the single-best or n-best list. +template class VertexGenerator { public: - VertexGenerator(ContextBase &context, Vertex &gen); + VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) { + gen.root_.InitRoot(); + } void NewHypothesis(PartialEdge partial) { - const lm::ngram::ChartState &state = partial.CompletedState(); - std::pair ret(existing_.insert(std::make_pair(hash_value(state), partial))); - if (!ret.second && ret.first->second < partial) { - ret.first->second = partial; - } + nbest_.Add(existing_[hash_value(partial.CompletedState())], partial); } - void FinishedSearch(); + void FinishedSearch() { +#if BOOST_VERSION > 104200 + Trie root; + root.under = &gen_.root_; + for (typename Existing::iterator i(existing_.begin()); i != existing_.end(); ++i) { + AddHypothesis(context_, root, nbest_.Complete(i->second)); + } + existing_.clear(); + root.under->SortAndSet(context_); +#else + UTIL_THROW(util::Exception, "Upgrade Boost to >= 1.42.0 to use incremental search."); +#endif + } const Vertex &Generating() const { return gen_; } @@ -38,8 +63,35 @@ class VertexGenerator { Vertex &gen_; - typedef boost::unordered_map Existing; + typedef boost::unordered_map Existing; Existing existing_; + + Output &nbest_; +}; + +// Special case for root vertex: everything should come together into the root +// node. In theory, this should happen naturally due to state collapsing with +// and . If that's the case, VertexGenerator is fine, though it will +// make one connection. +template class RootVertexGenerator { + public: + RootVertexGenerator(Vertex &gen, Output &out) : gen_(gen), out_(out) {} + + void NewHypothesis(PartialEdge partial) { + out_.Add(combine_, partial); + } + + void FinishedSearch() { + gen_.root_.InitRoot(); + NBestComplete completed(out_.Complete(combine_)); + gen_.root_.SetEnd(completed.history, completed.score); + } + + private: + Vertex &gen_; + + typename Output::Combine combine_; + Output &out_; }; } // namespace search diff --git a/klm/search/weights.cc b/klm/search/weights.cc deleted file mode 100644 index d65471ad..00000000 --- a/klm/search/weights.cc +++ /dev/null @@ -1,71 +0,0 @@ -#include "search/weights.hh" -#include "util/tokenize_piece.hh" - -#include - -namespace search { - -namespace { -struct Insert { - void operator()(boost::unordered_map &map, StringPiece name, search::Score score) const { - std::string copy(name.data(), name.size()); - map[copy] = score; - } -}; - -struct DotProduct { - search::Score total; - DotProduct() : total(0.0) {} - - void operator()(const boost::unordered_map &map, StringPiece name, search::Score score) { - boost::unordered_map::const_iterator i(FindStringPiece(map, name)); - if (i != map.end()) - total += score * i->second; - } -}; - -template void Parse(StringPiece text, Map &map, Op &op) { - for (util::TokenIter spaces(text, ' '); spaces; ++spaces) { - util::TokenIter equals(*spaces, '='); - UTIL_THROW_IF(!equals, WeightParseException, "Bad weight token " << *spaces); - StringPiece name(*equals); - UTIL_THROW_IF(!++equals, WeightParseException, "Bad weight token " << *spaces); - char *end; - // Assumes proper termination. - double value = std::strtod(equals->data(), &end); - UTIL_THROW_IF(end != equals->data() + equals->size(), WeightParseException, "Failed to parse weight" << *equals); - UTIL_THROW_IF(++equals, WeightParseException, "Too many equals in " << *spaces); - op(map, name, value); - } -} - -} // namespace - -Weights::Weights(StringPiece text) { - Insert op; - Parse(text, map_, op); - lm_ = Steal("LanguageModel"); - oov_ = Steal("OOV"); - word_penalty_ = Steal("WordPenalty"); -} - -Weights::Weights(Score lm, Score oov, Score word_penalty) : lm_(lm), oov_(oov), word_penalty_(word_penalty) {} - -search::Score Weights::DotNoLM(StringPiece text) const { - DotProduct dot; - Parse(text, map_, dot); - return dot.total; -} - -float Weights::Steal(const std::string &str) { - Map::iterator i(map_.find(str)); - if (i == map_.end()) { - return 0.0; - } else { - float ret = i->second; - map_.erase(i); - return ret; - } -} - -} // namespace search diff --git a/klm/search/weights.hh b/klm/search/weights.hh deleted file mode 100644 index df1c419f..00000000 --- a/klm/search/weights.hh +++ /dev/null @@ -1,52 +0,0 @@ -// For now, the individual features are not kept. -#ifndef SEARCH_WEIGHTS__ -#define SEARCH_WEIGHTS__ - -#include "search/types.hh" -#include "util/exception.hh" -#include "util/string_piece.hh" - -#include - -#include - -namespace search { - -class WeightParseException : public util::Exception { - public: - WeightParseException() {} - ~WeightParseException() throw() {} -}; - -class Weights { - public: - // Parses weights, sets lm_weight_, removes it from map_. - explicit Weights(StringPiece text); - - // Just the three scores we care about adding. - Weights(Score lm, Score oov, Score word_penalty); - - Score DotNoLM(StringPiece text) const; - - Score LM() const { return lm_; } - - Score OOV() const { return oov_; } - - Score WordPenalty() const { return word_penalty_; } - - // Mostly for testing. - const boost::unordered_map &GetMap() const { return map_; } - - private: - float Steal(const std::string &str); - - typedef boost::unordered_map Map; - - Map map_; - - Score lm_, oov_, word_penalty_; -}; - -} // namespace search - -#endif // SEARCH_WEIGHTS__ diff --git a/klm/search/weights_test.cc b/klm/search/weights_test.cc deleted file mode 100644 index 4811ff06..00000000 --- a/klm/search/weights_test.cc +++ /dev/null @@ -1,38 +0,0 @@ -#include "search/weights.hh" - -#define BOOST_TEST_MODULE WeightTest -#include -#include - -namespace search { -namespace { - -#define CHECK_WEIGHT(value, string) \ - i = parsed.find(string); \ - BOOST_REQUIRE(i != parsed.end()); \ - BOOST_CHECK_CLOSE((value), i->second, 0.001); - -BOOST_AUTO_TEST_CASE(parse) { - // These are not real feature weights. - Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5"); - const boost::unordered_map &parsed = w.GetMap(); - boost::unordered_map::const_iterator i; - CHECK_WEIGHT(0.0, "rarity"); - CHECK_WEIGHT(0.0, "phrase-SGT"); - CHECK_WEIGHT(9.45117, "phrase-TGS"); - CHECK_WEIGHT(2.33833, "lexical-SGT"); - BOOST_CHECK(parsed.end() == parsed.find("lm")); - BOOST_CHECK_CLOSE(3.0, w.LM(), 0.001); - CHECK_WEIGHT(-28.3317, "lexical-TGS"); - CHECK_WEIGHT(5.0, "glue?"); -} - -BOOST_AUTO_TEST_CASE(dot) { - Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5"); - BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0"), 0.001); - BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0 LanguageModel=10"), 0.001); - BOOST_CHECK_CLOSE(9.45117 * 3.0 + 28.3317 * 17.4, w.DotNoLM("rarity=5 phrase-TGS=3.0 LanguageModel=10 lexical-TGS=-17.4"), 0.001); -} - -} // namespace -} // namespace search -- cgit v1.2.3