diff options
author | Patrick Simianer <p@simianer.de> | 2016-07-14 21:44:24 +0200 |
---|---|---|
committer | Patrick Simianer <p@simianer.de> | 2016-07-14 21:44:24 +0200 |
commit | 1a749f62c19ea77b74a61a5ec747c16fea95f860 (patch) | |
tree | 8c42257bbbcc4a1e79f4892a03d4513ca78fc5d7 /training | |
parent | 62ceb12a0491d9490f02332422cce2613163a20c (diff) |
translate _and_ learn
Diffstat (limited to 'training')
-rw-r--r-- | training/dtrain/dtrain_net_interface.cc | 98 |
1 files changed, 63 insertions, 35 deletions
diff --git a/training/dtrain/dtrain_net_interface.cc b/training/dtrain/dtrain_net_interface.cc index 761930f8..ac447517 100644 --- a/training/dtrain/dtrain_net_interface.cc +++ b/training/dtrain/dtrain_net_interface.cc @@ -98,6 +98,8 @@ main(int argc, char** argv) debug_output << "{" << endl; // hack us a nice JSON output // -- debug + bool just_translate = false; + char *buf = NULL; string source; vector<Ngrams> refs; @@ -210,36 +212,41 @@ main(int argc, char** argv) } else { // translate vector<string> parts; boost::algorithm::split_regex(parts, in, boost::regex(" \\|\\|\\| ")); - if (parts[0] == "act:translate") { + if (parts[0] == "act:translate" || parts[0] == "act:translate_learn") { + if (parts[0] == "act:translate") + just_translate = true; cerr << "[dtrain] translating ..." << endl; lambdas.init_vector(&decoder_weights); observer->dont_score = true; decoder.Decode(parts[1], observer); observer->dont_score = false; samples = observer->GetSamples(); - ostringstream os; - cerr << "[dtrain] 1best features " << (*samples)[0].f << endl; - if (output_derivation) { - os << observer->GetViterbiTreeStr() << endl; - } else { - PrintWordIDVec((*samples)[0].w, os); - } - if (output_rules) { - os << observer->GetViterbiRules() << endl; + if (parts[0] == "act:translate") { + ostringstream os; + cerr << "[dtrain] 1best features " << (*samples)[0].f << endl; + if (output_derivation) { + os << observer->GetViterbiTreeStr() << endl; + } else { + PrintWordIDVec((*samples)[0].w, os); + } + if (output_rules) { + os << observer->GetViterbiRules() << endl; + } + sock.send(os.str().c_str(), os.str().size()+1, 0); + cerr << "[dtrain] done translating, looping again" << endl; } - sock.send(os.str().c_str(), os.str().size()+1, 0); - cerr << "[dtrain] done translating, looping again" << endl; - continue; - } else { // learn + } //else { // learn + if (!just_translate) { cerr << "[dtrain] learning ..." << endl; - source = parts[0]; + source = parts[1]; // debug -- debug_output << "\"source\":\"" - << source.substr(source.find_first_of(">")+2, source.find_last_of(">")-6) + << escapeJson(source.substr(source.find_first_of(">")+2, source.find_last_of(">")-6)) << "\"," << endl; - debug_output << "\"target\":\"" << parts[1] << "\"," << endl; + debug_output << "\"target\":\"" << escapeJson(parts[2]) << "\"," << endl; // -- debug parts.erase(parts.begin()); + parts.erase(parts.begin()); for (auto s: parts) { vector<WordID> r; vector<string> toks; @@ -252,6 +259,8 @@ main(int argc, char** argv) for (size_t r = 0; r < samples->size(); r++) (*samples)[r].gold = observer->scorer_->Score((*samples)[r].w, refs, rsz); + //} + //} } } } @@ -262,9 +271,10 @@ main(int argc, char** argv) // decode lambdas.init_vector(&decoder_weights); - // debug -- - debug_output << "\"1best\":\""; - PrintWordIDVec((*samples)[0].w, debug_output); + // debug --) + ostringstream os; + PrintWordIDVec((*samples)[0].w, os); + debug_output << "\"1best\":\"" << escapeJson(os.str()); debug_output << "\"," << endl; debug_output << "\"kbest\":[" << endl; size_t h = 0; @@ -272,9 +282,11 @@ main(int argc, char** argv) debug_output << "\"" << s.gold << " ||| " << s.model << " ||| " << s.rank << " ||| "; for (auto o: s.f) - debug_output << FD::Convert(o.first) << "=" << o.second << " "; + debug_output << escapeJson(FD::Convert(o.first)) << "=" << o.second << " "; debug_output << " ||| "; - PrintWordIDVec(s.w, debug_output); + ostringstream os; + PrintWordIDVec(s.w, os); + debug_output << escapeJson(os.str()); h += 1; debug_output << "\""; if (h < samples->size()) { @@ -296,8 +308,12 @@ main(int argc, char** argv) size_t num_up = CollectUpdates(samples, update, margin); // debug -- - debug_output << "\"1best_features\":\"" << (*samples)[0].f << "\"," << endl; - debug_output << "\"update_raw\":\"" << update << "\"," << endl; + debug_output << "\"1best_features\":{"; + sparseVectorToJson((*samples)[0].f, debug_output); + debug_output << "}," << endl; + debug_output << "\"update_raw\":{"; + sparseVectorToJson(update, debug_output); + debug_output << "}," << endl; // -- debug // update @@ -318,11 +334,16 @@ main(int argc, char** argv) } } } - lambdas += update; - i++; + if (!just_translate) { + lambdas += update; + } else { + i++; + } // debug -- - debug_output << "\"update\":\"" << update << "\"," << endl; + debug_output << "\"update\":{"; + sparseVectorToJson(update, debug_output); + debug_output << "}," << endl; debug_output << "\"num_up\":" << num_up << "," << endl; debug_output << "\"updated_features\":" << update.size() << "," << endl; debug_output << "\"learning_rate_R\":" << learning_rate_R << "," << endl; @@ -332,7 +353,9 @@ main(int argc, char** argv) sparseVectorToJson(learning_rates, debug_output); debug_output << "}," << endl; debug_output << "\"best_match\":\""; - PrintWordIDVec((*samples)[0].w, debug_output); + ostringstream ps; + PrintWordIDVec((*samples)[0].w, ps); + debug_output << escapeJson(ps.str()); debug_output << "\"," << endl; debug_output << "\"best_match_score\":" << (*samples)[0].gold << "," << endl ; // -- debug @@ -344,9 +367,6 @@ main(int argc, char** argv) debug_output << "}" << endl; // -- debug - cerr << "[dtrain] done learning, looping again" << endl; - sock.send(done.c_str(), done.size()+1, 0); - // debug -- WriteFile f(debug_fn); f.get() << debug_output.str(); @@ -354,10 +374,18 @@ main(int argc, char** argv) // -- debug // write current weights - lambdas.init_vector(decoder_weights); - ostringstream fn; - fn << output_fn << "." << i << ".gz"; - Weights::WriteToFile(fn.str(), decoder_weights, true); + if (!just_translate) { + lambdas.init_vector(decoder_weights); + ostringstream fn; + fn << output_fn << "." << i << ".gz"; + Weights::WriteToFile(fn.str(), decoder_weights, true); + } + + if (!just_translate) { + cerr << "[dtrain] done learning, looping again" << endl; + sock.send(done.c_str(), done.size()+1, 0); + } + } // input loop string shutdown = "off"; |