summary refs log tree commit diff
path: root/training/dtrain/update.h
diff options
context:
space:
mode:
author Patrick Simianer <p@simianer.de> 2015-02-26 13:26:37 +0100
committer Patrick Simianer <p@simianer.de> 2015-02-26 13:26:37 +0100
commit 4223261682388944fe1b1cf31b9d51d88f9ad53b (patch)
tree daf072c310d60b0386587bde5e554312f193b3b2 /training/dtrain/update.h
parent 2a37a7ad1b21ab54701de3b5b44dc4ea55a75307 (diff)
refactoring
Diffstat (limited to 'training/dtrain/update.h')
-rw-r--r-- training/dtrain/update.h 65
1 files changed, 65 insertions, 0 deletions
diff --git a/training/dtrain/update.h b/training/dtrain/update.h
new file mode 100644
index 00000000..57671ce1
--- /dev/null
+++ b/training/dtrain/update.h
@@ -0,0 +1,65 @@
+#ifndef _DTRAIN_UPDATE_H_
+#define _DTRAIN_UPDATE_H_
+
+namespace dtrain
+{
+
+/*
+ * Comparison predicate for sort(): returns true when `a` has a strictly
+ * larger gold score than `b`, so sorting with it yields a descending order
+ * by `gold`.
+ *
+ * Declared inline because the definition lives in a header; without it,
+ * every translation unit including this file emits its own external
+ * definition. Arguments are taken by const reference so that sort() does
+ * not copy two hypotheses for every comparison.
+ */
+inline bool
+CmpHypsByGold(const ScoredHyp& a, const ScoredHyp& b)
+{
+  return a.gold > b.gold;
+}
+
+/*
+ * Helper for CollectUpdates: pairs every hypothesis in [better_begin,
+ * better_end) with every hypothesis in [worse_begin, worse_end).
+ *
+ * A pair (i, j) is skipped when model(i) - model(j) already exceeds
+ * `margin`. Otherwise, if the two gold scores differ, f(i) - f(j) is added
+ * to `updates`. Returns the number of pairs that contributed.
+ */
+inline size_t
+CollectPairUpdates(vector<ScoredHyp>& hyps,
+                   const size_t better_begin, const size_t better_end,
+                   const size_t worse_begin, const size_t worse_end,
+                   SparseVector<weight_t>& updates,
+                   const float margin)
+{
+  size_t num_pairs = 0;
+  for (size_t i = better_begin; i < better_end; i++) {
+    for (size_t j = worse_begin; j < worse_end; j++) {
+      // already separated by more than the margin: nothing to learn
+      if ((hyps[i].model-hyps[j].model) > margin)
+        continue;
+      // equal gold scores express no preference between the two
+      if (hyps[i].gold != hyps[j].gold) {
+        updates += hyps[i].f-hyps[j].f;
+        num_pairs++;
+      }
+    }
+  }
+  return num_pairs;
+}
+
+/*
+ * multipartite ranking
+ *  sort (descending) by gold score (the original note calls it bleu)
+ *  compare top X (hi) to middle Y (med) and low X (lo)
+ *  cmp middle Y to low X
+ *
+ * s       : hypotheses; sorted IN PLACE, descending by `gold`
+ * updates : accumulator; f(better) - f(worse) is added for each used pair
+ * margin  : pairs whose model score difference exceeds it are skipped
+ *
+ * X is round(10% of the list size). Both bucket boundaries are moved so
+ * that a run of equal gold scores is never split across them.
+ * For lists of at most 4 hypotheses only the single best one forms the hi
+ * bucket and the lo bucket is empty.
+ *
+ * Returns the number of pairs added to `updates`; 0 if there are fewer
+ * than 2 hypotheses (in which case `s` is left unsorted).
+ */
+inline size_t
+CollectUpdates(vector<ScoredHyp>* s,
+               SparseVector<weight_t>& updates,
+               float margin=1.0)
+{
+  const size_t sz = s->size();
+  if (sz < 2) return 0;
+  sort(s->begin(), s->end(), CmpHypsByGold);
+
+  // bucket width; 0 for sz <= 4, >= 1 for sz >= 5 (round(0.5) == 1)
+  const size_t sep = round(sz*0.1);
+
+  // hi bucket is [0, sep_hi); grow it while the boundary splits a tie
+  size_t sep_hi = sep;
+  if (sz > 4) {
+    while (sep_hi < sz && (*s)[sep_hi-1].gold == (*s)[sep_hi].gold)
+      ++sep_hi;
+  }
+  else sep_hi = 1;
+
+  // lo bucket is [sep_lo, sz); grow it upwards while the boundary splits a
+  // tie. The sep_lo < sz guard matters when sep == 0: then sep_lo == sz and
+  // (*s)[sep_lo] would read one element past the end of the vector.
+  size_t sep_lo = sz-sep;
+  while (sep_lo > 0 && sep_lo < sz
+         && (*s)[sep_lo-1].gold == (*s)[sep_lo].gold)
+    --sep_lo;
+
+  // hi vs. everything below it (med and lo)
+  size_t num_pairs = CollectPairUpdates(*s, 0, sep_hi, sep_hi, sz,
+                                        updates, margin);
+  // med vs. lo; empty if the buckets met or crossed (sep_lo <= sep_hi)
+  num_pairs += CollectPairUpdates(*s, sep_hi, sep_lo, sep_lo, sz,
+                                  updates, margin);
+
+  return num_pairs;
+}
+
+} // namespace
+
+#endif
+