summaryrefslogtreecommitdiff
path: root/training/dtrain/dtrain_net.cc
blob: 306da957d39ecfb554fac46d77106b68ce567803 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#include "dtrain_net.h"
#include "sample.h"
#include "score.h"
#include "update.h"

#include <nanomsg/nn.h>
#include <nanomsg/pair.h>
#include "nn.hpp"

using namespace dtrain;

// dtrain_net worker.
//
// Connects to a master at conf["addr"] over a nanomsg PAIR socket and
// announces itself with "hello". It then loops:
//   1. receive one message; "shutdown" ends the loop,
//   2. otherwise parse "<source> ||| <ref 1> ||| <ref 2> ..." (references
//      are space-separated tokens),
//   3. decode the source with the current weights, observed by a ScoredKbest
//      (k, per-sentence BLEU of order N),
//   4. send the updates produced by CollectUpdates(samples, updates, margin)
//      as a NUL-terminated string,
//   5. receive a weight vector from the master and replace the local weights
//      with it.
// Returns 0 after "shutdown", and 1 if a message could not be received.
int
main(int argc, char** argv)
{
  // get configuration
  po::variables_map conf;
  if (!dtrain_net_init(argc, argv, &conf))
    exit(1); // something is wrong
  const size_t k              = conf["k"].as<size_t>();
  const size_t N              = conf["N"].as<size_t>();
  const weight_t margin       = conf["margin"].as<weight_t>();
  const string master_addr    = conf["addr"].as<string>();

  // setup decoder
  register_feature_functions();
  SetSilent(true);
  ReadFile f(conf["decoder_conf"].as<string>());
  Decoder decoder(f.stream());
  // NOTE(review): never deleted; who owns the scorer handed to ScoredKbest is
  // not visible here, so the allocation is left as it was.
  ScoredKbest* observer = new ScoredKbest(k, new PerSentenceBleuScorer(N));

  // weights
  vector<weight_t>& decoder_weights = decoder.CurrentWeightVector();
  SparseVector<weight_t> lambdas;
  if (conf.count("input_weights")) {
    Weights::InitFromFile(conf["input_weights"].as<string>(), &decoder_weights);
    Weights::InitSparseVector(decoder_weights, &lambdas);
  }

  cerr << _p4;
  // output configuration
  cerr << "dtrain_net" << endl << "Parameters:" << endl;
  cerr << setw(25) << "k " << k << endl;
  cerr << setw(25) << "N " << N << endl;
  cerr << setw(25) << "margin " << margin << endl;
  cerr << setw(25) << "decoder conf " << "'"
       << conf["decoder_conf"].as<string>() << "'" << endl;

  // socket
  nn::socket sock(AF_SP, NN_PAIR);
  sock.connect(master_addr.c_str());
  sock.send("hello", 6, 0);

  // Receives one message into `out`; returns false if nothing was received.
  // This worker itself sends NUL-terminated strings ("hello" and the update
  // message). A single trailing NUL is therefore dropped, so that a master
  // using the same convention still matches "shutdown" and does not leak a
  // NUL byte into the parsed payload.
  auto recv_string = [&sock](string& out) -> bool {
    char* buf = nullptr;
    const int rc = sock.recv(&buf, NN_MSG, 0);
    if (rc < 0 || !buf)
      return false;
    size_t len = static_cast<size_t>(rc);
    if (len > 0 && buf[len-1] == '\0')
      --len;
    out.assign(buf, len);
    nn::freemsg(buf);
    return true;
  };

  size_t i = 0;
  while (true)
  {
    string in;
    if (!recv_string(in)) {
      cerr << "dtrain_net: failed to receive a message from master" << endl;
      return 1;
    }
    const bool next = (in != "shutdown");

    string source;
    vector<Ngrams> refs;
    vector<size_t> rsz;
    if (next) {
      vector<string> parts;
      boost::algorithm::split_regex(parts, in, boost::regex(" \\|\\|\\| "));
      if (parts[0] == "act:translate") {
        // NOTE(review): not handled. source/refs stay empty, and the code
        // below still decodes an empty source and sends an update. Confirm
        // the intended protocol for this command.
      } else {
        source = parts[0];
        // parts[1..] are the references
        for (size_t j = 1; j < parts.size(); ++j) {
          vector<string> toks;
          boost::split(toks, parts[j], boost::is_any_of(" "));
          vector<WordID> r;
          r.reserve(toks.size());
          for (const auto& tok : toks) {
            // Repeated, leading or trailing spaces produce empty tokens.
            // They are not words and would distort the n-grams and the
            // reference length.
            if (!tok.empty())
              r.push_back(TD::Convert(tok));
          }
          refs.emplace_back(MakeNgrams(r, N));
          rsz.push_back(r.size());
        }
      }
    }

    // progress: a dot per input, the count after every 20th
    if (next) {
      if (i%20 == 0)
        cerr << " ";
      cerr << ".";
      if ((i+1)%20==0)
        cerr << " " << i+1 << endl;
    } else {
      if (i%20 != 0)
        cerr << " " << i << endl;
    }
    cerr.flush();

    if (!next)
      break;

    // decode
    lambdas.init_vector(&decoder_weights);
    observer->SetReference(refs, rsz);
    decoder.Decode(source, observer);
    vector<ScoredHyp>* samples = observer->GetSamples();

    // get pairs and update; the message is sent NUL-terminated
    SparseVector<weight_t> updates;
    CollectUpdates(samples, updates, margin);
    ostringstream os;
    vectorAsString(updates, os);
    const string msg = os.str();
    sock.send(msg.c_str(), msg.size()+1, 0);

    // replace local weights with the master's reply
    string new_weights;
    if (!recv_string(new_weights)) {
      cerr << "dtrain_net: failed to receive weights from master" << endl;
      return 1;
    }
    lambdas.clear();
    updateVectorFromString(new_weights, lambdas);
    i++;
  } // input loop

  return 0;
}