diff options
Diffstat (limited to 'training/dtrain')
-rw-r--r-- | training/dtrain/update.h | 33 |
1 files changed, 16 insertions, 17 deletions
diff --git a/training/dtrain/update.h b/training/dtrain/update.h index 72d369c4..d7224cca 100644 --- a/training/dtrain/update.h +++ b/training/dtrain/update.h @@ -10,6 +10,16 @@ _cmp(ScoredHyp a, ScoredHyp b) return a.gold > b.gold; } +inline bool +_good(ScoredHyp& a, ScoredHyp& b, weight_t margin) +{ + if ((a.model-b.model)>margin + || a.gold==b.gold) + return true; + + return false; +} + /* * multipartite ranking * sort (descending) by bleu @@ -19,35 +29,24 @@ _cmp(ScoredHyp a, ScoredHyp b) inline size_t CollectUpdates(vector<ScoredHyp>* s, SparseVector<weight_t>& updates, - float margin=0.) + weight_t margin=0.) { size_t num_up = 0; size_t sz = s->size(); - if (sz < 2) return 0; sort(s->begin(), s->end(), _cmp); size_t sep = round(sz*0.1); - size_t sep_hi = sep; - if (sz > 4) { - while (sep_hi<sz && (*s)[sep_hi-1].gold==(*s)[sep_hi].gold) - ++sep_hi; - } - else sep_hi = 1; - for (size_t i = 0; i < sep_hi; i++) { - for (size_t j = sep_hi; j < sz; j++) { - if (((*s)[i].model-(*s)[j].model) > margin - || (*s)[i].gold == (*s)[j].gold) + for (size_t i = 0; i < sep; i++) { + for (size_t j = sep; j < sz; j++) { + if (_good((*s)[i], (*s)[j], margin)) continue; updates += (*s)[i].f-(*s)[j].f; num_up++; } } size_t sep_lo = sz-sep; - while (sep_lo>=sep_hi && (*s)[sep_lo].gold==(*s)[sep_lo+1].gold) - --sep_lo; - for (size_t i = sep_hi; i < sep_lo; i++) { + for (size_t i = sep; i < sep_lo; i++) { for (size_t j = sep_lo; j < sz; j++) { - if (((*s)[i].model-(*s)[j].model) > margin - || (*s)[i].gold == (*s)[j].gold) + if (_good((*s)[i], (*s)[j], margin)) continue; updates += (*s)[i].f-(*s)[j].f; num_up++; |