blob: 22dc9fee13721540fdbb1ea0c575039a86bc832c (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
|
#include "ns.h"
#include "tdict.h"
// Callback that turns sufficient statistics into a single float score.
// The first argument is an opaque pointer; PythonEvaluationMetric passes its
// `pymetric` member there. `stats` is a non-const pointer, so the callee is
// allowed to modify the pointed-to object.
typedef float (*MetricScoreCallback)(void*, SufficientStats* stats);
// Callback that is handed one hypothesis and its references as strings, plus
// a SufficientStats pointer `out`.
// The first argument is the same kind of opaque pointer as above.
// NOTE(review): presumably the callee writes its result into `out` -- confirm
// against the implementation on the Python side.
typedef void (*MetricStatsCallback)(void*,
std::string *hyp,
std::vector<std::string> *refs,
SufficientStats* out);
struct PythonEvaluationMetric : public EvaluationMetric {
PythonEvaluationMetric(const std::string& id) : EvaluationMetric(id) {}
static EvaluationMetric* Instance(const std::string& id,
void* obj,
MetricStatsCallback statscb,
MetricScoreCallback scorecb) {
PythonEvaluationMetric* metric = new PythonEvaluationMetric(id);
metric->pymetric = obj;
metric->_compute_score = scorecb;
metric->_compute_sufficient_stats = statscb;
return metric;
}
float ComputeScore(const SufficientStats& stats) const {
SufficientStats stats_(stats);
return _compute_score(pymetric, &stats_);
}
void ComputeSufficientStatistics(const std::vector<WordID>& hyp,
const std::vector<std::vector<WordID> >& refs,
SufficientStats* out) const {
std::string hyp_(TD::GetString(hyp));
std::vector<std::string> refs_;
for(unsigned i = 0; i < refs.size(); ++i) {
refs_.push_back(TD::GetString(refs[i]));
}
_compute_sufficient_stats(pymetric, &hyp_, &refs_, out);
}
void* pymetric;
MetricStatsCallback _compute_sufficient_stats;
MetricScoreCallback _compute_score;
};
|