summaryrefslogtreecommitdiff
path: root/training/dtrain/update.h
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2015-07-22 12:42:57 +0200
committerPatrick Simianer <p@simianer.de>2015-07-22 12:42:57 +0200
commit0208c988890a72d4a3e80fb3cebf2abd03162050 (patch)
treec01402e63503dd3c0653647821084f45dde8878c /training/dtrain/update.h
parente606d71a31038281d141022cd8c26a21cada3f27 (diff)
parent434a42b9d096abb436cac1d9788157c16b8ccab0 (diff)
merge dtrain_struct
Diffstat (limited to 'training/dtrain/update.h')
-rw-r--r--training/dtrain/update.h38
1 files changed, 38 insertions, 0 deletions
diff --git a/training/dtrain/update.h b/training/dtrain/update.h
index d7224cca..6f42e5bd 100644
--- a/training/dtrain/update.h
+++ b/training/dtrain/update.h
@@ -10,6 +10,18 @@ _cmp(ScoredHyp a, ScoredHyp b)
return a.gold > b.gold;
}
+bool
+_cmpHope(ScoredHyp a, ScoredHyp b)
+{
+ return (a.model+a.gold) > (b.model+b.gold);
+}
+
+bool
+_cmpFear(ScoredHyp a, ScoredHyp b)
+{
+ return (a.model-a.gold) > (b.model-b.gold);
+}
+
inline bool
_good(ScoredHyp& a, ScoredHyp& b, weight_t margin)
{
@@ -20,6 +32,15 @@ _good(ScoredHyp& a, ScoredHyp& b, weight_t margin)
return false;
}
+inline bool
+_goodS(ScoredHyp& a, ScoredHyp& b)
+{
+ if (a.gold==b.gold)
+ return true;
+
+ return false;
+}
+
/*
* multipartite ranking
* sort (descending) by bleu
@@ -56,6 +77,23 @@ CollectUpdates(vector<ScoredHyp>* s,
return num_up;
}
+inline size_t
+CollectUpdatesStruct(vector<ScoredHyp>* s,
+ SparseVector<weight_t>& updates,
+ weight_t unused=-1)
+{
+ // hope
+ sort(s->begin(), s->end(), _cmpHope);
+ ScoredHyp hope = (*s)[0];
+ // fear
+ sort(s->begin(), s->end(), _cmpFear);
+ ScoredHyp fear = (*s)[0];
+ if (!_goodS(hope, fear))
+ updates += hope.f - fear.f;
+
+ return updates.size();
+}
+
} // namespace
#endif