From 64aac199c4a8821772dfaaaa9d162f4a3f5bf121 Mon Sep 17 00:00:00 2001
From: Patrick Simianer
Date: Thu, 2 Apr 2015 14:06:02 +0200
Subject: enable translation or learning: dtrain_net_interface
---
training/dtrain/Makefile.am | 5 +-
training/dtrain/dtrain_net.cc | 23 +++---
training/dtrain/dtrain_net_interface.cc | 120 ++++++++++++++++++++++++++++++++
training/dtrain/feed1.rb | 24 +++++++
training/dtrain/sample_net.h | 61 ++++++++++++++++
5 files changed, 222 insertions(+), 11 deletions(-)
create mode 100644 training/dtrain/dtrain_net_interface.cc
create mode 100755 training/dtrain/feed1.rb
create mode 100644 training/dtrain/sample_net.h
(limited to 'training/dtrain')
diff --git a/training/dtrain/Makefile.am b/training/dtrain/Makefile.am
index 590218be..d93597c9 100644
--- a/training/dtrain/Makefile.am
+++ b/training/dtrain/Makefile.am
@@ -1,4 +1,4 @@
-bin_PROGRAMS = dtrain dtrain_net
+bin_PROGRAMS = dtrain dtrain_net dtrain_net_interface
dtrain_SOURCES = dtrain.cc dtrain.h sample.h score.h update.h
dtrain_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a
@@ -6,5 +6,8 @@ dtrain_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mte
dtrain_net_SOURCES = dtrain_net.cc dtrain_net.h dtrain.h sample.h score.h update.h
dtrain_net_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a /usr/lib64/libnanomsg.so
+dtrain_net_interface_SOURCES = dtrain_net_interface.cc dtrain_net.h dtrain.h sample_net.h score.h update.h
+dtrain_net_interface_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a /usr/lib64/libnanomsg.so
+
AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval
diff --git a/training/dtrain/dtrain_net.cc b/training/dtrain/dtrain_net.cc
index 946b7587..306da957 100644
--- a/training/dtrain/dtrain_net.cc
+++ b/training/dtrain/dtrain_net.cc
@@ -67,16 +67,19 @@ main(int argc, char** argv)
} else {
vector parts;
boost::algorithm::split_regex(parts, in, boost::regex(" \\|\\|\\| "));
- source = parts[0];
- parts.erase(parts.begin());
- for (auto s: parts) {
- vector r;
- vector toks;
- boost::split(toks, s, boost::is_any_of(" "));
- for (auto tok: toks)
- r.push_back(TD::Convert(tok));
- refs.emplace_back(MakeNgrams(r, N));
- rsz.push_back(r.size());
+ if (parts[0] == "act:translate") {
+ } else {
+ source = parts[0];
+ parts.erase(parts.begin());
+ for (auto s: parts) {
+ vector r;
+ vector toks;
+ boost::split(toks, s, boost::is_any_of(" "));
+ for (auto tok: toks)
+ r.push_back(TD::Convert(tok));
+ refs.emplace_back(MakeNgrams(r, N));
+ rsz.push_back(r.size());
+ }
}
}
}
diff --git a/training/dtrain/dtrain_net_interface.cc b/training/dtrain/dtrain_net_interface.cc
new file mode 100644
index 00000000..f484b56b
--- /dev/null
+++ b/training/dtrain/dtrain_net_interface.cc
@@ -0,0 +1,120 @@
+#include "dtrain_net.h"
+#include "sample_net.h"
+#include "score.h"
+#include "update.h"
+
+#include
+#include
+#include "nn.hpp"
+
+using namespace dtrain;
+
+int
+main(int argc, char** argv)
+{
+ // get configuration
+ po::variables_map conf;
+ if (!dtrain_net_init(argc, argv, &conf))
+ exit(1); // something is wrong
+ const size_t k = conf["k"].as();
+ const size_t N = conf["N"].as();
+ const weight_t margin = conf["margin"].as();
+ const string master_addr = conf["addr"].as();
+
+ // setup decoder
+ register_feature_functions();
+ SetSilent(true);
+ ReadFile f(conf["decoder_conf"].as());
+ Decoder decoder(f.stream());
+ ScoredKbest* observer = new ScoredKbest(k, new PerSentenceBleuScorer(N));
+
+ // weights
+ vector& decoder_weights = decoder.CurrentWeightVector();
+ SparseVector lambdas, w_average;
+ if (conf.count("input_weights")) {
+ Weights::InitFromFile(conf["input_weights"].as(), &decoder_weights);
+ Weights::InitSparseVector(decoder_weights, &lambdas);
+ }
+
+ cerr << _p4;
+ // output configuration
+ cerr << "dtrain_net" << endl << "Parameters:" << endl;
+ cerr << setw(25) << "k " << k << endl;
+ cerr << setw(25) << "N " << N << endl;
+ cerr << setw(25) << "margin " << margin << endl;
+ cerr << setw(25) << "decoder conf " << "'"
+ << conf["decoder_conf"].as() << "'" << endl;
+
+ // socket
+ nn::socket sock(AF_SP, NN_PAIR);
+ sock.connect(master_addr.c_str());
+
+ size_t i = 0;
+ while(true)
+ {
+ char *buf = NULL;
+ string source;
+ vector refs;
+ vector rsz;
+ bool next = true;
+ size_t sz = sock.recv(&buf, NN_MSG, 0);
+ if (buf) {
+ const string in(buf, buf+sz);
+ nn::freemsg(buf);
+ if (in == "shutdown") {
+ next = false;
+ } else {
+ vector parts;
+ boost::algorithm::split_regex(parts, in, boost::regex(" \\|\\|\\| "));
+ if (parts[0] == "act:translate") {
+ cerr << "translating ..." << endl;
+ lambdas.init_vector(&decoder_weights);
+ observer->dont_score = true;
+ decoder.Decode(parts[1], observer);
+ observer->dont_score = false;
+ vector* samples = observer->GetSamples();
+ ostringstream os;
+ PrintWordIDVec((*samples)[0].w, os);
+ sock.send(os.str().c_str(), os.str().size()+1, 0);
+ cerr << "done" << endl;
+ continue;
+ } else {
+ cerr << "learning ..." << endl;
+ source = parts[0];
+ parts.erase(parts.begin());
+ for (auto s: parts) {
+ vector r;
+ vector toks;
+ boost::split(toks, s, boost::is_any_of(" "));
+ for (auto tok: toks)
+ r.push_back(TD::Convert(tok));
+ refs.emplace_back(MakeNgrams(r, N));
+ rsz.push_back(r.size());
+ }
+ }
+ }
+ }
+
+ if (!next)
+ break;
+
+ // decode
+ lambdas.init_vector(&decoder_weights);
+ observer->SetReference(refs, rsz);
+ decoder.Decode(source, observer);
+ vector* samples = observer->GetSamples();
+
+ // get pairs and update
+ SparseVector updates;
+ CollectUpdates(samples, updates, margin);
+ lambdas.plus_eq_v_times_s(updates, 1.0); // fixme
+ string s = "x";
+ sock.send(s.c_str(), s.size()+1, 0);
+ i++;
+
+ cerr << "done" << endl;
+ } // input loop
+
+ return 0;
+}
+
diff --git a/training/dtrain/feed1.rb b/training/dtrain/feed1.rb
new file mode 100755
index 00000000..a76e4c1e
--- /dev/null
+++ b/training/dtrain/feed1.rb
@@ -0,0 +1,24 @@
+#!/usr/bin/env ruby
+
+require 'nanomsg'
+
+#port = ARGV[0]
+port = 60667
+sock = NanoMsg::PairSocket.new
+addr = "tcp://127.0.0.1:#{port}"
+#addr = "ipc:///tmp/xxx.ipc"
+sock.bind addr
+
+#puts sock.recv
+while true
+ line = STDIN.gets
+ if !line
+ sock.send 'shutdown'
+ break
+ end
+ sock.send line.strip
+ sleep 1
+ puts "got translation: #{sock.recv}\n\n"
+ #sock.send "a=1 b=2"
+end
+
diff --git a/training/dtrain/sample_net.h b/training/dtrain/sample_net.h
new file mode 100644
index 00000000..497149d9
--- /dev/null
+++ b/training/dtrain/sample_net.h
@@ -0,0 +1,61 @@
+#ifndef _DTRAIN_SAMPLE_NET_H_
+#define _DTRAIN_SAMPLE_NET_H_
+
+#include "kbest.h"
+
+#include "score.h"
+
+namespace dtrain
+{
+
+struct ScoredKbest : public DecoderObserver
+{
+ const size_t k_;
+ size_t feature_count_, effective_sz_;
+ vector samples_;
+ PerSentenceBleuScorer* scorer_;
+ vector* ref_ngs_;
+ vector* ref_ls_;
+ bool dont_score;
+
+ ScoredKbest(const size_t k, PerSentenceBleuScorer* scorer) :
+ k_(k), scorer_(scorer), dont_score(false) {}
+
+ virtual void
+ NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg)
+ {
+ samples_.clear(); effective_sz_ = feature_count_ = 0;
+ KBest::KBestDerivations, ESentenceTraversal,
+ KBest::FilterUnique, prob_t, EdgeProb> kbest(*hg, k_);
+ for (size_t i = 0; i < k_; ++i) {
+ const KBest::KBestDerivations, ESentenceTraversal,
+ KBest::FilterUnique, prob_t, EdgeProb>::Derivation* d =
+ kbest.LazyKthBest(hg->nodes_.size() - 1, i);
+ if (!d) break;
+ ScoredHyp h;
+ h.w = d->yield;
+ h.f = d->feature_values;
+ h.model = log(d->score);
+ h.rank = i;
+ if (!dont_score)
+ h.gold = scorer_->Score(h.w, *ref_ngs_, *ref_ls_);
+ samples_.push_back(h);
+ effective_sz_++;
+ feature_count_ += h.f.size();
+ }
+ }
+
+ vector* GetSamples() { return &samples_; }
+ inline void SetReference(vector& ngs, vector& ls)
+ {
+ ref_ngs_ = &ngs;
+ ref_ls_ = &ls;
+ }
+ inline size_t GetFeatureCount() { return feature_count_; }
+ inline size_t GetSize() { return effective_sz_; }
+};
+
+} // namespace
+
+#endif
+
--
cgit v1.2.3