summaryrefslogtreecommitdiff
path: root/training/dtrain/dtrain.cc
diff options
context:
space:
mode:
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r--training/dtrain/dtrain.cc31
1 files changed, 18 insertions, 13 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index ddd27211..3e9902ab 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -143,10 +143,15 @@ main(int argc, char** argv)
time_t total_time = 0.;
// output
- WriteFile raw_out;
- if (output_raw) raw_out.Init(output_raw_fn);
- WriteFile updates_out;
- if (output_updates) updates_out.Init(output_raw_fn);
+ WriteFile out_up, out_raw;
+ if (output_raw) {
+ out_raw.Init(output_raw_fn);
+ *out_raw << setprecision(numeric_limits<double>::digits10+1);
+ }
+ if (output_updates) {
+ out_up.Init(output_updates_fn);
+ *out_up << setprecision(numeric_limits<double>::digits10+1);
+ }
for (size_t t = 0; t < T; t++) // T iterations
@@ -220,25 +225,25 @@ main(int argc, char** argv)
list_sz += observer->effective_size;
if (output_raw)
- output_sample(sample);
+ output_sample(sample, *out_raw, i);
// update model
if (!noup) {
SparseVector<weight_t> updates;
if (structured)
- num_up += update_structured(sample, updates, margin/*,
- output_updates, updates_out.get()*/); // FIXME
+ num_up += update_structured(sample, updates, margin,
+ output_updates, *out_up, i);
else if (all_pairs)
- num_up += updates_all(sample, updates, max_up, threshold/*,
- output_updates, updates_out.get()*/); // FIXME
+ num_up += updates_all(sample, updates, max_up, threshold,
+ output_updates, *out_up, i);
else if (pro)
- num_up += updates_pro(sample, updates, cut, max_up, threshold/*,
- output_updates, updates_out.get()*/); // FIXME
+ num_up += updates_pro(sample, updates, cut, max_up, threshold,
+ output_updates, *out_up, i);
else
num_up += updates_multipartite(sample, updates, cut, margin,
- max_up, threshold, adjust_cut/*,
- output_updates, updates_out.get()*/); // FIXME
+ max_up, threshold, adjust_cut,
+ output_updates, *out_up, i);
SparseVector<weight_t> lambdas_copy;
if (l1_reg)
lambdas_copy = lambdas;