summaryrefslogtreecommitdiff
path: root/klm
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2012-09-12 12:01:26 +0100
committerKenneth Heafield <github@kheafield.com>2012-09-12 12:01:26 +0100
commitc26c35a9bcbb4d42ae50ad0a75c1b5fb59702bd1 (patch)
treeeced40cfee4bff7c4cd3fc644016e45f7903a55a /klm
parent2ca3db90bd0a2e9a8619d2ebec7c6ac723838aca (diff)
Refactor search so that it knows even less, but keeps track of edge pointers
Diffstat (limited to 'klm')
-rw-r--r--klm/lm/word_index.hh3
-rw-r--r--klm/search/context.hh1
-rw-r--r--klm/search/final.hh12
-rw-r--r--klm/search/rule.cc32
-rw-r--r--klm/search/rule.hh21
-rw-r--r--klm/search/vertex_generator.cc4
-rw-r--r--klm/search/word.hh47
7 files changed, 25 insertions, 95 deletions
diff --git a/klm/lm/word_index.hh b/klm/lm/word_index.hh
index 67841c30..e09557a7 100644
--- a/klm/lm/word_index.hh
+++ b/klm/lm/word_index.hh
@@ -2,8 +2,11 @@
#ifndef LM_WORD_INDEX__
#define LM_WORD_INDEX__
+#include <limits.h>
+
namespace lm {
typedef unsigned int WordIndex;
+const WordIndex kMaxWordIndex = UINT_MAX;
} // namespace lm
typedef lm::WordIndex LMWordIndex;
diff --git a/klm/search/context.hh b/klm/search/context.hh
index ae248549..27940053 100644
--- a/klm/search/context.hh
+++ b/klm/search/context.hh
@@ -6,7 +6,6 @@
#include "search/final.hh"
#include "search/types.hh"
#include "search/vertex.hh"
-#include "search/word.hh"
#include "util/exception.hh"
#include <boost/pool/object_pool.hpp>
diff --git a/klm/search/final.hh b/klm/search/final.hh
index 24e6f0a5..823b8c1a 100644
--- a/klm/search/final.hh
+++ b/klm/search/final.hh
@@ -1,18 +1,20 @@
#ifndef SEARCH_FINAL__
#define SEARCH_FINAL__
-#include "search/rule.hh"
+#include "search/arity.hh"
#include "search/types.hh"
#include <boost/array.hpp>
namespace search {
+class Edge;
+
class Final {
public:
typedef boost::array<const Final*, search::kMaxArity> ChildArray;
- void Reset(Score bound, const Rule &from, const Final &left, const Final &right) {
+ void Reset(Score bound, const Edge &from, const Final &left, const Final &right) {
bound_ = bound;
from_ = &from;
children_[0] = &left;
@@ -21,16 +23,14 @@ class Final {
const ChildArray &Children() const { return children_; }
- unsigned int ChildCount() const { return from_->Arity(); }
-
- const Rule &From() const { return *from_; }
+ const Edge &From() const { return *from_; }
Score Bound() const { return bound_; }
private:
Score bound_;
- const Rule *from_;
+ const Edge *from_;
ChildArray children_;
};
diff --git a/klm/search/rule.cc b/klm/search/rule.cc
index a8b993eb..0a941527 100644
--- a/klm/search/rule.cc
+++ b/klm/search/rule.cc
@@ -9,47 +9,35 @@
namespace search {
-template <class Model> void Rule::FinishedAdding(const Context<Model> &context, Score additive, bool prepend_bos) {
+template <class Model> void Rule::Init(const Context<Model> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos) {
additive_ = additive;
Score lm_score = 0.0;
lexical_.clear();
const lm::WordIndex oov = context.LanguageModel().GetVocabulary().NotFound();
- for (std::vector<Word>::const_iterator word = items_.begin(); ; ++word) {
+ for (std::vector<lm::WordIndex>::const_iterator word = words.begin(); ; ++word) {
lexical_.resize(lexical_.size() + 1);
lm::ngram::RuleScore<Model> scorer(context.LanguageModel(), lexical_.back());
// TODO: optimize
- if (prepend_bos && (word == items_.begin())) {
+ if (prepend_bos && (word == words.begin())) {
scorer.BeginSentence();
}
for (; ; ++word) {
- if (word == items_.end()) {
+ if (word == words.end()) {
lm_score += scorer.Finish();
bound_ = additive_ + context.GetWeights().LM() * lm_score;
- assert(lexical_.size() == arity_ + 1);
+ arity_ = lexical_.size() - 1;
return;
}
- if (!word->Terminal()) break;
- if (word->Index() == oov) additive_ += context.GetWeights().OOV();
- scorer.Terminal(word->Index());
+ if (*word == kNonTerminal) break;
+ if (*word == oov) additive_ += context.GetWeights().OOV();
+ scorer.Terminal(*word);
}
lm_score += scorer.Finish();
}
}
-template void Rule::FinishedAdding(const Context<lm::ngram::RestProbingModel> &context, Score additive, bool prepend_bos);
-template void Rule::FinishedAdding(const Context<lm::ngram::ProbingModel> &context, Score additive, bool prepend_bos);
-
-std::ostream &operator<<(std::ostream &o, const Rule &rule) {
- const Rule::ItemsRet &items = rule.Items();
- for (Rule::ItemsRet::const_iterator i = items.begin(); i != items.end(); ++i) {
- if (i->Terminal()) {
- o << i->String() << ' ';
- } else {
- o << "[] ";
- }
- }
- return o;
-}
+template void Rule::Init(const Context<lm::ngram::RestProbingModel> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos);
+template void Rule::Init(const Context<lm::ngram::ProbingModel> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos);
} // namespace search
diff --git a/klm/search/rule.hh b/klm/search/rule.hh
index 79192d40..920c64a7 100644
--- a/klm/search/rule.hh
+++ b/klm/search/rule.hh
@@ -2,9 +2,9 @@
#define SEARCH_RULE__
#include "lm/left.hh"
+#include "lm/word_index.hh"
#include "search/arity.hh"
#include "search/types.hh"
-#include "search/word.hh"
#include <boost/array.hpp>
@@ -19,14 +19,10 @@ class Rule {
public:
Rule() : arity_(0) {}
- void AppendTerminal(Word w) { items_.push_back(w); }
+ static const lm::WordIndex kNonTerminal = lm::kMaxWordIndex;
- void AppendNonTerminal() {
- items_.resize(items_.size() + 1);
- ++arity_;
- }
-
- template <class Model> void FinishedAdding(const Context<Model> &context, Score additive, bool prepend_bos);
+ // Use kNonTerminal for non-terminals.
+ template <class Model> void Init(const Context<Model> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos);
Score Bound() const { return bound_; }
@@ -38,23 +34,14 @@ class Rule {
return lexical_[index];
}
- // For printing.
- typedef const std::vector<Word> ItemsRet;
- ItemsRet &Items() const { return items_; }
-
private:
Score bound_, additive_;
unsigned int arity_;
- // TODO: pool?
- std::vector<Word> items_;
-
std::vector<lm::ngram::ChartState> lexical_;
};
-std::ostream &operator<<(std::ostream &o, const Rule &rule);
-
} // namespace search
#endif // SEARCH_RULE__
diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc
index 0281fc37..78948c97 100644
--- a/klm/search/vertex_generator.cc
+++ b/klm/search/vertex_generator.cc
@@ -38,7 +38,7 @@ void VertexGenerator::NewHypothesis(const lm::ngram::ChartState &state, const Ed
// Found it already.
Final &exists = *got.first->second;
if (exists.Bound() < partial.score) {
- exists.Reset(partial.score, from.GetRule(), partial.nt[0].End(), partial.nt[1].End());
+ exists.Reset(partial.score, from, partial.nt[0].End(), partial.nt[1].End());
}
--to_pop_;
return;
@@ -91,7 +91,7 @@ Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const
assert(node.State().left.full == state.left.full);
assert(!node.End());
Final *final = context_.NewFinal();
- final->Reset(partial.score, from.GetRule(), partial.nt[0].End(), partial.nt[1].End());
+ final->Reset(partial.score, from, partial.nt[0].End(), partial.nt[1].End());
node.SetEnd(final);
return final;
}
diff --git a/klm/search/word.hh b/klm/search/word.hh
deleted file mode 100644
index e7a15be9..00000000
--- a/klm/search/word.hh
+++ /dev/null
@@ -1,47 +0,0 @@
-#ifndef SEARCH_WORD__
-#define SEARCH_WORD__
-
-#include "lm/word_index.hh"
-
-#include <boost/functional/hash.hpp>
-
-#include <string>
-#include <utility>
-
-namespace search {
-
-class Word {
- public:
- // Construct a non-terminal.
- Word() : entry_(NULL) {}
-
- explicit Word(const std::pair<const std::string, lm::WordIndex> &entry) {
- entry_ = &entry;
- }
-
- // Returns true for two non-terminals even if their labels are different (since we don't care about labels).
- bool operator==(const Word &other) const {
- return entry_ == other.entry_;
- }
-
- bool Terminal() const { return entry_ != NULL; }
-
- const std::string &String() const { return entry_->first; }
-
- lm::WordIndex Index() const { return entry_->second; }
-
- protected:
- friend size_t hash_value(const Word &word);
-
- const std::pair<const std::string, lm::WordIndex> *Entry() const { return entry_; }
-
- private:
- const std::pair<const std::string, lm::WordIndex> *entry_;
-};
-
-inline size_t hash_value(const Word &word) {
- return boost::hash_value(word.Entry());
-}
-
-} // namespace search
-#endif // SEARCH_WORD__