diff options
author | Kenneth Heafield <github@kheafield.com> | 2012-09-12 12:01:26 +0100 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2012-09-12 12:01:26 +0100 |
commit | c26c35a9bcbb4d42ae50ad0a75c1b5fb59702bd1 (patch) | |
tree | eced40cfee4bff7c4cd3fc644016e45f7903a55a | |
parent | 2ca3db90bd0a2e9a8619d2ebec7c6ac723838aca (diff) |
Refactor search so that it knows even less, but keeps track of edge pointers
-rw-r--r-- | klm/lm/word_index.hh | 3 | ||||
-rw-r--r-- | klm/search/context.hh | 1 | ||||
-rw-r--r-- | klm/search/final.hh | 12 | ||||
-rw-r--r-- | klm/search/rule.cc | 32 | ||||
-rw-r--r-- | klm/search/rule.hh | 21 | ||||
-rw-r--r-- | klm/search/vertex_generator.cc | 4 | ||||
-rw-r--r-- | klm/search/word.hh | 47 |
7 files changed, 25 insertions, 95 deletions
diff --git a/klm/lm/word_index.hh b/klm/lm/word_index.hh index 67841c30..e09557a7 100644 --- a/klm/lm/word_index.hh +++ b/klm/lm/word_index.hh @@ -2,8 +2,11 @@ #ifndef LM_WORD_INDEX__ #define LM_WORD_INDEX__ +#include <limits.h> + namespace lm { typedef unsigned int WordIndex; +const WordIndex kMaxWordIndex = UINT_MAX; } // namespace lm typedef lm::WordIndex LMWordIndex; diff --git a/klm/search/context.hh b/klm/search/context.hh index ae248549..27940053 100644 --- a/klm/search/context.hh +++ b/klm/search/context.hh @@ -6,7 +6,6 @@ #include "search/final.hh" #include "search/types.hh" #include "search/vertex.hh" -#include "search/word.hh" #include "util/exception.hh" #include <boost/pool/object_pool.hpp> diff --git a/klm/search/final.hh b/klm/search/final.hh index 24e6f0a5..823b8c1a 100644 --- a/klm/search/final.hh +++ b/klm/search/final.hh @@ -1,18 +1,20 @@ #ifndef SEARCH_FINAL__ #define SEARCH_FINAL__ -#include "search/rule.hh" +#include "search/arity.hh" #include "search/types.hh" #include <boost/array.hpp> namespace search { +class Edge; + class Final { public: typedef boost::array<const Final*, search::kMaxArity> ChildArray; - void Reset(Score bound, const Rule &from, const Final &left, const Final &right) { + void Reset(Score bound, const Edge &from, const Final &left, const Final &right) { bound_ = bound; from_ = &from; children_[0] = &left; @@ -21,16 +23,14 @@ class Final { const ChildArray &Children() const { return children_; } - unsigned int ChildCount() const { return from_->Arity(); } - - const Rule &From() const { return *from_; } + const Edge &From() const { return *from_; } Score Bound() const { return bound_; } private: Score bound_; - const Rule *from_; + const Edge *from_; ChildArray children_; }; diff --git a/klm/search/rule.cc b/klm/search/rule.cc index a8b993eb..0a941527 100644 --- a/klm/search/rule.cc +++ b/klm/search/rule.cc @@ -9,47 +9,35 @@ namespace search { -template <class Model> void Rule::FinishedAdding(const Context<Model> &context, Score additive, bool prepend_bos) { +template <class Model> void Rule::Init(const Context<Model> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos) { additive_ = additive; Score lm_score = 0.0; lexical_.clear(); const lm::WordIndex oov = context.LanguageModel().GetVocabulary().NotFound(); - for (std::vector<Word>::const_iterator word = items_.begin(); ; ++word) { + for (std::vector<lm::WordIndex>::const_iterator word = words.begin(); ; ++word) { lexical_.resize(lexical_.size() + 1); lm::ngram::RuleScore<Model> scorer(context.LanguageModel(), lexical_.back()); // TODO: optimize - if (prepend_bos && (word == items_.begin())) { + if (prepend_bos && (word == words.begin())) { scorer.BeginSentence(); } for (; ; ++word) { - if (word == items_.end()) { + if (word == words.end()) { lm_score += scorer.Finish(); bound_ = additive_ + context.GetWeights().LM() * lm_score; - assert(lexical_.size() == arity_ + 1); + arity_ = lexical_.size() - 1; return; } - if (!word->Terminal()) break; - if (word->Index() == oov) additive_ += context.GetWeights().OOV(); - scorer.Terminal(word->Index()); + if (*word == kNonTerminal) break; + if (*word == oov) additive_ += context.GetWeights().OOV(); + scorer.Terminal(*word); } lm_score += scorer.Finish(); } } -template void Rule::FinishedAdding(const Context<lm::ngram::RestProbingModel> &context, Score additive, bool prepend_bos); -template void Rule::FinishedAdding(const Context<lm::ngram::ProbingModel> &context, Score additive, bool prepend_bos); - -std::ostream &operator<<(std::ostream &o, const Rule &rule) { - const Rule::ItemsRet &items = rule.Items(); - for (Rule::ItemsRet::const_iterator i = items.begin(); i != items.end(); ++i) { - if (i->Terminal()) { - o << i->String() << ' '; - } else { - o << "[] "; - } - } - return o; -} +template void Rule::Init(const Context<lm::ngram::RestProbingModel> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos); +template void Rule::Init(const Context<lm::ngram::ProbingModel> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos); } // namespace search diff --git a/klm/search/rule.hh b/klm/search/rule.hh index 79192d40..920c64a7 100644 --- a/klm/search/rule.hh +++ b/klm/search/rule.hh @@ -2,9 +2,9 @@ #define SEARCH_RULE__ #include "lm/left.hh" +#include "lm/word_index.hh" #include "search/arity.hh" #include "search/types.hh" -#include "search/word.hh" #include <boost/array.hpp> @@ -19,14 +19,10 @@ class Rule { public: Rule() : arity_(0) {} - void AppendTerminal(Word w) { items_.push_back(w); } + static const lm::WordIndex kNonTerminal = lm::kMaxWordIndex; - void AppendNonTerminal() { - items_.resize(items_.size() + 1); - ++arity_; - } - - template <class Model> void FinishedAdding(const Context<Model> &context, Score additive, bool prepend_bos); + // Use kNonTerminal for non-terminals. + template <class Model> void Init(const Context<Model> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos); Score Bound() const { return bound_; } @@ -38,23 +34,14 @@ class Rule { return lexical_[index]; } - // For printing. - typedef const std::vector<Word> ItemsRet; - ItemsRet &Items() const { return items_; } - private: Score bound_, additive_; unsigned int arity_; - // TODO: pool? - std::vector<Word> items_; - std::vector<lm::ngram::ChartState> lexical_; }; -std::ostream &operator<<(std::ostream &o, const Rule &rule); - } // namespace search #endif // SEARCH_RULE__ diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc index 0281fc37..78948c97 100644 --- a/klm/search/vertex_generator.cc +++ b/klm/search/vertex_generator.cc @@ -38,7 +38,7 @@ void VertexGenerator::NewHypothesis(const lm::ngram::ChartState &state, const Ed // Found it already. Final &exists = *got.first->second; if (exists.Bound() < partial.score) { - exists.Reset(partial.score, from.GetRule(), partial.nt[0].End(), partial.nt[1].End()); + exists.Reset(partial.score, from, partial.nt[0].End(), partial.nt[1].End()); } --to_pop_; return; @@ -91,7 +91,7 @@ Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const assert(node.State().left.full == state.left.full); assert(!node.End()); Final *final = context_.NewFinal(); - final->Reset(partial.score, from.GetRule(), partial.nt[0].End(), partial.nt[1].End()); + final->Reset(partial.score, from, partial.nt[0].End(), partial.nt[1].End()); node.SetEnd(final); return final; } diff --git a/klm/search/word.hh b/klm/search/word.hh deleted file mode 100644 index e7a15be9..00000000 --- a/klm/search/word.hh +++ /dev/null @@ -1,47 +0,0 @@ -#ifndef SEARCH_WORD__ -#define SEARCH_WORD__ - -#include "lm/word_index.hh" - -#include <boost/functional/hash.hpp> - -#include <string> -#include <utility> - -namespace search { - -class Word { - public: - // Construct a non-terminal. - Word() : entry_(NULL) {} - - explicit Word(const std::pair<const std::string, lm::WordIndex> &entry) { - entry_ = &entry; - } - - // Returns true for two non-terminals even if their labels are different (since we don't care about labels). - bool operator==(const Word &other) const { - return entry_ == other.entry_; - } - - bool Terminal() const { return entry_ != NULL; } - - const std::string &String() const { return entry_->first; } - - lm::WordIndex Index() const { return entry_->second; } - - protected: - friend size_t hash_value(const Word &word); - - const std::pair<const std::string, lm::WordIndex> *Entry() const { return entry_; } - - private: - const std::pair<const std::string, lm::WordIndex> *entry_; -}; - -inline size_t hash_value(const Word &word) { - return boost::hash_value(word.Entry()); -} - -} // namespace search -#endif // SEARCH_WORD__ |