summaryrefslogtreecommitdiff
path: root/training/dtrain/downpour.rb
diff options
context:
space:
mode:
Diffstat (limited to 'training/dtrain/downpour.rb')
-rwxr-xr-xtraining/dtrain/downpour.rb120
1 files changed, 120 insertions, 0 deletions
diff --git a/training/dtrain/downpour.rb b/training/dtrain/downpour.rb
new file mode 100755
index 00000000..d6af6707
--- /dev/null
+++ b/training/dtrain/downpour.rb
@@ -0,0 +1,120 @@
+#!/usr/bin/env ruby
+
+require 'trollop'
+require 'zipf'
+require 'socket'
+require 'nanomsg'
+
+conf = Trollop::options do
+ opt :conf, "dtrain configuration", :type => :string, :required => true, :short => '-c'
+ opt :input, "input as bitext (f ||| e)", :type => :string, :required => true, :short => '-i'
+ opt :epochs, "number of epochs", :type => :int, :default => 10, :short => '-e'
+ opt :learning_rate, "learning rate", :type => :float, :default => 1.0, :short => '-l'
+ opt :slaves, "number of parallel learners", :type => :int, :default => 1, :short => '-p'
+ opt :dtrain_binary, "path to dtrain_net binary", :type => :string, :short => '-d'
+end
+
+dtrain_conf = conf[:conf]
+input = conf[:input]
+epochs = conf[:epochs]
+learning_rate = conf[:learning_rate]
+num_slaves = conf[:slaves]
+dtrain_dir = File.expand_path File.dirname(__FILE__)
+
+if not conf[:dtrain_binary]
+ dtrain_bin = "#{dtrain_dir}/dtrain_net"
+else
+ dtrain_bin = conf[:dtrain_binary]
+end
+
+socks = []
+port = 60666 # last port = port+slaves
+slave_pids = []
+master_ip = Socket.ip_address_list[0].ip_address
+
+`mkdir work`
+
+num_slaves.times { |i|
+ socks << NanoMsg::PairSocket.new
+ addr = "tcp://#{master_ip}:#{port}"
+ socks.last.bind addr
+ STDERR.write "listening on #{addr}\n"
+ slave_pids << Kernel.fork {
+ cmd = "#{dtrain_bin} -c #{dtrain_conf} -a #{addr} &>work/out.#{i}"
+ `#{cmd}`
+ }
+ port += 1
+}
+
+threads = []
+socks.each_with_index { |n,i|
+ threads << Thread.new {
+ n.recv
+ STDERR.write "got hello from slave ##{i}\n"
+ }
+}
+threads.each { |thr| thr.join } # timeout?
+threads.clear
+
+inf = ReadFile.new input
+buf = []
+j = 0
+m = Mutex.new
+n = Mutex.new
+w = SparseVector.new
+ready = num_slaves.times.map { true }
+cma = 1
+epochs.times { |epoch|
+STDERR.write "---\nepoch #{epoch}\n"
+inf.rewind
+i = 0
+while true # round-robin
+ d = inf.gets
+ break if !d
+ d.strip!
+ while !ready[j]
+ j += 1
+ j = 0 if j==num_slaves
+ end
+ STDERR.write "sending source ##{i} to slave ##{j}\n"
+ socks[j].send d
+ n.synchronize {
+ ready[j] = false
+ }
+ threads << Thread.new {
+ me = j
+ moment = cma
+ update = SparseVector::from_kv socks[me].recv
+ STDERR.write "T update from slave ##{me}\n"
+ update *= learning_rate
+ update -= w
+ update /= moment
+ m.synchronize { w += update }
+ STDERR.write "T sending new weights to slave ##{me}\n"
+ socks[me].send w.to_kv
+ STDERR.write "T sent new weights to slave ##{me}\n"
+ n.synchronize {
+ ready[me] = true
+ }
+ }
+ sleep 1
+ i += 1
+ cma += 1
+ j += 1
+ j = 0 if j==num_slaves
+ threads.delete_if { |thr| !thr.status }
+end
+}
+
+threads.each { |thr| thr.join }
+
+socks.each { |n|
+ Thread.new {
+ n.send "shutdown"
+ }
+}
+
+slave_pids.each { |pid| Process.wait(pid) }
+
+puts w.to_kv " ", "\n"
+