diff options
author | pks <pks@users.noreply.github.com> | 2019-05-12 20:10:37 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-05-12 20:10:37 +0200 |
commit | 4a13b41700f34c15c30b551f98dbea9cb41f67c3 (patch) | |
tree | 0218f41c350a626f5af9909d77406309fa873fdf /training/dtrain | |
parent | e9268eb3dcd867f3baf67a7bb3d2aad56196ecde (diff) | |
parent | f64746ac87fc7338629b19de9fa2da0f03fa2790 (diff) |
Merge branch 'net' into origin/net
Diffstat (limited to 'training/dtrain')
-rw-r--r-- | training/dtrain/Makefile.am | 7 | ||||
-rw-r--r-- | training/dtrain/dtrain_net.cc | 23 | ||||
-rw-r--r-- | training/dtrain/dtrain_net.h | 4 | ||||
-rw-r--r-- | training/dtrain/dtrain_net_interface.cc | 411 | ||||
-rw-r--r-- | training/dtrain/dtrain_net_interface.h | 134 | ||||
-rwxr-xr-x | training/dtrain/feed.rb | 22 | ||||
-rw-r--r-- | training/dtrain/nn.hpp | 204 | ||||
-rw-r--r-- | training/dtrain/sample.h | 3 | ||||
-rw-r--r-- | training/dtrain/sample_net_interface.h | 68 | ||||
-rw-r--r-- | training/dtrain/score.h | 2 | ||||
-rw-r--r-- | training/dtrain/score_net_interface.h | 200 |
11 files changed, 838 insertions, 240 deletions
diff --git a/training/dtrain/Makefile.am b/training/dtrain/Makefile.am index 82aac988..74c2a4b2 100644 --- a/training/dtrain/Makefile.am +++ b/training/dtrain/Makefile.am @@ -1,4 +1,4 @@ -bin_PROGRAMS = dtrain dtrain_net +bin_PROGRAMS = dtrain dtrain_net dtrain_net_interface dtrain_SOURCES = dtrain.cc dtrain.h sample.h score.h update.h dtrain_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a @@ -6,5 +6,8 @@ dtrain_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mte dtrain_net_SOURCES = dtrain_net.cc dtrain_net.h dtrain.h sample.h score.h update.h dtrain_net_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a /usr/local/lib/libnanomsg.so -AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval -I /usr/local/include +dtrain_net_interface_SOURCES = dtrain_net_interface.cc dtrain_net_interface.h dtrain.h sample_net_interface.h score_net_interface.h update.h +dtrain_net_interface_LDFLAGS = -rdynamic +dtrain_net_interface_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a /srv/postedit/lib/nanomsg-0.5-beta/lib/libnanomsg.so +AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval -I/usr/local/include -I/srv/postedit/lib/cppnanomsg diff --git a/training/dtrain/dtrain_net.cc b/training/dtrain/dtrain_net.cc index 946b7587..306da957 100644 --- a/training/dtrain/dtrain_net.cc +++ b/training/dtrain/dtrain_net.cc @@ -67,16 +67,19 @@ main(int argc, char** argv) } else { vector<string> parts; boost::algorithm::split_regex(parts, in, boost::regex(" \\|\\|\\| ")); - source = parts[0]; - parts.erase(parts.begin()); - for (auto s: parts) { - vector<WordID> r; - vector<string> toks; - boost::split(toks, s, boost::is_any_of(" ")); - for (auto tok: toks) - r.push_back(TD::Convert(tok)); - refs.emplace_back(MakeNgrams(r, N)); - rsz.push_back(r.size()); + if (parts[0] == "act:translate") { + } else { + source = parts[0]; + parts.erase(parts.begin()); + for (auto s: parts) { + vector<WordID> r; + vector<string> toks; + boost::split(toks, s, boost::is_any_of(" ")); + for (auto tok: toks) + r.push_back(TD::Convert(tok)); + refs.emplace_back(MakeNgrams(r, N)); + rsz.push_back(r.size()); + } } } } diff --git a/training/dtrain/dtrain_net.h b/training/dtrain/dtrain_net.h index 24f95500..e0d33d64 100644 --- a/training/dtrain/dtrain_net.h +++ b/training/dtrain/dtrain_net.h @@ -42,7 +42,9 @@ dtrain_net_init(int argc, char** argv, po::variables_map* conf) ("decoder_conf,C", po::value<string>(), "configuration file for decoder") ("k", po::value<size_t>()->default_value(100), "size of kbest list") ("N", po::value<size_t>()->default_value(4), "N for BLEU approximation") - ("margin,m", po::value<weight_t>()->default_value(0.), "margin for margin perceptron"); + ("margin,m", po::value<weight_t>()->default_value(0.), "margin for margin perceptron") + ("output,o", po::value<string>()->default_value(""), "final weights file") + ("input_weights,w", po::value<string>(), "input weights file"); po::options_description cl("Command Line Options"); cl.add_options() ("conf,c", po::value<string>(), "dtrain configuration file") diff --git a/training/dtrain/dtrain_net_interface.cc b/training/dtrain/dtrain_net_interface.cc new file mode 100644 index 00000000..37dff496 --- /dev/null +++ b/training/dtrain/dtrain_net_interface.cc @@ -0,0 +1,411 @@ +#include "dtrain_net_interface.h" +#include "sample_net_interface.h" +#include "score_net_interface.h" +#include "update.h" + +#include <nanomsg/nn.h> +#include <nanomsg/pair.h> +#include "nn.hpp" + +#include <sys/types.h> // mkfifo +#include <sys/stat.h> +#include <stdio.h> +#include <unistd.h> +#include <stdlib.h> +#include <fcntl.h> + + +using namespace dtrain; + +int +main(int argc, char** argv) +{ + // get configuration + po::variables_map conf; + if (!dtrain_net_init(argc, argv, &conf)) + exit(1); // something is wrong + const size_t k = conf["k"].as<size_t>(); + const size_t N = conf["N"].as<size_t>(); + const weight_t margin = conf["margin"].as<weight_t>(); + const string master_addr = conf["addr"].as<string>(); + const string output_fn = conf["output"].as<string>(); + const string debug_fn = conf["debug_output"].as<string>(); + vector<string> dense_features; + boost::split(dense_features, conf["dense_features"].as<string>(), + boost::is_any_of(" ")); + const bool output_derivation = conf["output_derivation"].as<bool>(); + const bool output_rules = conf["output_rules"].as<bool>(); + + // update lm + /*if (conf["update_lm_fn"].as<string>() != "") + mkfifo(conf["update_lm_fn"].as<string>().c_str(), 0666);*/ + + // setup socket + nn::socket sock(AF_SP, NN_PAIR); + sock.bind(master_addr.c_str()); + string hello = "hello"; + sock.send(hello.c_str(), hello.size()+1, 0); + + // setup decoder + register_feature_functions(); + SetSilent(true); + ReadFile f(conf["decoder_conf"].as<string>()); + Decoder decoder(f.stream()); + ScoredKbest* observer = new ScoredKbest(k, new PerSentenceBleuScorer(N)); + + // weights + vector<weight_t>& decoder_weights = decoder.CurrentWeightVector(); + SparseVector<weight_t> lambdas, w_average, original_lambdas; + if (conf.count("input_weights")) { + Weights::InitFromFile(conf["input_weights"].as<string>(), &decoder_weights); + Weights::InitSparseVector(decoder_weights, &lambdas); + Weights::InitSparseVector(decoder_weights, &original_lambdas); + } + + // learning rates + SparseVector<weight_t> learning_rates, original_learning_rates; + weight_t learning_rate_R, original_learning_rate_R; + weight_t learning_rate_RB, original_learning_rate_RB; + weight_t learning_rate_Shape, original_learning_rate_Shape; + vector<weight_t> l; + Weights::InitFromFile(conf["learning_rates"].as<string>(), &l); + Weights::InitSparseVector(l, &learning_rates); + original_learning_rates = learning_rates; + learning_rate_R = conf["learning_rate_R"].as<weight_t>(); + original_learning_rate_R = learning_rate_R; + learning_rate_RB = conf["learning_rate_RB"].as<weight_t>(); + original_learning_rate_RB = learning_rate_RB; + learning_rate_Shape = conf["learning_rate_Shape"].as<weight_t>(); + original_learning_rate_Shape = learning_rate_Shape; + + cerr << _p4; + // output configuration + cerr << "dtrain_net_interface" << endl << "Parameters:" << endl; + cerr << setw(25) << "k " << k << endl; + cerr << setw(25) << "N " << N << endl; + cerr << setw(25) << "margin " << margin << endl; + cerr << setw(25) << "decoder conf " << "'" + << conf["decoder_conf"].as<string>() << "'" << endl; + cerr << setw(25) << "output " << "'" << output_fn << "'" << endl; + cerr << setw(25) << "debug " << "'" << debug_fn << "'" << endl; + cerr << setw(25) << "learning rates " << "'" + << conf["learning_rates"].as<string>() << "'" << endl; + cerr << setw(25) << "learning rate R " << learning_rate_R << endl; + cerr << setw(25) << "learning rate RB " << learning_rate_RB << endl; + cerr << setw(25) << "learning rate Shape " << learning_rate_Shape << endl; + + // debug + ostringstream debug_output; + + string done = "done"; + + vector<ScoredHyp>* samples; + + size_t i = 0; + while(true) + { + cerr << "[dtrain] looping" << endl; + // debug -- + debug_output.str(string()); + debug_output.clear(); + debug_output << "{" << endl; // hack us a nice JSON output + // -- debug + + bool just_translate = false; + + char *buf = NULL; + string source; + vector<Ngrams> refs; + vector<size_t> rsz; + bool next = true; + size_t sz = sock.recv(&buf, NN_MSG, 0); + if (buf) { + const string in(buf, buf+sz); + nn::freemsg(buf); + cerr << "[dtrain] got input '" << in << "'" << endl; + if (boost::starts_with(in, "set_learning_rates")) { // set learning rates + stringstream ss(in); + string _,name; weight_t w; + ss >> _; ss >> name; ss >> w; + weight_t before = 0; + ostringstream long_name; + if (name == "R") { + before = learning_rate_R; + learning_rate_R = w; + long_name << "rule id feature group"; + } else if (name == "RB") { + before = learning_rate_RB; + learning_rate_RB = w; + long_name << "rule bigram feature group"; + } else if (name == "Shape") { + before = learning_rate_Shape; + learning_rate_Shape = w; + long_name << "rule shape feature group"; + } else { + unsigned fid = FD::Convert(name); + before = learning_rates[fid]; + learning_rates[fid] = w; + long_name << "feature '" << name << "'"; + } + ostringstream o; + o << "set learning rate for " << long_name.str() << " to " << w + << " (was: " << before << ")" << endl; + string s = o.str(); + cerr << "[dtrain] " << s; + cerr << "[dtrain] done, looping again" << endl; + sock.send(s.c_str(), s.size()+1, 0); + continue; + } else if (boost::starts_with(in, "reset_learning_rates")) { + cerr << "[dtrain] resetting learning rates" << endl; + learning_rates = original_learning_rates; + learning_rate_R = original_learning_rate_R; + learning_rate_RB = original_learning_rate_RB; + learning_rate_Shape = original_learning_rate_Shape; + cerr << "[dtrain] done, looping again" << endl; + sock.send(done.c_str(), done.size()+1, 0); + continue; + } else if (boost::starts_with(in, "set_weights")) { // set learning rates + stringstream ss(in); + string _,name; weight_t w; + ss >> _; ss >> name; ss >> w; + weight_t before = 0; + ostringstream o; + unsigned fid = FD::Convert(name); + before = lambdas[fid]; + lambdas[fid] = w; + o << "set weight for feature '" << name << "'" + << "' to " << w << " (was: " << before << ")" << endl; + string s = o.str(); + cerr << "[dtrain] " << s; + cerr << "[dtrain] done, looping again" << endl; + sock.send(s.c_str(), s.size()+1, 0); + continue; + } else if (boost::starts_with(in, "reset_weights")) { // reset weights + cerr << "[dtrain] resetting weights" << endl; + lambdas = original_lambdas; + cerr << "[dtrain] done, looping again" << endl; + sock.send(done.c_str(), done.size()+1, 0); + continue; + } else if (in == "shutdown") { // shut down + cerr << "[dtrain] got shutdown signal" << endl; + next = false; + continue; + } else if (boost::starts_with(in, "get_weight")) { // get weight + stringstream ss(in); + string _,name; + ss >> _; ss >> name; + cerr << "[dtrain] getting weight for " << name << endl; + ostringstream o; + unsigned fid = FD::Convert(name); + weight_t w = lambdas[fid]; + o << w; + string s = o.str(); + sock.send(s.c_str(), s.size()+1, 0); + continue; + } else if (boost::starts_with(in, "get_rate")) { // get rate + stringstream ss(in); + string _,name; + ss >> _; ss >> name; + cerr << "[dtrain] getting rate for " << name << endl; + ostringstream o; + unsigned fid = FD::Convert(name); + weight_t r; + if (name == "R") + r = learning_rate_R; + else if (name == "RB") + r = learning_rate_RB; + else if (name == "Shape") + r = learning_rate_Shape; + else + r = learning_rates[fid]; + o << r; + string s = o.str(); + sock.send(s.c_str(), s.size()+1, 0); + continue; + } else { // translate + vector<string> parts; + boost::algorithm::split_regex(parts, in, boost::regex(" \\|\\|\\| ")); + if (parts[0] == "act:translate" || parts[0] == "act:translate_learn") { + if (parts[0] == "act:translate") + just_translate = true; + cerr << "[dtrain] translating ..." << endl; + lambdas.init_vector(&decoder_weights); + observer->dont_score = true; + decoder.Decode(parts[1], observer); + observer->dont_score = false; + samples = observer->GetSamples(); + if (parts[0] == "act:translate") { + ostringstream os; + cerr << "[dtrain] 1best features " << (*samples)[0].f << endl; + if (output_derivation) { + os << observer->GetViterbiTreeStr() << endl; + } else { + PrintWordIDVec((*samples)[0].w, os); + } + if (output_rules) { + os << observer->GetViterbiRules() << endl; + } + sock.send(os.str().c_str(), os.str().size()+1, 0); + cerr << "[dtrain] done translating, looping again" << endl; + } + } //else { // learn + if (!just_translate) { + cerr << "[dtrain] learning ..." << endl; + source = parts[1]; + // debug -- + debug_output << "\"source\":\"" + << escapeJson(source.substr(source.find_first_of(">")+2, source.find_last_of(">")-6)) + << "\"," << endl; + debug_output << "\"target\":\"" << escapeJson(parts[2]) << "\"," << endl; + // -- debug + parts.erase(parts.begin()); + parts.erase(parts.begin()); + for (auto s: parts) { + vector<WordID> r; + vector<string> toks; + boost::split(toks, s, boost::is_any_of(" ")); + for (auto tok: toks) + r.push_back(TD::Convert(tok)); + refs.emplace_back(MakeNgrams(r, N)); + rsz.push_back(r.size()); + } + + for (size_t r = 0; r < samples->size(); r++) + (*samples)[r].gold = observer->scorer_->Score((*samples)[r].w, refs, rsz); + //} + //} + } + } + } + + if (!next) + break; + + // decode + lambdas.init_vector(&decoder_weights); + + // debug --) + ostringstream os; + PrintWordIDVec((*samples)[0].w, os); + debug_output << "\"1best\":\"" << escapeJson(os.str()); + debug_output << "\"," << endl; + debug_output << "\"kbest\":[" << endl; + size_t h = 0; + for (auto s: *samples) { + debug_output << "\"" << s.gold << " ||| " + << s.model << " ||| " << s.rank << " ||| "; + for (auto o: s.f) + debug_output << escapeJson(FD::Convert(o.first)) << "=" << o.second << " "; + debug_output << " ||| "; + ostringstream os; + PrintWordIDVec(s.w, os); + debug_output << escapeJson(os.str()); + h += 1; + debug_output << "\""; + if (h < samples->size()) { + debug_output << ","; + } + debug_output << endl; + } + + debug_output << "]," << endl; + debug_output << "\"samples_size\":" << samples->size() << "," << endl; + debug_output << "\"weights_before\":{" << endl; + sparseVectorToJson(lambdas, debug_output); + debug_output << "}," << endl; + // -- debug + // + + // get pairs + SparseVector<weight_t> update; + size_t num_up = CollectUpdates(samples, update, margin); + + // debug -- + debug_output << "\"1best_features\":{"; + sparseVectorToJson((*samples)[0].f, debug_output); + debug_output << "}," << endl; + debug_output << "\"update_raw\":{"; + sparseVectorToJson(update, debug_output); + debug_output << "}," << endl; + // -- debug + + // update + for (auto it: update) { + string fname = FD::Convert(it.first); + unsigned k = it.first; + weight_t v = it.second; + if (learning_rates.find(it.first) != learning_rates.end()) { + update[k] = learning_rates[k]*v; + } else { + if (boost::starts_with(fname, "R:")) { + update[k] = learning_rate_R*v; + } else if (boost::starts_with(fname, "RBS:") || + boost::starts_with(fname, "RBT:")) { + update[k] = learning_rate_RB*v; + } else if (boost::starts_with(fname, "Shape_")) { + update[k] = learning_rate_Shape*v; + } + } + } + if (!just_translate) { + lambdas += update; + } else { + i++; + } + + // debug -- + debug_output << "\"update\":{"; + sparseVectorToJson(update, debug_output); + debug_output << "}," << endl; + debug_output << "\"num_up\":" << num_up << "," << endl; + debug_output << "\"updated_features\":" << update.size() << "," << endl; + debug_output << "\"learning_rate_R\":" << learning_rate_R << "," << endl; + debug_output << "\"learning_rate_RB\":" << learning_rate_R << "," << endl; + debug_output << "\"learning_rate_Shape\":" << learning_rate_R << "," << endl; + debug_output << "\"learning_rates\":{" << endl; + sparseVectorToJson(learning_rates, debug_output); + debug_output << "}," << endl; + debug_output << "\"best_match\":\""; + ostringstream ps; + PrintWordIDVec((*samples)[0].w, ps); + debug_output << escapeJson(ps.str()); + debug_output << "\"," << endl; + debug_output << "\"best_match_score\":" << (*samples)[0].gold << "," << endl ; + // -- debug + + // debug -- + debug_output << "\"weights_after\":{" << endl; + sparseVectorToJson(lambdas, debug_output); + debug_output << "}" << endl; + debug_output << "}" << endl; + // -- debug + + // debug -- + WriteFile f(debug_fn); + f.get() << debug_output.str(); + f.get() << std::flush; + // -- debug + + // write current weights + if (!just_translate) { + lambdas.init_vector(decoder_weights); + ostringstream fn; + fn << output_fn << "." << i << ".gz"; + Weights::WriteToFile(fn.str(), decoder_weights, true); + } + + if (!just_translate) { + cerr << "[dtrain] done learning, looping again" << endl; + sock.send(done.c_str(), done.size()+1, 0); + } + + } // input loop + + string shutdown = "off"; + sock.send(shutdown.c_str(), shutdown.size()+1, 0); + + cerr << "[dtrain] shutting down, goodbye" << endl; + + return 0; +} + diff --git a/training/dtrain/dtrain_net_interface.h b/training/dtrain/dtrain_net_interface.h new file mode 100644 index 00000000..91c2e538 --- /dev/null +++ b/training/dtrain/dtrain_net_interface.h @@ -0,0 +1,134 @@ +#ifndef _DTRAIN_NET_INTERFACE_H_ +#define _DTRAIN_NET_INTERFACE_H_ + +#include "dtrain.h" + +namespace dtrain +{ + +/* + * source: http://stackoverflow.com/questions/7724448/\ + simple-json-string-escape-for-c/33799784#33799784 + * + */ +inline string +escapeJson(const string& s) { + ostringstream o; + for (auto c = s.cbegin(); c != s.cend(); c++) { + switch (*c) { + case '"': o << "\\\""; break; + case '\\': o << "\\\\"; break; + case '\b': o << "\\b"; break; + case '\f': o << "\\f"; break; + case '\n': o << "\\n"; break; + case '\r': o << "\\r"; break; + case '\t': o << "\\t"; break; + default: + if ('\x00' <= *c && *c <= '\x1f') { + o << "\\u" + << std::hex << std::setw(4) << std::setfill('0') << (int)*c; + } else { + o << *c; + } + } + } + return o.str(); +} + +inline void +sparseVectorToJson(SparseVector<weight_t>& w, ostringstream& os) +{ + vector<string> strs; + for (typename SparseVector<weight_t>::iterator it=w.begin(),e=w.end(); it!=e; ++it) { + ostringstream a; + a << "\"" << escapeJson(FD::Convert(it->first)) << "\":" << it->second; + strs.push_back(a.str()); + } + for (vector<string>::const_iterator it=strs.begin(); it!=strs.end(); it++) { + os << *it; + if ((it+1) != strs.end()) + os << ","; + os << endl; + } +} + +template<typename T> +inline void +vectorAsString(SparseVector<T>& v, ostringstream& os) +{ + SparseVector<weight_t>::iterator it = v.begin(); + for (; it != v.end(); ++it) { + os << FD::Convert(it->first) << "=" << it->second; + auto peek = it; + if (++peek != v.end()) + os << " "; + } +} + +template<typename T> +inline void +updateVectorFromString(string& s, SparseVector<T>& v) +{ + string buf; + istringstream ss; + while (ss >> buf) { + size_t p = buf.find_last_of("="); + istringstream c(buf.substr(p+1,buf.size())); + weight_t val; + c >> val; + v[FD::Convert(buf.substr(0,p))] = val; + } +} + +bool +dtrain_net_init(int argc, char** argv, po::variables_map* conf) +{ + po::options_description ini("Configuration File Options"); + ini.add_options() + ("decoder_conf,C", po::value<string>(), "configuration file for decoder") + ("k", po::value<size_t>()->default_value(100), "size of kbest list") + ("N", po::value<size_t>()->default_value(4), "N for BLEU approximation") + ("margin,m", po::value<weight_t>()->default_value(0.), "margin for margin perceptron") + ("output,o", po::value<string>()->default_value(""), "final weights file") + ("input_weights,w", po::value<string>(), "input weights file") + ("learning_rates,l", po::value<string>(), "pre-defined learning rates per feature") + ("learning_rate_R", po::value<weight_t>(), "learning rate for rule id features") + ("learning_rate_RB", po::value<weight_t>(), "learning rate for rule bigram features") + ("learning_rate_Shape", po::value<weight_t>(), "learning rate for shape features") + ("output_derivation,E", po::bool_switch()->default_value(false), "output derivation, not viterbi str") + ("output_rules,R", po::bool_switch()->default_value(false), "also output rules") + ("update_lm_fn", po::value<string>()->default_value(""), "TODO") + ("dense_features,D", po::value<string>()->default_value("EgivenFCoherent SampleCountF CountEF MaxLexFgivenE MaxLexEgivenF IsSingletonF IsSingletonFE Glue WordPenalty PassThrough LanguageModel LanguageModel_OOV Shape_S01111_T11011 Shape_S11110_T11011 Shape_S11100_T11000 Shape_S01110_T01110 Shape_S01111_T01111 Shape_S01100_T11000 Shape_S10000_T10000 Shape_S11100_T11100 Shape_S11110_T11110 Shape_S11110_T11010 Shape_S01100_T11100 Shape_S01000_T01000 Shape_S01010_T01010 Shape_S01111_T01011 Shape_S01100_T01100 Shape_S01110_T11010 Shape_S11000_T11000 Shape_S11000_T01100 IsSupportedOnline NewRule KnownRule OOVFix"), + "dense features") + ("debug_output,d", po::value<string>()->default_value(""), "file for debug output"); + po::options_description cl("Command Line Options"); + cl.add_options() + ("conf,c", po::value<string>(), "dtrain configuration file") + ("addr,a", po::value<string>(), "address of master"); + cl.add(ini); + po::store(parse_command_line(argc, argv, cl), *conf); + if (conf->count("conf")) { + ifstream f((*conf)["conf"].as<string>().c_str()); + po::store(po::parse_config_file(f, ini), *conf); + } + po::notify(*conf); + if (!conf->count("decoder_conf")) { + cerr << "Missing decoder configuration. Exiting." << endl; + return false; + } + if (!conf->count("learning_rates")) { + cerr << "Missing learning rates. Exiting." << endl; + return false; + } + if (!conf->count("addr")) { + cerr << "No master address given! Exiting." << endl; + return false; + } + + return true; +} + +} // namespace + +#endif + diff --git a/training/dtrain/feed.rb b/training/dtrain/feed.rb deleted file mode 100755 index fe8dd509..00000000 --- a/training/dtrain/feed.rb +++ /dev/null @@ -1,22 +0,0 @@ -#!/usr/bin/env ruby - -require 'nanomsg' - -port = ARGV[0] -sock = NanoMsg::PairSocket.new -addr = "tcp://127.0.0.1:#{port}" -sock.bind addr - -puts sock.recv -while true - line = STDIN.gets - if !line - sock.send 'shutdown' - break - end - sock.send line.strip - sleep 1 - sock.recv - sock.send "a=1 b=2" -end - diff --git a/training/dtrain/nn.hpp b/training/dtrain/nn.hpp deleted file mode 100644 index 50b8304c..00000000 --- a/training/dtrain/nn.hpp +++ /dev/null @@ -1,204 +0,0 @@ -/* - Copyright (c) 2013 250bpm s.r.o. - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), - to deal in the Software without restriction, including without limitation - the rights to use, copy, modify, merge, publish, distribute, sublicense, - and/or sell copies of the Software, and to permit persons to whom - the Software is furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included - in all copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL - THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS - IN THE SOFTWARE. -*/ - -#ifndef NN_HPP_INCLUDED -#define NN_HPP_INCLUDED - -#include <nanomsg/nn.h> - -#include <cassert> -#include <cstring> -#include <algorithm> -#include <exception> - -#if defined __GNUC__ -#define nn_slow(x) __builtin_expect ((x), 0) -#else -#define nn_slow(x) (x) -#endif - -namespace nn -{ - - class exception : public std::exception - { - public: - - exception () : err (nn_errno ()) {} - - virtual const char *what () const throw () - { - return nn_strerror (err); - } - - int num () const - { - return err; - } - - private: - - int err; - }; - - inline const char *symbol (int i, int *value) - { - return nn_symbol (i, value); - } - - inline void *allocmsg (size_t size, int type) - { - void *msg = nn_allocmsg (size, type); - if (nn_slow (!msg)) - throw nn::exception (); - return msg; - } - - inline int freemsg (void *msg) - { - int rc = nn_freemsg (msg); - if (nn_slow (rc != 0)) - throw nn::exception (); - return rc; - } - - class socket - { - public: - - inline socket (int domain, int protocol) - { - s = nn_socket (domain, protocol); - if (nn_slow (s < 0)) - throw nn::exception (); - } - - inline ~socket () - { - int rc = nn_close (s); - assert (rc == 0); - } - - inline void setsockopt (int level, int option, const void *optval, - size_t optvallen) - { - int rc = nn_setsockopt (s, level, option, optval, optvallen); - if (nn_slow (rc != 0)) - throw nn::exception (); - } - - inline void getsockopt (int level, int option, void *optval, - size_t *optvallen) - { - int rc = nn_getsockopt (s, level, option, optval, optvallen); - if (nn_slow (rc != 0)) - throw nn::exception (); - } - - inline int bind (const char *addr) - { - int rc = nn_bind (s, addr); - if (nn_slow (rc < 0)) - throw nn::exception (); - return rc; - } - - inline int connect (const char *addr) - { - int rc = nn_connect (s, addr); - if (nn_slow (rc < 0)) - throw nn::exception (); - return rc; - } - - inline void shutdown (int how) - { - int rc = nn_shutdown (s, how); - if (nn_slow (rc != 0)) - throw nn::exception (); - } - - inline int send (const void *buf, size_t len, int flags) - { - int rc = nn_send (s, buf, len, flags); - if (nn_slow (rc < 0)) { - if (nn_slow (nn_errno () != EAGAIN)) - throw nn::exception (); - return -1; - } - return rc; - } - - inline int recv (void *buf, size_t len, int flags) - { - int rc = nn_recv (s, buf, len, flags); - if (nn_slow (rc < 0)) { - if (nn_slow (nn_errno () != EAGAIN)) - throw nn::exception (); - return -1; - } - return rc; - } - - inline int sendmsg (const struct nn_msghdr *msghdr, int flags) - { - int rc = nn_sendmsg (s, msghdr, flags); - if (nn_slow (rc < 0)) { - if (nn_slow (nn_errno () != EAGAIN)) - throw nn::exception (); - return -1; - } - return rc; - } - - inline int recvmsg (struct nn_msghdr *msghdr, int flags) - { - int rc = nn_recvmsg (s, msghdr, flags); - if (nn_slow (rc < 0)) { - if (nn_slow (nn_errno () != EAGAIN)) - throw nn::exception (); - return -1; - } - return rc; - } - - private: - - int s; - - /* Prevent making copies of the socket by accident. */ - socket (const socket&); - void operator = (const socket&); - }; - - inline void term () - { - nn_term (); - } - -} - -#undef nn_slow - -#endif - - diff --git a/training/dtrain/sample.h b/training/dtrain/sample.h index 03cc82c3..e24b65cf 100644 --- a/training/dtrain/sample.h +++ b/training/dtrain/sample.h @@ -16,6 +16,7 @@ struct ScoredKbest : public DecoderObserver PerSentenceBleuScorer* scorer_; vector<Ngrams>* ref_ngs_; vector<size_t>* ref_ls_; + string viterbi_tree_str; ScoredKbest(const size_t k, PerSentenceBleuScorer* scorer) : k_(k), scorer_(scorer) {} @@ -40,6 +41,7 @@ struct ScoredKbest : public DecoderObserver samples_.push_back(h); effective_sz_++; feature_count_ += h.f.size(); + viterbi_tree_str = hg->show_viterbi_tree(false); } } @@ -51,6 +53,7 @@ struct ScoredKbest : public DecoderObserver } inline size_t GetFeatureCount() { return feature_count_; } inline size_t GetSize() { return effective_sz_; } + inline string GetViterbiTreeString() { return viterbi_tree_str; } }; } // namespace diff --git a/training/dtrain/sample_net_interface.h b/training/dtrain/sample_net_interface.h new file mode 100644 index 00000000..6d00e5d5 --- /dev/null +++ b/training/dtrain/sample_net_interface.h @@ -0,0 +1,68 @@ +#ifndef _DTRAIN_SAMPLE_NET_H_ +#define _DTRAIN_SAMPLE_NET_H_ + +#include "kbest.h" + +#include "score_net_interface.h" + +namespace dtrain +{ + +struct ScoredKbest : public DecoderObserver +{ + const size_t k_; + size_t feature_count_, effective_sz_; + vector<ScoredHyp> samples_; + PerSentenceBleuScorer* scorer_; + vector<Ngrams>* ref_ngs_; + vector<size_t>* ref_ls_; + bool dont_score; + string viterbiTreeStr_, viterbiRules_; + + ScoredKbest(const size_t k, PerSentenceBleuScorer* scorer) : + k_(k), scorer_(scorer), dont_score(false) {} + + virtual void + NotifyTranslationForest(const SentenceMetadata& /*smeta*/, Hypergraph* hg) + { + samples_.clear(); effective_sz_ = feature_count_ = 0; + KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, + KBest::FilterUnique, prob_t, EdgeProb> kbest(*hg, k_); + for (size_t i = 0; i < k_; ++i) { + const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, + KBest::FilterUnique, prob_t, EdgeProb>::Derivation* d = + kbest.LazyKthBest(hg->nodes_.size() - 1, i); + if (!d) break; + ScoredHyp h; + h.w = d->yield; + h.f = d->feature_values; + h.model = log(d->score); + h.rank = i; + if (!dont_score) + h.gold = scorer_->Score(h.w, *ref_ngs_, *ref_ls_); + samples_.push_back(h); + effective_sz_++; + feature_count_ += h.f.size(); + viterbiTreeStr_ = hg->show_viterbi_tree(false); + ostringstream ss; + ViterbiRules(*hg, &ss); + viterbiRules_ = ss.str(); + } + } + + vector<ScoredHyp>* GetSamples() { return &samples_; } + inline void SetReference(vector<Ngrams>& ngs, vector<size_t>& ls) + { + ref_ngs_ = &ngs; + ref_ls_ = &ls; + } + inline size_t GetFeatureCount() { return feature_count_; } + inline size_t GetSize() { return effective_sz_; } + inline string GetViterbiTreeStr() { return viterbiTreeStr_; } + inline string GetViterbiRules() { return viterbiRules_; } +}; + +} // namespace + +#endif + diff --git a/training/dtrain/score.h b/training/dtrain/score.h index 06dbc5a4..e6e60acb 100644 --- a/training/dtrain/score.h +++ b/training/dtrain/score.h @@ -153,7 +153,7 @@ struct PerSentenceBleuScorer size_t best = numeric_limits<size_t>::max(); for (auto l: ref_ls) { size_t d = abs(hl-l); - if (d < best) { + if (d < best) { best_idx = i; best = d; } diff --git a/training/dtrain/score_net_interface.h b/training/dtrain/score_net_interface.h new file mode 100644 index 00000000..58357cf6 --- /dev/null +++ b/training/dtrain/score_net_interface.h @@ -0,0 +1,200 @@ +#ifndef _DTRAIN_SCORE_NET_INTERFACE_H_ +#define _DTRAIN_SCORE_NET_INTERFACE_H_ + +#include "dtrain.h" + +namespace dtrain +{ + +struct NgramCounts +{ + size_t N_; + map<size_t, weight_t> clipped_; + map<size_t, weight_t> sum_; + + NgramCounts(const size_t N) : N_(N) { Zero(); } + + inline void + operator+=(const NgramCounts& rhs) + { + if (rhs.N_ > N_) Resize(rhs.N_); + for (size_t i = 0; i < N_; i++) { + this->clipped_[i] += rhs.clipped_.find(i)->second; + this->sum_[i] += rhs.sum_.find(i)->second; + } + } + + inline const NgramCounts + operator+(const NgramCounts &other) const + { + NgramCounts result = *this; + result += other; + + return result; + } + + inline void + Add(const size_t count, const size_t ref_count, const size_t i) + { + assert(i < N_); + if (count > ref_count) { + clipped_[i] += ref_count; + } else { + clipped_[i] += count; + } + sum_[i] += count; + } + + inline void + Zero() + { + for (size_t i = 0; i < N_; i++) { + clipped_[i] = 0.; + sum_[i] = 0.; + } + } + + inline void + Resize(size_t N) + { + if (N == N_) return; + else if (N > N_) { + for (size_t i = N_; i < N; i++) { + clipped_[i] = 0.; + sum_[i] = 0.; + } + } else { // N < N_ + for (size_t i = N_-1; i > N-1; i--) { + clipped_.erase(i); + sum_.erase(i); + } + } + N_ = N; + } +}; + +typedef map<vector<WordID>, size_t> Ngrams; + +inline Ngrams +MakeNgrams(const vector<WordID>& s, const size_t N) +{ + Ngrams ngrams; + vector<WordID> ng; + for (size_t i = 0; i < s.size(); i++) { + ng.clear(); + for (size_t j = i; j < min(i+N, s.size()); j++) { + ng.push_back(s[j]); + ngrams[ng]++; + } + } + + return ngrams; +} + +inline NgramCounts +MakeNgramCounts(const vector<WordID>& hyp, + const vector<Ngrams>& ref, + const size_t N) +{ + Ngrams hyp_ngrams = MakeNgrams(hyp, N); + NgramCounts counts(N); + Ngrams::iterator it, ti; + for (it = hyp_ngrams.begin(); it != hyp_ngrams.end(); it++) { + size_t max_ref_count = 0; + for (auto r: ref) { + ti = r.find(it->first); + if (ti != r.end()) + max_ref_count = max(max_ref_count, ti->second); + } + counts.Add(it->second, min(it->second, max_ref_count), it->first.size()-1); + } + + return counts; +} + +/* + * per-sentence BLEU + * as in "Optimizing for Sentence-Level BLEU+1 + * Yields Short Translations" + * (Nakov et al. '12) + * + * [simply add 1 to reference length for calculation of BP] + * + */ +struct PerSentenceBleuScorer +{ + const size_t N_; + vector<weight_t> w_; + + PerSentenceBleuScorer(size_t n) : N_(n) + { + for (size_t i = 1; i <= N_; i++) + w_.push_back(1.0/N_); + } + + inline weight_t + BrevityPenalty(const size_t hl, const size_t rl) + { + if (hl > rl) + return 1; + + return exp(1 - (weight_t)rl/hl); + } + + inline size_t + BestMatchLength(const size_t hl, + const vector<size_t>& ref_ls) + { + size_t m; + if (ref_ls.size() == 1) { + m = ref_ls.front(); + } else { + size_t i = 0, best_idx = 0; + size_t best = numeric_limits<size_t>::max(); + for (auto l: ref_ls) { + size_t d = abs(hl-l); + if (d < best) { + best_idx = i; + best = d; + } + i += 1; + } + m = ref_ls[best_idx]; + } + + return m; + } + + weight_t + Score(const vector<WordID>& hyp, + const vector<Ngrams>& ref_ngs, + const vector<size_t>& ref_ls) + { + size_t hl = hyp.size(), rl = 0; + if (hl == 0) return 0.; + rl = BestMatchLength(hl, ref_ls); + if (rl == 0) return 0.; + NgramCounts counts = MakeNgramCounts(hyp, ref_ngs, N_); + size_t M = N_; + vector<weight_t> v = w_; + if (rl < N_) { + M = rl; + for (size_t i = 0; i < M; i++) v[i] = 1/((weight_t)M); + } + weight_t sum = 0, add = 0; + for (size_t i = 0; i < M; i++) { + if (i == 0 && (counts.sum_[i] == 0 || counts.clipped_[i] == 0)) return 0.; + if (i > 0) add = 1; + sum += v[i] * log(((weight_t)counts.clipped_[i] + add) + / ((counts.sum_[i] + add))); + } + + //return BrevityPenalty(hl, rl+1) * exp(sum); + return BrevityPenalty(hl, rl) * exp(sum); + } +}; + +} // namespace + +#endif + |