summaryrefslogtreecommitdiff
path: root/training/dtrain/dtrain_net_interface.cc
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2015-05-08 21:43:45 +0200
committerPatrick Simianer <p@simianer.de>2015-05-08 21:43:45 +0200
commitd0b8fa29b83e6424e6d5848dbc42734b03896304 (patch)
treeb1f2529183e99b92b20f972d3c5e6739ad855adf /training/dtrain/dtrain_net_interface.cc
parentf678b442e8a0c2e685652d2b7006ccccce989c81 (diff)
parent64aac199c4a8821772dfaaaa9d162f4a3f5bf121 (diff)
Merge branch 'net' of github.com:pks/cdec-dtrain into net
Diffstat (limited to 'training/dtrain/dtrain_net_interface.cc')
-rw-r--r--training/dtrain/dtrain_net_interface.cc120
1 files changed, 120 insertions, 0 deletions
diff --git a/training/dtrain/dtrain_net_interface.cc b/training/dtrain/dtrain_net_interface.cc
new file mode 100644
index 00000000..f484b56b
--- /dev/null
+++ b/training/dtrain/dtrain_net_interface.cc
@@ -0,0 +1,120 @@
+#include "dtrain_net.h"
+#include "sample_net.h"
+#include "score.h"
+#include "update.h"
+
+#include <nanomsg/nn.h>
+#include <nanomsg/pair.h>
+#include "nn.hpp"
+
+using namespace dtrain;
+
+int
+main(int argc, char** argv)
+{
+ // get configuration
+ po::variables_map conf;
+ if (!dtrain_net_init(argc, argv, &conf))
+ exit(1); // something is wrong
+ const size_t k = conf["k"].as<size_t>();
+ const size_t N = conf["N"].as<size_t>();
+ const weight_t margin = conf["margin"].as<weight_t>();
+ const string master_addr = conf["addr"].as<string>();
+
+ // setup decoder
+ register_feature_functions();
+ SetSilent(true);
+ ReadFile f(conf["decoder_conf"].as<string>());
+ Decoder decoder(f.stream());
+ ScoredKbest* observer = new ScoredKbest(k, new PerSentenceBleuScorer(N));
+
+ // weights
+ vector<weight_t>& decoder_weights = decoder.CurrentWeightVector();
+ SparseVector<weight_t> lambdas, w_average;
+ if (conf.count("input_weights")) {
+ Weights::InitFromFile(conf["input_weights"].as<string>(), &decoder_weights);
+ Weights::InitSparseVector(decoder_weights, &lambdas);
+ }
+
+ cerr << _p4;
+ // output configuration
+ cerr << "dtrain_net" << endl << "Parameters:" << endl;
+ cerr << setw(25) << "k " << k << endl;
+ cerr << setw(25) << "N " << N << endl;
+ cerr << setw(25) << "margin " << margin << endl;
+ cerr << setw(25) << "decoder conf " << "'"
+ << conf["decoder_conf"].as<string>() << "'" << endl;
+
+ // socket
+ nn::socket sock(AF_SP, NN_PAIR);
+ sock.connect(master_addr.c_str());
+
+ size_t i = 0;
+ while(true)
+ {
+ char *buf = NULL;
+ string source;
+ vector<Ngrams> refs;
+ vector<size_t> rsz;
+ bool next = true;
+ size_t sz = sock.recv(&buf, NN_MSG, 0);
+ if (buf) {
+ const string in(buf, buf+sz);
+ nn::freemsg(buf);
+ if (in == "shutdown") {
+ next = false;
+ } else {
+ vector<string> parts;
+ boost::algorithm::split_regex(parts, in, boost::regex(" \\|\\|\\| "));
+ if (parts[0] == "act:translate") {
+ cerr << "translating ..." << endl;
+ lambdas.init_vector(&decoder_weights);
+ observer->dont_score = true;
+ decoder.Decode(parts[1], observer);
+ observer->dont_score = false;
+ vector<ScoredHyp>* samples = observer->GetSamples();
+ ostringstream os;
+ PrintWordIDVec((*samples)[0].w, os);
+ sock.send(os.str().c_str(), os.str().size()+1, 0);
+ cerr << "done" << endl;
+ continue;
+ } else {
+ cerr << "learning ..." << endl;
+ source = parts[0];
+ parts.erase(parts.begin());
+ for (auto s: parts) {
+ vector<WordID> r;
+ vector<string> toks;
+ boost::split(toks, s, boost::is_any_of(" "));
+ for (auto tok: toks)
+ r.push_back(TD::Convert(tok));
+ refs.emplace_back(MakeNgrams(r, N));
+ rsz.push_back(r.size());
+ }
+ }
+ }
+ }
+
+ if (!next)
+ break;
+
+ // decode
+ lambdas.init_vector(&decoder_weights);
+ observer->SetReference(refs, rsz);
+ decoder.Decode(source, observer);
+ vector<ScoredHyp>* samples = observer->GetSamples();
+
+ // get pairs and update
+ SparseVector<weight_t> updates;
+ CollectUpdates(samples, updates, margin);
+ lambdas.plus_eq_v_times_s(updates, 1.0); // fixme
+ string s = "x";
+ sock.send(s.c_str(), s.size()+1, 0);
+ i++;
+
+ cerr << "done" << endl;
+ } // input loop
+
+ return 0;
+}
+