// path: root/training/risk.h
// blob: 00ff60ec59a2519c599508a5962bde2369279914
#ifndef _RISK_H_
#define _RISK_H_

#include <vector>
#include "sparse_vector.h"
class EvaluationMetric;

namespace training {
  class CandidateSet;

  // Evaluates the risk (expected loss) of a CandidateSet under an
  // EvaluationMetric, optionally together with the gradient of that risk
  // with respect to the model parameters.
  //
  // Lifetime: this object keeps references (not copies) to the candidate set
  // and the metric passed to the constructor, so both must outlive it.
  class CandidateSetRisk {
   public:  // without this, the ctor and operator() were private and unusable
    // Binds the candidate set to score and the metric used to score it.
    // Neither argument is copied; see the lifetime note above.
    explicit CandidateSetRisk(const CandidateSet& cs, const EvaluationMetric& metric) :
       cands_(cs),
       metric_(metric) {}

    // Computes the risk (expected loss) of the bound CandidateSet under
    // the given parameter vector.
    //
    //   params - model parameters at which the risk is evaluated.
    //   g      - optional; when non-NULL, receives the gradient of the risk
    //            with respect to params. NOTE(review): whether *g is
    //            overwritten or accumulated into is decided by the
    //            implementation, which is not in this header — confirm.
    // Returns the risk value.
    double operator()(const std::vector<double>& params,
                      SparseVector<double>* g = NULL) const;

   private:
    const CandidateSet& cands_;       // candidate set being scored (not owned)
    const EvaluationMetric& metric_;  // loss/metric definition (not owned)
  };
}  // namespace training

#endif