path: root/training/dtrain/dtrain_net_interface.h
author      pks <pks@users.noreply.github.com>    2019-05-12 20:10:37 +0200
committer   GitHub <noreply@github.com>           2019-05-12 20:10:37 +0200
commit      4a13b41700f34c15c30b551f98dbea9cb41f67c3 (patch)
tree        0218f41c350a626f5af9909d77406309fa873fdf /training/dtrain/dtrain_net_interface.h
parent      e9268eb3dcd867f3baf67a7bb3d2aad56196ecde (diff)
parent      f64746ac87fc7338629b19de9fa2da0f03fa2790 (diff)
Merge branch 'net' into origin/net
Diffstat (limited to 'training/dtrain/dtrain_net_interface.h')
-rw-r--r--   training/dtrain/dtrain_net_interface.h   134
1 file changed, 134 insertions, 0 deletions
diff --git a/training/dtrain/dtrain_net_interface.h b/training/dtrain/dtrain_net_interface.h
new file mode 100644
index 00000000..91c2e538
--- /dev/null
+++ b/training/dtrain/dtrain_net_interface.h
@@ -0,0 +1,134 @@
+#ifndef _DTRAIN_NET_INTERFACE_H_
+#define _DTRAIN_NET_INTERFACE_H_
+
+#include "dtrain.h"
+
+namespace dtrain
+{
+
+/*
+ * source: http://stackoverflow.com/questions/7724448/\
+ simple-json-string-escape-for-c/33799784#33799784
+ *
+ */
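+// Returns a copy of s with JSON metacharacters (quote, backslash) and
+// ASCII control characters escaped so the string can be embedded in a
+// JSON document.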
+inline string
+escapeJson(const string& s) {
+ ostringstream o;
+ for (auto c = s.cbegin(); c != s.cend(); c++) {
+ switch (*c) {
+ case '"': o << "\\\""; break;
+ case '\\': o << "\\\\"; break;
+ case '\b': o << "\\b"; break;
+ case '\f': o << "\\f"; break;
+ case '\n': o << "\\n"; break;
+ case '\r': o << "\\r"; break;
+ case '\t': o << "\\t"; break;
+ default:
+ if ('\x00' <= *c && *c <= '\x1f') {
+ o << "\\u"
+ << std::hex << std::setw(4) << std::setfill('0') << (int)*c;
+ } else {
+ o << *c;
+ }
+ }
+ }
+ return o.str();
+}
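+// Illustrative example: escapeJson("say \"hi\"") yields say \"hi\",
+// i.e. the embedded quotes come back escaped and can be placed directly
+// between JSON double quotes.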
+
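+// Streams the sparse vector as JSON object members, one
+// "feature_name":value pair per line, comma-separated. The enclosing
+// braces are not written here and are left to the caller.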
+inline void
+sparseVectorToJson(SparseVector<weight_t>& w, ostringstream& os)
+{
+ vector<string> strs;
+  for (SparseVector<weight_t>::iterator it=w.begin(),e=w.end(); it!=e; ++it) {
+ ostringstream a;
+ a << "\"" << escapeJson(FD::Convert(it->first)) << "\":" << it->second;
+ strs.push_back(a.str());
+ }
+ for (vector<string>::const_iterator it=strs.begin(); it!=strs.end(); it++) {
+ os << *it;
+ if ((it+1) != strs.end())
+ os << ",";
+ os << endl;
+ }
+}
+
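+// Serializes a sparse vector as space-separated feature=value pairs,
+// the plain-text format parsed back by updateVectorFromString below.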
+template<typename T>
+inline void
+vectorAsString(SparseVector<T>& v, ostringstream& os)
+{
+  typename SparseVector<T>::iterator it = v.begin();
+ for (; it != v.end(); ++it) {
+ os << FD::Convert(it->first) << "=" << it->second;
+ auto peek = it;
+ if (++peek != v.end())
+ os << " ";
+ }
+}
+
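+// Parses space-separated feature=value pairs from s and stores each
+// value in v under the corresponding feature id.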
+template<typename T>
+inline void
+updateVectorFromString(string& s, SparseVector<T>& v)
+{
+ string buf;
+  istringstream ss(s);
+ while (ss >> buf) {
+ size_t p = buf.find_last_of("=");
+ istringstream c(buf.substr(p+1,buf.size()));
+ weight_t val;
+ c >> val;
+ v[FD::Convert(buf.substr(0,p))] = val;
+ }
+}
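+// Hypothetical usage sketch: restoring weights received over the wire,
+// assuming the sender produced them with vectorAsString:
+//   string msg = "Glue=0.5 WordPenalty=-1.0";
+//   SparseVector<weight_t> w;
+//   updateVectorFromString(msg, w);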
+
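+// Registers the configuration file and command line options, loads an
+// optional configuration file (-c), and verifies that decoder_conf,
+// learning_rates and addr are present. Returns false if any of them
+// is missing.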
+inline bool
+dtrain_net_init(int argc, char** argv, po::variables_map* conf)
+{
+ po::options_description ini("Configuration File Options");
+ ini.add_options()
+ ("decoder_conf,C", po::value<string>(), "configuration file for decoder")
+ ("k", po::value<size_t>()->default_value(100), "size of kbest list")
+ ("N", po::value<size_t>()->default_value(4), "N for BLEU approximation")
+ ("margin,m", po::value<weight_t>()->default_value(0.), "margin for margin perceptron")
+ ("output,o", po::value<string>()->default_value(""), "final weights file")
+ ("input_weights,w", po::value<string>(), "input weights file")
+ ("learning_rates,l", po::value<string>(), "pre-defined learning rates per feature")
+ ("learning_rate_R", po::value<weight_t>(), "learning rate for rule id features")
+ ("learning_rate_RB", po::value<weight_t>(), "learning rate for rule bigram features")
+ ("learning_rate_Shape", po::value<weight_t>(), "learning rate for shape features")
+ ("output_derivation,E", po::bool_switch()->default_value(false), "output derivation, not viterbi str")
+ ("output_rules,R", po::bool_switch()->default_value(false), "also output rules")
+ ("update_lm_fn", po::value<string>()->default_value(""), "TODO")
+ ("dense_features,D", po::value<string>()->default_value("EgivenFCoherent SampleCountF CountEF MaxLexFgivenE MaxLexEgivenF IsSingletonF IsSingletonFE Glue WordPenalty PassThrough LanguageModel LanguageModel_OOV Shape_S01111_T11011 Shape_S11110_T11011 Shape_S11100_T11000 Shape_S01110_T01110 Shape_S01111_T01111 Shape_S01100_T11000 Shape_S10000_T10000 Shape_S11100_T11100 Shape_S11110_T11110 Shape_S11110_T11010 Shape_S01100_T11100 Shape_S01000_T01000 Shape_S01010_T01010 Shape_S01111_T01011 Shape_S01100_T01100 Shape_S01110_T11010 Shape_S11000_T11000 Shape_S11000_T01100 IsSupportedOnline NewRule KnownRule OOVFix"),
+ "dense features")
+ ("debug_output,d", po::value<string>()->default_value(""), "file for debug output");
+ po::options_description cl("Command Line Options");
+ cl.add_options()
+ ("conf,c", po::value<string>(), "dtrain configuration file")
+ ("addr,a", po::value<string>(), "address of master");
+ cl.add(ini);
+ po::store(parse_command_line(argc, argv, cl), *conf);
+ if (conf->count("conf")) {
+ ifstream f((*conf)["conf"].as<string>().c_str());
+ po::store(po::parse_config_file(f, ini), *conf);
+ }
+ po::notify(*conf);
+ if (!conf->count("decoder_conf")) {
+ cerr << "Missing decoder configuration. Exiting." << endl;
+ return false;
+ }
+ if (!conf->count("learning_rates")) {
+ cerr << "Missing learning rates. Exiting." << endl;
+ return false;
+ }
+ if (!conf->count("addr")) {
+ cerr << "No master address given! Exiting." << endl;
+ return false;
+ }
+
+ return true;
+}
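+// Minimal usage sketch (assuming a typical dtrain-style main()):
+//   po::variables_map conf;
+//   if (!dtrain_net_init(argc, argv, &conf)) return 1;
+//   const string master_addr = conf["addr"].as<string>();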
+
+} // namespace
+
+#endif
+