summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/dtrain/dtrain_net_interface.cc98
1 files changed, 63 insertions, 35 deletions
diff --git a/training/dtrain/dtrain_net_interface.cc b/training/dtrain/dtrain_net_interface.cc
index 761930f8..ac447517 100644
--- a/training/dtrain/dtrain_net_interface.cc
+++ b/training/dtrain/dtrain_net_interface.cc
@@ -98,6 +98,8 @@ main(int argc, char** argv)
debug_output << "{" << endl; // hack us a nice JSON output
// -- debug
+ bool just_translate = false;
+
char *buf = NULL;
string source;
vector<Ngrams> refs;
@@ -210,36 +212,41 @@ main(int argc, char** argv)
} else { // translate
vector<string> parts;
boost::algorithm::split_regex(parts, in, boost::regex(" \\|\\|\\| "));
- if (parts[0] == "act:translate") {
+ if (parts[0] == "act:translate" || parts[0] == "act:translate_learn") {
+ if (parts[0] == "act:translate")
+ just_translate = true;
cerr << "[dtrain] translating ..." << endl;
lambdas.init_vector(&decoder_weights);
observer->dont_score = true;
decoder.Decode(parts[1], observer);
observer->dont_score = false;
samples = observer->GetSamples();
- ostringstream os;
- cerr << "[dtrain] 1best features " << (*samples)[0].f << endl;
- if (output_derivation) {
- os << observer->GetViterbiTreeStr() << endl;
- } else {
- PrintWordIDVec((*samples)[0].w, os);
- }
- if (output_rules) {
- os << observer->GetViterbiRules() << endl;
+ if (parts[0] == "act:translate") {
+ ostringstream os;
+ cerr << "[dtrain] 1best features " << (*samples)[0].f << endl;
+ if (output_derivation) {
+ os << observer->GetViterbiTreeStr() << endl;
+ } else {
+ PrintWordIDVec((*samples)[0].w, os);
+ }
+ if (output_rules) {
+ os << observer->GetViterbiRules() << endl;
+ }
+ sock.send(os.str().c_str(), os.str().size()+1, 0);
+ cerr << "[dtrain] done translating, looping again" << endl;
}
- sock.send(os.str().c_str(), os.str().size()+1, 0);
- cerr << "[dtrain] done translating, looping again" << endl;
- continue;
- } else { // learn
+ } //else { // learn
+ if (!just_translate) {
cerr << "[dtrain] learning ..." << endl;
- source = parts[0];
+ source = parts[1];
// debug --
debug_output << "\"source\":\""
- << source.substr(source.find_first_of(">")+2, source.find_last_of(">")-6)
+ << escapeJson(source.substr(source.find_first_of(">")+2, source.find_last_of(">")-6))
<< "\"," << endl;
- debug_output << "\"target\":\"" << parts[1] << "\"," << endl;
+ debug_output << "\"target\":\"" << escapeJson(parts[2]) << "\"," << endl;
// -- debug
parts.erase(parts.begin());
+ parts.erase(parts.begin());
for (auto s: parts) {
vector<WordID> r;
vector<string> toks;
@@ -252,6 +259,8 @@ main(int argc, char** argv)
for (size_t r = 0; r < samples->size(); r++)
(*samples)[r].gold = observer->scorer_->Score((*samples)[r].w, refs, rsz);
+ //}
+ //}
}
}
}
@@ -262,9 +271,10 @@ main(int argc, char** argv)
// decode
lambdas.init_vector(&decoder_weights);
- // debug --
- debug_output << "\"1best\":\"";
- PrintWordIDVec((*samples)[0].w, debug_output);
+ // debug --)
+ ostringstream os;
+ PrintWordIDVec((*samples)[0].w, os);
+ debug_output << "\"1best\":\"" << escapeJson(os.str());
debug_output << "\"," << endl;
debug_output << "\"kbest\":[" << endl;
size_t h = 0;
@@ -272,9 +282,11 @@ main(int argc, char** argv)
debug_output << "\"" << s.gold << " ||| "
<< s.model << " ||| " << s.rank << " ||| ";
for (auto o: s.f)
- debug_output << FD::Convert(o.first) << "=" << o.second << " ";
+ debug_output << escapeJson(FD::Convert(o.first)) << "=" << o.second << " ";
debug_output << " ||| ";
- PrintWordIDVec(s.w, debug_output);
+ ostringstream os;
+ PrintWordIDVec(s.w, os);
+ debug_output << escapeJson(os.str());
h += 1;
debug_output << "\"";
if (h < samples->size()) {
@@ -296,8 +308,12 @@ main(int argc, char** argv)
size_t num_up = CollectUpdates(samples, update, margin);
// debug --
- debug_output << "\"1best_features\":\"" << (*samples)[0].f << "\"," << endl;
- debug_output << "\"update_raw\":\"" << update << "\"," << endl;
+ debug_output << "\"1best_features\":{";
+ sparseVectorToJson((*samples)[0].f, debug_output);
+ debug_output << "}," << endl;
+ debug_output << "\"update_raw\":{";
+ sparseVectorToJson(update, debug_output);
+ debug_output << "}," << endl;
// -- debug
// update
@@ -318,11 +334,16 @@ main(int argc, char** argv)
}
}
}
- lambdas += update;
- i++;
+ if (!just_translate) {
+ lambdas += update;
+ } else {
+ i++;
+ }
// debug --
- debug_output << "\"update\":\"" << update << "\"," << endl;
+ debug_output << "\"update\":{";
+ sparseVectorToJson(update, debug_output);
+ debug_output << "}," << endl;
debug_output << "\"num_up\":" << num_up << "," << endl;
debug_output << "\"updated_features\":" << update.size() << "," << endl;
debug_output << "\"learning_rate_R\":" << learning_rate_R << "," << endl;
@@ -332,7 +353,9 @@ main(int argc, char** argv)
sparseVectorToJson(learning_rates, debug_output);
debug_output << "}," << endl;
debug_output << "\"best_match\":\"";
- PrintWordIDVec((*samples)[0].w, debug_output);
+ ostringstream ps;
+ PrintWordIDVec((*samples)[0].w, ps);
+ debug_output << escapeJson(ps.str());
debug_output << "\"," << endl;
debug_output << "\"best_match_score\":" << (*samples)[0].gold << "," << endl ;
// -- debug
@@ -344,9 +367,6 @@ main(int argc, char** argv)
debug_output << "}" << endl;
// -- debug
- cerr << "[dtrain] done learning, looping again" << endl;
- sock.send(done.c_str(), done.size()+1, 0);
-
// debug --
WriteFile f(debug_fn);
f.get() << debug_output.str();
@@ -354,10 +374,18 @@ main(int argc, char** argv)
// -- debug
// write current weights
- lambdas.init_vector(decoder_weights);
- ostringstream fn;
- fn << output_fn << "." << i << ".gz";
- Weights::WriteToFile(fn.str(), decoder_weights, true);
+ if (!just_translate) {
+ lambdas.init_vector(decoder_weights);
+ ostringstream fn;
+ fn << output_fn << "." << i << ".gz";
+ Weights::WriteToFile(fn.str(), decoder_weights, true);
+ }
+
+ if (!just_translate) {
+ cerr << "[dtrain] done learning, looping again" << endl;
+ sock.send(done.c_str(), done.size()+1, 0);
+ }
+
} // input loop
string shutdown = "off";