summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2016-02-17 17:31:36 +0100
committerPatrick Simianer <p@simianer.de>2016-02-17 17:31:36 +0100
commit3adfda72d6c2e63dbc62a1111e17f59326970cb7 (patch)
tree366cf9aa0fa89c5d94b7f2ccdb3c49967b74dee1
parentdf2e8a4297304706691e6edfb2b56b9fbc7e692d (diff)
add get rates/weights
-rw-r--r--training/dtrain/dtrain_net_interface.cc33
1 files changed, 33 insertions, 0 deletions
diff --git a/training/dtrain/dtrain_net_interface.cc b/training/dtrain/dtrain_net_interface.cc
index 5c2df022..f2f200ef 100644
--- a/training/dtrain/dtrain_net_interface.cc
+++ b/training/dtrain/dtrain_net_interface.cc
@@ -172,6 +172,39 @@ main(int argc, char** argv)
} else if (in == "shutdown") { // shut down
cerr << "[dtrain] got shutdown signal" << endl;
next = false;
+ continue;
+ } else if (boost::starts_with(in, "get_weight")) { // get weight
+ stringstream ss(in);
+ string _,name;
+ ss >> _; ss >> name;
+ cerr << "[dtrain] getting weight for " << name << endl;
+ ostringstream o;
+ unsigned fid = FD::Convert(name);
+ weight_t w = lambdas[fid];
+ o << w;
+ string s = o.str();
+ sock.send(s.c_str(), s.size()+1, 0);
+ continue;
+ } else if (boost::starts_with(in, "get_rate")) { // get rate
+ stringstream ss(in);
+ string _,name;
+ ss >> _; ss >> name;
+ cerr << "[dtrain] getting rate for " << name << endl;
+ ostringstream o;
+ unsigned fid = FD::Convert(name);
+ weight_t r;
+ if (name == "R")
+ r = learning_rate_R;
+ else if (name == "RB")
+ r = learning_rate_RB;
+ else if (name == "Shape")
+ r = learning_rate_Shape;
+ else
+ r = learning_rates[fid];
+ o << r;
+ string s = o.str();
+ sock.send(s.c_str(), s.size()+1, 0);
+ continue;
} else { // translate
vector<string> parts;
boost::algorithm::split_regex(parts, in, boost::regex(" \\|\\|\\| "));