diff options
author | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-08-11 02:37:10 +0000 |
---|---|---|
committer | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-08-11 02:37:10 +0000 |
commit | 80686d4e567bae579ea39e009826a2de92cd4ace (patch) | |
tree | c3c35fcba57dde423a248f38aa121ad197c79734 /mteval/aer_scorer.cc | |
parent | 3c85c407c333899f6b4bc26632d312b8e568b638 (diff) |
major refactor, break bad circular deps
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@509 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'mteval/aer_scorer.cc')
-rw-r--r-- | mteval/aer_scorer.cc | 135 |
1 files changed, 135 insertions, 0 deletions
diff --git a/mteval/aer_scorer.cc b/mteval/aer_scorer.cc new file mode 100644 index 00000000..edd4390f --- /dev/null +++ b/mteval/aer_scorer.cc @@ -0,0 +1,135 @@ +#include "aer_scorer.h" + +#include <cmath> +#include <cassert> +#include <sstream> + +#include "tdict.h" +#include "alignment_pharaoh.h" + +using namespace std; + +class AERScore : public ScoreBase<AERScore> { + friend class AERScorer; + public: + AERScore() : num_matches(), num_predicted(), num_in_ref() {} + AERScore(int m, int p, int r) : + num_matches(m), num_predicted(p), num_in_ref(r) {} + virtual void PlusPartialEquals(const Score& rhs, int oracle_e_cover, int oracle_f_cover, int src_len){} + virtual void PlusEquals(const Score& delta, const float scale) { + const AERScore& other = static_cast<const AERScore&>(delta); + num_matches += scale*other.num_matches; + num_predicted += scale*other.num_predicted; + num_in_ref += scale*other.num_in_ref; + } + virtual void PlusEquals(const Score& delta) { + const AERScore& other = static_cast<const AERScore&>(delta); + num_matches += other.num_matches; + num_predicted += other.num_predicted; + num_in_ref += other.num_in_ref; + } + + + virtual ScoreP GetZero() const { + return ScoreP(new AERScore); + } + virtual ScoreP GetOne() const { + return ScoreP(new AERScore); + } + virtual void Subtract(const Score& rhs, Score* out) const { + AERScore* res = static_cast<AERScore*>(out); + const AERScore& other = static_cast<const AERScore&>(rhs); + res->num_matches = num_matches - other.num_matches; + res->num_predicted = num_predicted - other.num_predicted; + res->num_in_ref = num_in_ref - other.num_in_ref; + } + float Precision() const { + return static_cast<float>(num_matches) / num_predicted; + } + float Recall() const { + return static_cast<float>(num_matches) / num_in_ref; + } + float ComputePartialScore() const { return 0.0;} + virtual float ComputeScore() const { + const float prec = Precision(); + const float rec = Recall(); + const float f = (2.0 * prec * rec) / (rec + prec); + if (isnan(f)) return 1.0f; + return 1.0f - f; + } + virtual bool IsAdditiveIdentity() const { + return (num_matches == 0) && (num_predicted == 0) && (num_in_ref == 0); + } + virtual void ScoreDetails(std::string* out) const { + ostringstream os; + os << "AER=" << (ComputeScore() * 100.0) + << " F=" << (100 - ComputeScore() * 100.0) + << " P=" << (Precision() * 100.0) << " R=" << (Recall() * 100.0) + << " [" << num_matches << " " << num_predicted << " " << num_in_ref << "]"; + *out = os.str(); + } + virtual void Encode(std::string*out) const { + out->resize(sizeof(int) * 3); + *(int *)&(*out)[sizeof(int) * 0] = num_matches; + *(int *)&(*out)[sizeof(int) * 1] = num_predicted; + *(int *)&(*out)[sizeof(int) * 2] = num_in_ref; + } + private: + int num_matches; + int num_predicted; + int num_in_ref; +}; + +AERScorer::AERScorer(const vector<vector<WordID> >& refs, const string& src) : src_(src) { + if (refs.size() != 1) { + cerr << "AERScorer can only take a single reference!\n"; + abort(); + } + ref_ = AlignmentPharaoh::ReadPharaohAlignmentGrid(TD::GetString(refs.front())); +} + +static inline bool Safe(const Array2D<bool>& a, int i, int j) { + if (i >= 0 && j >= 0 && i < a.width() && j < a.height()) + return a(i,j); + else + return false; +} + +ScoreP AERScorer::ScoreCCandidate(const vector<WordID>& shyp) const { + return ScoreP(); +} + +ScoreP AERScorer::ScoreCandidate(const vector<WordID>& shyp) const { + boost::shared_ptr<Array2D<bool> > hyp = + AlignmentPharaoh::ReadPharaohAlignmentGrid(TD::GetString(shyp)); + + int m = 0; + int r = 0; + int p = 0; + int i_len = ref_->width(); + int j_len = ref_->height(); + for (int i = 0; i < i_len; ++i) { + for (int j = 0; j < j_len; ++j) { + if ((*ref_)(i,j)) { + ++r; + if (Safe(*hyp, i, j)) ++m; + } + } + } + for (int i = 0; i < hyp->width(); ++i) + for (int j = 0; j < hyp->height(); ++j) + if ((*hyp)(i,j)) ++p; + + return ScoreP(new AERScore(m,p,r)); +} + +ScoreP AERScorer::ScoreFromString(const string& in) { + AERScore* res = new AERScore; + res->num_matches = *(const int *)&in[sizeof(int) * 0]; + res->num_predicted = *(const int *)&in[sizeof(int) * 1]; + res->num_in_ref = *(const int *)&in[sizeof(int) * 2]; + return ScoreP(res); +} + +const std::string* AERScorer::GetSource() const { return &src_; } + |