diff options
Diffstat (limited to 'training/dtrain')
-rw-r--r-- | training/dtrain/Makefile.am | 5 | ||||
-rwxr-xr-x | training/dtrain/downpour.rb | 120 | ||||
-rw-r--r-- | training/dtrain/dtrain_net.cc | 121 | ||||
-rw-r--r-- | training/dtrain/dtrain_net.h | 72 | ||||
-rw-r--r-- | training/dtrain/examples/net/10.gz | bin | 0 -> 1196 bytes | |||
-rw-r--r-- | training/dtrain/examples/net/README | 6 | ||||
-rw-r--r-- | training/dtrain/examples/net/cdec.ini | 27 | ||||
-rw-r--r-- | training/dtrain/examples/net/dtrain.ini | 4 | ||||
-rw-r--r-- | training/dtrain/examples/net/work/out.0 | 11 | ||||
-rw-r--r-- | training/dtrain/examples/net/work/out.1 | 11 | ||||
-rw-r--r-- | training/dtrain/examples/net/work/out.2 | 11 | ||||
-rwxr-xr-x | training/dtrain/feed.rb | 22 | ||||
-rw-r--r-- | training/dtrain/nn.hpp | 204 |
13 files changed, 613 insertions, 1 deletions
diff --git a/training/dtrain/Makefile.am b/training/dtrain/Makefile.am index a6c65b1e..590218be 100644 --- a/training/dtrain/Makefile.am +++ b/training/dtrain/Makefile.am @@ -1,7 +1,10 @@ -bin_PROGRAMS = dtrain +bin_PROGRAMS = dtrain dtrain_net dtrain_SOURCES = dtrain.cc dtrain.h sample.h score.h update.h dtrain_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a +dtrain_net_SOURCES = dtrain_net.cc dtrain_net.h dtrain.h sample.h score.h update.h +dtrain_net_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a /usr/lib64/libnanomsg.so + AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/training/dtrain/downpour.rb b/training/dtrain/downpour.rb new file mode 100755 index 00000000..d6af6707 --- /dev/null +++ b/training/dtrain/downpour.rb @@ -0,0 +1,120 @@ +#!/usr/bin/env ruby + +require 'trollop' +require 'zipf' +require 'socket' +require 'nanomsg' + +conf = Trollop::options do + opt :conf, "dtrain configuration", :type => :string, :required => true, :short => '-c' + opt :input, "input as bitext (f ||| e)", :type => :string, :required => true, :short => '-i' + opt :epochs, "number of epochs", :type => :int, :default => 10, :short => '-e' + opt :learning_rate, "learning rate", :type => :float, :default => 1.0, :short => '-l' + opt :slaves, "number of parallel learners", :type => :int, :default => 1, :short => '-p' + opt :dtrain_binary, "path to dtrain_net binary", :type => :string, :short => '-d' +end + +dtrain_conf = conf[:conf] +input = conf[:input] +epochs = conf[:epochs] +learning_rate = conf[:learning_rate] +num_slaves = conf[:slaves] +dtrain_dir = File.expand_path File.dirname(__FILE__) + +if not conf[:dtrain_binary] + dtrain_bin = "#{dtrain_dir}/dtrain_net" +else + dtrain_bin = conf[:dtrain_binary] +end + +socks = [] +port = 60666 # last port = port+slaves +slave_pids = [] +master_ip = Socket.ip_address_list[0].ip_address + +`mkdir work` + +num_slaves.times { |i| + socks << NanoMsg::PairSocket.new + addr = "tcp://#{master_ip}:#{port}" + socks.last.bind addr + STDERR.write "listening on #{addr}\n" + slave_pids << Kernel.fork { + cmd = "#{dtrain_bin} -c #{dtrain_conf} -a #{addr} &>work/out.#{i}" + `#{cmd}` + } + port += 1 +} + +threads = [] +socks.each_with_index { |n,i| + threads << Thread.new { + n.recv + STDERR.write "got hello from slave ##{i}\n" + } +} +threads.each { |thr| thr.join } # timeout? +threads.clear + +inf = ReadFile.new input +buf = [] +j = 0 +m = Mutex.new +n = Mutex.new +w = SparseVector.new +ready = num_slaves.times.map { true } +cma = 1 +epochs.times { |epoch| +STDERR.write "---\nepoch #{epoch}\n" +inf.rewind +i = 0 +while true # round-robin + d = inf.gets + break if !d + d.strip! + while !ready[j] + j += 1 + j = 0 if j==num_slaves + end + STDERR.write "sending source ##{i} to slave ##{j}\n" + socks[j].send d + n.synchronize { + ready[j] = false + } + threads << Thread.new { + me = j + moment = cma + update = SparseVector::from_kv socks[me].recv + STDERR.write "T update from slave ##{me}\n" + update *= learning_rate + update -= w + update /= moment + m.synchronize { w += update } + STDERR.write "T sending new weights to slave ##{me}\n" + socks[me].send w.to_kv + STDERR.write "T sent new weights to slave ##{me}\n" + n.synchronize { + ready[me] = true + } + } + sleep 1 + i += 1 + cma += 1 + j += 1 + j = 0 if j==num_slaves + threads.delete_if { |thr| !thr.status } +end +} + +threads.each { |thr| thr.join } + +socks.each { |n| + Thread.new { + n.send "shutdown" + } +} + +slave_pids.each { |pid| Process.wait(pid) } + +puts w.to_kv " ", "\n" + diff --git a/training/dtrain/dtrain_net.cc b/training/dtrain/dtrain_net.cc new file mode 100644 index 00000000..946b7587 --- /dev/null +++ b/training/dtrain/dtrain_net.cc @@ -0,0 +1,121 @@ +#include "dtrain_net.h" +#include "sample.h" +#include "score.h" +#include "update.h" + +#include <nanomsg/nn.h> +#include <nanomsg/pair.h> +#include "nn.hpp" + +using namespace dtrain; + +int +main(int argc, char** argv) +{ + // get configuration + po::variables_map conf; + if (!dtrain_net_init(argc, argv, &conf)) + exit(1); // something is wrong + const size_t k = conf["k"].as<size_t>(); + const size_t N = conf["N"].as<size_t>(); + const weight_t margin = conf["margin"].as<weight_t>(); + const string master_addr = conf["addr"].as<string>(); + + // setup decoder + register_feature_functions(); + SetSilent(true); + ReadFile f(conf["decoder_conf"].as<string>()); + Decoder decoder(f.stream()); + ScoredKbest* observer = new ScoredKbest(k, new PerSentenceBleuScorer(N)); + + // weights + vector<weight_t>& decoder_weights = decoder.CurrentWeightVector(); + SparseVector<weight_t> lambdas, w_average; + if (conf.count("input_weights")) { + Weights::InitFromFile(conf["input_weights"].as<string>(), &decoder_weights); + Weights::InitSparseVector(decoder_weights, &lambdas); + } + + cerr << _p4; + // output configuration + cerr << "dtrain_net" << endl << "Parameters:" << endl; + cerr << setw(25) << "k " << k << endl; + cerr << setw(25) << "N " << N << endl; + cerr << setw(25) << "margin " << margin << endl; + cerr << setw(25) << "decoder conf " << "'" + << conf["decoder_conf"].as<string>() << "'" << endl; + + // socket + nn::socket sock(AF_SP, NN_PAIR); + sock.connect(master_addr.c_str()); + sock.send("hello", 6, 0); + + size_t i = 0; + while(true) + { + char *buf = NULL; + string source; + vector<Ngrams> refs; + vector<size_t> rsz; + bool next = true; + size_t sz = sock.recv(&buf, NN_MSG, 0); + if (buf) { + const string in(buf, buf+sz); + nn::freemsg(buf); + if (in == "shutdown") { + next = false; + } else { + vector<string> parts; + boost::algorithm::split_regex(parts, in, boost::regex(" \\|\\|\\| ")); + source = parts[0]; + parts.erase(parts.begin()); + for (auto s: parts) { + vector<WordID> r; + vector<string> toks; + boost::split(toks, s, boost::is_any_of(" ")); + for (auto tok: toks) + r.push_back(TD::Convert(tok)); + refs.emplace_back(MakeNgrams(r, N)); + rsz.push_back(r.size()); + } + } + } + if (next) { + if (i%20 == 0) + cerr << " "; + cerr << "."; + if ((i+1)%20==0) + cerr << " " << i+1 << endl; + } else { + if (i%20 != 0) + cerr << " " << i << endl; + } + cerr.flush(); + + if (!next) + break; + + // decode + lambdas.init_vector(&decoder_weights); + observer->SetReference(refs, rsz); + decoder.Decode(source, observer); + vector<ScoredHyp>* samples = observer->GetSamples(); + + // get pairs and update + SparseVector<weight_t> updates; + CollectUpdates(samples, updates, margin); + ostringstream os; + vectorAsString(updates, os); + sock.send(os.str().c_str(), os.str().size()+1, 0); + buf = NULL; + sz = sock.recv(&buf, NN_MSG, 0); + string new_weights(buf, buf+sz); + nn::freemsg(buf); + lambdas.clear(); + updateVectorFromString(new_weights, lambdas); + i++; + } // input loop + + return 0; +} + diff --git a/training/dtrain/dtrain_net.h b/training/dtrain/dtrain_net.h new file mode 100644 index 00000000..ecacf3ee --- /dev/null +++ b/training/dtrain/dtrain_net.h @@ -0,0 +1,72 @@ +#ifndef _DTRAIN_NET_H_ +#define _DTRAIN_NET_H_ + +#include "dtrain.h" + +namespace dtrain +{ + +template<typename T> +inline void +vectorAsString(SparseVector<T>& v, ostringstream& os) +{ + SparseVector<weight_t>::iterator it = v.begin(); + for (; it != v.end(); ++it) { + os << FD::Convert(it->first) << "=" << it->second; + auto peek = it; + if (++peek != v.end()) + os << " "; + } +} + +template<typename T> +inline void +updateVectorFromString(string& s, SparseVector<T>& v) +{ + string buf; + istringstream ss; + while (ss >> buf) { + size_t p = buf.find_last_of("="); + istringstream c(buf.substr(p+1,buf.size())); + weight_t val; + c >> val; + v[FD::Convert(buf.substr(0,p))] = val; + } +} + +bool +dtrain_net_init(int argc, char** argv, po::variables_map* conf) +{ + po::options_description ini("Configuration File Options"); + ini.add_options() + ("decoder_conf,C", po::value<string>(), "configuration file for decoder") + ("k", po::value<size_t>()->default_value(100), "size of kbest list") + ("N", po::value<size_t>()->default_value(4), "N for BLEU approximation") + ("margin,m", po::value<weight_t>()->default_value(0.), "margin for margin perceptron"); + po::options_description cl("Command Line Options"); + cl.add_options() + ("conf,c", po::value<string>(), "dtrain configuration file") + ("addr,a", po::value<string>(), "address of master"); + cl.add(ini); + po::store(parse_command_line(argc, argv, cl), *conf); + if (conf->count("conf")) { + ifstream f((*conf)["conf"].as<string>().c_str()); + po::store(po::parse_config_file(f, ini), *conf); + } + po::notify(*conf); + if (!conf->count("decoder_conf")) { + cerr << "Missing decoder configuration. Exiting." << endl; + return false; + } + if (!conf->count("addr")) { + cerr << "No master address given! Exiting." << endl; + return false; + } + + return true; +} + +} // namespace + +#endif + diff --git a/training/dtrain/examples/net/10.gz b/training/dtrain/examples/net/10.gz Binary files differnew file mode 100644 index 00000000..44775573 --- /dev/null +++ b/training/dtrain/examples/net/10.gz diff --git a/training/dtrain/examples/net/README b/training/dtrain/examples/net/README new file mode 100644 index 00000000..4acb721b --- /dev/null +++ b/training/dtrain/examples/net/README @@ -0,0 +1,6 @@ +run + ../../downpour.rb -c dtrain.ini -p 3 -i 10.gz -l 0.00001 -e 3 +or + zcat 10.gz | head -6 | ../../feed.rb 60667 + ../../dtrain_net -c dtrain.ini -a tcp://127.0.0.1:60667 + diff --git a/training/dtrain/examples/net/cdec.ini b/training/dtrain/examples/net/cdec.ini new file mode 100644 index 00000000..6c986d03 --- /dev/null +++ b/training/dtrain/examples/net/cdec.ini @@ -0,0 +1,27 @@ +formalism=scfg +add_pass_through_rules=true +scfg_max_span_limit=15 +intersection_strategy=cube_pruning +cubepruning_pop_limit=200 +grammar=../standard/nc-wmt11.grammar.gz +feature_function=WordPenalty +feature_function=KLanguageModel ../standard/nc-wmt11.en.srilm.gz +# all currently working feature functions for translation: +# (with those features active that were used in the ACL paper) +#feature_function=ArityPenalty +#feature_function=CMR2008ReorderingFeatures +#feature_function=Dwarf +#feature_function=InputIndicator +#feature_function=LexNullJump +#feature_function=NewJump +#feature_function=NgramFeatures +#feature_function=NonLatinCount +#feature_function=OutputIndicator +#feature_function=RuleIdentityFeatures +#feature_function=RuleSourceBigramFeatures +#feature_function=RuleTargetBigramFeatures +#feature_function=RuleShape +#feature_function=LexicalFeatures 1 1 1 +#feature_function=SourceSpanSizeFeatures +#feature_function=SourceWordPenalty +#feature_function=SpanFeatures diff --git a/training/dtrain/examples/net/dtrain.ini b/training/dtrain/examples/net/dtrain.ini new file mode 100644 index 00000000..cfcd91bc --- /dev/null +++ b/training/dtrain/examples/net/dtrain.ini @@ -0,0 +1,4 @@ +decoder_conf=./cdec.ini # config for cdec +k=100 # use 100best lists +N=4 # optimize (approx.) BLEU4 +margin=1.0 # perceptron's margin diff --git a/training/dtrain/examples/net/work/out.0 b/training/dtrain/examples/net/work/out.0 new file mode 100644 index 00000000..37b0ea44 --- /dev/null +++ b/training/dtrain/examples/net/work/out.0 @@ -0,0 +1,11 @@ +Loading the LM will be faster if you build a binary file. +Reading ../standard/nc-wmt11.en.srilm.gz +----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100 +**************************************************************************************************** +dtrain_net +Parameters: + k 100 + N 4 + margin 1 + decoder conf './cdec.ini' + ........ 8 diff --git a/training/dtrain/examples/net/work/out.1 b/training/dtrain/examples/net/work/out.1 new file mode 100644 index 00000000..187b726e --- /dev/null +++ b/training/dtrain/examples/net/work/out.1 @@ -0,0 +1,11 @@ +Loading the LM will be faster if you build a binary file. +Reading ../standard/nc-wmt11.en.srilm.gz +----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100 +**************************************************************************************************** +dtrain_net +Parameters: + k 100 + N 4 + margin 1 + decoder conf './cdec.ini' + ........... 11 diff --git a/training/dtrain/examples/net/work/out.2 b/training/dtrain/examples/net/work/out.2 new file mode 100644 index 00000000..187b726e --- /dev/null +++ b/training/dtrain/examples/net/work/out.2 @@ -0,0 +1,11 @@ +Loading the LM will be faster if you build a binary file. +Reading ../standard/nc-wmt11.en.srilm.gz +----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100 +**************************************************************************************************** +dtrain_net +Parameters: + k 100 + N 4 + margin 1 + decoder conf './cdec.ini' + ........... 11 diff --git a/training/dtrain/feed.rb b/training/dtrain/feed.rb new file mode 100755 index 00000000..fe8dd509 --- /dev/null +++ b/training/dtrain/feed.rb @@ -0,0 +1,22 @@ +#!/usr/bin/env ruby + +require 'nanomsg' + +port = ARGV[0] +sock = NanoMsg::PairSocket.new +addr = "tcp://127.0.0.1:#{port}" +sock.bind addr + +puts sock.recv +while true + line = STDIN.gets + if !line + sock.send 'shutdown' + break + end + sock.send line.strip + sleep 1 + sock.recv + sock.send "a=1 b=2" +end + diff --git a/training/dtrain/nn.hpp b/training/dtrain/nn.hpp new file mode 100644 index 00000000..50b8304c --- /dev/null +++ b/training/dtrain/nn.hpp @@ -0,0 +1,204 @@ +/* + Copyright (c) 2013 250bpm s.r.o. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), + to deal in the Software without restriction, including without limitation + the rights to use, copy, modify, merge, publish, distribute, sublicense, + and/or sell copies of the Software, and to permit persons to whom + the Software is furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included + in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + IN THE SOFTWARE. +*/ + +#ifndef NN_HPP_INCLUDED +#define NN_HPP_INCLUDED + +#include <nanomsg/nn.h> + +#include <cassert> +#include <cstring> +#include <algorithm> +#include <exception> + +#if defined __GNUC__ +#define nn_slow(x) __builtin_expect ((x), 0) +#else +#define nn_slow(x) (x) +#endif + +namespace nn +{ + + class exception : public std::exception + { + public: + + exception () : err (nn_errno ()) {} + + virtual const char *what () const throw () + { + return nn_strerror (err); + } + + int num () const + { + return err; + } + + private: + + int err; + }; + + inline const char *symbol (int i, int *value) + { + return nn_symbol (i, value); + } + + inline void *allocmsg (size_t size, int type) + { + void *msg = nn_allocmsg (size, type); + if (nn_slow (!msg)) + throw nn::exception (); + return msg; + } + + inline int freemsg (void *msg) + { + int rc = nn_freemsg (msg); + if (nn_slow (rc != 0)) + throw nn::exception (); + return rc; + } + + class socket + { + public: + + inline socket (int domain, int protocol) + { + s = nn_socket (domain, protocol); + if (nn_slow (s < 0)) + throw nn::exception (); + } + + inline ~socket () + { + int rc = nn_close (s); + assert (rc == 0); + } + + inline void setsockopt (int level, int option, const void *optval, + size_t optvallen) + { + int rc = nn_setsockopt (s, level, option, optval, optvallen); + if (nn_slow (rc != 0)) + throw nn::exception (); + } + + inline void getsockopt (int level, int option, void *optval, + size_t *optvallen) + { + int rc = nn_getsockopt (s, level, option, optval, optvallen); + if (nn_slow (rc != 0)) + throw nn::exception (); + } + + inline int bind (const char *addr) + { + int rc = nn_bind (s, addr); + if (nn_slow (rc < 0)) + throw nn::exception (); + return rc; + } + + inline int connect (const char *addr) + { + int rc = nn_connect (s, addr); + if (nn_slow (rc < 0)) + throw nn::exception (); + return rc; + } + + inline void shutdown (int how) + { + int rc = nn_shutdown (s, how); + if (nn_slow (rc != 0)) + throw nn::exception (); + } + + inline int send (const void *buf, size_t len, int flags) + { + int rc = nn_send (s, buf, len, flags); + if (nn_slow (rc < 0)) { + if (nn_slow (nn_errno () != EAGAIN)) + throw nn::exception (); + return -1; + } + return rc; + } + + inline int recv (void *buf, size_t len, int flags) + { + int rc = nn_recv (s, buf, len, flags); + if (nn_slow (rc < 0)) { + if (nn_slow (nn_errno () != EAGAIN)) + throw nn::exception (); + return -1; + } + return rc; + } + + inline int sendmsg (const struct nn_msghdr *msghdr, int flags) + { + int rc = nn_sendmsg (s, msghdr, flags); + if (nn_slow (rc < 0)) { + if (nn_slow (nn_errno () != EAGAIN)) + throw nn::exception (); + return -1; + } + return rc; + } + + inline int recvmsg (struct nn_msghdr *msghdr, int flags) + { + int rc = nn_recvmsg (s, msghdr, flags); + if (nn_slow (rc < 0)) { + if (nn_slow (nn_errno () != EAGAIN)) + throw nn::exception (); + return -1; + } + return rc; + } + + private: + + int s; + + /* Prevent making copies of the socket by accident. */ + socket (const socket&); + void operator = (const socket&); + }; + + inline void term () + { + nn_term (); + } + +} + +#undef nn_slow + +#endif + + |