diff options
Diffstat (limited to 'training/dtrain/update.h')
-rw-r--r-- | training/dtrain/update.h | 34 |
1 files changed, 16 insertions, 18 deletions
diff --git a/training/dtrain/update.h b/training/dtrain/update.h index 57671ce1..72d369c4 100644 --- a/training/dtrain/update.h +++ b/training/dtrain/update.h @@ -5,7 +5,7 @@ namespace dtrain { bool -CmpHypsByGold(ScoredHyp a, ScoredHyp b) +_cmp(ScoredHyp a, ScoredHyp b) { return a.gold > b.gold; } @@ -19,44 +19,42 @@ CmpHypsByGold(ScoredHyp a, ScoredHyp b) inline size_t CollectUpdates(vector<ScoredHyp>* s, SparseVector<weight_t>& updates, - float margin=1.0) + float margin=0.) { - size_t num_pairs = 0; + size_t num_up = 0; size_t sz = s->size(); if (sz < 2) return 0; - sort(s->begin(), s->end(), CmpHypsByGold); + sort(s->begin(), s->end(), _cmp); size_t sep = round(sz*0.1); size_t sep_hi = sep; if (sz > 4) { - while - (sep_hi < sz && (*s)[sep_hi-1].gold == (*s)[sep_hi].gold) ++sep_hi; + while (sep_hi<sz && (*s)[sep_hi-1].gold==(*s)[sep_hi].gold) + ++sep_hi; } else sep_hi = 1; for (size_t i = 0; i < sep_hi; i++) { for (size_t j = sep_hi; j < sz; j++) { - if (((*s)[i].model-(*s)[j].model) > margin) + if (((*s)[i].model-(*s)[j].model) > margin + || (*s)[i].gold == (*s)[j].gold) continue; - if ((*s)[i].gold != (*s)[j].gold) { - updates += (*s)[i].f-(*s)[j].f; - num_pairs++; - } + updates += (*s)[i].f-(*s)[j].f; + num_up++; } } size_t sep_lo = sz-sep; - while (sep_lo > 0 && (*s)[sep_lo-1].gold == (*s)[sep_lo].gold) + while (sep_lo>=sep_hi && (*s)[sep_lo].gold==(*s)[sep_lo+1].gold) --sep_lo; for (size_t i = sep_hi; i < sep_lo; i++) { for (size_t j = sep_lo; j < sz; j++) { - if (((*s)[i].model-(*s)[j].model) > margin) + if (((*s)[i].model-(*s)[j].model) > margin + || (*s)[i].gold == (*s)[j].gold) continue; - if ((*s)[i].gold != (*s)[j].gold) { - updates += (*s)[i].f-(*s)[j].f; - num_pairs++; - } + updates += (*s)[i].f-(*s)[j].f; + num_up++; } } - return num_pairs; + return num_up; } } // namespace |