diff options
author | Patrick Simianer <p@simianer.de> | 2015-05-12 17:46:56 +0200 |
---|---|---|
committer | Patrick Simianer <p@simianer.de> | 2015-05-12 17:46:56 +0200 |
commit | 6c0fcee726662285e7c4cb3857ca28296f5c525c (patch) | |
tree | 77de036e546e79b60db0b60e7e273eb2e8b50e3f /training/dtrain | |
parent | e0bb79a01ed07cce540e5ebb757e03d801ca287e (diff) |
integrated updating grammars
Diffstat (limited to 'training/dtrain')
-rw-r--r-- | training/dtrain/dtrain_net.h | 3 | ||||
-rw-r--r-- | training/dtrain/dtrain_net_interface.cc | 40 |
2 files changed, 32 insertions, 11 deletions
diff --git a/training/dtrain/dtrain_net.h b/training/dtrain/dtrain_net.h index ecacf3ee..f6aa08b2 100644 --- a/training/dtrain/dtrain_net.h +++ b/training/dtrain/dtrain_net.h @@ -42,7 +42,8 @@ dtrain_net_init(int argc, char** argv, po::variables_map* conf) ("decoder_conf,C", po::value<string>(), "configuration file for decoder") ("k", po::value<size_t>()->default_value(100), "size of kbest list") ("N", po::value<size_t>()->default_value(4), "N for BLEU approximation") - ("margin,m", po::value<weight_t>()->default_value(0.), "margin for margin perceptron"); + ("margin,m", po::value<weight_t>()->default_value(0.), "margin for margin perceptron") + ("output,o", po::value<string>()->default_value(""), "final weights file"); po::options_description cl("Command Line Options"); cl.add_options() ("conf,c", po::value<string>(), "dtrain configuration file") diff --git a/training/dtrain/dtrain_net_interface.cc b/training/dtrain/dtrain_net_interface.cc index f484b56b..2719e946 100644 --- a/training/dtrain/dtrain_net_interface.cc +++ b/training/dtrain/dtrain_net_interface.cc @@ -20,6 +20,7 @@ main(int argc, char** argv) const size_t N = conf["N"].as<size_t>(); const weight_t margin = conf["margin"].as<weight_t>(); const string master_addr = conf["addr"].as<string>(); + const string output_fn = conf["output"].as<string>(); // setup decoder register_feature_functions(); @@ -44,10 +45,13 @@ main(int argc, char** argv) cerr << setw(25) << "margin " << margin << endl; cerr << setw(25) << "decoder conf " << "'" << conf["decoder_conf"].as<string>() << "'" << endl; + cerr << setw(25) << "output " << output_fn << endl; - // socket + // setup socket nn::socket sock(AF_SP, NN_PAIR); - sock.connect(master_addr.c_str()); + sock.bind(master_addr.c_str()); + string hello = "hello"; + sock.send(hello.c_str(), hello.size()+1, 0); size_t i = 0; while(true) @@ -61,9 +65,11 @@ main(int argc, char** argv) if (buf) { const string in(buf, buf+sz); nn::freemsg(buf); - if (in == "shutdown") { + cerr << "got input '" << in << "'" << endl; + if (in == "shutdown") { // shut down + cerr << "got shutdown signal" << endl; next = false; - } else { + } else { // translate vector<string> parts; boost::algorithm::split_regex(parts, in, boost::regex(" \\|\\|\\| ")); if (parts[0] == "act:translate") { @@ -74,11 +80,12 @@ main(int argc, char** argv) observer->dont_score = false; vector<ScoredHyp>* samples = observer->GetSamples(); ostringstream os; + cerr << "1best features " << (*samples)[0].f << endl; PrintWordIDVec((*samples)[0].w, os); sock.send(os.str().c_str(), os.str().size()+1, 0); - cerr << "done" << endl; + cerr << "> done translating, looping" << endl; continue; - } else { + } else { // learn cerr << "learning ..." << endl; source = parts[0]; parts.erase(parts.begin()); @@ -103,17 +110,30 @@ main(int argc, char** argv) observer->SetReference(refs, rsz); decoder.Decode(source, observer); vector<ScoredHyp>* samples = observer->GetSamples(); + cerr << "samples size " << samples->size() << endl; // get pairs and update SparseVector<weight_t> updates; CollectUpdates(samples, updates, margin); - lambdas.plus_eq_v_times_s(updates, 1.0); // fixme - string s = "x"; - sock.send(s.c_str(), s.size()+1, 0); + cerr << "updates size " << updates.size() << endl; + cerr << "lambdas before " << lambdas << endl; + lambdas.plus_eq_v_times_s(updates, 1.0); // FIXME: learning rate? + cerr << "lambdas after " << lambdas << endl; i++; - cerr << "done" << endl; + cerr << "> done learning, looping" << endl; } // input loop + + if (output_fn != "") { + cerr << "writing final weights to '" << output_fn << "'" << endl; + lambdas.init_vector(decoder_weights); + Weights::WriteToFile(output_fn, decoder_weights, true); + } + + string shutdown = "off"; + sock.send(shutdown.c_str(), shutdown.size()+1, 0); + + cerr << "shutting down, goodbye" << endl; return 0; } |