diff options
| author | Victor Chahuneau <vchahune@cs.cmu.edu> | 2012-07-23 19:59:44 -0400 | 
|---|---|---|
| committer | Victor Chahuneau <vchahune@cs.cmu.edu> | 2012-07-23 19:59:44 -0400 | 
| commit | 070e7fbbacd70dd916a95c1ab08b0113ec221c30 (patch) | |
| tree | f3cf94bf86fa79b592b1869072f85e529ec16a11 /python/src/py_scorer.h | |
| parent | 9a5e10322c82916a3c3fdfa0489ed1999bc988c5 (diff) | |
[python] Evaluation metrics in Python
Diffstat (limited to 'python/src/py_scorer.h')
| -rw-r--r-- | python/src/py_scorer.h | 44 | 
1 files changed, 44 insertions, 0 deletions
| diff --git a/python/src/py_scorer.h b/python/src/py_scorer.h new file mode 100644 index 00000000..22dc9fee --- /dev/null +++ b/python/src/py_scorer.h @@ -0,0 +1,44 @@ +#include "ns.h" +#include "tdict.h" + +typedef float (*MetricScoreCallback)(void*, SufficientStats* stats); +typedef void (*MetricStatsCallback)(void*, +        std::string *hyp, +        std::vector<std::string> *refs, +        SufficientStats* out); + +struct PythonEvaluationMetric : public EvaluationMetric { + +    PythonEvaluationMetric(const std::string& id) : EvaluationMetric(id) {} + +    static EvaluationMetric* Instance(const std::string& id,  +            void* obj, +            MetricStatsCallback statscb, +            MetricScoreCallback scorecb) { +        PythonEvaluationMetric* metric = new PythonEvaluationMetric(id); +        metric->pymetric = obj; +        metric->_compute_score =  scorecb; +        metric->_compute_sufficient_stats = statscb; +        return metric; +    } + +    float ComputeScore(const SufficientStats& stats) const { +        SufficientStats stats_(stats); +        return _compute_score(pymetric, &stats_); +    } + +    void ComputeSufficientStatistics(const std::vector<WordID>& hyp, +            const std::vector<std::vector<WordID> >& refs, +            SufficientStats* out) const { +        std::string hyp_(TD::GetString(hyp)); +        std::vector<std::string> refs_; +        for(unsigned i = 0; i < refs.size(); ++i) { +            refs_.push_back(TD::GetString(refs[i])); +        } +        _compute_sufficient_stats(pymetric, &hyp_, &refs_, out); +    } + +    void* pymetric; +    MetricStatsCallback _compute_sufficient_stats; +    MetricScoreCallback _compute_score; +}; | 
