summaryrefslogtreecommitdiff
path: root/klm/alone/threading.hh
blob: 0ab0f73917fc31198aa0afb2863e00bf702118d3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#ifndef ALONE_THREADING__
#define ALONE_THREADING__

#ifdef WITH_THREADS
#include "util/pcqueue.hh"
#include "util/pool.hh"
#endif

#include <cstddef>
#include <deque>
#include <iosfwd>
#include <queue>
#include <string>

// Forward declaration only: this header handles util::FilePiece exclusively
// through pointers, so the full definition is not required here.
namespace util { class FilePiece; }

// Forward declarations: search::Config is only bound by reference in this
// header and search::Context is not used by the code below, so neither
// definition is needed here.
namespace search {
template <class Model> class Context;
class Config;
} // namespace search

namespace alone {

// Decoding entry point shared by the threaded and single-threaded front ends.
// Defined out of line.  Judging by its parameters it presumably reads input
// from *in_ptr, decodes with `model` under `config`, and writes to `out`;
// callers in this header state that they take ownership of the FilePiece, so
// presumably Decode releases it -- confirm at the definition.
template <class Model>
void Decode(const search::Config &config,
            const Model &model,
            util::FilePiece *in_ptr,
            std::ostream &out);

// Forward declaration; not referenced by the code in this header.
class Graph;

#ifdef WITH_THREADS
// Common base of the messages passed between pipeline stages: the position
// of the sentence in submission order.
struct SentenceID {
  unsigned int sentence_id;

  // Messages are equal exactly when their ids match.  Fields added by derived
  // message types play no part in the comparison.
  bool operator==(const SentenceID &other) const {
    const unsigned int mine = sentence_id;
    return other.sentence_id == mine;
  }
};

// Request consumed by the decode workers: the inherited sentence id plus the
// source to decode.
struct Input : public SentenceID {
  // Source to decode; NULL in the poison value.
  util::FilePiece *file;

  // Sentinel message carrying the largest unsigned id and a NULL file.  Since
  // SentenceID::operator== compares ids only, any Input with that id compares
  // equal to this one.  Presumably used by util::Pool to stop its workers --
  // confirm in util/pool.hh.
  static Input Poison() {
    Input sentinel;
    sentinel.file = NULL;
    sentinel.sentence_id = static_cast<unsigned int>(-1);
    return sentinel;
  }
};

// Result handed from the decode workers to the printer: the inherited
// sentence id plus the decoded text.
struct Output : public SentenceID {
  // Decoded text; NULL in the poison value.  Ownership is not visible from
  // this header -- NOTE(review): presumably the printer frees it, confirm.
  std::string *str;

  // Sentinel message carrying the largest unsigned id and a NULL string.
  // Since SentenceID::operator== compares ids only, any Output with that id
  // compares equal to this one.  Presumably used by util::Pool to stop its
  // workers -- confirm in util/pool.hh.
  static Output Poison() {
    Output sentinel;
    sentinel.str = NULL;
    sentinel.sentence_id = static_cast<unsigned int>(-1);
    return sentinel;
  }
};

// Worker functor for the decode pool.  It holds only references, so the
// config, model and output queue must all outlive it.  operator() and
// Produce are defined out of line; presumably each call decodes one Input and
// pushes the resulting Output(s) onto out_ via Produce -- confirm at the
// definitions.
template <class Model> class DecodeHandler {
  public:
    // Message type this handler consumes (presumably read by util::Pool to
    // type its request queue -- confirm in util/pool.hh).
    typedef Input Request;

    DecodeHandler(const search::Config &config, const Model &model, util::PCQueue<Output> &out)
      : config_(config),
        model_(model),
        out_(out) {}

    // Handles one request; defined out of line.
    void operator()(Input message);

  private:
    // Emits `str` tagged with `sentence_id`; defined out of line.
    void Produce(unsigned int sentence_id, const std::string &str);

    const search::Config &config_;
    const Model &model_;
    util::PCQueue<Output> &out_;
};

// Worker functor for the printer pool: receives decoded Outputs and writes
// them to the stream it was built with.  operator() is defined out of line.
// The members suggest that results arriving ahead of their predecessors are
// parked in waiting_ until the next expected sentence (tracked by done_)
// shows up, so output follows submission order -- confirm at the definition.
//
// waiting_ is a std::deque, whose header <deque> is included directly rather
// than relied on transitively through <queue>.
class PrintHandler {
  public:
    // Message type this handler consumes.
    typedef Output Request;

    // Only a reference to `o` is kept, so the stream must outlive the handler.
    explicit PrintHandler(std::ostream &o) : out_(o), done_(0) {}

    // Handles one Output; defined out of line.
    void operator()(Output message);

  private:
    std::ostream &out_;
    // Pending results not yet written (raw pointers; ownership is decided by
    // the out-of-line code).
    std::deque<std::string*> waiting_;
    // Counter starting at 0; presumably the number of sentences printed.
    unsigned int done_;
};

template <class Model> class Controller {
  public:
    // This config must remain valid.   
    explicit Controller(const search::Config &config, const Model &model, size_t decode_workers, std::ostream &to);

    // Takes ownership of in.    
    void Add(util::FilePiece *in) {
      Input input;
      input.sentence_id = sentence_id_++;
      input.file = in;
      decoder_.Produce(input);
    }

  private:
    unsigned int sentence_id_;

    util::Pool<PrintHandler> printer_;

    util::Pool<DecodeHandler<Model> > decoder_;
};
#endif

// Single-threaded counterpart to Controller with the same Add() interface
// (the constructor differs: there is no worker count).  It sits outside the
// WITH_THREADS guard, so it is always available.
template <class Model> class InThread {
  public:
    // All three arguments are stored by reference and must outlive this
    // object.
    InThread(const search::Config &config, const Model &model, std::ostream &to)
      : config_(config),
        model_(model),
        to_(to) {}

    // Decodes `in` synchronously in the calling thread by passing it straight
    // to Decode together with the stored config, model and stream.
    // Takes ownership of in.
    void Add(util::FilePiece *in) {
      std::ostream &sink = to_;
      Decode(config_, model_, in, sink);
    }

  private:
    const search::Config &config_;
    const Model &model_;
    std::ostream &to_;
};

} // namespace alone
#endif // ALONE_THREADING__