From 64aac199c4a8821772dfaaaa9d162f4a3f5bf121 Mon Sep 17 00:00:00 2001
From: Patrick Simianer
Date: Thu, 2 Apr 2015 14:06:02 +0200
Subject: enable translation or learning: dtrain_net_interface
---
training/dtrain/dtrain_net_interface.cc | 120 ++++++++++++++++++++++++++++++++
1 file changed, 120 insertions(+)
create mode 100644 training/dtrain/dtrain_net_interface.cc
(limited to 'training/dtrain/dtrain_net_interface.cc')
diff --git a/training/dtrain/dtrain_net_interface.cc b/training/dtrain/dtrain_net_interface.cc
new file mode 100644
index 00000000..f484b56b
--- /dev/null
+++ b/training/dtrain/dtrain_net_interface.cc
@@ -0,0 +1,120 @@
+#include "dtrain_net.h"
+#include "sample_net.h"
+#include "score.h"
+#include "update.h"
+
+#include
+#include
+#include "nn.hpp"
+
+using namespace dtrain;
+
+int
+main(int argc, char** argv)
+{
+ // get configuration
+ po::variables_map conf;
+ if (!dtrain_net_init(argc, argv, &conf))
+ exit(1); // something is wrong
+ const size_t k = conf["k"].as();
+ const size_t N = conf["N"].as();
+ const weight_t margin = conf["margin"].as();
+ const string master_addr = conf["addr"].as();
+
+ // setup decoder
+ register_feature_functions();
+ SetSilent(true);
+ ReadFile f(conf["decoder_conf"].as());
+ Decoder decoder(f.stream());
+ ScoredKbest* observer = new ScoredKbest(k, new PerSentenceBleuScorer(N));
+
+ // weights
+ vector& decoder_weights = decoder.CurrentWeightVector();
+ SparseVector lambdas, w_average;
+ if (conf.count("input_weights")) {
+ Weights::InitFromFile(conf["input_weights"].as(), &decoder_weights);
+ Weights::InitSparseVector(decoder_weights, &lambdas);
+ }
+
+ cerr << _p4;
+ // output configuration
+ cerr << "dtrain_net" << endl << "Parameters:" << endl;
+ cerr << setw(25) << "k " << k << endl;
+ cerr << setw(25) << "N " << N << endl;
+ cerr << setw(25) << "margin " << margin << endl;
+ cerr << setw(25) << "decoder conf " << "'"
+ << conf["decoder_conf"].as() << "'" << endl;
+
+ // socket
+ nn::socket sock(AF_SP, NN_PAIR);
+ sock.connect(master_addr.c_str());
+
+ size_t i = 0;
+ while(true)
+ {
+ char *buf = NULL;
+ string source;
+ vector refs;
+ vector rsz;
+ bool next = true;
+ size_t sz = sock.recv(&buf, NN_MSG, 0);
+ if (buf) {
+ const string in(buf, buf+sz);
+ nn::freemsg(buf);
+ if (in == "shutdown") {
+ next = false;
+ } else {
+ vector parts;
+ boost::algorithm::split_regex(parts, in, boost::regex(" \\|\\|\\| "));
+ if (parts[0] == "act:translate") {
+ cerr << "translating ..." << endl;
+ lambdas.init_vector(&decoder_weights);
+ observer->dont_score = true;
+ decoder.Decode(parts[1], observer);
+ observer->dont_score = false;
+ vector* samples = observer->GetSamples();
+ ostringstream os;
+ PrintWordIDVec((*samples)[0].w, os);
+ sock.send(os.str().c_str(), os.str().size()+1, 0);
+ cerr << "done" << endl;
+ continue;
+ } else {
+ cerr << "learning ..." << endl;
+ source = parts[0];
+ parts.erase(parts.begin());
+ for (auto s: parts) {
+ vector r;
+ vector toks;
+ boost::split(toks, s, boost::is_any_of(" "));
+ for (auto tok: toks)
+ r.push_back(TD::Convert(tok));
+ refs.emplace_back(MakeNgrams(r, N));
+ rsz.push_back(r.size());
+ }
+ }
+ }
+ }
+
+ if (!next)
+ break;
+
+ // decode
+ lambdas.init_vector(&decoder_weights);
+ observer->SetReference(refs, rsz);
+ decoder.Decode(source, observer);
+ vector* samples = observer->GetSamples();
+
+ // get pairs and update
+ SparseVector updates;
+ CollectUpdates(samples, updates, margin);
+ lambdas.plus_eq_v_times_s(updates, 1.0); // fixme
+ string s = "x";
+ sock.send(s.c_str(), s.size()+1, 0);
+ i++;
+
+ cerr << "done" << endl;
+ } // input loop
+
+ return 0;
+}
+
--
cgit v1.2.3