diff options
Diffstat (limited to 'klm')
| -rw-r--r-- | klm/search/Jamfile | 5 | ||||
| -rw-r--r-- | klm/search/arity.hh | 8 | ||||
| -rw-r--r-- | klm/search/config.hh | 25 | ||||
| -rw-r--r-- | klm/search/context.hh | 66 | ||||
| -rw-r--r-- | klm/search/edge.hh | 54 | ||||
| -rw-r--r-- | klm/search/edge_generator.cc | 129 | ||||
| -rw-r--r-- | klm/search/edge_generator.hh | 54 | ||||
| -rw-r--r-- | klm/search/final.hh | 40 | ||||
| -rw-r--r-- | klm/search/rule.cc | 55 | ||||
| -rw-r--r-- | klm/search/rule.hh | 60 | ||||
| -rw-r--r-- | klm/search/source.hh | 48 | ||||
| -rw-r--r-- | klm/search/types.hh | 18 | ||||
| -rw-r--r-- | klm/search/vertex.cc | 48 | ||||
| -rw-r--r-- | klm/search/vertex.hh | 165 | ||||
| -rw-r--r-- | klm/search/vertex_generator.cc | 99 | ||||
| -rw-r--r-- | klm/search/vertex_generator.hh | 70 | ||||
| -rw-r--r-- | klm/search/weights.cc | 69 | ||||
| -rw-r--r-- | klm/search/weights.hh | 49 | ||||
| -rw-r--r-- | klm/search/weights_test.cc | 38 | ||||
| -rw-r--r-- | klm/search/word.hh | 47 | 
20 files changed, 1147 insertions, 0 deletions
| diff --git a/klm/search/Jamfile b/klm/search/Jamfile new file mode 100644 index 00000000..ac47c249 --- /dev/null +++ b/klm/search/Jamfile @@ -0,0 +1,5 @@ +lib search : weights.cc vertex.cc vertex_generator.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil : : : <include>.. ; + +import testing ; + +unit-test weights_test : weights_test.cc search /top//boost_unit_test_framework ; diff --git a/klm/search/arity.hh b/klm/search/arity.hh new file mode 100644 index 00000000..09c2c671 --- /dev/null +++ b/klm/search/arity.hh @@ -0,0 +1,8 @@ +#ifndef SEARCH_ARITY__ +#define SEARCH_ARITY__ +namespace search { + +const unsigned int kMaxArity = 2; + +} // namespace search +#endif // SEARCH_ARITY__ diff --git a/klm/search/config.hh b/klm/search/config.hh new file mode 100644 index 00000000..e21e4b7c --- /dev/null +++ b/klm/search/config.hh @@ -0,0 +1,25 @@ +#ifndef SEARCH_CONFIG__ +#define SEARCH_CONFIG__ + +#include "search/weights.hh" +#include "util/string_piece.hh" + +namespace search { + +class Config { +  public: +    Config(StringPiece weight_str, unsigned int pop_limit) : +      weights_(weight_str), pop_limit_(pop_limit) {} + +    const Weights &GetWeights() const { return weights_; } + +    unsigned int PopLimit() const { return pop_limit_; } + +  private: +    search::Weights weights_; +    unsigned int pop_limit_; +}; + +} // namespace search + +#endif // SEARCH_CONFIG__ diff --git a/klm/search/context.hh b/klm/search/context.hh new file mode 100644 index 00000000..ae248549 --- /dev/null +++ b/klm/search/context.hh @@ -0,0 +1,66 @@ +#ifndef SEARCH_CONTEXT__ +#define SEARCH_CONTEXT__ + +#include "lm/model.hh" +#include "search/config.hh" +#include "search/final.hh" +#include "search/types.hh" +#include "search/vertex.hh" +#include "search/word.hh" +#include "util/exception.hh" + +#include <boost/pool/object_pool.hpp> +#include <boost/ptr_container/ptr_vector.hpp> + +#include <vector> + +namespace search { + +class Weights; + +class ContextBase { +  public: +    explicit ContextBase(const Config &config) : pop_limit_(config.PopLimit()), weights_(config.GetWeights()) {} + +    Final *NewFinal() { +     Final *ret = final_pool_.construct(); +     assert(ret); +     return ret; +    } + +    VertexNode *NewVertexNode() { +      VertexNode *ret = vertex_node_pool_.construct(); +      assert(ret); +      return ret; +    } + +    void DeleteVertexNode(VertexNode *node) { +      vertex_node_pool_.destroy(node); +    } + +    unsigned int PopLimit() const { return pop_limit_; } + +    const Weights &GetWeights() const { return weights_; } + +  private: +    boost::object_pool<Final> final_pool_; +    boost::object_pool<VertexNode> vertex_node_pool_; + +    unsigned int pop_limit_; + +    const Weights &weights_; +}; + +template <class Model> class Context : public ContextBase { +  public: +    Context(const Config &config, const Model &model) : ContextBase(config), model_(model) {} + +    const Model &LanguageModel() const { return model_; } + +  private: +    const Model &model_; +}; + +} // namespace search + +#endif // SEARCH_CONTEXT__ diff --git a/klm/search/edge.hh b/klm/search/edge.hh new file mode 100644 index 00000000..4d2a5cbf --- /dev/null +++ b/klm/search/edge.hh @@ -0,0 +1,54 @@ +#ifndef SEARCH_EDGE__ +#define SEARCH_EDGE__ + +#include "lm/state.hh" +#include "search/arity.hh" +#include "search/rule.hh" +#include "search/types.hh" +#include "search/vertex.hh" + +#include <queue> + +namespace search { + +class Edge { +  public: +    Edge() { +      end_to_ = to_; +    } + +    Rule &InitRule() { return rule_; } + +    void Add(Vertex &vertex) { +      assert(end_to_ - to_ < kMaxArity); +      *(end_to_++) = &vertex; +    } + +    const Vertex &GetVertex(std::size_t index) const { +      return *to_[index]; +    } + +    const Rule &GetRule() const { return rule_; } + +  private: +    // Rule and pointers to rule arguments.   +    Rule rule_; + +    Vertex *to_[kMaxArity]; +    Vertex **end_to_; +}; + +struct PartialEdge { +  Score score; +  // Terminals +  lm::ngram::ChartState between[kMaxArity + 1]; +  // Non-terminals +  PartialVertex nt[kMaxArity]; + +  bool operator<(const PartialEdge &other) const { +    return score < other.score; +  } +}; + +} // namespace search +#endif // SEARCH_EDGE__ diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc new file mode 100644 index 00000000..d135899a --- /dev/null +++ b/klm/search/edge_generator.cc @@ -0,0 +1,129 @@ +#include "search/edge_generator.hh" + +#include "lm/left.hh" +#include "lm/partial.hh" +#include "search/context.hh" +#include "search/vertex.hh" +#include "search/vertex_generator.hh" + +#include <numeric> + +namespace search { + +bool EdgeGenerator::Init(Edge &edge, VertexGenerator &parent) { +  from_ = &edge; +  for (unsigned int i = 0; i < GetRule().Arity(); ++i) { +    if (edge.GetVertex(i).RootPartial().Empty()) return false; +  } +  PartialEdge &root = *parent.MallocPartialEdge(); +  root.score = GetRule().Bound(); +  for (unsigned int i = 0; i < GetRule().Arity(); ++i) { +    root.nt[i] = edge.GetVertex(i).RootPartial(); +    root.score += root.nt[i].Bound(); +  } +  for (unsigned int i = GetRule().Arity(); i < 2; ++i) { +    root.nt[i] = kBlankPartialVertex; +  } +  for (unsigned int i = 0; i < GetRule().Arity() + 1; ++i) { +    root.between[i] = GetRule().Lexical(i); +  } +  // wtf no clear method? +  generate_ = Generate(); +  generate_.push(&root); +  top_ = root.score; +  return true; +} + +namespace { + +template <class Model> float FastScore(const Context<Model> &context, unsigned char victim, unsigned char arity, const PartialEdge &previous, PartialEdge &update) { +  memcpy(update.between, previous.between, sizeof(lm::ngram::ChartState) * (arity + 1)); + +  float ret = 0.0; +  lm::ngram::ChartState *before, *after; +  if (victim == 0) { +    before = &update.between[0]; +    after = &update.between[(arity == 2 && previous.nt[1].Complete()) ? 2 : 1]; +  } else { +    assert(victim == 1); +    assert(arity == 2); +    before = &update.between[previous.nt[0].Complete() ? 0 : 1]; +    after = &update.between[2]; +  } +  const lm::ngram::ChartState &previous_reveal = previous.nt[victim].State(); +  const PartialVertex &update_nt = update.nt[victim]; +  const lm::ngram::ChartState &update_reveal = update_nt.State(); +  float just_after = 0.0; +  if ((update_reveal.left.length > previous_reveal.left.length) || (update_reveal.left.full && !previous_reveal.left.full)) { +    just_after += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length); +  } +  if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous.nt[victim].RightFull())) { +    ret += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right); +  } +  if (update_nt.Complete()) { +    if (update_reveal.left.full) { +      before->left.full = true; +    } else { +      assert(update_reveal.left.length == update_reveal.right.length); +      ret += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length); +    } +    if (victim == 0) { +      update.between[0].right = after->right; +    } else { +      update.between[2].left = before->left; +    } +  } +  return previous.score + (ret + just_after) * context.GetWeights().LM(); +} + +} // namespace + +template <class Model> bool EdgeGenerator::Pop(Context<Model> &context, VertexGenerator &parent) { +  assert(!generate_.empty()); +  PartialEdge &top = *generate_.top(); +  generate_.pop(); +  unsigned int victim = 0; +  unsigned char lowest_length = 255; +  for (unsigned int i = 0; i != GetRule().Arity(); ++i) { +    if (!top.nt[i].Complete() && top.nt[i].Length() < lowest_length) { +      lowest_length = top.nt[i].Length(); +      victim = i; +    } +  } +  if (lowest_length == 255) { +    // All states report complete.   +    top.between[0].right = top.between[GetRule().Arity()].right; +    parent.NewHypothesis(top.between[0], *from_, top); +    top_ = generate_.empty() ? -kScoreInf : generate_.top()->score; +    return !generate_.empty(); +  } + +  unsigned int stay = !victim; +  PartialEdge &continuation = *parent.MallocPartialEdge(); +  float old_bound = top.nt[victim].Bound(); +  // The alternate's score will change because alternate.nt[victim] changes.   +  bool split = top.nt[victim].Split(continuation.nt[victim]); +  // top is now the alternate.   + +  continuation.nt[stay] = top.nt[stay]; +  continuation.score = FastScore(context, victim, GetRule().Arity(), top, continuation); +  // TODO: dedupe?   +  generate_.push(&continuation); + +  if (split) { +    // We have an alternate.   +    top.score += top.nt[victim].Bound() - old_bound; +    // TODO: dedupe?   +    generate_.push(&top); +  } else { +    parent.FreePartialEdge(&top); +  } + +  top_ = generate_.top()->score; +  return true; +} + +template bool EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context, VertexGenerator &parent); +template bool EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context, VertexGenerator &parent); + +} // namespace search diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh new file mode 100644 index 00000000..e306dc61 --- /dev/null +++ b/klm/search/edge_generator.hh @@ -0,0 +1,54 @@ +#ifndef SEARCH_EDGE_GENERATOR__ +#define SEARCH_EDGE_GENERATOR__ + +#include "search/edge.hh" + +#include <boost/unordered_map.hpp> + +#include <functional> +#include <queue> + +namespace lm { +namespace ngram { +class ChartState; +} // namespace ngram +} // namespace lm + +namespace search { + +template <class Model> class Context; + +class VertexGenerator; + +struct PartialEdgePointerLess : std::binary_function<const PartialEdge *, const PartialEdge *, bool> { +  bool operator()(const PartialEdge *first, const PartialEdge *second) const { +    return *first < *second; +  } +}; + +class EdgeGenerator { +  public: +    // True if it has a hypothesis.   +    bool Init(Edge &edge, VertexGenerator &parent); + +    Score Top() const { +      return top_; +    } + +    template <class Model> bool Pop(Context<Model> &context, VertexGenerator &parent); + +  private: +    const Rule &GetRule() const { +      return from_->GetRule(); +    } + +    Score top_; + +    typedef std::priority_queue<PartialEdge*, std::vector<PartialEdge*>, PartialEdgePointerLess> Generate; +    Generate generate_; + +    Edge *from_; +}; + +} // namespace search +#endif // SEARCH_EDGE_GENERATOR__ diff --git a/klm/search/final.hh b/klm/search/final.hh new file mode 100644 index 00000000..24e6f0a5 --- /dev/null +++ b/klm/search/final.hh @@ -0,0 +1,40 @@ +#ifndef SEARCH_FINAL__ +#define SEARCH_FINAL__ + +#include "search/rule.hh" +#include "search/types.hh" + +#include <boost/array.hpp> + +namespace search { + +class Final { +  public: +    typedef boost::array<const Final*, search::kMaxArity> ChildArray; + +    void Reset(Score bound, const Rule &from, const Final &left, const Final &right) { +      bound_ = bound; +      from_ = &from; +      children_[0] = &left; +      children_[1] = &right; +    } + +    const ChildArray &Children() const { return children_; } + +    unsigned int ChildCount() const { return from_->Arity(); } + +    const Rule &From() const { return *from_; } + +    Score Bound() const { return bound_; } + +  private: +    Score bound_; + +    const Rule *from_; + +    ChildArray children_; +}; + +} // namespace search + +#endif // SEARCH_FINAL__ diff --git a/klm/search/rule.cc b/klm/search/rule.cc new file mode 100644 index 00000000..a8b993eb --- /dev/null +++ b/klm/search/rule.cc @@ -0,0 +1,55 @@ +#include "search/rule.hh" + +#include "search/context.hh" +#include "search/final.hh" + +#include <ostream> + +#include <math.h> + +namespace search { + +template <class Model> void Rule::FinishedAdding(const Context<Model> &context, Score additive, bool prepend_bos) { +  additive_ = additive; +  Score lm_score = 0.0; +  lexical_.clear(); +  const lm::WordIndex oov = context.LanguageModel().GetVocabulary().NotFound(); + +  for (std::vector<Word>::const_iterator word = items_.begin(); ; ++word) { +    lexical_.resize(lexical_.size() + 1); +    lm::ngram::RuleScore<Model> scorer(context.LanguageModel(), lexical_.back()); +    // TODO: optimize +    if (prepend_bos && (word == items_.begin())) { +      scorer.BeginSentence(); +    } +    for (; ; ++word) { +      if (word == items_.end()) { +        lm_score += scorer.Finish(); +        bound_ = additive_ + context.GetWeights().LM() * lm_score; +        assert(lexical_.size() == arity_ + 1); +        return; +      } +      if (!word->Terminal()) break; +      if (word->Index() == oov) additive_ += context.GetWeights().OOV(); +      scorer.Terminal(word->Index()); +    } +    lm_score += scorer.Finish(); +  } +} + +template void Rule::FinishedAdding(const Context<lm::ngram::RestProbingModel> &context, Score additive, bool prepend_bos); +template void Rule::FinishedAdding(const Context<lm::ngram::ProbingModel> &context, Score additive, bool prepend_bos); + +std::ostream &operator<<(std::ostream &o, const Rule &rule) { +  const Rule::ItemsRet &items = rule.Items(); +  for (Rule::ItemsRet::const_iterator i = items.begin(); i != items.end(); ++i) { +    if (i->Terminal()) { +      o << i->String() << ' '; +    } else { +      o << "[] "; +    } +  } +  return o; +} + +} // namespace search diff --git a/klm/search/rule.hh b/klm/search/rule.hh new file mode 100644 index 00000000..79192d40 --- /dev/null +++ b/klm/search/rule.hh @@ -0,0 +1,60 @@ +#ifndef SEARCH_RULE__ +#define SEARCH_RULE__ + +#include "lm/left.hh" +#include "search/arity.hh" +#include "search/types.hh" +#include "search/word.hh" + +#include <boost/array.hpp> + +#include <iosfwd> +#include <vector> + +namespace search { + +template <class Model> class Context; + +class Rule { +  public: +    Rule() : arity_(0) {} + +    void AppendTerminal(Word w) { items_.push_back(w); } + +    void AppendNonTerminal() { +      items_.resize(items_.size() + 1); +      ++arity_; +    } + +    template <class Model> void FinishedAdding(const Context<Model> &context, Score additive, bool prepend_bos); + +    Score Bound() const { return bound_; } + +    Score Additive() const { return additive_; } + +    unsigned int Arity() const { return arity_; } + +    const lm::ngram::ChartState &Lexical(unsigned int index) const { +      return lexical_[index]; +    } + +    // For printing.   +    typedef const std::vector<Word> ItemsRet; +    ItemsRet &Items() const { return items_; } + +  private: +    Score bound_, additive_; + +    unsigned int arity_; + +    // TODO: pool? +    std::vector<Word> items_; + +    std::vector<lm::ngram::ChartState> lexical_; +}; + +std::ostream &operator<<(std::ostream &o, const Rule &rule); + +} // namespace search + +#endif // SEARCH_RULE__ diff --git a/klm/search/source.hh b/klm/search/source.hh new file mode 100644 index 00000000..11839f7b --- /dev/null +++ b/klm/search/source.hh @@ -0,0 +1,48 @@ +#ifndef SEARCH_SOURCE__ +#define SEARCH_SOURCE__ + +#include "search/types.hh" + +#include <assert.h> +#include <vector> + +namespace search { + +template <class Final> class Source { +  public: +    Source() : bound_(kScoreInf) {} + +    Index Size() const { +      return final_.size(); +    } + +    Score Bound() const { +      return bound_; +    } + +    const Final &operator[](Index index) const { +      return *final_[index]; +    } + +    Score ScoreOrBound(Index index) const { +      return Size() > index ? final_[index]->Total() : Bound(); +    } + +  protected: +    void AddFinal(const Final &store) { +      final_.push_back(&store); +    } + +    void SetBound(Score to) { +      assert(to <= bound_ + 0.001); +      bound_ = to; +    } + +  private: +    std::vector<const Final *> final_; + +    Score bound_; +}; + +} // namespace search +#endif // SEARCH_SOURCE__ diff --git a/klm/search/types.hh b/klm/search/types.hh new file mode 100644 index 00000000..9726379f --- /dev/null +++ b/klm/search/types.hh @@ -0,0 +1,18 @@ +#ifndef SEARCH_TYPES__ +#define SEARCH_TYPES__ + +#include <cmath> + +namespace search { + +typedef float Score; +const Score kScoreInf = INFINITY; + +// This could have been an enum but gcc wants 4 bytes.   +typedef bool ExtendDirection; +const ExtendDirection kExtendLeft = 0; +const ExtendDirection kExtendRight = 1; + +} // namespace search + +#endif // SEARCH_TYPES__ diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc new file mode 100644 index 00000000..cc53c0dd --- /dev/null +++ b/klm/search/vertex.cc @@ -0,0 +1,48 @@ +#include "search/vertex.hh" + +#include "search/context.hh" + +#include <algorithm> +#include <functional> + +#include <assert.h> + +namespace search { + +namespace { + +struct GreaterByBound : public std::binary_function<const VertexNode *, const VertexNode *, bool> { +  bool operator()(const VertexNode *first, const VertexNode *second) const { +    return first->Bound() > second->Bound(); +  } +}; + +} // namespace + +void VertexNode::SortAndSet(ContextBase &context, VertexNode **parent_ptr) { +  if (Complete()) { +    assert(end_); +    assert(extend_.empty()); +    bound_ = end_->Bound(); +    return; +  } +  if (extend_.size() == 1 && parent_ptr) { +    *parent_ptr = extend_[0]; +    extend_[0]->SortAndSet(context, parent_ptr); +    context.DeleteVertexNode(this); +    return; +  } +  for (std::vector<VertexNode*>::iterator i = extend_.begin(); i != extend_.end(); ++i) { +    (*i)->SortAndSet(context, &*i); +  } +  std::sort(extend_.begin(), extend_.end(), GreaterByBound()); +  bound_ = extend_.front()->Bound(); +} + +namespace { +VertexNode kBlankVertexNode; +} // namespace + +PartialVertex kBlankPartialVertex(kBlankVertexNode); + +} // namespace search diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh new file mode 100644 index 00000000..7ef29efc --- /dev/null +++ b/klm/search/vertex.hh @@ -0,0 +1,165 @@ +#ifndef SEARCH_VERTEX__ +#define SEARCH_VERTEX__ + +#include "lm/left.hh" +#include "search/final.hh" +#include "search/types.hh" + +#include <boost/unordered_set.hpp> + +#include <queue> +#include <vector> + +#include <stdint.h> + +namespace search { + +class ContextBase; + +class Edge; + +class VertexNode { +  public: +    VertexNode() : end_(NULL) {} + +    void InitRoot() { +      extend_.clear(); +      state_.left.full = false; +      state_.left.length = 0; +      state_.right.length = 0; +      right_full_ = false; +      bound_ = -kScoreInf; +      end_ = NULL; +    } + +    lm::ngram::ChartState &MutableState() { return state_; } +    bool &MutableRightFull() { return right_full_; } + +    void AddExtend(VertexNode *next) { +      extend_.push_back(next); +    } + +    void SetEnd(Final *end) { end_ = end; } +     +    Final &MutableEnd() { return *end_; } + +    void SortAndSet(ContextBase &context, VertexNode **parent_pointer); + +    // Should only happen to a root node when the entire vertex is empty.    +    bool Empty() const { +      return !end_ && extend_.empty(); +    } + +    bool Complete() const { +      return end_; +    } + +    const lm::ngram::ChartState &State() const { return state_; } +    bool RightFull() const { return right_full_; } + +    Score Bound() const { +      return bound_; +    } + +    unsigned char Length() const { +      return state_.left.length + state_.right.length; +    } + +    // May be NULL. +    const Final *End() const { return end_; } + +    const VertexNode &operator[](size_t index) const { +      return *extend_[index]; +    } + +    size_t Size() const { +      return extend_.size(); +    } + +  private: +    std::vector<VertexNode*> extend_; + +    lm::ngram::ChartState state_; +    bool right_full_; + +    Score bound_; +    Final *end_; +}; + +class PartialVertex { +  public: +    PartialVertex() {} + +    explicit PartialVertex(const VertexNode &back) : back_(&back), index_(0) {} + +    bool Empty() const { return back_->Empty(); } + +    bool Complete() const { return back_->Complete(); } + +    const lm::ngram::ChartState &State() const { return back_->State(); } +    bool RightFull() const { return back_->RightFull(); } + +    Score Bound() const { return Complete() ? back_->End()->Bound() : (*back_)[index_].Bound(); } + +    unsigned char Length() const { return back_->Length(); } + +    // Split into continuation and alternative, rendering this the alternative. +    bool Split(PartialVertex &continuation) { +      assert(!Complete()); +      continuation.back_ = &((*back_)[index_]); +      continuation.index_ = 0; +      if (index_ + 1 < back_->Size()) { +        ++index_; +        return true; +      } +      return false; +    } + +    const Final &End() const { +      return *back_->End(); +    } + +  private: +    const VertexNode *back_; +    unsigned int index_; +}; + +extern PartialVertex kBlankPartialVertex; + +class Vertex { +  public: +    Vertex()  +#ifdef DEBUG +      : finished_adding_(false) +#endif +    {} + +    void Add(Edge &edge) { +#ifdef DEBUG +      assert(!finished_adding_); +#endif +      edges_.push_back(&edge); +    } + +    void FinishedAdding() { +#ifdef DEBUG +      assert(!finished_adding_); +      finished_adding_ = true; +#endif +    } + +    PartialVertex RootPartial() const { return PartialVertex(root_); } + +  private: +    friend class VertexGenerator; +    std::vector<Edge*> edges_; + +#ifdef DEBUG +    bool finished_adding_; +#endif + +    VertexNode root_; +}; + +} // namespace search +#endif // SEARCH_VERTEX__ diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc new file mode 100644 index 00000000..0281fc37 --- /dev/null +++ b/klm/search/vertex_generator.cc @@ -0,0 +1,99 @@ +#include "search/vertex_generator.hh" + +#include "lm/left.hh" +#include "search/context.hh" + +#include <stdint.h> + +namespace search { + +template <class Model> VertexGenerator::VertexGenerator(Context<Model> &context, Vertex &gen) : context_(context), edges_(gen.edges_.size()), partial_edge_pool_(sizeof(PartialEdge), context.PopLimit() * 2) { +  for (std::size_t i = 0; i < gen.edges_.size(); ++i) { +    if (edges_[i].Init(*gen.edges_[i], *this)) +      generate_.push(&edges_[i]); +  } +  gen.root_.InitRoot(); +  root_.under = &gen.root_; +  to_pop_ = context.PopLimit(); +  while (to_pop_ > 0 && !generate_.empty()) { +    EdgeGenerator *top = generate_.top(); +    generate_.pop(); +    if (top->Pop(context, *this)) { +      generate_.push(top); +    } +  } +  gen.root_.SortAndSet(context, NULL); +} + +template VertexGenerator::VertexGenerator(Context<lm::ngram::ProbingModel> &context, Vertex &gen); +template VertexGenerator::VertexGenerator(Context<lm::ngram::RestProbingModel> &context, Vertex &gen); + +namespace { +const uint64_t kCompleteAdd = static_cast<uint64_t>(-1); +} // namespace + +void VertexGenerator::NewHypothesis(const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial) { +  std::pair<Existing::iterator, bool> got(existing_.insert(std::pair<uint64_t, Final*>(hash_value(state), NULL))); +  if (!got.second) { +    // Found it already.   +    Final &exists = *got.first->second; +    if (exists.Bound() < partial.score) { +      exists.Reset(partial.score, from.GetRule(), partial.nt[0].End(), partial.nt[1].End()); +    } +    --to_pop_; +    return; +  } +  unsigned char left = 0, right = 0; +  Trie *node = &root_; +  while (true) { +    if (left == state.left.length) { +      node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, true, right, false); +      for (; right < state.right.length; ++right) { +        node = &FindOrInsert(*node, state.right.words[right], state, left, true, right + 1, false); +      } +      break; +    } +    node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, false); +    left++; +    if (right == state.right.length) { +      node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, false, right, true); +      for (; left < state.left.length; ++left) { +        node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, true); +      } +      break; +    } +    node = &FindOrInsert(*node, state.right.words[right], state, left, false, right + 1, false); +    right++; +  } + +  node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); +  got.first->second = CompleteTransition(*node, state, from, partial); +  --to_pop_; +} + +VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { +  VertexGenerator::Trie &next = node.extend[added]; +  if (!next.under) { +    next.under = context_.NewVertexNode(); +    lm::ngram::ChartState &writing = next.under->MutableState(); +    writing = state; +    writing.left.full &= left_full && state.left.full; +    next.under->MutableRightFull() = right_full && state.left.full; +    writing.left.length = left; +    writing.right.length = right; +    node.under->AddExtend(next.under); +  } +  return next; +} + +Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial) { +  VertexNode &node = *starter.under; +  assert(node.State().left.full == state.left.full); +  assert(!node.End()); +  Final *final = context_.NewFinal(); +  final->Reset(partial.score, from.GetRule(), partial.nt[0].End(), partial.nt[1].End()); +  node.SetEnd(final); +  return final; +} + +} // namespace search diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh new file mode 100644 index 00000000..8cdf1420 --- /dev/null +++ b/klm/search/vertex_generator.hh @@ -0,0 +1,70 @@ +#ifndef SEARCH_VERTEX_GENERATOR__ +#define SEARCH_VERTEX_GENERATOR__ + +#include "search/edge.hh" +#include "search/edge_generator.hh" + +#include <boost/pool/pool.hpp> +#include <boost/unordered_map.hpp> + +#include <queue> + +namespace lm { +namespace ngram { +class ChartState; +} // namespace ngram +} // namespace lm + +namespace search { + +template <class Model> class Context; +class ContextBase; +class Final; + +class VertexGenerator { +  public: +    template <class Model> VertexGenerator(Context<Model> &context, Vertex &gen); + +    PartialEdge *MallocPartialEdge() { return static_cast<PartialEdge*>(partial_edge_pool_.malloc()); } +    void FreePartialEdge(PartialEdge *value) { partial_edge_pool_.free(value); } + +    void NewHypothesis(const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial); + +  private: +    // Parallel structure to VertexNode.   +    struct Trie { +      Trie() : under(NULL) {} + +      VertexNode *under; +      boost::unordered_map<uint64_t, Trie> extend; +    }; + +    Trie &FindOrInsert(Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full); + +    Final *CompleteTransition(Trie &node, const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial); + +    ContextBase &context_; + +    std::vector<EdgeGenerator> edges_; + +    struct LessByTop : public std::binary_function<const EdgeGenerator *, const EdgeGenerator *, bool> { +      bool operator()(const EdgeGenerator *first, const EdgeGenerator *second) const { +        return first->Top() < second->Top(); +      } +    }; + +    typedef std::priority_queue<EdgeGenerator*, std::vector<EdgeGenerator*>, LessByTop> Generate; +    Generate generate_; + +    Trie root_; + +    typedef boost::unordered_map<uint64_t, Final*> Existing; +    Existing existing_; + +    int to_pop_; + +    boost::pool<> partial_edge_pool_; +}; + +} // namespace search +#endif // SEARCH_VERTEX_GENERATOR__ diff --git a/klm/search/weights.cc b/klm/search/weights.cc new file mode 100644 index 00000000..82ff3f12 --- /dev/null +++ b/klm/search/weights.cc @@ -0,0 +1,69 @@ +#include "search/weights.hh" +#include "util/tokenize_piece.hh" + +#include <cstdlib> + +namespace search { + +namespace { +struct Insert { +  void operator()(boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) const { +    std::string copy(name.data(), name.size()); +    map[copy] = score; +  } +}; + +struct DotProduct { +  search::Score total; +  DotProduct() : total(0.0) {} + +  void operator()(const boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) { +    boost::unordered_map<std::string, search::Score>::const_iterator i(FindStringPiece(map, name)); +    if (i != map.end())  +      total += score * i->second; +  } +}; + +template <class Map, class Op> void Parse(StringPiece text, Map &map, Op &op) { +  for (util::TokenIter<util::SingleCharacter, true> spaces(text, ' '); spaces; ++spaces) { +    util::TokenIter<util::SingleCharacter> equals(*spaces, '='); +    UTIL_THROW_IF(!equals, WeightParseException, "Bad weight token " << *spaces); +    StringPiece name(*equals); +    UTIL_THROW_IF(!++equals, WeightParseException, "Bad weight token " << *spaces); +    char *end; +    // Assumes proper termination.   +    double value = std::strtod(equals->data(), &end); +    UTIL_THROW_IF(end != equals->data() + equals->size(), WeightParseException, "Failed to parse weight" << *equals); +    UTIL_THROW_IF(++equals, WeightParseException, "Too many equals in " << *spaces); +    op(map, name, value); +  } +} + +} // namespace + +Weights::Weights(StringPiece text) { +  Insert op; +  Parse<Map, Insert>(text, map_, op); +  lm_ = Steal("LanguageModel"); +  oov_ = Steal("OOV"); +  word_penalty_ = Steal("WordPenalty"); +} + +search::Score Weights::DotNoLM(StringPiece text) const { +  DotProduct dot; +  Parse<const Map, DotProduct>(text, map_, dot); +  return dot.total; +} + +float Weights::Steal(const std::string &str) { +  Map::iterator i(map_.find(str)); +  if (i == map_.end()) { +    return 0.0; +  } else { +    float ret = i->second; +    map_.erase(i); +    return ret; +  } +} + +} // namespace search diff --git a/klm/search/weights.hh b/klm/search/weights.hh new file mode 100644 index 00000000..4a4388c7 --- /dev/null +++ b/klm/search/weights.hh @@ -0,0 +1,49 @@ +// For now, the individual features are not kept.   +#ifndef SEARCH_WEIGHTS__ +#define SEARCH_WEIGHTS__ + +#include "search/types.hh" +#include "util/exception.hh" +#include "util/string_piece.hh" + +#include <boost/unordered_map.hpp> + +#include <string> + +namespace search { + +class WeightParseException : public util::Exception { +  public: +    WeightParseException() {} +    ~WeightParseException() throw() {} +}; + +class Weights { +  public: +    // Parses weights, sets lm_weight_, removes it from map_. +    explicit Weights(StringPiece text); + +    search::Score DotNoLM(StringPiece text) const; + +    search::Score LM() const { return lm_; } + +    search::Score OOV() const { return oov_; } + +    search::Score WordPenalty() const { return word_penalty_; } + +    // Mostly for testing.   +    const boost::unordered_map<std::string, search::Score> &GetMap() const { return map_; } + +  private: +    float Steal(const std::string &str); + +    typedef boost::unordered_map<std::string, search::Score> Map; + +    Map map_; + +    search::Score lm_, oov_, word_penalty_; +}; + +} // namespace search + +#endif // SEARCH_WEIGHTS__ diff --git a/klm/search/weights_test.cc b/klm/search/weights_test.cc new file mode 100644 index 00000000..4811ff06 --- /dev/null +++ b/klm/search/weights_test.cc @@ -0,0 +1,38 @@ +#include "search/weights.hh" + +#define BOOST_TEST_MODULE WeightTest +#include <boost/test/unit_test.hpp> +#include <boost/test/floating_point_comparison.hpp> + +namespace search { +namespace { + +#define CHECK_WEIGHT(value, string) \ +  i = parsed.find(string); \ +  BOOST_REQUIRE(i != parsed.end()); \ +  BOOST_CHECK_CLOSE((value), i->second, 0.001); + +BOOST_AUTO_TEST_CASE(parse) { +  // These are not real feature weights.   +  Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5"); +  const boost::unordered_map<std::string, search::Score> &parsed = w.GetMap(); +  boost::unordered_map<std::string, search::Score>::const_iterator i; +  CHECK_WEIGHT(0.0, "rarity"); +  CHECK_WEIGHT(0.0, "phrase-SGT"); +  CHECK_WEIGHT(9.45117, "phrase-TGS"); +  CHECK_WEIGHT(2.33833, "lexical-SGT"); +  BOOST_CHECK(parsed.end() == parsed.find("lm")); +  BOOST_CHECK_CLOSE(3.0, w.LM(), 0.001); +  CHECK_WEIGHT(-28.3317, "lexical-TGS"); +  CHECK_WEIGHT(5.0, "glue?"); +} + +BOOST_AUTO_TEST_CASE(dot) { +  Weights w("rarity=0 phrase-SGT=0 phrase-TGS=9.45117 lhsGrhs=0 lexical-SGT=2.33833 lexical-TGS=-28.3317 abstract?=0 LanguageModel=3 lexical?=1 glue?=5"); +  BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0"), 0.001); +  BOOST_CHECK_CLOSE(9.45117 * 3.0, w.DotNoLM("phrase-TGS=3.0 LanguageModel=10"), 0.001); +  BOOST_CHECK_CLOSE(9.45117 * 3.0 + 28.3317 * 17.4, w.DotNoLM("rarity=5 phrase-TGS=3.0 LanguageModel=10 lexical-TGS=-17.4"), 0.001); +} + +} // namespace +} // namespace search diff --git a/klm/search/word.hh b/klm/search/word.hh new file mode 100644 index 00000000..e7a15be9 --- /dev/null +++ b/klm/search/word.hh @@ -0,0 +1,47 @@ +#ifndef SEARCH_WORD__ +#define SEARCH_WORD__ + +#include "lm/word_index.hh" + +#include <boost/functional/hash.hpp> + +#include <string> +#include <utility> + +namespace search { + +class Word { +  public: +    // Construct a non-terminal. +    Word() : entry_(NULL) {} + +    explicit Word(const std::pair<const std::string, lm::WordIndex> &entry) { +      entry_ = &entry; +    } + +    // Returns true for two non-terminals even if their labels are different (since we don't care about labels). +    bool operator==(const Word &other) const { +      return entry_ == other.entry_; +    } + +    bool Terminal() const { return entry_ != NULL; } + +    const std::string &String() const { return entry_->first; } + +    lm::WordIndex Index() const { return entry_->second; } + +  protected: +    friend size_t hash_value(const Word &word); + +    const std::pair<const std::string, lm::WordIndex> *Entry() const { return entry_; } + +  private: +    const std::pair<const std::string, lm::WordIndex> *entry_; +}; + +inline size_t hash_value(const Word &word) { +  return boost::hash_value(word.Entry()); +} + +} // namespace search +#endif // SEARCH_WORD__ | 
