diff options
-rw-r--r-- | training/dtrain/dtrain_net_interface.cc | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/training/dtrain/dtrain_net_interface.cc b/training/dtrain/dtrain_net_interface.cc index 5c2df022..f2f200ef 100644 --- a/training/dtrain/dtrain_net_interface.cc +++ b/training/dtrain/dtrain_net_interface.cc @@ -172,6 +172,39 @@ main(int argc, char** argv) } else if (in == "shutdown") { // shut down cerr << "[dtrain] got shutdown signal" << endl; next = false; + continue; + } else if (boost::starts_with(in, "get_weight")) { // get weight + stringstream ss(in); + string _,name; + ss >> _; ss >> name; + cerr << "[dtrain] getting weight for " << name << endl; + ostringstream o; + unsigned fid = FD::Convert(name); + weight_t w = lambdas[fid]; + o << w; + string s = o.str(); + sock.send(s.c_str(), s.size()+1, 0); + continue; + } else if (boost::starts_with(in, "get_rate")) { // get rate + stringstream ss(in); + string _,name; + ss >> _; ss >> name; + cerr << "[dtrain] getting rate for " << name << endl; + ostringstream o; + unsigned fid = FD::Convert(name); + weight_t r; + if (name == "R") + r = learning_rate_R; + else if (name == "RB") + r = learning_rate_RB; + else if (name == "Shape") + r = learning_rate_Shape; + else + r = learning_rates[fid]; + o << r; + string s = o.str(); + sock.send(s.c_str(), s.size()+1, 0); + continue; } else { // translate vector<string> parts; boost::algorithm::split_regex(parts, in, boost::regex(" \\|\\|\\| ")); |