From 176f311f6b4b2048dd05e0304d66ae5c61a4506e Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Fri, 8 Apr 2016 21:20:29 +0200 Subject: dtrain: adadelta, fix output, max input, batch learning --- training/dtrain/update.h | 36 ++++++++++++++++-------------------- 1 file changed, 16 insertions(+), 20 deletions(-) (limited to 'training/dtrain/update.h') diff --git a/training/dtrain/update.h b/training/dtrain/update.h index f6aa9842..405a3f76 100644 --- a/training/dtrain/update.h +++ b/training/dtrain/update.h @@ -20,9 +20,8 @@ updates_multipartite(vector* sample, size_t max_up, weight_t threshold, bool adjust, - bool output=false, - ostream& os=cout, - size_t id=0) + WriteFile& output, + size_t id) { size_t up = 0; size_t sz = sample->size(); @@ -50,7 +49,7 @@ updates_multipartite(vector* sample, || (threshold && (first.gold-second.gold < threshold))) continue; if (output) - os << id << "\t" << first.f-second.f << endl; + *output << id << "\t" << first.f-second.f << endl; updates += first.f-second.f; if (++up==max_up) return up; @@ -70,7 +69,7 @@ updates_multipartite(vector* sample, || (threshold && (first.gold-second.gold < threshold))) continue; if (output) - os << id << "\t" << first.f-second.f << endl; + *output << id << "\t" << first.f-second.f << endl; updates += first.f-second.f; if (++up==max_up) break; @@ -91,9 +90,8 @@ updates_all(vector* sample, SparseVector& updates, size_t max_up, weight_t threshold, - bool output=false, - ostream& os=cout, - size_t id=0) + WriteFile output, + size_t id) { size_t up = 0; size_t sz = sample->size(); @@ -108,7 +106,7 @@ updates_all(vector* sample, || (threshold && (first.gold-second.gold < threshold))) continue; if (output) - os << id << "\t" << first.f-second.f << endl; + *output << id << "\t" << first.f-second.f << endl; updates += first.f-second.f; if (++up==max_up) break; @@ -127,9 +125,8 @@ inline size_t update_structured(vector* sample, SparseVector& updates, weight_t margin, - bool output=false, - ostream& os=cout, - size_t id=0) + WriteFile output, + size_t id) { // hope sort(sample->begin(), sample->end(), [](Hyp first, Hyp second) @@ -147,13 +144,13 @@ update_structured(vector* sample, if (hope.gold != fear.gold) { updates += hope.f - fear.f; if (output) - os << id << "\t" << hope.f << "\t" << fear.f << endl; + *output << id << "\t" << hope.f << "\t" << fear.f << endl; return 1; } if (output) - os << endl; + *output << endl; return 0; } @@ -172,9 +169,8 @@ updates_pro(vector* sample, size_t maxs, size_t max_up, weight_t threshold, - bool output=false, - ostream& os=cout, - size_t id=0) + WriteFile& output, + size_t id) { size_t sz = sample->size(), s; @@ -202,7 +198,7 @@ updates_pro(vector* sample, for (auto i: g) { if (output) - os << id << "\t" << i.first->f-i.second->f << endl; + *output << id << "\t" << i.first->f-i.second->f << endl; updates += i.first->f-i.second->f; } @@ -215,7 +211,7 @@ updates_pro(vector* sample, */ inline void output_sample(vector* sample, - ostream& os=cout, + WriteFile& output, size_t id=0, bool sorted=true) { @@ -227,7 +223,7 @@ output_sample(vector* sample, } size_t j = 0; for (auto k: *sample) { - os << id << "\t" << j << "\t" << k.gold << "\t" << k.model + *output << id << "\t" << j << "\t" << k.gold << "\t" << k.model << "\t" << k.f << endl; j++; } -- cgit v1.2.3