Diffstat (limited to 'training/dtrain/dtrain_net_interface.cc')
-rw-r--r--  training/dtrain/dtrain_net_interface.cc | 94
1 file changed, 52 insertions(+), 42 deletions(-)
diff --git a/training/dtrain/dtrain_net_interface.cc b/training/dtrain/dtrain_net_interface.cc
index e9612def..3b19ecbf 100644
--- a/training/dtrain/dtrain_net_interface.cc
+++ b/training/dtrain/dtrain_net_interface.cc
@@ -19,10 +19,14 @@ main(int argc, char** argv)
const size_t k = conf["k"].as<size_t>();
const size_t N = conf["N"].as<size_t>();
weight_t eta = conf["learning_rate"].as<weight_t>();
+ weight_t eta_sparse = conf["learning_rate_sparse"].as<weight_t>();
const weight_t margin = conf["margin"].as<weight_t>();
const string master_addr = conf["addr"].as<string>();
const string output_fn = conf["output"].as<string>();
const string debug_fn = conf["debug_output"].as<string>();
+ vector<string> dense_features;
+ boost::split(dense_features, conf["dense_features"].as<string>(),
+ boost::is_any_of(" "));
// setup decoder
register_feature_functions();
@@ -33,10 +37,11 @@ main(int argc, char** argv)
// weights
vector<weight_t>& decoder_weights = decoder.CurrentWeightVector();
- SparseVector<weight_t> lambdas, w_average;
+ SparseVector<weight_t> lambdas, w_average, original_lambdas;
if (conf.count("input_weights")) {
Weights::InitFromFile(conf["input_weights"].as<string>(), &decoder_weights);
Weights::InitSparseVector(decoder_weights, &lambdas);
+ Weights::InitSparseVector(decoder_weights, &original_lambdas);
}
cerr << _p4;
@@ -44,6 +49,8 @@ main(int argc, char** argv)
cerr << "dtrain_net_interface" << endl << "Parameters:" << endl;
cerr << setw(25) << "k " << k << endl;
cerr << setw(25) << "N " << N << endl;
+ cerr << setw(25) << "eta " << eta << endl;
+ cerr << setw(25) << "eta (sparse) " << eta_sparse << endl;
cerr << setw(25) << "margin " << margin << endl;
cerr << setw(25) << "decoder conf " << "'"
<< conf["decoder_conf"].as<string>() << "'" << endl;
@@ -58,13 +65,15 @@ main(int argc, char** argv)
// debug
ostringstream debug_output;
+ string done = "done";
+
size_t i = 0;
while(true)
{
// debug --
debug_output.str(string());
debug_output.clear();
- debug_output << "{" << endl;
+ debug_output << "{" << endl; // hack us a nice JSON output
// -- debug
char *buf = NULL;
@@ -77,7 +86,31 @@ main(int argc, char** argv)
const string in(buf, buf+sz);
nn::freemsg(buf);
cerr << "[dtrain] got input '" << in << "'" << endl;
- if (in == "shutdown") { // shut down
+ if (boost::starts_with(in, "set_learning_rate")) { // set learning rate
+ stringstream ss(in);
+ string x; weight_t w;
+ ss >> x; ss >> w;
+ cerr << "[dtrain] setting (dense) learning rate to " << w << " (was: " << eta << ")" << endl;
+ eta = w;
+ cerr << "[dtrain] done, looping again" << endl;
+ sock.send(done.c_str(), done.size()+1, 0);
+ continue;
+ } else if (boost::starts_with(in, "set_sparse_learning_rate")) { // set sparse learning rate
+ stringstream ss(in);
+ string x; weight_t w;
+ ss >> x; ss >> w;
+ cerr << "[dtrain] setting sparse learning rate to " << w << " (was: " << eta_sparse << ")" << endl;
+ eta_sparse = w;
+ cerr << "[dtrain] done, looping again" << endl;
+ sock.send(done.c_str(), done.size()+1, 0);
+ continue;
+ } else if (boost::starts_with(in, "reset_weights")) { // reset weights
+ cerr << "[dtrain] resetting weights" << endl;
+ lambdas = original_lambdas;
+ cerr << "[dtrain] done, looping again" << endl;
+ sock.send(done.c_str(), done.size()+1, 0);
+ continue;
+ } else if (in == "shutdown") { // shut down
cerr << "[dtrain] got shutdown signal" << endl;
next = false;
} else { // translate
@@ -134,16 +167,8 @@ main(int argc, char** argv)
size_t h = 0;
for (auto s: *samples) {
debug_output << "\"" << s.gold << " ||| " << s.model << " ||| " << s.rank << " ||| ";
- debug_output << "EgivenFCoherent=" << s.f[FD::Convert("EgivenFCoherent")] << " ";
- debug_output << "SampleCountF=" << s.f[FD::Convert("CountEF")] << " ";
- debug_output << "MaxLexFgivenE=" << s.f[FD::Convert("MaxLexFgivenE")] << " ";
- debug_output << "MaxLexEgivenF=" << s.f[FD::Convert("MaxLexEgivenF")] << " ";
- debug_output << "IsSingletonF=" << s.f[FD::Convert("IsSingletonF")] << " ";
- debug_output << "IsSingletonFE=" << s.f[FD::Convert("IsSingletonFE")] << " ";
- debug_output << "Glue=:" << s.f[FD::Convert("Glue")] << " ";
- debug_output << "WordPenalty=" << s.f[FD::Convert("WordPenalty")] << " ";
- debug_output << "PassThrough=" << s.f[FD::Convert("PassThrough")] << " ";
- debug_output << "LanguageModel=" << s.f[FD::Convert("LanguageModel_OOV")];
+ for (auto o: s.f)
+ debug_output << FD::Convert(o.first) << "=" << o.second << " ";
debug_output << " ||| ";
PrintWordIDVec(s.w, debug_output);
h += 1;
@@ -156,67 +181,52 @@ main(int argc, char** argv)
debug_output << "]," << endl;
debug_output << "\"samples_size\":" << samples->size() << "," << endl;
debug_output << "\"weights_before\":{" << endl;
- debug_output << "\"EgivenFCoherent\":" << lambdas[FD::Convert("EgivenFCoherent")] << "," << endl;
- debug_output << "\"SampleCountF\":" << lambdas[FD::Convert("CountEF")] << "," << endl;
- debug_output << "\"MaxLexFgivenE\":" << lambdas[FD::Convert("MaxLexFgivenE")] << "," << endl;
- debug_output << "\"MaxLexEgivenF\":" << lambdas[FD::Convert("MaxLexEgivenF")] << "," << endl;
- debug_output << "\"IsSingletonF\":" << lambdas[FD::Convert("IsSingletonF")] << "," << endl;
- debug_output << "\"IsSingletonFE\":" << lambdas[FD::Convert("IsSingletonFE")] << "," << endl;
- debug_output << "\"Glue\":" << lambdas[FD::Convert("Glue")] << "," << endl;
- debug_output << "\"WordPenalty\":" << lambdas[FD::Convert("WordPenalty")] << "," << endl;
- debug_output << "\"PassThrough\":" << lambdas[FD::Convert("PassThrough")] << "," << endl;
- debug_output << "\"LanguageModel\":" << lambdas[FD::Convert("LanguageModel_OOV")] << endl;
+ weightsToJson(lambdas, debug_output);
debug_output << "}," << endl;
// -- debug
// get pairs and update
SparseVector<weight_t> updates;
size_t num_up = CollectUpdates(samples, updates, margin);
-
+ updates *= eta_sparse; // apply learning rate for sparse features
+ for (auto feat: dense_features) { // apply learning rate for dense features
+ updates[FD::Convert(feat)] /= eta_sparse;
+ updates[FD::Convert(feat)] *= eta;
+ }
// debug --
debug_output << "\"num_up\":" << num_up << "," << endl;
debug_output << "\"updated_features\":" << updates.size() << "," << endl;
debug_output << "\"learning_rate\":" << eta << "," << endl;
+ debug_output << "\"learning_rate_sparse\":" << eta_sparse << "," << endl;
debug_output << "\"best_match\":\"";
PrintWordIDVec((*samples)[0].w, debug_output);
debug_output << "\"," << endl;
debug_output << "\"best_match_score\":" << (*samples)[0].gold << "," << endl ;
// -- debug
-
- lambdas.plus_eq_v_times_s(updates, eta);
+ lambdas.plus_eq_v_times_s(updates, 1.0);
i++;
// debug --
debug_output << "\"weights_after\":{" << endl;
- debug_output << "\"EgivenFCoherent\":" << lambdas[FD::Convert("EgivenFCoherent")] << "," << endl;
- debug_output << "\"SampleCountF\":" << lambdas[FD::Convert("CountEF")] << "," << endl;
- debug_output << "\"MaxLexFgivenE\":" << lambdas[FD::Convert("MaxLexFgivenE")] << "," << endl;
- debug_output << "\"MaxLexEgivenF\":" << lambdas[FD::Convert("MaxLexEgivenF")] << "," << endl;
- debug_output << "\"IsSingletonF\":" << lambdas[FD::Convert("IsSingletonF")] << "," << endl;
- debug_output << "\"IsSingletonFE\":" << lambdas[FD::Convert("IsSingletonFE")] << "," << endl;
- debug_output << "\"Glue\":" << lambdas[FD::Convert("Glue")] << "," << endl;
- debug_output << "\"WordPenalty\":" << lambdas[FD::Convert("WordPenalty")] << "," << endl;
- debug_output << "\"PassThrough\":" << lambdas[FD::Convert("PassThrough")] << "," << endl;
- debug_output << "\"LanguageModel\":" << lambdas[FD::Convert("LanguageModel_OOV")] << endl;
+ weightsToJson(lambdas, debug_output);
debug_output << "}" << endl;
debug_output << "}" << endl;
// -- debug
cerr << "[dtrain] done learning, looping again" << endl;
- string done = "done";
sock.send(done.c_str(), done.size()+1, 0);
// debug --
WriteFile f(debug_fn);
*f << debug_output.str();
// -- debug
- } // input loop
- if (output_fn != "") {
- cerr << "[dtrain] writing final weights to '" << output_fn << "'" << endl;
+ // write current weights
lambdas.init_vector(decoder_weights);
- Weights::WriteToFile(output_fn, decoder_weights, true);
- }
+ ostringstream fn;
+ fn << output_fn << "." << i << ".gz";
+ Weights::WriteToFile(fn.str(), decoder_weights, true);
+ } // input loop
string shutdown = "off";
sock.send(shutdown.c_str(), shutdown.size()+1, 0);
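
Two notes on the patch, each with a hedged sketch.

First, the new control branch accepts plain NUL-terminated string commands over the socket ("set_learning_rate <rate>", "set_sparse_learning_rate <rate>", "reset_weights"), answering each with "done"; "shutdown" is answered with "off" once the loop exits. A minimal driver against the nanomsg C API, assuming the interface is bound to a PAIR socket; the address and the helper name request are illustrative, not taken from the patch:

// Hypothetical driver for the control messages added in this patch.
// Assumes dtrain_net_interface sits on a nanomsg PAIR socket; the
// address below is an example, not from the patch.
#include <nanomsg/nn.h>
#include <nanomsg/pair.h>
#include <iostream>
#include <string>

static std::string request(int sock, const std::string& msg)
{
  // The interface sends and expects NUL-terminated strings (size()+1);
  // error handling omitted for brevity.
  nn_send(sock, msg.c_str(), msg.size() + 1, 0);
  char* buf = NULL;
  int sz = nn_recv(sock, &buf, NN_MSG, 0);
  std::string reply(buf, buf + sz - 1);  // strip the trailing NUL
  nn_freemsg(buf);
  return reply;
}

int main()
{
  int sock = nn_socket(AF_SP, NN_PAIR);
  nn_connect(sock, "tcp://127.0.0.1:60666");  // example address
  std::cout << request(sock, "set_learning_rate 0.001") << "\n";         // "done"
  std::cout << request(sock, "set_sparse_learning_rate 0.00001") << "\n"; // "done"
  std::cout << request(sock, "reset_weights") << "\n";                    // "done"
  request(sock, "shutdown");  // answered with "off"
  nn_close(sock);
}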
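Second, the update hunk folds two learning rates into one pass: the output of CollectUpdates is scaled by eta_sparse wholesale, then each feature listed in dense_features is divided back out and rescaled by eta, so the final plus_eq_v_times_s(updates, 1.0) is plain addition. A standalone sketch of the same arithmetic, written per feature, which avoids the division and hence the eta_sparse == 0 corner case; std::map stands in for cdec's SparseVector<weight_t>, and apply_learning_rates is an illustrative name, not from the patch:

// Standalone sketch of the two-rate update scaling from the patch.
#include <iostream>
#include <map>
#include <set>
#include <string>

using Weights = std::map<std::string, double>;

void apply_learning_rates(Weights& updates,
                          const std::set<std::string>& dense_features,
                          double eta, double eta_sparse)
{
  // Equivalent to "updates *= eta_sparse; then divide the dense
  // features back out and multiply by eta", but done per feature.
  for (auto& kv : updates)
    kv.second *= dense_features.count(kv.first) ? eta : eta_sparse;
}

int main()
{
  Weights updates = { {"Glue", 0.5}, {"WordPenalty", -1.0}, {"R:foo_bar", 0.25} };
  std::set<std::string> dense = { "Glue", "WordPenalty" };
  apply_learning_rates(updates, dense, /*eta=*/0.1, /*eta_sparse=*/0.01);
  for (auto& kv : updates)  // dense features scaled by 0.1, sparse by 0.01
    std::cout << kv.first << " = " << kv.second << "\n";
}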