diff options
| author | Paul Baltescu <pauldb89@gmail.com> | 2013-06-19 15:06:34 +0100 | 
|---|---|---|
| committer | Paul Baltescu <pauldb89@gmail.com> | 2013-06-19 15:06:34 +0100 | 
| commit | 459775095b46b4625ce26ea5a34001ec74ab3aa8 (patch) | |
| tree | 844d1a650a302114ae619d37b8778ab66207a834 /klm/search | |
| parent | 02099a01350a41a99ec400e9b29df08a01d88979 (diff) | |
| parent | 0dc7755f7fb1ef15db5a60c70866aa61b6367898 (diff) | |
Merge branch 'master' of https://github.com/redpony/cdec
Diffstat (limited to 'klm/search')
| -rw-r--r-- | klm/search/Makefile.am | 17 | ||||
| -rw-r--r-- | klm/search/context.hh | 12 | ||||
| -rw-r--r-- | klm/search/edge_generator.cc | 12 | ||||
| -rw-r--r-- | klm/search/vertex.cc | 204 | ||||
| -rw-r--r-- | klm/search/vertex.hh | 121 | ||||
| -rw-r--r-- | klm/search/vertex_generator.hh | 36 | 
6 files changed, 271 insertions, 131 deletions
| diff --git a/klm/search/Makefile.am b/klm/search/Makefile.am index 03554276..b8c8a050 100644 --- a/klm/search/Makefile.am +++ b/klm/search/Makefile.am @@ -1,23 +1,10 @@  noinst_LIBRARIES = libksearch.a  libksearch_a_SOURCES = \ -  applied.hh \ -  config.hh \ -  context.hh \ -  dedupe.hh \ -  edge.hh \ -  edge_generator.hh \ -  header.hh \ -  nbest.hh \ -  rule.hh \ -  types.hh \ -  vertex.hh \ -  vertex_generator.hh \    edge_generator.cc \  	nbest.cc \    rule.cc \ -  vertex.cc \ -  vertex_generator.cc +  vertex.cc -AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/klm +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. diff --git a/klm/search/context.hh b/klm/search/context.hh index 08f21bbf..c3c8e53b 100644 --- a/klm/search/context.hh +++ b/klm/search/context.hh @@ -12,16 +12,6 @@ class ContextBase {    public:      explicit ContextBase(const Config &config) : config_(config) {} -    VertexNode *NewVertexNode() { -      VertexNode *ret = vertex_node_pool_.construct(); -      assert(ret); -      return ret; -    } - -    void DeleteVertexNode(VertexNode *node) { -      vertex_node_pool_.destroy(node); -    } -      unsigned int PopLimit() const { return config_.PopLimit(); }      Score LMWeight() const { return config_.LMWeight(); } @@ -29,8 +19,6 @@ class ContextBase {      const Config &GetConfig() const { return config_; }    private: -    boost::object_pool<VertexNode> vertex_node_pool_; -      Config config_;  }; diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc index eacf5de5..dd9d61e4 100644 --- a/klm/search/edge_generator.cc +++ b/klm/search/edge_generator.cc @@ -54,20 +54,20 @@ template <class Model> PartialEdge EdgeGenerator::Pop(Context<Model> &context) {    Arity victim = 0;    Arity victim_completed;    Arity incomplete; +  unsigned char lowest_niceness = 255;    // Select victim or return if complete.       {      Arity completed = 0; -    unsigned char lowest_length = 255;      for (Arity i = 0; i != arity; ++i) {        if (top_nt[i].Complete()) {          ++completed; -      } else if (top_nt[i].Length() < lowest_length) { -        lowest_length = top_nt[i].Length(); +      } else if (top_nt[i].Niceness() < lowest_niceness) { +        lowest_niceness = top_nt[i].Niceness();          victim = i;          victim_completed = completed;        }      } -    if (lowest_length == 255) { +    if (lowest_niceness == 255) {        return top;      }      incomplete = arity - completed; @@ -92,10 +92,14 @@ template <class Model> PartialEdge EdgeGenerator::Pop(Context<Model> &context) {      generate_.push(alternate);    } +#ifndef NDEBUG   +  Score before = top.GetScore(); +#endif    // top is now the continuation.    FastScore(context, victim, victim - victim_completed, incomplete, old_value, top);    // TODO: dedupe?      generate_.push(top); +  assert(lowest_niceness != 254 || top.GetScore() == before);    // Invalid indicates no new hypothesis generated.      return PartialEdge(); diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc index 45842982..bf40810e 100644 --- a/klm/search/vertex.cc +++ b/klm/search/vertex.cc @@ -2,6 +2,8 @@  #include "search/context.hh" +#include <boost/unordered_map.hpp> +  #include <algorithm>  #include <functional> @@ -11,45 +13,193 @@ namespace search {  namespace { -struct GreaterByBound : public std::binary_function<const VertexNode *, const VertexNode *, bool> { -  bool operator()(const VertexNode *first, const VertexNode *second) const { -    return first->Bound() > second->Bound(); +const uint64_t kCompleteAdd = static_cast<uint64_t>(-1); + +class DivideLeft { +  public: +    explicit DivideLeft(unsigned char index) +      : index_(index) {} + +    uint64_t operator()(const lm::ngram::ChartState &state) const { +      return (index_ < state.left.length) ?  +        state.left.pointers[index_] : +        (kCompleteAdd - state.left.full); +    } + +  private: +    unsigned char index_; +}; + +class DivideRight { +  public: +    explicit DivideRight(unsigned char index) +      : index_(index) {} + +    uint64_t operator()(const lm::ngram::ChartState &state) const { +      return (index_ < state.right.length) ? +        static_cast<uint64_t>(state.right.words[index_]) : +        (kCompleteAdd - state.left.full); +    } + +  private: +    unsigned char index_; +}; + +template <class Divider> void Split(const Divider ÷r, const std::vector<HypoState> &hypos, std::vector<VertexNode> &extend) { +  // Map from divider to index in extend. +  typedef boost::unordered_map<uint64_t, std::size_t> Lookup; +  Lookup lookup; +  for (std::vector<HypoState>::const_iterator i = hypos.begin(); i != hypos.end(); ++i) { +    uint64_t key = divider(i->state); +    std::pair<Lookup::iterator, bool> res(lookup.insert(std::make_pair(key, extend.size()))); +    if (res.second) { +      extend.resize(extend.size() + 1); +      extend.back().AppendHypothesis(*i); +    } else { +      extend[res.first->second].AppendHypothesis(*i); +    }    } +  //assert((extend.size() != 1) || (hypos.size() == 1)); +} + +lm::WordIndex Identify(const lm::ngram::Right &right, unsigned char index) { +  return right.words[index]; +} + +uint64_t Identify(const lm::ngram::Left &left, unsigned char index) { +  return left.pointers[index]; +} + +template <class Side> class DetermineSame { +  public: +    DetermineSame(const Side &side, unsigned char guaranteed)  +      : side_(side), guaranteed_(guaranteed), shared_(side.length), complete_(true) {} + +    void Consider(const Side &other) { +      if (shared_ != other.length) { +        complete_ = false; +        if (shared_ > other.length) +          shared_ = other.length; +      } +      for (unsigned char i = guaranteed_; i < shared_; ++i) { +        if (Identify(side_, i) != Identify(other, i)) { +          shared_ = i; +          complete_ = false; +          return; +        } +      } +    } + +    unsigned char Shared() const { return shared_; } + +    bool Complete() const { return complete_; } + +  private: +    const Side &side_; +    unsigned char guaranteed_, shared_; +    bool complete_;  }; +// Custom enum to save memory: valid values of policy_. +// Alternate and there is still alternation to do. +const unsigned char kPolicyAlternate = 0; +// Branch based on left state only, because right ran out or this is a left tree. +const unsigned char kPolicyOneLeft = 1; +// Branch based on right state only. +const unsigned char kPolicyOneRight = 2; +// Reveal everything in the next branch.  Used to terminate the left/right policies. +//    static const unsigned char kPolicyEverything = 3; + +} // namespace + +namespace { +struct GreaterByScore : public std::binary_function<const HypoState &, const HypoState &, bool> { +  bool operator()(const HypoState &first, const HypoState &second) const { +    return first.score > second.score; +  } +};  } // namespace -void VertexNode::RecursiveSortAndSet(ContextBase &context, VertexNode *&parent_ptr) { -  if (Complete()) { -    assert(end_); -    assert(extend_.empty()); -    return; +void VertexNode::FinishRoot() { +  std::sort(hypos_.begin(), hypos_.end(), GreaterByScore()); +  extend_.clear(); +  // HACK: extend to one hypo so that root can be blank. +  state_.left.full = false; +  state_.left.length = 0; +  state_.right.length = 0; +  right_full_ = false; +  niceness_ = 0; +  policy_ = kPolicyAlternate; +  if (hypos_.size() == 1) { +    extend_.resize(1); +    extend_.front().AppendHypothesis(hypos_.front()); +    extend_.front().FinishedAppending(0, 0); +  } +  if (hypos_.empty()) { +    bound_ = -INFINITY; +  } else { +    bound_ = hypos_.front().score;    } -  if (extend_.size() == 1) { -    parent_ptr = extend_[0]; -    extend_[0]->RecursiveSortAndSet(context, parent_ptr); -    context.DeleteVertexNode(this); -    return; +} + +void VertexNode::FinishedAppending(const unsigned char common_left, const unsigned char common_right) { +  assert(!hypos_.empty()); +  assert(extend_.empty()); +  bound_ = hypos_.front().score; +  state_ = hypos_.front().state; +  bool all_full = state_.left.full; +  bool all_non_full = !state_.left.full; +  DetermineSame<lm::ngram::Left> left(state_.left, common_left); +  DetermineSame<lm::ngram::Right> right(state_.right, common_right); +  for (std::vector<HypoState>::const_iterator i = hypos_.begin() + 1; i != hypos_.end(); ++i) { +    all_full &= i->state.left.full; +    all_non_full &= !i->state.left.full; +    left.Consider(i->state.left); +    right.Consider(i->state.right);    } -  for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) { -    (*i)->RecursiveSortAndSet(context, *i); +  state_.left.full = all_full && left.Complete(); +  right_full_ = all_full && right.Complete(); +  state_.left.length = left.Shared(); +  state_.right.length = right.Shared(); + +  if (!all_full && !all_non_full) { +    policy_ = kPolicyAlternate; +  } else if (left.Complete()) { +    policy_ = kPolicyOneRight; +  } else if (right.Complete()) { +    policy_ = kPolicyOneLeft; +  } else { +    policy_ = kPolicyAlternate;    } -  std::sort(extend_.begin(), extend_.end(), GreaterByBound()); -  bound_ = extend_.front()->Bound(); +  niceness_ = state_.left.length + state_.right.length;  } -void VertexNode::SortAndSet(ContextBase &context) { -  // This is the root.  The root might be empty.   -  if (extend_.empty()) { -    bound_ = -INFINITY; -    return; +void VertexNode::BuildExtend() { +  // Already built. +  if (!extend_.empty()) return; +  // Nothing to build since this is a leaf. +  if (hypos_.size() <= 1) return; +  bool left_branch = true; +  switch (policy_) { +    case kPolicyAlternate: +      left_branch = (state_.left.length <= state_.right.length); +      break; +    case kPolicyOneLeft: +      left_branch = true; +      break; +    case kPolicyOneRight: +      left_branch = false; +      break; +  } +  if (left_branch) { +    Split(DivideLeft(state_.left.length), hypos_, extend_); +  } else { +    Split(DivideRight(state_.right.length), hypos_, extend_);    } -  // The root cannot be replaced.  There's always one transition.   -  for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) { -    (*i)->RecursiveSortAndSet(context, *i); +  for (std::vector<VertexNode>::iterator i = extend_.begin(); i != extend_.end(); ++i) { +    // TODO: provide more here for branching? +    i->FinishedAppending(state_.left.length, state_.right.length);    } -  std::sort(extend_.begin(), extend_.end(), GreaterByBound()); -  bound_ = extend_.front()->Bound();  }  } // namespace search diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh index ca9a4fcd..81c3cfed 100644 --- a/klm/search/vertex.hh +++ b/klm/search/vertex.hh @@ -16,59 +16,74 @@ namespace search {  class ContextBase; +struct HypoState { +  History history; +  lm::ngram::ChartState state; +  Score score; +}; +  class VertexNode {    public: -    VertexNode() : end_() {} - -    void InitRoot() { -      extend_.clear(); -      state_.left.full = false; -      state_.left.length = 0; -      state_.right.length = 0; -      right_full_ = false; -      end_ = History(); +    VertexNode() {} + +    void InitRoot() { hypos_.clear(); } + +    /* The steps of building a VertexNode: +     * 1. Default construct. +     * 2. AppendHypothesis at least once, possibly multiple times. +     * 3. FinishAppending with the number of words on left and right guaranteed +     * to be common. +     * 4. If !Complete(), call BuildExtend to construct the extensions +     */ +    // Must default construct, call AppendHypothesis 1 or more times then do FinishedAppending. +    void AppendHypothesis(const NBestComplete &best) { +      assert(hypos_.empty() || !(hypos_.front().state == *best.state)); +      HypoState hypo; +      hypo.history = best.history; +      hypo.state = *best.state; +      hypo.score = best.score; +      hypos_.push_back(hypo); +    } +    void AppendHypothesis(const HypoState &hypo) { +      hypos_.push_back(hypo);      } -    lm::ngram::ChartState &MutableState() { return state_; } -    bool &MutableRightFull() { return right_full_; } +    // Sort hypotheses for the root. +    void FinishRoot(); -    void AddExtend(VertexNode *next) { -      extend_.push_back(next); -    } +    void FinishedAppending(const unsigned char common_left, const unsigned char common_right); -    void SetEnd(History end, Score score) { -      assert(!end_); -      end_ = end; -      bound_ = score; -    } -     -    void SortAndSet(ContextBase &context); +    void BuildExtend();      // Should only happen to a root node when the entire vertex is empty.         bool Empty() const { -      return !end_ && extend_.empty(); +      return hypos_.empty() && extend_.empty();      }      bool Complete() const { -      return end_; +      // HACK: prevent root from being complete.  TODO: allow root to be complete. +      return hypos_.size() == 1 && extend_.empty();      }      const lm::ngram::ChartState &State() const { return state_; }      bool RightFull() const { return right_full_; } +    // Priority relative to other non-terminals.  0 is highest. +    unsigned char Niceness() const { return niceness_; } +      Score Bound() const {        return bound_;      } -    unsigned char Length() const { -      return state_.left.length + state_.right.length; -    } -      // Will be invalid unless this is a leaf.    -    History End() const { return end_; } +    History End() const { +      assert(hypos_.size() == 1); +      return hypos_.front().history; +    } -    const VertexNode &operator[](size_t index) const { -      return *extend_[index]; +    VertexNode &operator[](size_t index) { +      assert(!extend_.empty()); +      return extend_[index];      }      size_t Size() const { @@ -76,22 +91,26 @@ class VertexNode {      }    private: -    void RecursiveSortAndSet(ContextBase &context, VertexNode *&parent); +    // Hypotheses to be split. +    std::vector<HypoState> hypos_; -    std::vector<VertexNode*> extend_; +    std::vector<VertexNode> extend_;      lm::ngram::ChartState state_;      bool right_full_; +    unsigned char niceness_; + +    unsigned char policy_; +      Score bound_; -    History end_;  };  class PartialVertex {    public:      PartialVertex() {} -    explicit PartialVertex(const VertexNode &back) : back_(&back), index_(0) {} +    explicit PartialVertex(VertexNode &back) : back_(&back), index_(0) {}      bool Empty() const { return back_->Empty(); } @@ -100,17 +119,14 @@ class PartialVertex {      const lm::ngram::ChartState &State() const { return back_->State(); }      bool RightFull() const { return back_->RightFull(); } -    Score Bound() const { return Complete() ? back_->Bound() : (*back_)[index_].Bound(); } - -    unsigned char Length() const { return back_->Length(); } +    Score Bound() const { return index_ ? (*back_)[index_].Bound() : back_->Bound(); } -    bool HasAlternative() const { -      return index_ + 1 < back_->Size(); -    } +    unsigned char Niceness() const { return back_->Niceness(); }      // Split into continuation and alternative, rendering this the continuation.      bool Split(PartialVertex &alternative) {        assert(!Complete()); +      back_->BuildExtend();        bool ret;        if (index_ + 1 < back_->Size()) {          alternative.index_ = index_ + 1; @@ -129,7 +145,7 @@ class PartialVertex {      }    private: -    const VertexNode *back_; +    VertexNode *back_;      unsigned int index_;  }; @@ -139,10 +155,21 @@ class Vertex {    public:      Vertex() {} -    PartialVertex RootPartial() const { return PartialVertex(root_); } +    //PartialVertex RootFirst() const { return PartialVertex(right_); } +    PartialVertex RootAlternate() { return PartialVertex(root_); } +    //PartialVertex RootLast() const { return PartialVertex(left_); } + +    bool Empty() const { +      return root_.Empty(); +    } + +    Score Bound() const { +      return root_.Bound(); +    } -    History BestChild() const { -      PartialVertex top(RootPartial()); +    History BestChild() { +      // left_ and right_ are not set at the root. +      PartialVertex top(RootAlternate());        if (top.Empty()) {          return History();        } else { @@ -158,6 +185,12 @@ class Vertex {      template <class Output> friend class VertexGenerator;      template <class Output> friend class RootVertexGenerator;      VertexNode root_; + +    // These will not be set for the root vertex. +    // Branches only on left state. +    //VertexNode left_; +    // Branches only on right state. +    //VertexNode right_;  };  } // namespace search diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh index 646b8189..91000012 100644 --- a/klm/search/vertex_generator.hh +++ b/klm/search/vertex_generator.hh @@ -4,10 +4,8 @@  #include "search/edge.hh"  #include "search/types.hh"  #include "search/vertex.hh" -#include "util/exception.hh"  #include <boost/unordered_map.hpp> -#include <boost/version.hpp>  namespace lm {  namespace ngram { @@ -19,45 +17,25 @@ namespace search {  class ContextBase; -#if BOOST_VERSION > 104200 -// Parallel structure to VertexNode.   -struct Trie { -  Trie() : under(NULL) {} - -  VertexNode *under; -  boost::unordered_map<uint64_t, Trie> extend; -}; - -void AddHypothesis(ContextBase &context, Trie &root, const NBestComplete &end); - -#endif // BOOST_VERSION -  // Output makes the single-best or n-best list.     template <class Output> class VertexGenerator {    public: -    VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) { -      gen.root_.InitRoot(); -    } +    VertexGenerator(ContextBase &context, Vertex &gen, Output &nbest) : context_(context), gen_(gen), nbest_(nbest) {}      void NewHypothesis(PartialEdge partial) {        nbest_.Add(existing_[hash_value(partial.CompletedState())], partial);      }      void FinishedSearch() { -#if BOOST_VERSION > 104200 -      Trie root; -      root.under = &gen_.root_; +      gen_.root_.InitRoot();        for (typename Existing::iterator i(existing_.begin()); i != existing_.end(); ++i) { -        AddHypothesis(context_, root, nbest_.Complete(i->second)); +        gen_.root_.AppendHypothesis(nbest_.Complete(i->second));        }        existing_.clear(); -      root.under->SortAndSet(context_); -#else -      UTIL_THROW(util::Exception, "Upgrade Boost to >= 1.42.0 to use incremental search."); -#endif +      gen_.root_.FinishRoot();      } -    const Vertex &Generating() const { return gen_; } +    Vertex &Generating() { return gen_; }    private:      ContextBase &context_; @@ -84,8 +62,8 @@ template <class Output> class RootVertexGenerator {      void FinishedSearch() {        gen_.root_.InitRoot(); -      NBestComplete completed(out_.Complete(combine_)); -      gen_.root_.SetEnd(completed.history, completed.score); +      gen_.root_.AppendHypothesis(out_.Complete(combine_)); +      gen_.root_.FinishRoot();      }    private: | 
