summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/Makefile.am3
-rw-r--r--training/risk.cc43
-rw-r--r--training/risk.h25
3 files changed, 70 insertions, 1 deletions
diff --git a/training/Makefile.am b/training/Makefile.am
index 19ee8f0d..68ebfab4 100644
--- a/training/Makefile.am
+++ b/training/Makefile.am
@@ -27,7 +27,8 @@ noinst_LIBRARIES = libtraining.a
libtraining_a_SOURCES = \
candidate_set.cc \
optimize.cc \
- online_optimizer.cc
+ online_optimizer.cc \
+ risk.cc
mpi_online_optimize_SOURCES = mpi_online_optimize.cc
mpi_online_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz
diff --git a/training/risk.cc b/training/risk.cc
new file mode 100644
index 00000000..347ed3cb
--- /dev/null
+++ b/training/risk.cc
@@ -0,0 +1,43 @@
+#include "risk.h"
+
+#include "prob.h"
+#include "candidate_set.h"
+#include "ns.h"
+
+using namespace std;
+
+namespace training {
+
+// g = \sum_e p(e|f) * loss(e) * (phi(e,f) - E[phi(e,f)])
+double CandidateSetRisk::operator()(const vector<double>& params,
+ SparseVector<double>* g) const {
+ prob_t z;
+ for (unsigned i = 0; i < cands_.size(); ++i) {
+ const prob_t u(cands_[i].fmap.dot(params), init_lnx());
+ z += u;
+ }
+ const double log_z = log(z);
+
+ SparseVector<double> exp_feats;
+ if (g) {
+ for (unsigned i = 0; i < cands_.size(); ++i) {
+ const double log_prob = cands_[i].fmap.dot(params) - log_z;
+ const double prob = exp(log_prob);
+ exp_feats += cands_[i].fmap * prob;
+ }
+ }
+
+ double risk = 0;
+ for (unsigned i = 0; i < cands_.size(); ++i) {
+ const double log_prob = cands_[i].fmap.dot(params) - log_z;
+ const double prob = exp(log_prob);
+ const double r = prob * metric_.ComputeScore(cands_[i].eval_feats);
+ risk += r;
+ if (g) (*g) += (cands_[i].fmap - exp_feats) * r;
+ }
+ return risk;
+}
+
+}
+
+
diff --git a/training/risk.h b/training/risk.h
new file mode 100644
index 00000000..00ff60ec
--- /dev/null
+++ b/training/risk.h
@@ -0,0 +1,25 @@
+#ifndef _RISK_H_
+#define _RISK_H_
+
+#include <vector>
+#include "sparse_vector.h"
+class EvaluationMetric;
+
+namespace training {
+ class CandidateSet;
+
+ class CandidateSetRisk {
+ explicit CandidateSetRisk(const CandidateSet& cs, const EvaluationMetric& metric) :
+ cands_(cs),
+ metric_(metric) {}
+ // compute the risk (expected loss) of a CandidateSet
+ // (optional) the gradient of the risk with respect to params
+ double operator()(const std::vector<double>& params,
+ SparseVector<double>* g = NULL) const;
+ private:
+ const CandidateSet& cands_;
+ const EvaluationMetric& metric_;
+ };
+};
+
+#endif