diff options
Diffstat (limited to 'klm')
| -rw-r--r-- | klm/lm/word_index.hh | 3 | ||||
| -rw-r--r-- | klm/search/context.hh | 1 | ||||
| -rw-r--r-- | klm/search/final.hh | 12 | ||||
| -rw-r--r-- | klm/search/rule.cc | 32 | ||||
| -rw-r--r-- | klm/search/rule.hh | 21 | ||||
| -rw-r--r-- | klm/search/vertex_generator.cc | 4 | ||||
| -rw-r--r-- | klm/search/word.hh | 47 | 
7 files changed, 25 insertions, 95 deletions
| diff --git a/klm/lm/word_index.hh b/klm/lm/word_index.hh index 67841c30..e09557a7 100644 --- a/klm/lm/word_index.hh +++ b/klm/lm/word_index.hh @@ -2,8 +2,11 @@  #ifndef LM_WORD_INDEX__  #define LM_WORD_INDEX__ +#include <limits.h> +  namespace lm {  typedef unsigned int WordIndex; +const WordIndex kMaxWordIndex = UINT_MAX;  } // namespace lm  typedef lm::WordIndex LMWordIndex; diff --git a/klm/search/context.hh b/klm/search/context.hh index ae248549..27940053 100644 --- a/klm/search/context.hh +++ b/klm/search/context.hh @@ -6,7 +6,6 @@  #include "search/final.hh"  #include "search/types.hh"  #include "search/vertex.hh" -#include "search/word.hh"  #include "util/exception.hh"  #include <boost/pool/object_pool.hpp> diff --git a/klm/search/final.hh b/klm/search/final.hh index 24e6f0a5..823b8c1a 100644 --- a/klm/search/final.hh +++ b/klm/search/final.hh @@ -1,18 +1,20 @@  #ifndef SEARCH_FINAL__  #define SEARCH_FINAL__ -#include "search/rule.hh" +#include "search/arity.hh"  #include "search/types.hh"  #include <boost/array.hpp>  namespace search { +class Edge; +  class Final {    public:      typedef boost::array<const Final*, search::kMaxArity> ChildArray; -    void Reset(Score bound, const Rule &from, const Final &left, const Final &right) { +    void Reset(Score bound, const Edge &from, const Final &left, const Final &right) {        bound_ = bound;        from_ = &from;        children_[0] = &left; @@ -21,16 +23,14 @@ class Final {      const ChildArray &Children() const { return children_; } -    unsigned int ChildCount() const { return from_->Arity(); } - -    const Rule &From() const { return *from_; } +    const Edge &From() const { return *from_; }      Score Bound() const { return bound_; }    private:      Score bound_; -    const Rule *from_; +    const Edge *from_;      ChildArray children_;  }; diff --git a/klm/search/rule.cc b/klm/search/rule.cc index a8b993eb..0a941527 100644 --- a/klm/search/rule.cc +++ b/klm/search/rule.cc @@ -9,47 +9,35 @@  namespace search { -template <class Model> void Rule::FinishedAdding(const Context<Model> &context, Score additive, bool prepend_bos) { +template <class Model> void Rule::Init(const Context<Model> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos) {    additive_ = additive;    Score lm_score = 0.0;    lexical_.clear();    const lm::WordIndex oov = context.LanguageModel().GetVocabulary().NotFound(); -  for (std::vector<Word>::const_iterator word = items_.begin(); ; ++word) { +  for (std::vector<lm::WordIndex>::const_iterator word = words.begin(); ; ++word) {      lexical_.resize(lexical_.size() + 1);      lm::ngram::RuleScore<Model> scorer(context.LanguageModel(), lexical_.back());      // TODO: optimize -    if (prepend_bos && (word == items_.begin())) { +    if (prepend_bos && (word == words.begin())) {        scorer.BeginSentence();      }      for (; ; ++word) { -      if (word == items_.end()) { +      if (word == words.end()) {          lm_score += scorer.Finish();          bound_ = additive_ + context.GetWeights().LM() * lm_score; -        assert(lexical_.size() == arity_ + 1); +        arity_ = lexical_.size() - 1;           return;        } -      if (!word->Terminal()) break; -      if (word->Index() == oov) additive_ += context.GetWeights().OOV(); -      scorer.Terminal(word->Index()); +      if (*word == kNonTerminal) break; +      if (*word == oov) additive_ += context.GetWeights().OOV(); +      scorer.Terminal(*word);      }      lm_score += scorer.Finish();    }  } -template void Rule::FinishedAdding(const Context<lm::ngram::RestProbingModel> &context, Score additive, bool prepend_bos); -template void Rule::FinishedAdding(const Context<lm::ngram::ProbingModel> &context, Score additive, bool prepend_bos); - -std::ostream &operator<<(std::ostream &o, const Rule &rule) { -  const Rule::ItemsRet &items = rule.Items(); -  for (Rule::ItemsRet::const_iterator i = items.begin(); i != items.end(); ++i) { -    if (i->Terminal()) { -      o << i->String() << ' '; -    } else { -      o << "[] "; -    } -  } -  return o; -} +template void Rule::Init(const Context<lm::ngram::RestProbingModel> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos); +template void Rule::Init(const Context<lm::ngram::ProbingModel> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos);  } // namespace search diff --git a/klm/search/rule.hh b/klm/search/rule.hh index 79192d40..920c64a7 100644 --- a/klm/search/rule.hh +++ b/klm/search/rule.hh @@ -2,9 +2,9 @@  #define SEARCH_RULE__  #include "lm/left.hh" +#include "lm/word_index.hh"  #include "search/arity.hh"  #include "search/types.hh" -#include "search/word.hh"  #include <boost/array.hpp> @@ -19,14 +19,10 @@ class Rule {    public:      Rule() : arity_(0) {} -    void AppendTerminal(Word w) { items_.push_back(w); } +    static const lm::WordIndex kNonTerminal = lm::kMaxWordIndex; -    void AppendNonTerminal() { -      items_.resize(items_.size() + 1); -      ++arity_; -    } - -    template <class Model> void FinishedAdding(const Context<Model> &context, Score additive, bool prepend_bos); +    // Use kNonTerminal for non-terminals.   +    template <class Model> void Init(const Context<Model> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos);      Score Bound() const { return bound_; } @@ -38,23 +34,14 @@ class Rule {        return lexical_[index];      } -    // For printing.   -    typedef const std::vector<Word> ItemsRet; -    ItemsRet &Items() const { return items_; } -    private:      Score bound_, additive_;      unsigned int arity_; -    // TODO: pool? -    std::vector<Word> items_; -      std::vector<lm::ngram::ChartState> lexical_;  }; -std::ostream &operator<<(std::ostream &o, const Rule &rule); -  } // namespace search  #endif // SEARCH_RULE__ diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc index 0281fc37..78948c97 100644 --- a/klm/search/vertex_generator.cc +++ b/klm/search/vertex_generator.cc @@ -38,7 +38,7 @@ void VertexGenerator::NewHypothesis(const lm::ngram::ChartState &state, const Ed      // Found it already.        Final &exists = *got.first->second;      if (exists.Bound() < partial.score) { -      exists.Reset(partial.score, from.GetRule(), partial.nt[0].End(), partial.nt[1].End()); +      exists.Reset(partial.score, from, partial.nt[0].End(), partial.nt[1].End());      }      --to_pop_;      return; @@ -91,7 +91,7 @@ Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const    assert(node.State().left.full == state.left.full);    assert(!node.End());    Final *final = context_.NewFinal(); -  final->Reset(partial.score, from.GetRule(), partial.nt[0].End(), partial.nt[1].End()); +  final->Reset(partial.score, from, partial.nt[0].End(), partial.nt[1].End());    node.SetEnd(final);    return final;  } diff --git a/klm/search/word.hh b/klm/search/word.hh deleted file mode 100644 index e7a15be9..00000000 --- a/klm/search/word.hh +++ /dev/null @@ -1,47 +0,0 @@ -#ifndef SEARCH_WORD__ -#define SEARCH_WORD__ - -#include "lm/word_index.hh" - -#include <boost/functional/hash.hpp> - -#include <string> -#include <utility> - -namespace search { - -class Word { -  public: -    // Construct a non-terminal. -    Word() : entry_(NULL) {} - -    explicit Word(const std::pair<const std::string, lm::WordIndex> &entry) { -      entry_ = &entry; -    } - -    // Returns true for two non-terminals even if their labels are different (since we don't care about labels). -    bool operator==(const Word &other) const { -      return entry_ == other.entry_; -    } - -    bool Terminal() const { return entry_ != NULL; } - -    const std::string &String() const { return entry_->first; } - -    lm::WordIndex Index() const { return entry_->second; } - -  protected: -    friend size_t hash_value(const Word &word); - -    const std::pair<const std::string, lm::WordIndex> *Entry() const { return entry_; } - -  private: -    const std::pair<const std::string, lm::WordIndex> *entry_; -}; - -inline size_t hash_value(const Word &word) { -  return boost::hash_value(word.Entry()); -} - -} // namespace search -#endif // SEARCH_WORD__ | 
