summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/dtrain/dtrain_net.h3
-rw-r--r--training/dtrain/dtrain_net_interface.cc40
2 files changed, 32 insertions, 11 deletions
diff --git a/training/dtrain/dtrain_net.h b/training/dtrain/dtrain_net.h
index ecacf3ee..f6aa08b2 100644
--- a/training/dtrain/dtrain_net.h
+++ b/training/dtrain/dtrain_net.h
@@ -42,7 +42,8 @@ dtrain_net_init(int argc, char** argv, po::variables_map* conf)
("decoder_conf,C", po::value<string>(), "configuration file for decoder")
("k", po::value<size_t>()->default_value(100), "size of kbest list")
("N", po::value<size_t>()->default_value(4), "N for BLEU approximation")
- ("margin,m", po::value<weight_t>()->default_value(0.), "margin for margin perceptron");
+ ("margin,m", po::value<weight_t>()->default_value(0.), "margin for margin perceptron")
+ ("output,o", po::value<string>()->default_value(""), "final weights file");
po::options_description cl("Command Line Options");
cl.add_options()
("conf,c", po::value<string>(), "dtrain configuration file")
diff --git a/training/dtrain/dtrain_net_interface.cc b/training/dtrain/dtrain_net_interface.cc
index f484b56b..2719e946 100644
--- a/training/dtrain/dtrain_net_interface.cc
+++ b/training/dtrain/dtrain_net_interface.cc
@@ -20,6 +20,7 @@ main(int argc, char** argv)
const size_t N = conf["N"].as<size_t>();
const weight_t margin = conf["margin"].as<weight_t>();
const string master_addr = conf["addr"].as<string>();
+ const string output_fn = conf["output"].as<string>();
// setup decoder
register_feature_functions();
@@ -44,10 +45,13 @@ main(int argc, char** argv)
cerr << setw(25) << "margin " << margin << endl;
cerr << setw(25) << "decoder conf " << "'"
<< conf["decoder_conf"].as<string>() << "'" << endl;
+ cerr << setw(25) << "output " << output_fn << endl;
- // socket
+ // setup socket
nn::socket sock(AF_SP, NN_PAIR);
- sock.connect(master_addr.c_str());
+ sock.bind(master_addr.c_str());
+ string hello = "hello";
+ sock.send(hello.c_str(), hello.size()+1, 0);
size_t i = 0;
while(true)
@@ -61,9 +65,11 @@ main(int argc, char** argv)
if (buf) {
const string in(buf, buf+sz);
nn::freemsg(buf);
- if (in == "shutdown") {
+ cerr << "got input '" << in << "'" << endl;
+ if (in == "shutdown") { // shut down
+ cerr << "got shutdown signal" << endl;
next = false;
- } else {
+ } else { // translate
vector<string> parts;
boost::algorithm::split_regex(parts, in, boost::regex(" \\|\\|\\| "));
if (parts[0] == "act:translate") {
@@ -74,11 +80,12 @@ main(int argc, char** argv)
observer->dont_score = false;
vector<ScoredHyp>* samples = observer->GetSamples();
ostringstream os;
+ cerr << "1best features " << (*samples)[0].f << endl;
PrintWordIDVec((*samples)[0].w, os);
sock.send(os.str().c_str(), os.str().size()+1, 0);
- cerr << "done" << endl;
+ cerr << "> done translating, looping" << endl;
continue;
- } else {
+ } else { // learn
cerr << "learning ..." << endl;
source = parts[0];
parts.erase(parts.begin());
@@ -103,17 +110,30 @@ main(int argc, char** argv)
observer->SetReference(refs, rsz);
decoder.Decode(source, observer);
vector<ScoredHyp>* samples = observer->GetSamples();
+ cerr << "samples size " << samples->size() << endl;
// get pairs and update
SparseVector<weight_t> updates;
CollectUpdates(samples, updates, margin);
- lambdas.plus_eq_v_times_s(updates, 1.0); // fixme
- string s = "x";
- sock.send(s.c_str(), s.size()+1, 0);
+ cerr << "updates size " << updates.size() << endl;
+ cerr << "lambdas before " << lambdas << endl;
+ lambdas.plus_eq_v_times_s(updates, 1.0); // FIXME: learning rate?
+ cerr << "lambdas after " << lambdas << endl;
i++;
- cerr << "done" << endl;
+ cerr << "> done learning, looping" << endl;
} // input loop
+
+ if (output_fn != "") {
+ cerr << "writing final weights to '" << output_fn << "'" << endl;
+ lambdas.init_vector(decoder_weights);
+ Weights::WriteToFile(output_fn, decoder_weights, true);
+ }
+
+ string shutdown = "off";
+ sock.send(shutdown.c_str(), shutdown.size()+1, 0);
+
+ cerr << "shutting down, goodbye" << endl;
return 0;
}