summaryrefslogtreecommitdiff
path: root/training/dtrain
diff options
context:
space:
mode:
Diffstat (limited to 'training/dtrain')
-rw-r--r--training/dtrain/dtrain.cc31
-rw-r--r--training/dtrain/update.h37
2 files changed, 40 insertions, 28 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index ddd27211..3e9902ab 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -143,10 +143,15 @@ main(int argc, char** argv)
time_t total_time = 0.;
// output
- WriteFile raw_out;
- if (output_raw) raw_out.Init(output_raw_fn);
- WriteFile updates_out;
- if (output_updates) updates_out.Init(output_raw_fn);
+ WriteFile out_up, out_raw;
+ if (output_raw) {
+ out_raw.Init(output_raw_fn);
+ *out_raw << setprecision(numeric_limits<double>::digits10+1);
+ }
+ if (output_updates) {
+ out_up.Init(output_updates_fn);
+ *out_up << setprecision(numeric_limits<double>::digits10+1);
+ }
for (size_t t = 0; t < T; t++) // T iterations
@@ -220,25 +225,25 @@ main(int argc, char** argv)
list_sz += observer->effective_size;
if (output_raw)
- output_sample(sample);
+ output_sample(sample, *out_raw, i);
// update model
if (!noup) {
SparseVector<weight_t> updates;
if (structured)
- num_up += update_structured(sample, updates, margin/*,
- output_updates, updates_out.get()*/); // FIXME
+ num_up += update_structured(sample, updates, margin,
+ output_updates, *out_up, i);
else if (all_pairs)
- num_up += updates_all(sample, updates, max_up, threshold/*,
- output_updates, updates_out.get()*/); // FIXME
+ num_up += updates_all(sample, updates, max_up, threshold,
+ output_updates, *out_up, i);
else if (pro)
- num_up += updates_pro(sample, updates, cut, max_up, threshold/*,
- output_updates, updates_out.get()*/); // FIXME
+ num_up += updates_pro(sample, updates, cut, max_up, threshold,
+ output_updates, *out_up, i);
else
num_up += updates_multipartite(sample, updates, cut, margin,
- max_up, threshold, adjust_cut/*,
- output_updates, updates_out.get()*/); // FIXME
+ max_up, threshold, adjust_cut,
+ output_updates, *out_up, i);
SparseVector<weight_t> lambdas_copy;
if (l1_reg)
lambdas_copy = lambdas;
diff --git a/training/dtrain/update.h b/training/dtrain/update.h
index 30b14771..f6aa9842 100644
--- a/training/dtrain/update.h
+++ b/training/dtrain/update.h
@@ -21,7 +21,8 @@ updates_multipartite(vector<Hyp>* sample,
weight_t threshold,
bool adjust,
bool output=false,
- ostream& os=cout)
+ ostream& os=cout,
+ size_t id=0)
{
size_t up = 0;
size_t sz = sample->size();
@@ -45,11 +46,11 @@ updates_multipartite(vector<Hyp>* sample,
for (size_t j = sep_hi; j < sz; j++) {
Hyp& first=(*sample)[i], second=(*sample)[j];
if ((first.model-second.model)>margin
- || (!adjust && first.gold==second.gold)
+ || (first.gold==second.gold)
|| (threshold && (first.gold-second.gold < threshold)))
continue;
if (output)
- os << first.f-second.f << endl;
+ os << id << "\t" << first.f-second.f << endl;
updates += first.f-second.f;
if (++up==max_up)
return up;
@@ -65,11 +66,11 @@ updates_multipartite(vector<Hyp>* sample,
for (size_t j = sep_lo; j < sz; j++) {
Hyp& first=(*sample)[i], second=(*sample)[j];
if ((first.model-second.model)>margin
- || (!adjust && first.gold==second.gold)
+ || (first.gold==second.gold)
|| (threshold && (first.gold-second.gold < threshold)))
continue;
if (output)
- os << first.f-second.f << endl;
+ os << id << "\t" << first.f-second.f << endl;
updates += first.f-second.f;
if (++up==max_up)
break;
@@ -91,7 +92,8 @@ updates_all(vector<Hyp>* sample,
size_t max_up,
weight_t threshold,
bool output=false,
- ostream& os=cout)
+ ostream& os=cout,
+ size_t id=0)
{
size_t up = 0;
size_t sz = sample->size();
@@ -102,11 +104,11 @@ updates_all(vector<Hyp>* sample,
for (size_t i = 0; i < sz-1; i++) {
for (size_t j = i+1; j < sz; j++) {
Hyp& first=(*sample)[i], second=(*sample)[j];
- if (first.gold == second.gold
+ if ((first.gold == second.gold)
|| (threshold && (first.gold-second.gold < threshold)))
continue;
if (output)
- os << first.f-second.f << endl;
+ os << id << "\t" << first.f-second.f << endl;
updates += first.f-second.f;
if (++up==max_up)
break;
@@ -126,7 +128,8 @@ update_structured(vector<Hyp>* sample,
SparseVector<weight_t>& updates,
weight_t margin,
bool output=false,
- ostream& os=cout)
+ ostream& os=cout,
+ size_t id=0)
{
// hope
sort(sample->begin(), sample->end(), [](Hyp first, Hyp second)
@@ -144,7 +147,7 @@ update_structured(vector<Hyp>* sample,
if (hope.gold != fear.gold) {
updates += hope.f - fear.f;
if (output)
- os << hope.f << "\t" << fear.f << endl;
+ os << id << "\t" << hope.f << "\t" << fear.f << endl;
return 1;
}
@@ -170,7 +173,8 @@ updates_pro(vector<Hyp>* sample,
size_t max_up,
weight_t threshold,
bool output=false,
- ostream& os=cout)
+ ostream& os=cout,
+ size_t id=0)
{
size_t sz = sample->size(), s;
@@ -198,7 +202,7 @@ updates_pro(vector<Hyp>* sample,
for (auto i: g) {
if (output)
- os << i.first->f-i.second->f << endl;
+ os << id << "\t" << i.first->f-i.second->f << endl;
updates += i.first->f-i.second->f;
}
@@ -212,16 +216,19 @@ updates_pro(vector<Hyp>* sample,
inline void
output_sample(vector<Hyp>* sample,
ostream& os=cout,
+ size_t id=0,
bool sorted=true)
{
- if (sorted)
+ if (sorted) {
sort(sample->begin(), sample->end(), [](Hyp first, Hyp second)
{
return first.gold > second.gold;
});
+ }
size_t j = 0;
- for (auto i: *sample) {
- os << j << "\t" << i.gold << "\t" << i.model << "\t" << i.f << endl;
+ for (auto k: *sample) {
+ os << id << "\t" << j << "\t" << k.gold << "\t" << k.model
+ << "\t" << k.f << endl;
j++;
}
}