summary refs log tree commit diff
path: root/klm/search/edge_queue.hh
blob: 187eaed715e716c7f21fd787f31960ddff6e176b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#ifndef SEARCH_EDGE_QUEUE__
#define SEARCH_EDGE_QUEUE__

#include "search/edge.hh"
#include "search/edge_generator.hh"
#include "search/note.hh"

#include <boost/pool/pool.hpp>
#include <boost/pool/object_pool.hpp>

#include <new>
#include <queue>
#include <vector>

namespace search {

// Forward declaration: EdgeQueue::Search() only takes Context<Model> by reference.
template <typename Model> class Context;

class EdgeQueue {
  public:
    explicit EdgeQueue(unsigned int pop_limit_hint);

    PartialEdge &InitializeEdge() {
      return *take_;
    }

    void AddEdge(unsigned char arity, Note note) {
      generate_.push(edge_pool_.construct(*take_, arity, note));
      take_ = static_cast<PartialEdge*>(partial_edge_pool_.malloc());
    }

    bool Empty() const { return generate_.empty(); }

    /* Generate hypotheses and send them to output.  Normally, output is a
     * VertexGenerator, but the decoder may want to route edges to different
     * vertices i.e. if they have different LHS non-terminal labels.  
     */
    template <class Model, class Output> void Search(Context<Model> &context, Output &output) {
      int to_pop = context.PopLimit();
      while (to_pop > 0 && !generate_.empty()) {
        EdgeGenerator *top = generate_.top();
        generate_.pop();
        PartialEdge *ret = top->Pop(context, partial_edge_pool_);
        if (ret) {
          output.NewHypothesis(*ret, top->GetNote());
          --to_pop;
          if (top->TopScore() != -kScoreInf) {
            generate_.push(top);
          }
        } else {
          generate_.push(top);
        }
      }
      output.FinishedSearch();
    }

  private:
    boost::object_pool<EdgeGenerator> edge_pool_;

    struct LessByTopScore : public std::binary_function<const EdgeGenerator *, const EdgeGenerator *, bool> {
      bool operator()(const EdgeGenerator *first, const EdgeGenerator *second) const {
        return first->TopScore() < second->TopScore();
      }
    };

    typedef std::priority_queue<EdgeGenerator*, std::vector<EdgeGenerator*>, LessByTopScore> Generate;
    Generate generate_;

    boost::pool<> partial_edge_pool_;

    PartialEdge *take_;
};

} // namespace search
#endif // SEARCH_EDGE_QUEUE__