From 6b1a2121cba7c32686d29515a6dc322e62435049 Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Fri, 27 Feb 2015 12:43:45 +0100 Subject: simplified update --- training/dtrain/update.h | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/training/dtrain/update.h b/training/dtrain/update.h index 72d369c4..d7224cca 100644 --- a/training/dtrain/update.h +++ b/training/dtrain/update.h @@ -10,6 +10,16 @@ _cmp(ScoredHyp a, ScoredHyp b) return a.gold > b.gold; } +inline bool +_good(ScoredHyp& a, ScoredHyp& b, weight_t margin) +{ + if ((a.model-b.model)>margin + || a.gold==b.gold) + return true; + + return false; +} + /* * multipartite ranking * sort (descending) by bleu @@ -19,35 +29,24 @@ _cmp(ScoredHyp a, ScoredHyp b) inline size_t CollectUpdates(vector* s, SparseVector& updates, - float margin=0.) + weight_t margin=0.) { size_t num_up = 0; size_t sz = s->size(); - if (sz < 2) return 0; sort(s->begin(), s->end(), _cmp); size_t sep = round(sz*0.1); - size_t sep_hi = sep; - if (sz > 4) { - while (sep_hi margin - || (*s)[i].gold == (*s)[j].gold) + for (size_t i = 0; i < sep; i++) { + for (size_t j = sep; j < sz; j++) { + if (_good((*s)[i], (*s)[j], margin)) continue; updates += (*s)[i].f-(*s)[j].f; num_up++; } } size_t sep_lo = sz-sep; - while (sep_lo>=sep_hi && (*s)[sep_lo].gold==(*s)[sep_lo+1].gold) - --sep_lo; - for (size_t i = sep_hi; i < sep_lo; i++) { + for (size_t i = sep; i < sep_lo; i++) { for (size_t j = sep_lo; j < sz; j++) { - if (((*s)[i].model-(*s)[j].model) > margin - || (*s)[i].gold == (*s)[j].gold) + if (_good((*s)[i], (*s)[j], margin)) continue; updates += (*s)[i].f-(*s)[j].f; num_up++; -- cgit v1.2.3