summaryrefslogtreecommitdiff
path: root/training/dtrain/dtrain_net.cc
diff options
context:
space:
mode:
Diffstat (limited to 'training/dtrain/dtrain_net.cc')
-rw-r--r--training/dtrain/dtrain_net.cc121
1 files changed, 121 insertions, 0 deletions
diff --git a/training/dtrain/dtrain_net.cc b/training/dtrain/dtrain_net.cc
new file mode 100644
index 00000000..946b7587
--- /dev/null
+++ b/training/dtrain/dtrain_net.cc
@@ -0,0 +1,121 @@
+#include "dtrain_net.h"
+#include "sample.h"
+#include "score.h"
+#include "update.h"
+
+#include <nanomsg/nn.h>
+#include <nanomsg/pair.h>
+#include "nn.hpp"
+
+using namespace dtrain;
+
+int
+main(int argc, char** argv)
+{
+ // get configuration
+ po::variables_map conf;
+ if (!dtrain_net_init(argc, argv, &conf))
+ exit(1); // something is wrong
+ const size_t k = conf["k"].as<size_t>();
+ const size_t N = conf["N"].as<size_t>();
+ const weight_t margin = conf["margin"].as<weight_t>();
+ const string master_addr = conf["addr"].as<string>();
+
+ // setup decoder
+ register_feature_functions();
+ SetSilent(true);
+ ReadFile f(conf["decoder_conf"].as<string>());
+ Decoder decoder(f.stream());
+ ScoredKbest* observer = new ScoredKbest(k, new PerSentenceBleuScorer(N));
+
+ // weights
+ vector<weight_t>& decoder_weights = decoder.CurrentWeightVector();
+ SparseVector<weight_t> lambdas, w_average;
+ if (conf.count("input_weights")) {
+ Weights::InitFromFile(conf["input_weights"].as<string>(), &decoder_weights);
+ Weights::InitSparseVector(decoder_weights, &lambdas);
+ }
+
+ cerr << _p4;
+ // output configuration
+ cerr << "dtrain_net" << endl << "Parameters:" << endl;
+ cerr << setw(25) << "k " << k << endl;
+ cerr << setw(25) << "N " << N << endl;
+ cerr << setw(25) << "margin " << margin << endl;
+ cerr << setw(25) << "decoder conf " << "'"
+ << conf["decoder_conf"].as<string>() << "'" << endl;
+
+ // socket
+ nn::socket sock(AF_SP, NN_PAIR);
+ sock.connect(master_addr.c_str());
+ sock.send("hello", 6, 0);
+
+ size_t i = 0;
+ while(true)
+ {
+ char *buf = NULL;
+ string source;
+ vector<Ngrams> refs;
+ vector<size_t> rsz;
+ bool next = true;
+ size_t sz = sock.recv(&buf, NN_MSG, 0);
+ if (buf) {
+ const string in(buf, buf+sz);
+ nn::freemsg(buf);
+ if (in == "shutdown") {
+ next = false;
+ } else {
+ vector<string> parts;
+ boost::algorithm::split_regex(parts, in, boost::regex(" \\|\\|\\| "));
+ source = parts[0];
+ parts.erase(parts.begin());
+ for (auto s: parts) {
+ vector<WordID> r;
+ vector<string> toks;
+ boost::split(toks, s, boost::is_any_of(" "));
+ for (auto tok: toks)
+ r.push_back(TD::Convert(tok));
+ refs.emplace_back(MakeNgrams(r, N));
+ rsz.push_back(r.size());
+ }
+ }
+ }
+ if (next) {
+ if (i%20 == 0)
+ cerr << " ";
+ cerr << ".";
+ if ((i+1)%20==0)
+ cerr << " " << i+1 << endl;
+ } else {
+ if (i%20 != 0)
+ cerr << " " << i << endl;
+ }
+ cerr.flush();
+
+ if (!next)
+ break;
+
+ // decode
+ lambdas.init_vector(&decoder_weights);
+ observer->SetReference(refs, rsz);
+ decoder.Decode(source, observer);
+ vector<ScoredHyp>* samples = observer->GetSamples();
+
+ // get pairs and update
+ SparseVector<weight_t> updates;
+ CollectUpdates(samples, updates, margin);
+ ostringstream os;
+ vectorAsString(updates, os);
+ sock.send(os.str().c_str(), os.str().size()+1, 0);
+ buf = NULL;
+ sz = sock.recv(&buf, NN_MSG, 0);
+ string new_weights(buf, buf+sz);
+ nn::freemsg(buf);
+ lambdas.clear();
+ updateVectorFromString(new_weights, lambdas);
+ i++;
+ } // input loop
+
+ return 0;
+}
+