diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/risk.cc | 4 | ||||
| -rw-r--r-- | training/risk.h | 1 | 
2 files changed, 4 insertions, 1 deletions
| diff --git a/training/risk.cc b/training/risk.cc index 347ed3cb..d5a12cfd 100644 --- a/training/risk.cc +++ b/training/risk.cc @@ -31,7 +31,9 @@ double CandidateSetRisk::operator()(const vector<double>& params,    for (unsigned i = 0; i < cands_.size(); ++i) {      const double log_prob = cands_[i].fmap.dot(params) - log_z;      const double prob = exp(log_prob); -    const double r = prob * metric_.ComputeScore(cands_[i].eval_feats); +    const double cost = metric_.IsErrorMetric() ? metric_.ComputeScore(cands_[i].eval_feats) +                                                : 1.0 - metric_.ComputeScore(cands_[i].eval_feats); +    const double r = prob * cost;      risk += r;      if (g) (*g) += (cands_[i].fmap - exp_feats) * r;    } diff --git a/training/risk.h b/training/risk.h index 00ff60ec..2e8db0fb 100644 --- a/training/risk.h +++ b/training/risk.h @@ -9,6 +9,7 @@ namespace training {    class CandidateSet;    class CandidateSetRisk { +   public:      explicit CandidateSetRisk(const CandidateSet& cs, const EvaluationMetric& metric) :         cands_(cs),         metric_(metric) {} | 
