From 143ba7317dcaee3058d66f9e6558316f88f95212 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Wed, 12 Sep 2012 12:01:26 +0100 Subject: Refactor search so that it knows even less, but keeps track of edge pointers --- klm/lm/word_index.hh | 3 +++ klm/search/context.hh | 1 - klm/search/final.hh | 12 +++++------ klm/search/rule.cc | 32 +++++++++------------------- klm/search/rule.hh | 21 ++++--------------- klm/search/vertex_generator.cc | 4 ++-- klm/search/word.hh | 47 ------------------------------------------ 7 files changed, 25 insertions(+), 95 deletions(-) delete mode 100644 klm/search/word.hh diff --git a/klm/lm/word_index.hh b/klm/lm/word_index.hh index 67841c30..e09557a7 100644 --- a/klm/lm/word_index.hh +++ b/klm/lm/word_index.hh @@ -2,8 +2,11 @@ #ifndef LM_WORD_INDEX__ #define LM_WORD_INDEX__ +#include + namespace lm { typedef unsigned int WordIndex; +const WordIndex kMaxWordIndex = UINT_MAX; } // namespace lm typedef lm::WordIndex LMWordIndex; diff --git a/klm/search/context.hh b/klm/search/context.hh index ae248549..27940053 100644 --- a/klm/search/context.hh +++ b/klm/search/context.hh @@ -6,7 +6,6 @@ #include "search/final.hh" #include "search/types.hh" #include "search/vertex.hh" -#include "search/word.hh" #include "util/exception.hh" #include diff --git a/klm/search/final.hh b/klm/search/final.hh index 24e6f0a5..823b8c1a 100644 --- a/klm/search/final.hh +++ b/klm/search/final.hh @@ -1,18 +1,20 @@ #ifndef SEARCH_FINAL__ #define SEARCH_FINAL__ -#include "search/rule.hh" +#include "search/arity.hh" #include "search/types.hh" #include namespace search { +class Edge; + class Final { public: typedef boost::array ChildArray; - void Reset(Score bound, const Rule &from, const Final &left, const Final &right) { + void Reset(Score bound, const Edge &from, const Final &left, const Final &right) { bound_ = bound; from_ = &from; children_[0] = &left; @@ -21,16 +23,14 @@ class Final { const ChildArray &Children() const { return children_; } - unsigned int ChildCount() const { return from_->Arity(); } - - const Rule &From() const { return *from_; } + const Edge &From() const { return *from_; } Score Bound() const { return bound_; } private: Score bound_; - const Rule *from_; + const Edge *from_; ChildArray children_; }; diff --git a/klm/search/rule.cc b/klm/search/rule.cc index a8b993eb..0a941527 100644 --- a/klm/search/rule.cc +++ b/klm/search/rule.cc @@ -9,47 +9,35 @@ namespace search { -template void Rule::FinishedAdding(const Context &context, Score additive, bool prepend_bos) { +template void Rule::Init(const Context &context, Score additive, const std::vector &words, bool prepend_bos) { additive_ = additive; Score lm_score = 0.0; lexical_.clear(); const lm::WordIndex oov = context.LanguageModel().GetVocabulary().NotFound(); - for (std::vector::const_iterator word = items_.begin(); ; ++word) { + for (std::vector::const_iterator word = words.begin(); ; ++word) { lexical_.resize(lexical_.size() + 1); lm::ngram::RuleScore scorer(context.LanguageModel(), lexical_.back()); // TODO: optimize - if (prepend_bos && (word == items_.begin())) { + if (prepend_bos && (word == words.begin())) { scorer.BeginSentence(); } for (; ; ++word) { - if (word == items_.end()) { + if (word == words.end()) { lm_score += scorer.Finish(); bound_ = additive_ + context.GetWeights().LM() * lm_score; - assert(lexical_.size() == arity_ + 1); + arity_ = lexical_.size() - 1; return; } - if (!word->Terminal()) break; - if (word->Index() == oov) additive_ += context.GetWeights().OOV(); - scorer.Terminal(word->Index()); + if (*word == kNonTerminal) break; + if (*word == oov) additive_ += context.GetWeights().OOV(); + scorer.Terminal(*word); } lm_score += scorer.Finish(); } } -template void Rule::FinishedAdding(const Context &context, Score additive, bool prepend_bos); -template void Rule::FinishedAdding(const Context &context, Score additive, bool prepend_bos); - -std::ostream &operator<<(std::ostream &o, const Rule &rule) { - const Rule::ItemsRet &items = rule.Items(); - for (Rule::ItemsRet::const_iterator i = items.begin(); i != items.end(); ++i) { - if (i->Terminal()) { - o << i->String() << ' '; - } else { - o << "[] "; - } - } - return o; -} +template void Rule::Init(const Context &context, Score additive, const std::vector &words, bool prepend_bos); +template void Rule::Init(const Context &context, Score additive, const std::vector &words, bool prepend_bos); } // namespace search diff --git a/klm/search/rule.hh b/klm/search/rule.hh index 79192d40..920c64a7 100644 --- a/klm/search/rule.hh +++ b/klm/search/rule.hh @@ -2,9 +2,9 @@ #define SEARCH_RULE__ #include "lm/left.hh" +#include "lm/word_index.hh" #include "search/arity.hh" #include "search/types.hh" -#include "search/word.hh" #include @@ -19,14 +19,10 @@ class Rule { public: Rule() : arity_(0) {} - void AppendTerminal(Word w) { items_.push_back(w); } + static const lm::WordIndex kNonTerminal = lm::kMaxWordIndex; - void AppendNonTerminal() { - items_.resize(items_.size() + 1); - ++arity_; - } - - template void FinishedAdding(const Context &context, Score additive, bool prepend_bos); + // Use kNonTerminal for non-terminals. + template void Init(const Context &context, Score additive, const std::vector &words, bool prepend_bos); Score Bound() const { return bound_; } @@ -38,23 +34,14 @@ class Rule { return lexical_[index]; } - // For printing. - typedef const std::vector ItemsRet; - ItemsRet &Items() const { return items_; } - private: Score bound_, additive_; unsigned int arity_; - // TODO: pool? - std::vector items_; - std::vector lexical_; }; -std::ostream &operator<<(std::ostream &o, const Rule &rule); - } // namespace search #endif // SEARCH_RULE__ diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc index 0281fc37..78948c97 100644 --- a/klm/search/vertex_generator.cc +++ b/klm/search/vertex_generator.cc @@ -38,7 +38,7 @@ void VertexGenerator::NewHypothesis(const lm::ngram::ChartState &state, const Ed // Found it already. Final &exists = *got.first->second; if (exists.Bound() < partial.score) { - exists.Reset(partial.score, from.GetRule(), partial.nt[0].End(), partial.nt[1].End()); + exists.Reset(partial.score, from, partial.nt[0].End(), partial.nt[1].End()); } --to_pop_; return; @@ -91,7 +91,7 @@ Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const assert(node.State().left.full == state.left.full); assert(!node.End()); Final *final = context_.NewFinal(); - final->Reset(partial.score, from.GetRule(), partial.nt[0].End(), partial.nt[1].End()); + final->Reset(partial.score, from, partial.nt[0].End(), partial.nt[1].End()); node.SetEnd(final); return final; } diff --git a/klm/search/word.hh b/klm/search/word.hh deleted file mode 100644 index e7a15be9..00000000 --- a/klm/search/word.hh +++ /dev/null @@ -1,47 +0,0 @@ -#ifndef SEARCH_WORD__ -#define SEARCH_WORD__ - -#include "lm/word_index.hh" - -#include - -#include -#include - -namespace search { - -class Word { - public: - // Construct a non-terminal. - Word() : entry_(NULL) {} - - explicit Word(const std::pair &entry) { - entry_ = &entry; - } - - // Returns true for two non-terminals even if their labels are different (since we don't care about labels). - bool operator==(const Word &other) const { - return entry_ == other.entry_; - } - - bool Terminal() const { return entry_ != NULL; } - - const std::string &String() const { return entry_->first; } - - lm::WordIndex Index() const { return entry_->second; } - - protected: - friend size_t hash_value(const Word &word); - - const std::pair *Entry() const { return entry_; } - - private: - const std::pair *entry_; -}; - -inline size_t hash_value(const Word &word) { - return boost::hash_value(word.Entry()); -} - -} // namespace search -#endif // SEARCH_WORD__ -- cgit v1.2.3