summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/risk.cc4
-rw-r--r--training/risk.h1
2 files changed, 4 insertions, 1 deletions
diff --git a/training/risk.cc b/training/risk.cc
index 347ed3cb..d5a12cfd 100644
--- a/training/risk.cc
+++ b/training/risk.cc
@@ -31,7 +31,9 @@ double CandidateSetRisk::operator()(const vector<double>& params,
for (unsigned i = 0; i < cands_.size(); ++i) {
const double log_prob = cands_[i].fmap.dot(params) - log_z;
const double prob = exp(log_prob);
- const double r = prob * metric_.ComputeScore(cands_[i].eval_feats);
+ const double cost = metric_.IsErrorMetric() ? metric_.ComputeScore(cands_[i].eval_feats)
+ : 1.0 - metric_.ComputeScore(cands_[i].eval_feats);
+ const double r = prob * cost;
risk += r;
if (g) (*g) += (cands_[i].fmap - exp_feats) * r;
}
diff --git a/training/risk.h b/training/risk.h
index 00ff60ec..2e8db0fb 100644
--- a/training/risk.h
+++ b/training/risk.h
@@ -9,6 +9,7 @@ namespace training {
class CandidateSet;
class CandidateSetRisk {
+ public:
explicit CandidateSetRisk(const CandidateSet& cs, const EvaluationMetric& metric) :
cands_(cs),
metric_(metric) {}