Diffstat (limited to 'training/dtrain/dtrain.cc')
 training/dtrain/dtrain.cc | 36 +++++++++++++++++++++---------------
 1 file changed, 21 insertions(+), 15 deletions(-)
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 441e2cd7..0a27a068 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -414,6 +414,12 @@ main(int argc, char** argv)
score_t kbest_loss_first, kbest_loss_last = 0.0;
+ for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();
+ it != pairs.end(); it++) {
+ score_t model_diff = it->first.model - it->second.model;
+ kbest_loss_first += max(0.0, -1.0 * model_diff);
+ }
+
for (int ki=0; ki < repeat; ki++) {
score_t kbest_loss = 0.0; // test-k-best
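
Note: the block added in this hunk sums a margin-less, hinge-style rank loss over the sampled hypothesis pairs before the repeat loop starts. A self-contained sketch of that computation follows; the ScoredHyp fields and the score_t typedef mirror the dtrain code in this diff, while the free-standing function and the explicit zero-initialization of the accumulator are illustrative assumptions.

    // Minimal sketch of the k-best pairwise loss summed in the block above.
    // ScoredHyp/score_t stand in for the dtrain types; the rest is assumed.
    #include <algorithm>
    #include <utility>
    #include <vector>

    typedef double score_t;
    struct ScoredHyp { score_t model; score_t score; };

    // Sum of max(0, -(model_a - model_b)) over all sampled pairs, i.e. how far
    // the current weights are from ranking the first hypothesis above the second.
    score_t
    pairwise_rank_loss(const std::vector<std::pair<ScoredHyp, ScoredHyp> >& pairs)
    {
      score_t loss = 0.0;  // accumulator starts at zero (assumption)
      for (std::vector<std::pair<ScoredHyp, ScoredHyp> >::const_iterator it = pairs.begin();
           it != pairs.end(); ++it) {
        score_t model_diff = it->first.model - it->second.model;
        loss += std::max(0.0, -1.0 * model_diff);
      }
      return loss;
    }
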
@@ -520,21 +526,22 @@ main(int argc, char** argv)
}
}
- if (ki==0) kbest_loss_first = kbest_loss;
if (ki==repeat-1) { // done
kbest_loss_last = kbest_loss;
- score_t best_score = -1.;
- score_t best_model = -std::numeric_limits<score_t>::max();
- unsigned best_idx;
- for (unsigned i=0; i < samples->size(); i++) {
- score_t s = lambdas.dot((*samples)[i].f);
- if (s > best_model) {
- best_idx = i;
- best_model = s;
+ if (repeat > 1) {
+ score_t best_score = -1.;
+ score_t best_model = -std::numeric_limits<score_t>::max();
+ unsigned best_idx;
+ for (unsigned i=0; i < samples->size(); i++) {
+ score_t s = lambdas.dot((*samples)[i].f);
+ if (s > best_model) {
+ best_idx = i;
+ best_model = s;
+ }
}
+ score_sum += (*samples)[best_idx].score;
+ model_sum += best_model;
}
- score_sum += (*samples)[best_idx].score;
- model_sum += best_model;
}
} // repeat
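
The relocated bookkeeping now runs only when repeat > 1: it selects the hypothesis with the highest model score under the current weights and feeds score_sum/model_sum. A standalone sketch of that argmax step follows; dtrain's lambdas is a sparse weight vector, so the dense containers and helper names below are assumptions made for brevity.

    // Sketch of the argmax-by-model-score selection now guarded by (repeat > 1).
    // Dense stand-ins replace dtrain's SparseVector-based weight/feature types.
    #include <limits>
    #include <vector>

    typedef double score_t;

    struct Hyp {
      std::vector<score_t> f;  // feature vector (dense here for simplicity)
      score_t score;           // metric score, added to score_sum in the patch
    };

    static score_t
    dot(const std::vector<score_t>& w, const std::vector<score_t>& f)
    {
      score_t s = 0.0;
      for (size_t i = 0; i < f.size() && i < w.size(); ++i) s += w[i] * f[i];
      return s;
    }

    // Index of the hypothesis with the highest model score under weights w,
    // mirroring the loop that feeds score_sum/model_sum above.
    size_t
    best_model_idx(const std::vector<score_t>& w, const std::vector<Hyp>& samples)
    {
      score_t best_model = -std::numeric_limits<score_t>::max();
      size_t best_idx = 0;
      for (size_t i = 0; i < samples.size(); ++i) {
        score_t s = dot(w, samples[i].f);
        if (s > best_model) { best_idx = i; best_model = s; }
      }
      return best_idx;
    }
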
@@ -588,15 +595,14 @@ main(int argc, char** argv)
cerr << _p << " (" << model_diff << ")" << endl;
cerr << " avg # pairs: ";
cerr << _np << npairs/(float)in_sz << endl;
- cerr << " avg # margin viol: ";
- cerr << margin_violations/(float)in_sz << endl;
cerr << " avg # rank err: ";
cerr << rank_errors/(float)in_sz;
if (faster_perceptron) cerr << " (meaningless)";
cerr << endl;
+ cerr << " avg # margin viol: ";
+ cerr << margin_violations/(float)in_sz << endl;
if (batch) cerr << " batch loss: " << batch_loss << endl;
- if (repeat > 1) cerr << " k-best loss imp: " << ((float)kbest_loss_improve/in_sz)*100 << "%" << endl;
-
+ cerr << " k-best loss imp: " << ((float)kbest_loss_improve/in_sz)*100 << "%" << endl;
cerr << " non0 feature count: " << nonz << endl;
cerr << " avg list sz: " << list_sz/(float)in_sz << endl;
cerr << " avg f count: " << f_count/(float)list_sz << endl;
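
The last hunk only reorders the per-iteration stderr report and prints the k-best loss improvement unconditionally. A hedged sketch of how that percentage could be produced is below; the exact bookkeeping of kbest_loss_improve is not visible in this diff, so counting one improvement per input whose last-pass loss beats its first-pass loss is an assumption.

    // Hedged sketch tying the loss from the first hunk to the "k-best loss imp"
    // line in the report. The counter semantics are assumed, not taken from the diff.
    #include <iostream>

    typedef double score_t;

    int
    main()
    {
      unsigned in_sz = 0;               // inputs seen this iteration
      unsigned kbest_loss_improve = 0;  // inputs whose k-best loss went down (assumed meaning)

      // Per input (toy values standing in for the real first/last pass losses):
      score_t kbest_loss_first = 3.2, kbest_loss_last = 1.1;
      in_sz++;
      if (kbest_loss_last < kbest_loss_first) kbest_loss_improve++;

      // End-of-iteration report, matching the now-unconditional output above:
      std::cerr << " k-best loss imp: "
                << ((float)kbest_loss_improve / in_sz) * 100 << "%" << std::endl;
      return 0;
    }
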