summaryrefslogtreecommitdiff
path: root/training/dtrain/update.h
diff options
context:
space:
mode:
Diffstat (limited to 'training/dtrain/update.h')
-rw-r--r--training/dtrain/update.h37
1 files changed, 22 insertions, 15 deletions
diff --git a/training/dtrain/update.h b/training/dtrain/update.h
index 30b14771..f6aa9842 100644
--- a/training/dtrain/update.h
+++ b/training/dtrain/update.h
@@ -21,7 +21,8 @@ updates_multipartite(vector<Hyp>* sample,
weight_t threshold,
bool adjust,
bool output=false,
- ostream& os=cout)
+ ostream& os=cout,
+ size_t id=0)
{
size_t up = 0;
size_t sz = sample->size();
@@ -45,11 +46,11 @@ updates_multipartite(vector<Hyp>* sample,
for (size_t j = sep_hi; j < sz; j++) {
Hyp& first=(*sample)[i], second=(*sample)[j];
if ((first.model-second.model)>margin
- || (!adjust && first.gold==second.gold)
+ || (first.gold==second.gold)
|| (threshold && (first.gold-second.gold < threshold)))
continue;
if (output)
- os << first.f-second.f << endl;
+ os << id << "\t" << first.f-second.f << endl;
updates += first.f-second.f;
if (++up==max_up)
return up;
@@ -65,11 +66,11 @@ updates_multipartite(vector<Hyp>* sample,
for (size_t j = sep_lo; j < sz; j++) {
Hyp& first=(*sample)[i], second=(*sample)[j];
if ((first.model-second.model)>margin
- || (!adjust && first.gold==second.gold)
+ || (first.gold==second.gold)
|| (threshold && (first.gold-second.gold < threshold)))
continue;
if (output)
- os << first.f-second.f << endl;
+ os << id << "\t" << first.f-second.f << endl;
updates += first.f-second.f;
if (++up==max_up)
break;
@@ -91,7 +92,8 @@ updates_all(vector<Hyp>* sample,
size_t max_up,
weight_t threshold,
bool output=false,
- ostream& os=cout)
+ ostream& os=cout,
+ size_t id=0)
{
size_t up = 0;
size_t sz = sample->size();
@@ -102,11 +104,11 @@ updates_all(vector<Hyp>* sample,
for (size_t i = 0; i < sz-1; i++) {
for (size_t j = i+1; j < sz; j++) {
Hyp& first=(*sample)[i], second=(*sample)[j];
- if (first.gold == second.gold
+ if ((first.gold == second.gold)
|| (threshold && (first.gold-second.gold < threshold)))
continue;
if (output)
- os << first.f-second.f << endl;
+ os << id << "\t" << first.f-second.f << endl;
updates += first.f-second.f;
if (++up==max_up)
break;
@@ -126,7 +128,8 @@ update_structured(vector<Hyp>* sample,
SparseVector<weight_t>& updates,
weight_t margin,
bool output=false,
- ostream& os=cout)
+ ostream& os=cout,
+ size_t id=0)
{
// hope
sort(sample->begin(), sample->end(), [](Hyp first, Hyp second)
@@ -144,7 +147,7 @@ update_structured(vector<Hyp>* sample,
if (hope.gold != fear.gold) {
updates += hope.f - fear.f;
if (output)
- os << hope.f << "\t" << fear.f << endl;
+ os << id << "\t" << hope.f << "\t" << fear.f << endl;
return 1;
}
@@ -170,7 +173,8 @@ updates_pro(vector<Hyp>* sample,
size_t max_up,
weight_t threshold,
bool output=false,
- ostream& os=cout)
+ ostream& os=cout,
+ size_t id=0)
{
size_t sz = sample->size(), s;
@@ -198,7 +202,7 @@ updates_pro(vector<Hyp>* sample,
for (auto i: g) {
if (output)
- os << i.first->f-i.second->f << endl;
+ os << id << "\t" << i.first->f-i.second->f << endl;
updates += i.first->f-i.second->f;
}
@@ -212,16 +216,19 @@ updates_pro(vector<Hyp>* sample,
inline void
output_sample(vector<Hyp>* sample,
ostream& os=cout,
+ size_t id=0,
bool sorted=true)
{
- if (sorted)
+ if (sorted) {
sort(sample->begin(), sample->end(), [](Hyp first, Hyp second)
{
return first.gold > second.gold;
});
+ }
size_t j = 0;
- for (auto i: *sample) {
- os << j << "\t" << i.gold << "\t" << i.model << "\t" << i.f << endl;
+ for (auto k: *sample) {
+ os << id << "\t" << j << "\t" << k.gold << "\t" << k.model
+ << "\t" << k.f << endl;
j++;
}
}