From 176f311f6b4b2048dd05e0304d66ae5c61a4506e Mon Sep 17 00:00:00 2001
From: Patrick Simianer
Date: Fri, 8 Apr 2016 21:20:29 +0200
Subject: dtrain: adadelta, fix output, max input, batch learning
---
training/dtrain/update.h | 36 ++++++++++++++++--------------------
1 file changed, 16 insertions(+), 20 deletions(-)
(limited to 'training/dtrain/update.h')
diff --git a/training/dtrain/update.h b/training/dtrain/update.h
index f6aa9842..405a3f76 100644
--- a/training/dtrain/update.h
+++ b/training/dtrain/update.h
@@ -20,9 +20,8 @@ updates_multipartite(vector* sample,
size_t max_up,
weight_t threshold,
bool adjust,
- bool output=false,
- ostream& os=cout,
- size_t id=0)
+ WriteFile& output,
+ size_t id)
{
size_t up = 0;
size_t sz = sample->size();
@@ -50,7 +49,7 @@ updates_multipartite(vector* sample,
|| (threshold && (first.gold-second.gold < threshold)))
continue;
if (output)
- os << id << "\t" << first.f-second.f << endl;
+ *output << id << "\t" << first.f-second.f << endl;
updates += first.f-second.f;
if (++up==max_up)
return up;
@@ -70,7 +69,7 @@ updates_multipartite(vector* sample,
|| (threshold && (first.gold-second.gold < threshold)))
continue;
if (output)
- os << id << "\t" << first.f-second.f << endl;
+ *output << id << "\t" << first.f-second.f << endl;
updates += first.f-second.f;
if (++up==max_up)
break;
@@ -91,9 +90,8 @@ updates_all(vector* sample,
SparseVector& updates,
size_t max_up,
weight_t threshold,
- bool output=false,
- ostream& os=cout,
- size_t id=0)
+ WriteFile output,
+ size_t id)
{
size_t up = 0;
size_t sz = sample->size();
@@ -108,7 +106,7 @@ updates_all(vector* sample,
|| (threshold && (first.gold-second.gold < threshold)))
continue;
if (output)
- os << id << "\t" << first.f-second.f << endl;
+ *output << id << "\t" << first.f-second.f << endl;
updates += first.f-second.f;
if (++up==max_up)
break;
@@ -127,9 +125,8 @@ inline size_t
update_structured(vector* sample,
SparseVector& updates,
weight_t margin,
- bool output=false,
- ostream& os=cout,
- size_t id=0)
+ WriteFile output,
+ size_t id)
{
// hope
sort(sample->begin(), sample->end(), [](Hyp first, Hyp second)
@@ -147,13 +144,13 @@ update_structured(vector* sample,
if (hope.gold != fear.gold) {
updates += hope.f - fear.f;
if (output)
- os << id << "\t" << hope.f << "\t" << fear.f << endl;
+ *output << id << "\t" << hope.f << "\t" << fear.f << endl;
return 1;
}
if (output)
- os << endl;
+ *output << endl;
return 0;
}
@@ -172,9 +169,8 @@ updates_pro(vector* sample,
size_t maxs,
size_t max_up,
weight_t threshold,
- bool output=false,
- ostream& os=cout,
- size_t id=0)
+ WriteFile& output,
+ size_t id)
{
size_t sz = sample->size(), s;
@@ -202,7 +198,7 @@ updates_pro(vector* sample,
for (auto i: g) {
if (output)
- os << id << "\t" << i.first->f-i.second->f << endl;
+ *output << id << "\t" << i.first->f-i.second->f << endl;
updates += i.first->f-i.second->f;
}
@@ -215,7 +211,7 @@ updates_pro(vector* sample,
*/
inline void
output_sample(vector* sample,
- ostream& os=cout,
+ WriteFile& output,
size_t id=0,
bool sorted=true)
{
@@ -227,7 +223,7 @@ output_sample(vector* sample,
}
size_t j = 0;
for (auto k: *sample) {
- os << id << "\t" << j << "\t" << k.gold << "\t" << k.model
+ *output << id << "\t" << j << "\t" << k.gold << "\t" << k.model
<< "\t" << k.f << endl;
j++;
}
--
cgit v1.2.3