diff options
Diffstat (limited to 'klm')
| -rw-r--r-- | klm/lm/fragment.cc | 37 | ||||
| -rw-r--r-- | klm/search/Jamfile | 2 | ||||
| -rw-r--r-- | klm/search/edge.hh | 31 | ||||
| -rw-r--r-- | klm/search/edge_generator.cc | 53 | ||||
| -rw-r--r-- | klm/search/edge_generator.hh | 24 | ||||
| -rw-r--r-- | klm/search/edge_queue.cc | 25 | ||||
| -rw-r--r-- | klm/search/edge_queue.hh | 73 | ||||
| -rw-r--r-- | klm/search/final.hh | 11 | ||||
| -rw-r--r-- | klm/search/note.hh | 12 | ||||
| -rw-r--r-- | klm/search/rule.cc | 32 | ||||
| -rw-r--r-- | klm/search/rule.hh | 31 | ||||
| -rw-r--r-- | klm/search/vertex.hh | 45 | ||||
| -rw-r--r-- | klm/search/vertex_generator.cc | 32 | ||||
| -rw-r--r-- | klm/search/vertex_generator.hh | 35 | 
14 files changed, 250 insertions, 193 deletions
| diff --git a/klm/lm/fragment.cc b/klm/lm/fragment.cc new file mode 100644 index 00000000..0267cd4e --- /dev/null +++ b/klm/lm/fragment.cc @@ -0,0 +1,37 @@ +#include "lm/binary_format.hh" +#include "lm/model.hh" +#include "lm/left.hh" +#include "util/tokenize_piece.hh" + +template <class Model> void Query(const char *name) { +  Model model(name); +  std::string line; +  lm::ngram::ChartState ignored; +  while (getline(std::cin, line)) { +    lm::ngram::RuleScore<Model> scorer(model, ignored); +    for (util::TokenIter<util::SingleCharacter, true> i(line, ' '); i; ++i) { +      scorer.Terminal(model.GetVocabulary().Index(*i)); +    } +    std::cout << scorer.Finish() << '\n'; +  } +} + +int main(int argc, char *argv[]) { +  if (argc != 2) { +    std::cerr << "Expected model file name." << std::endl; +    return 1; +  } +  const char *name = argv[1]; +  lm::ngram::ModelType model_type = lm::ngram::PROBING; +  lm::ngram::RecognizeBinary(name, model_type); +  switch (model_type) { +    case lm::ngram::PROBING: +      Query<lm::ngram::ProbingModel>(name); +      break; +    case lm::ngram::REST_PROBING: +      Query<lm::ngram::RestProbingModel>(name); +      break; +    default: +      std::cerr << "Model type not supported yet." << std::endl; +  } +} diff --git a/klm/search/Jamfile b/klm/search/Jamfile index ac47c249..e8b14363 100644 --- a/klm/search/Jamfile +++ b/klm/search/Jamfile @@ -1,4 +1,4 @@ -lib search : weights.cc vertex.cc vertex_generator.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil : : : <include>.. ; +lib search : weights.cc vertex.cc vertex_generator.cc edge_queue.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : <include>.. 
;  import testing ; diff --git a/klm/search/edge.hh b/klm/search/edge.hh index 4d2a5cbf..77ab0ade 100644 --- a/klm/search/edge.hh +++ b/klm/search/edge.hh @@ -11,33 +11,6 @@  namespace search { -class Edge { -  public: -    Edge() { -      end_to_ = to_; -    } - -    Rule &InitRule() { return rule_; } - -    void Add(Vertex &vertex) { -      assert(end_to_ - to_ < kMaxArity); -      *(end_to_++) = &vertex; -    } - -    const Vertex &GetVertex(std::size_t index) const { -      return *to_[index]; -    } - -    const Rule &GetRule() const { return rule_; } - -  private: -    // Rule and pointers to rule arguments.   -    Rule rule_; - -    Vertex *to_[kMaxArity]; -    Vertex **end_to_; -}; -  struct PartialEdge {    Score score;    // Terminals @@ -45,6 +18,10 @@ struct PartialEdge {    // Non-terminals    PartialVertex nt[kMaxArity]; +  const lm::ngram::ChartState &CompletedState() const { +    return between[0]; +  } +    bool operator<(const PartialEdge &other) const {      return score < other.score;    } diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc index d135899a..56239dfb 100644 --- a/klm/search/edge_generator.cc +++ b/klm/search/edge_generator.cc @@ -10,28 +10,15 @@  namespace search { -bool EdgeGenerator::Init(Edge &edge, VertexGenerator &parent) { -  from_ = &edge; -  for (unsigned int i = 0; i < GetRule().Arity(); ++i) { -    if (edge.GetVertex(i).RootPartial().Empty()) return false; -  } -  PartialEdge &root = *parent.MallocPartialEdge(); -  root.score = GetRule().Bound(); -  for (unsigned int i = 0; i < GetRule().Arity(); ++i) { +EdgeGenerator::EdgeGenerator(PartialEdge &root, unsigned char arity, Note note) : arity_(arity), note_(note) { +/*  for (unsigned char i = 0; i < edge.Arity(); ++i) {      root.nt[i] = edge.GetVertex(i).RootPartial(); -    root.score += root.nt[i].Bound();    } -  for (unsigned int i = GetRule().Arity(); i < 2; ++i) { +  for (unsigned char i = edge.Arity(); i < 2; ++i) {      root.nt[i] = 
kBlankPartialVertex; -  } -  for (unsigned int i = 0; i < GetRule().Arity() + 1; ++i) { -    root.between[i] = GetRule().Lexical(i); -  } -  // wtf no clear method? -  generate_ = Generate(); +  }*/    generate_.push(&root); -  top_ = root.score; -  return true; +  top_score_ = root.score;  }  namespace { @@ -78,13 +65,13 @@ template <class Model> float FastScore(const Context<Model> &context, unsigned c  } // namespace -template <class Model> bool EdgeGenerator::Pop(Context<Model> &context, VertexGenerator &parent) { +template <class Model> PartialEdge *EdgeGenerator::Pop(Context<Model> &context, boost::pool<> &partial_edge_pool) {    assert(!generate_.empty());    PartialEdge &top = *generate_.top();    generate_.pop();    unsigned int victim = 0;    unsigned char lowest_length = 255; -  for (unsigned int i = 0; i != GetRule().Arity(); ++i) { +  for (unsigned char i = 0; i != arity_; ++i) {      if (!top.nt[i].Complete() && top.nt[i].Length() < lowest_length) {        lowest_length = top.nt[i].Length();        victim = i; @@ -92,21 +79,21 @@ template <class Model> bool EdgeGenerator::Pop(Context<Model> &context, VertexGe    }    if (lowest_length == 255) {      // All states report complete.   -    top.between[0].right = top.between[GetRule().Arity()].right; -    parent.NewHypothesis(top.between[0], *from_, top); -    top_ = generate_.empty() ? -kScoreInf : generate_.top()->score; -    return !generate_.empty(); +    top.between[0].right = top.between[arity_].right; +    // Now top.between[0] is the full edge state.   +    top_score_ = generate_.empty() ? -kScoreInf : generate_.top()->score; +    return &top;    }    unsigned int stay = !victim; -  PartialEdge &continuation = *parent.MallocPartialEdge(); +  PartialEdge &continuation = *static_cast<PartialEdge*>(partial_edge_pool.malloc());    float old_bound = top.nt[victim].Bound();    // The alternate's score will change because alternate.nt[victim] changes.      
bool split = top.nt[victim].Split(continuation.nt[victim]);    // top is now the alternate.      continuation.nt[stay] = top.nt[stay]; -  continuation.score = FastScore(context, victim, GetRule().Arity(), top, continuation); +  continuation.score = FastScore(context, victim, arity_, top, continuation);    // TODO: dedupe?      generate_.push(&continuation); @@ -116,14 +103,18 @@ template <class Model> bool EdgeGenerator::Pop(Context<Model> &context, VertexGe      // TODO: dedupe?        generate_.push(&top);    } else { -    parent.FreePartialEdge(&top); +    partial_edge_pool.free(&top);    } -  top_ = generate_.top()->score; -  return true; +  top_score_ = generate_.top()->score; +  return NULL;  } -template bool EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context, VertexGenerator &parent); -template bool EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context, VertexGenerator &parent); +template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::RestProbingModel> &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::ProbingModel> &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::TrieModel> &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::QuantTrieModel> &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::ArrayTrieModel> &context, boost::pool<> &partial_edge_pool); +template PartialEdge *EdgeGenerator::Pop(Context<lm::ngram::QuantArrayTrieModel> &context, boost::pool<> &partial_edge_pool);  } // namespace search diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh index e306dc61..875ccc5e 100644 --- a/klm/search/edge_generator.hh +++ b/klm/search/edge_generator.hh @@ -2,7 +2,9 @@  #define SEARCH_EDGE_GENERATOR__  #include "search/edge.hh" +#include "search/note.hh" +#include <boost/pool/pool.hpp>  #include 
<boost/unordered_map.hpp>  #include <functional> @@ -28,26 +30,28 @@ struct PartialEdgePointerLess : std::binary_function<const PartialEdge *, const  class EdgeGenerator {    public: -    // True if it has a hypothesis.   -    bool Init(Edge &edge, VertexGenerator &parent); +    EdgeGenerator(PartialEdge &root, unsigned char arity, Note note); -    Score Top() const { -      return top_; +    Score TopScore() const { +      return top_score_;      } -    template <class Model> bool Pop(Context<Model> &context, VertexGenerator &parent); +    Note GetNote() const { +      return note_; +    } + +    // Pop.  If there's a complete hypothesis, return it.  Otherwise return NULL.   +    template <class Model> PartialEdge *Pop(Context<Model> &context, boost::pool<> &partial_edge_pool);    private: -    const Rule &GetRule() const { -      return from_->GetRule(); -    } +    Score top_score_; -    Score top_; +    unsigned char arity_;      typedef std::priority_queue<PartialEdge*, std::vector<PartialEdge*>, PartialEdgePointerLess> Generate;      Generate generate_; -    Edge *from_; +    Note note_;  };  } // namespace search diff --git a/klm/search/edge_queue.cc b/klm/search/edge_queue.cc new file mode 100644 index 00000000..e3ae6ebf --- /dev/null +++ b/klm/search/edge_queue.cc @@ -0,0 +1,25 @@ +#include "search/edge_queue.hh" + +#include "lm/left.hh" +#include "search/context.hh" + +#include <stdint.h> + +namespace search { + +EdgeQueue::EdgeQueue(unsigned int pop_limit_hint) : partial_edge_pool_(sizeof(PartialEdge), pop_limit_hint * 2) { +  take_ = static_cast<PartialEdge*>(partial_edge_pool_.malloc()); +} + +/*void EdgeQueue::AddEdge(PartialEdge &root, unsigned char arity, Note note) { +  // Ignore empty edges.   
+  for (unsigned char i = 0; i < edge.Arity(); ++i) { +    PartialVertex root(edge.GetVertex(i).RootPartial()); +    if (root.Empty()) return; +    total_score += root.Bound(); +  } +  PartialEdge &allocated = *static_cast<PartialEdge*>(partial_edge_pool_.malloc()); +  allocated.score = total_score; +}*/ + +} // namespace search diff --git a/klm/search/edge_queue.hh b/klm/search/edge_queue.hh new file mode 100644 index 00000000..187eaed7 --- /dev/null +++ b/klm/search/edge_queue.hh @@ -0,0 +1,73 @@ +#ifndef SEARCH_EDGE_QUEUE__ +#define SEARCH_EDGE_QUEUE__ + +#include "search/edge.hh" +#include "search/edge_generator.hh" +#include "search/note.hh" + +#include <boost/pool/pool.hpp> +#include <boost/pool/object_pool.hpp> + +#include <queue> + +namespace search { + +template <class Model> class Context; + +class EdgeQueue { +  public: +    explicit EdgeQueue(unsigned int pop_limit_hint); + +    PartialEdge &InitializeEdge() { +      return *take_; +    } + +    void AddEdge(unsigned char arity, Note note) { +      generate_.push(edge_pool_.construct(*take_, arity, note)); +      take_ = static_cast<PartialEdge*>(partial_edge_pool_.malloc()); +    } + +    bool Empty() const { return generate_.empty(); } + +    /* Generate hypotheses and send them to output.  Normally, output is a +     * VertexGenerator, but the decoder may want to route edges to different +     * vertices i.e. if they have different LHS non-terminal labels.   
+     */ +    template <class Model, class Output> void Search(Context<Model> &context, Output &output) { +      int to_pop = context.PopLimit(); +      while (to_pop > 0 && !generate_.empty()) { +        EdgeGenerator *top = generate_.top(); +        generate_.pop(); +        PartialEdge *ret = top->Pop(context, partial_edge_pool_); +        if (ret) { +          output.NewHypothesis(*ret, top->GetNote()); +          --to_pop; +          if (top->TopScore() != -kScoreInf) { +            generate_.push(top); +          } +        } else { +          generate_.push(top); +        } +      } +      output.FinishedSearch(); +    } + +  private: +    boost::object_pool<EdgeGenerator> edge_pool_; + +    struct LessByTopScore : public std::binary_function<const EdgeGenerator *, const EdgeGenerator *, bool> { +      bool operator()(const EdgeGenerator *first, const EdgeGenerator *second) const { +        return first->TopScore() < second->TopScore(); +      } +    }; + +    typedef std::priority_queue<EdgeGenerator*, std::vector<EdgeGenerator*>, LessByTopScore> Generate; +    Generate generate_; + +    boost::pool<> partial_edge_pool_; + +    PartialEdge *take_; +}; + +} // namespace search +#endif // SEARCH_EDGE_QUEUE__ diff --git a/klm/search/final.hh b/klm/search/final.hh index 823b8c1a..1b3092ac 100644 --- a/klm/search/final.hh +++ b/klm/search/final.hh @@ -2,35 +2,34 @@  #define SEARCH_FINAL__  #include "search/arity.hh" +#include "search/note.hh"  #include "search/types.hh"  #include <boost/array.hpp>  namespace search { -class Edge; -  class Final {    public:      typedef boost::array<const Final*, search::kMaxArity> ChildArray; -    void Reset(Score bound, const Edge &from, const Final &left, const Final &right) { +    void Reset(Score bound, Note note, const Final &left, const Final &right) {        bound_ = bound; -      from_ = &from; +      note_ = note;        children_[0] = &left;        children_[1] = &right;      }      const ChildArray &Children() const 
{ return children_; } -    const Edge &From() const { return *from_; } +    Note GetNote() const { return note_; }      Score Bound() const { return bound_; }    private:      Score bound_; -    const Edge *from_; +    Note note_;      ChildArray children_;  }; diff --git a/klm/search/note.hh b/klm/search/note.hh new file mode 100644 index 00000000..50bed06e --- /dev/null +++ b/klm/search/note.hh @@ -0,0 +1,12 @@ +#ifndef SEARCH_NOTE__ +#define SEARCH_NOTE__ + +namespace search { + +union Note { +  const void *vp; +}; + +} // namespace search + +#endif // SEARCH_NOTE__ diff --git a/klm/search/rule.cc b/klm/search/rule.cc index 0a941527..5b00207e 100644 --- a/klm/search/rule.cc +++ b/klm/search/rule.cc @@ -9,35 +9,35 @@  namespace search { -template <class Model> void Rule::Init(const Context<Model> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos) { -  additive_ = additive; -  Score lm_score = 0.0; -  lexical_.clear(); -  const lm::WordIndex oov = context.LanguageModel().GetVocabulary().NotFound(); - +template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing) { +  unsigned int oov_count = 0; +  float prob = 0.0; +  const Model &model = context.LanguageModel(); +  const lm::WordIndex oov = model.GetVocabulary().NotFound();    for (std::vector<lm::WordIndex>::const_iterator word = words.begin(); ; ++word) { -    lexical_.resize(lexical_.size() + 1); -    lm::ngram::RuleScore<Model> scorer(context.LanguageModel(), lexical_.back()); +    lm::ngram::RuleScore<Model> scorer(model, *(writing++));      // TODO: optimize      if (prepend_bos && (word == words.begin())) {        scorer.BeginSentence();      }      for (; ; ++word) {        if (word == words.end()) { -        lm_score += scorer.Finish(); -        bound_ = additive_ + context.GetWeights().LM() * lm_score; -        arity_ = lexical_.size() - 1;  -        return; +        prob 
+= scorer.Finish(); +        return static_cast<float>(oov_count) * context.GetWeights().OOV() + prob * context.GetWeights().LM();        }        if (*word == kNonTerminal) break; -      if (*word == oov) additive_ += context.GetWeights().OOV(); +      if (*word == oov) ++oov_count;        scorer.Terminal(*word);      } -    lm_score += scorer.Finish(); +    prob += scorer.Finish();    }  } -template void Rule::Init(const Context<lm::ngram::RestProbingModel> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos); -template void Rule::Init(const Context<lm::ngram::ProbingModel> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos); +template float ScoreRule(const Context<lm::ngram::RestProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context<lm::ngram::ProbingModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context<lm::ngram::TrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context<lm::ngram::QuantTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context<lm::ngram::ArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing); +template float ScoreRule(const Context<lm::ngram::QuantArrayTrieModel> &model, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *writing);  } // namespace search diff --git a/klm/search/rule.hh b/klm/search/rule.hh index 920c64a7..0ce2794d 100644 --- a/klm/search/rule.hh +++ b/klm/search/rule.hh @@ -3,44 +3,17 @@  #include "lm/left.hh"  #include "lm/word_index.hh" -#include "search/arity.hh"  #include "search/types.hh" -#include 
<boost/array.hpp> - -#include <iosfwd>  #include <vector>  namespace search {  template <class Model> class Context; -class Rule { -  public: -    Rule() : arity_(0) {} - -    static const lm::WordIndex kNonTerminal = lm::kMaxWordIndex; - -    // Use kNonTerminal for non-terminals.   -    template <class Model> void Init(const Context<Model> &context, Score additive, const std::vector<lm::WordIndex> &words, bool prepend_bos); - -    Score Bound() const { return bound_; } - -    Score Additive() const { return additive_; } - -    unsigned int Arity() const { return arity_; } - -    const lm::ngram::ChartState &Lexical(unsigned int index) const { -      return lexical_[index]; -    } - -  private: -    Score bound_, additive_; - -    unsigned int arity_; +const lm::WordIndex kNonTerminal = lm::kMaxWordIndex; -    std::vector<lm::ngram::ChartState> lexical_; -}; +template <class Model> float ScoreRule(const Context<Model> &context, const std::vector<lm::WordIndex> &words, bool prepend_bos, lm::ngram::ChartState *state_out);  } // namespace search diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh index 7ef29efc..e1a9ad11 100644 --- a/klm/search/vertex.hh +++ b/klm/search/vertex.hh @@ -16,8 +16,6 @@ namespace search {  class ContextBase; -class Edge; -  class VertexNode {    public:      VertexNode() : end_(NULL) {} @@ -103,6 +101,10 @@ class PartialVertex {      unsigned char Length() const { return back_->Length(); } +    bool HasAlternative() const { +      return index_ + 1 < back_->Size(); +    } +      // Split into continuation and alternative, rendering this the alternative.      
bool Split(PartialVertex &continuation) {        assert(!Complete()); @@ -128,35 +130,26 @@ extern PartialVertex kBlankPartialVertex;  class Vertex {    public: -    Vertex()  -#ifdef DEBUG -      : finished_adding_(false) -#endif -    {} - -    void Add(Edge &edge) { -#ifdef DEBUG -      assert(!finished_adding_); -#endif -      edges_.push_back(&edge); -    } - -    void FinishedAdding() { -#ifdef DEBUG -      assert(!finished_adding_); -      finished_adding_ = true; -#endif -    } +    Vertex() {}      PartialVertex RootPartial() const { return PartialVertex(root_); } +    const Final *BestChild() const { +      PartialVertex top(RootPartial()); +      if (top.Empty()) { +        return NULL; +      } else { +        PartialVertex continuation; +        while (!top.Complete()) { +          top.Split(continuation); +          top = continuation; +        } +        return &top.End(); +      } +    } +    private:      friend class VertexGenerator; -    std::vector<Edge*> edges_; - -#ifdef DEBUG -    bool finished_adding_; -#endif      VertexNode root_;  }; diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc index 78948c97..d94e6e06 100644 --- a/klm/search/vertex_generator.cc +++ b/klm/search/vertex_generator.cc @@ -2,45 +2,30 @@  #include "lm/left.hh"  #include "search/context.hh" +#include "search/edge.hh"  #include <stdint.h>  namespace search { -template <class Model> VertexGenerator::VertexGenerator(Context<Model> &context, Vertex &gen) : context_(context), edges_(gen.edges_.size()), partial_edge_pool_(sizeof(PartialEdge), context.PopLimit() * 2) { -  for (std::size_t i = 0; i < gen.edges_.size(); ++i) { -    if (edges_[i].Init(*gen.edges_[i], *this)) -      generate_.push(&edges_[i]); -  } +VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) {    gen.root_.InitRoot();    root_.under = &gen.root_; -  to_pop_ = context.PopLimit(); -  while (to_pop_ > 0 && !generate_.empty()) { -    
EdgeGenerator *top = generate_.top(); -    generate_.pop(); -    if (top->Pop(context, *this)) { -      generate_.push(top); -    } -  } -  gen.root_.SortAndSet(context, NULL);  } -template VertexGenerator::VertexGenerator(Context<lm::ngram::ProbingModel> &context, Vertex &gen); -template VertexGenerator::VertexGenerator(Context<lm::ngram::RestProbingModel> &context, Vertex &gen); -  namespace {  const uint64_t kCompleteAdd = static_cast<uint64_t>(-1);  } // namespace -void VertexGenerator::NewHypothesis(const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial) { +void VertexGenerator::NewHypothesis(const PartialEdge &partial, Note note) { +  const lm::ngram::ChartState &state = partial.CompletedState();    std::pair<Existing::iterator, bool> got(existing_.insert(std::pair<uint64_t, Final*>(hash_value(state), NULL)));    if (!got.second) {      // Found it already.        Final &exists = *got.first->second;      if (exists.Bound() < partial.score) { -      exists.Reset(partial.score, from, partial.nt[0].End(), partial.nt[1].End()); +      exists.Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End());      } -    --to_pop_;      return;    }    unsigned char left = 0, right = 0; @@ -67,8 +52,7 @@ void VertexGenerator::NewHypothesis(const lm::ngram::ChartState &state, const Ed    }    node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); -  got.first->second = CompleteTransition(*node, state, from, partial); -  --to_pop_; +  got.first->second = CompleteTransition(*node, state, note, partial);  }  VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { @@ -86,12 +70,12 @@ VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node    return next;  } -Final 
*VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial) { +Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial) {    VertexNode &node = *starter.under;    assert(node.State().left.full == state.left.full);    assert(!node.End());    Final *final = context_.NewFinal(); -  final->Reset(partial.score, from, partial.nt[0].End(), partial.nt[1].End()); +  final->Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End());    node.SetEnd(final);    return final;  } diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh index 8cdf1420..6b98da3e 100644 --- a/klm/search/vertex_generator.hh +++ b/klm/search/vertex_generator.hh @@ -1,10 +1,9 @@  #ifndef SEARCH_VERTEX_GENERATOR__  #define SEARCH_VERTEX_GENERATOR__ -#include "search/edge.hh" -#include "search/edge_generator.hh" +#include "search/note.hh" +#include "search/vertex.hh" -#include <boost/pool/pool.hpp>  #include <boost/unordered_map.hpp>  #include <queue> @@ -17,18 +16,21 @@ class ChartState;  namespace search { -template <class Model> class Context;  class ContextBase;  class Final; +struct PartialEdge;  class VertexGenerator {    public: -    template <class Model> VertexGenerator(Context<Model> &context, Vertex &gen); +    VertexGenerator(ContextBase &context, Vertex &gen); -    PartialEdge *MallocPartialEdge() { return static_cast<PartialEdge*>(partial_edge_pool_.malloc()); } -    void FreePartialEdge(PartialEdge *value) { partial_edge_pool_.free(value); } +    void NewHypothesis(const PartialEdge &partial, Note note); -    void NewHypothesis(const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial); +    void FinishedSearch() { +      root_.under->SortAndSet(context_, NULL); +    } + +    const Vertex &Generating() const { return gen_; }    private:      // Parallel 
structure to VertexNode.   @@ -41,29 +43,16 @@ class VertexGenerator {      Trie &FindOrInsert(Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full); -    Final *CompleteTransition(Trie &node, const lm::ngram::ChartState &state, const Edge &from, const PartialEdge &partial); +    Final *CompleteTransition(Trie &node, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial);      ContextBase &context_; -    std::vector<EdgeGenerator> edges_; - -    struct LessByTop : public std::binary_function<const EdgeGenerator *, const EdgeGenerator *, bool> { -      bool operator()(const EdgeGenerator *first, const EdgeGenerator *second) const { -        return first->Top() < second->Top(); -      } -    }; - -    typedef std::priority_queue<EdgeGenerator*, std::vector<EdgeGenerator*>, LessByTop> Generate; -    Generate generate_; +    Vertex &gen_;      Trie root_;      typedef boost::unordered_map<uint64_t, Final*> Existing;      Existing existing_; - -    int to_pop_; - -    boost::pool<> partial_edge_pool_;  };  } // namespace search | 
