summaryrefslogtreecommitdiff
path: root/training/dtrain/dtrain_net_interface.h
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2016-01-27 17:26:24 +0100
committerPatrick Simianer <p@simianer.de>2016-01-27 17:26:24 +0100
commitd5b5e9e31ca3f222ba6cfe5e788a14a087c0b66d (patch)
treecb52f789f0257d8d3235fd3dc1e0a8f27b2ddf1e /training/dtrain/dtrain_net_interface.h
parent7de6a7dc654a37a70999a4e6d06a8fb6efebb82f (diff)
dtrain_net_interface: support for per coordinate learning rates
Diffstat (limited to 'training/dtrain/dtrain_net_interface.h')
-rw-r--r--training/dtrain/dtrain_net_interface.h45
1 files changed, 40 insertions, 5 deletions
diff --git a/training/dtrain/dtrain_net_interface.h b/training/dtrain/dtrain_net_interface.h
index b201c7a3..720c4c9b 100644
--- a/training/dtrain/dtrain_net_interface.h
+++ b/training/dtrain/dtrain_net_interface.h
@@ -6,13 +6,42 @@
namespace dtrain
{
+/*
+ * source: http://stackoverflow.com/questions/7724448/\
+ simple-json-string-escape-for-c/33799784#33799784
+ *
+ */
+inline string
+escapeJson(const string& s) {
+ ostringstream o;
+ for (auto c = s.cbegin(); c != s.cend(); c++) {
+ switch (*c) {
+ case '"': o << "\\\""; break;
+ case '\\': o << "\\\\"; break;
+ case '\b': o << "\\b"; break;
+ case '\f': o << "\\f"; break;
+ case '\n': o << "\\n"; break;
+ case '\r': o << "\\r"; break;
+ case '\t': o << "\\t"; break;
+ default:
+ if ('\x00' <= *c && *c <= '\x1f') {
+ o << "\\u"
+ << std::hex << std::setw(4) << std::setfill('0') << (int)*c;
+ } else {
+ o << *c;
+ }
+ }
+ }
+ return o.str();
+}
+
inline void
-weightsToJson(SparseVector<weight_t>& w, ostringstream& os)
+sparseVectorToJson(SparseVector<weight_t>& w, ostringstream& os)
{
vector<string> strs;
for (typename SparseVector<weight_t>::iterator it=w.begin(),e=w.end(); it!=e; ++it) {
ostringstream a;
- a << "\"" << FD::Convert(it->first) << "\":" << it->second;
+ a << "\"" << escapeJson(FD::Convert(it->first)) << "\":" << it->second;
strs.push_back(a.str());
}
for (vector<string>::const_iterator it=strs.begin(); it!=strs.end(); it++) {
@@ -62,10 +91,12 @@ dtrain_net_init(int argc, char** argv, po::variables_map* conf)
("margin,m", po::value<weight_t>()->default_value(0.), "margin for margin perceptron")
("output,o", po::value<string>()->default_value(""), "final weights file")
("input_weights,w", po::value<string>(), "input weights file")
- ("learning_rate,l", po::value<weight_t>()->default_value(0.001), "learning rate")
- ("learning_rate_sparse,l", po::value<weight_t>()->default_value(0.00001), "learning rate for sparse features")
+ ("learning_rates,l", po::value<string>(), "pre-defined learning rates per feature")
+ ("learning_rate_R", po::value<weight_t>(), "learning rate for rule id features")
+ ("learning_rate_RB", po::value<weight_t>(), "learning rate for rule bigram features")
+ ("learning_rate_Shape", po::value<weight_t>(), "learning rate for shape features")
("output_derivation,E", po::bool_switch()->default_value(false), "output derivation, not viterbi str")
- ("output_rules,R", po::bool_switch()->default_value(false), "also output rules")
+ ("output_rules,R", po::bool_switch()->default_value(false), "also output rules")
("dense_features,D", po::value<string>()->default_value("EgivenFCoherent SampleCountF CountEF MaxLexFgivenE MaxLexEgivenF IsSingletonF IsSingletonFE Glue WordPenalty PassThrough LanguageModel LanguageModel_OOV Shape_S01111_T11011 Shape_S11110_T11011 Shape_S11100_T11000 Shape_S01110_T01110 Shape_S01111_T01111 Shape_S01100_T11000 Shape_S10000_T10000 Shape_S11100_T11100 Shape_S11110_T11110 Shape_S11110_T11010 Shape_S01100_T11100 Shape_S01000_T01000 Shape_S01010_T01010 Shape_S01111_T01011 Shape_S01100_T01100 Shape_S01110_T11010 Shape_S11000_T11000 Shape_S11000_T01100 IsSupportedOnline NewRule KnownRule OOVFix"),
"dense features")
("debug_output,d", po::value<string>()->default_value(""), "file for debug output");
@@ -84,6 +115,10 @@ dtrain_net_init(int argc, char** argv, po::variables_map* conf)
cerr << "Missing decoder configuration. Exiting." << endl;
return false;
}
+ if (!conf->count("learning_rates")) {
+ cerr << "Missing learning rates. Exiting." << endl;
+ return false;
+ }
if (!conf->count("addr")) {
cerr << "No master address given! Exiting." << endl;
return false;