diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2012-06-24 22:30:50 -0400 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2012-06-24 22:30:50 -0400 |
commit | ac40b555b98a2ea295d48e95263086b52ed3b74b (patch) | |
tree | bb324192fd23f9d77b61024cb14e9d6d4907a7c1 /training | |
parent | 535cfe8da7c6107d8415afd1381d59f2a6b9844f (diff) |
minimum risk training, not completely ready for primetime
Diffstat (limited to 'training')
-rw-r--r-- | training/risk.cc | 4 | ||||
-rw-r--r-- | training/risk.h | 1 |
2 files changed, 4 insertions, 1 deletions
diff --git a/training/risk.cc b/training/risk.cc index 347ed3cb..d5a12cfd 100644 --- a/training/risk.cc +++ b/training/risk.cc @@ -31,7 +31,9 @@ double CandidateSetRisk::operator()(const vector<double>& params, for (unsigned i = 0; i < cands_.size(); ++i) { const double log_prob = cands_[i].fmap.dot(params) - log_z; const double prob = exp(log_prob); - const double r = prob * metric_.ComputeScore(cands_[i].eval_feats); + const double cost = metric_.IsErrorMetric() ? metric_.ComputeScore(cands_[i].eval_feats) + : 1.0 - metric_.ComputeScore(cands_[i].eval_feats); + const double r = prob * cost; risk += r; if (g) (*g) += (cands_[i].fmap - exp_feats) * r; } diff --git a/training/risk.h b/training/risk.h index 00ff60ec..2e8db0fb 100644 --- a/training/risk.h +++ b/training/risk.h @@ -9,6 +9,7 @@ namespace training { class CandidateSet; class CandidateSetRisk { + public: explicit CandidateSetRisk(const CandidateSet& cs, const EvaluationMetric& metric) : cands_(cs), metric_(metric) {} |