diff options
| author | Kenneth Heafield <github@kheafield.com> | 2012-08-03 07:46:54 -0400 | 
|---|---|---|
| committer | Kenneth Heafield <github@kheafield.com> | 2012-08-03 07:46:54 -0400 | 
| commit | be1ab0a8937f9c5668ea5e6c31b798e87672e55e (patch) | |
| tree | a13aad60ab6cced213401bce6a38ac885ba171ba /training/risk.cc | |
| parent | e5d6f4ae41009c26978ecd62668501af9762b0bc (diff) | |
| parent | 9fe0219562e5db25171cce8776381600ff9a5649 (diff) | |
Merge branch 'master' of github.com:redpony/cdec
Diffstat (limited to 'training/risk.cc')
| -rw-r--r-- | training/risk.cc | 45 | 
1 files changed, 45 insertions, 0 deletions
| diff --git a/training/risk.cc b/training/risk.cc new file mode 100644 index 00000000..d5a12cfd --- /dev/null +++ b/training/risk.cc @@ -0,0 +1,45 @@ +#include "risk.h" + +#include "prob.h" +#include "candidate_set.h" +#include "ns.h" + +using namespace std; + +namespace training { + +// g = \sum_e p(e|f) * loss(e) * (phi(e,f) - E[phi(e,f)]) +double CandidateSetRisk::operator()(const vector<double>& params, +                                    SparseVector<double>* g) const { +  prob_t z; +  for (unsigned i = 0; i < cands_.size(); ++i) { +    const prob_t u(cands_[i].fmap.dot(params), init_lnx()); +    z += u; +  } +  const double log_z = log(z); + +  SparseVector<double> exp_feats; +  if (g) { +    for (unsigned i = 0; i < cands_.size(); ++i) { +      const double log_prob = cands_[i].fmap.dot(params) - log_z; +      const double prob = exp(log_prob); +      exp_feats += cands_[i].fmap * prob; +    } +  } + +  double risk = 0; +  for (unsigned i = 0; i < cands_.size(); ++i) { +    const double log_prob = cands_[i].fmap.dot(params) - log_z; +    const double prob = exp(log_prob); +    const double cost = metric_.IsErrorMetric() ? metric_.ComputeScore(cands_[i].eval_feats) +                                                : 1.0 - metric_.ComputeScore(cands_[i].eval_feats); +    const double r = prob * cost; +    risk += r; +    if (g) (*g) += (cands_[i].fmap - exp_feats) * r; +  } +  return risk; +} + +} + + | 
