diff options
author | Patrick Simianer <p@simianer.de> | 2015-02-27 12:43:45 +0100 |
---|---|---|
committer | Patrick Simianer <p@simianer.de> | 2015-02-27 12:43:45 +0100 |
commit | 6b1a2121cba7c32686d29515a6dc322e62435049 (patch) | |
tree | 1df9aea6899c159bc1e632baf03caf01668f8ae4 /training/dtrain | |
parent | 0982b8bbb67a42d4b0e084a217e0acd7ddcff243 (diff) |
simplified update
Diffstat (limited to 'training/dtrain')
-rw-r--r-- | training/dtrain/update.h | 33 |
1 files changed, 16 insertions, 17 deletions
```diff
diff --git a/training/dtrain/update.h b/training/dtrain/update.h
index 72d369c4..d7224cca 100644
--- a/training/dtrain/update.h
+++ b/training/dtrain/update.h
@@ -10,6 +10,16 @@ _cmp(ScoredHyp a, ScoredHyp b)
   return a.gold > b.gold;
 }
 
+inline bool
+_good(ScoredHyp& a, ScoredHyp& b, weight_t margin)
+{
+  if ((a.model-b.model)>margin
+      || a.gold==b.gold)
+    return true;
+
+  return false;
+}
+
 /*
  * multipartite ranking
  * sort (descending) by bleu
@@ -19,35 +29,24 @@ _cmp(ScoredHyp a, ScoredHyp b)
 inline size_t
 CollectUpdates(vector<ScoredHyp>* s,
                SparseVector<weight_t>& updates,
-               float margin=0.)
+               weight_t margin=0.)
 {
   size_t num_up = 0;
   size_t sz = s->size();
-  if (sz < 2) return 0;
   sort(s->begin(), s->end(), _cmp);
   size_t sep = round(sz*0.1);
-  size_t sep_hi = sep;
-  if (sz > 4) {
-    while (sep_hi<sz && (*s)[sep_hi-1].gold==(*s)[sep_hi].gold)
-      ++sep_hi;
-  }
-  else sep_hi = 1;
-  for (size_t i = 0; i < sep_hi; i++) {
-    for (size_t j = sep_hi; j < sz; j++) {
-      if (((*s)[i].model-(*s)[j].model) > margin
-          || (*s)[i].gold == (*s)[j].gold)
+  for (size_t i = 0; i < sep; i++) {
+    for (size_t j = sep; j < sz; j++) {
+      if (_good((*s)[i], (*s)[j], margin))
         continue;
       updates += (*s)[i].f-(*s)[j].f;
       num_up++;
     }
   }
   size_t sep_lo = sz-sep;
-  while (sep_lo>=sep_hi && (*s)[sep_lo].gold==(*s)[sep_lo+1].gold)
-    --sep_lo;
-  for (size_t i = sep_hi; i < sep_lo; i++) {
+  for (size_t i = sep; i < sep_lo; i++) {
     for (size_t j = sep_lo; j < sz; j++) {
-      if (((*s)[i].model-(*s)[j].model) > margin
-          || (*s)[i].gold == (*s)[j].gold)
+      if (_good((*s)[i], (*s)[j], margin))
         continue;
       updates += (*s)[i].f-(*s)[j].f;
       num_up++;
```