summaryrefslogtreecommitdiff
path: root/vest/aer_scorer.cc
diff options
context:
space:
mode:
Diffstat (limited to 'vest/aer_scorer.cc')
-rw-r--r--vest/aer_scorer.cc118
1 files changed, 118 insertions, 0 deletions
diff --git a/vest/aer_scorer.cc b/vest/aer_scorer.cc
new file mode 100644
index 00000000..9c8a783a
--- /dev/null
+++ b/vest/aer_scorer.cc
@@ -0,0 +1,118 @@
+#include "aer_scorer.h"
+
+#include <cmath>
+#include <cassert>
+#include <sstream>
+
+#include "tdict.h"
+#include "aligner.h"
+
+using namespace std;
+
+class AERScore : public Score {
+ friend class AERScorer;
+ public:
+ AERScore() : num_matches(), num_predicted(), num_in_ref() {}
+ AERScore(int m, int p, int r) :
+ num_matches(m), num_predicted(p), num_in_ref(r) {}
+ virtual void PlusEquals(const Score& delta) {
+ const AERScore& other = static_cast<const AERScore&>(delta);
+ num_matches += other.num_matches;
+ num_predicted += other.num_predicted;
+ num_in_ref += other.num_in_ref;
+ }
+ virtual Score* GetZero() const {
+ return new AERScore;
+ }
+ virtual void Subtract(const Score& rhs, Score* out) const {
+ AERScore* res = static_cast<AERScore*>(out);
+ const AERScore& other = static_cast<const AERScore&>(rhs);
+ res->num_matches = num_matches - other.num_matches;
+ res->num_predicted = num_predicted - other.num_predicted;
+ res->num_in_ref = num_in_ref - other.num_in_ref;
+ }
+ float Precision() const {
+ return static_cast<float>(num_matches) / num_predicted;
+ }
+ float Recall() const {
+ return static_cast<float>(num_matches) / num_in_ref;
+ }
+ virtual float ComputeScore() const {
+ const float prec = Precision();
+ const float rec = Recall();
+ const float f = (2.0 * prec * rec) / (rec + prec);
+ if (isnan(f)) return 1.0f;
+ return 1.0f - f;
+ }
+ virtual bool IsAdditiveIdentity() const {
+ return (num_matches == 0) && (num_predicted == 0) && (num_in_ref == 0);
+ }
+ virtual void ScoreDetails(std::string* out) const {
+ ostringstream os;
+ os << "AER=" << (ComputeScore() * 100.0)
+ << " F=" << (100 - ComputeScore() * 100.0)
+ << " P=" << (Precision() * 100.0) << " R=" << (Recall() * 100.0)
+ << " [" << num_matches << " " << num_predicted << " " << num_in_ref << "]";
+ *out = os.str();
+ }
+ virtual void Encode(std::string*out) const {
+ out->resize(sizeof(int) * 3);
+ *(int *)&(*out)[sizeof(int) * 0] = num_matches;
+ *(int *)&(*out)[sizeof(int) * 1] = num_predicted;
+ *(int *)&(*out)[sizeof(int) * 2] = num_in_ref;
+ }
+ private:
+ int num_matches;
+ int num_predicted;
+ int num_in_ref;
+};
+
+// Builds a scorer from exactly one reference, given as a token sequence that
+// is turned back into text and parsed by ReadPharaohAlignmentGrid.  src is
+// copied and later exposed through GetSource().  Aborts the process if the
+// number of references is not one.
+AERScorer::AERScorer(const vector<vector<WordID> >& refs, const string& src) : src_(src) {
+  const bool has_single_ref = (refs.size() == 1);
+  if (!has_single_ref) {
+    cerr << "AERScorer can only take a single reference!\n";
+    abort();
+  }
+  const string ref_text = TD::GetString(refs.front());
+  ref_ = AlignerTools::ReadPharaohAlignmentGrid(ref_text);
+}
+
+// Bounds-checked read of a(i,j): returns false for any index outside
+// [0, a.width()) x [0, a.height()) instead of touching the grid.
+static inline bool Safe(const Array2D<bool>& a, int i, int j) {
+  const bool in_bounds = i >= 0 && j >= 0 && i < a.width() && j < a.height();
+  // Short-circuit keeps a(i,j) from being evaluated for out-of-range indices.
+  return in_bounds && a(i,j);
+}
+
+// Scores one hypothesis against the stored reference grid.  shyp is turned
+// back into text and parsed by ReadPharaohAlignmentGrid.  Returns a newly
+// allocated AERScore holding (matches, predicted, in_ref); the caller owns it.
+Score* AERScorer::ScoreCandidate(const vector<WordID>& shyp) const {
+  boost::shared_ptr<Array2D<bool> > hyp =
+    AlignerTools::ReadPharaohAlignmentGrid(TD::GetString(shyp));
+  const Array2D<bool>& hyp_grid = *hyp;
+  const Array2D<bool>& ref_grid = *ref_;
+
+  // Every set cell of the hypothesis grid is a predicted link.
+  int predicted = 0;
+  for (int x = 0; x < hyp_grid.width(); ++x) {
+    for (int y = 0; y < hyp_grid.height(); ++y) {
+      if (hyp_grid(x, y)) ++predicted;
+    }
+  }
+
+  // Every set cell of the reference grid is a reference link; it is a match
+  // when the hypothesis has the same cell set.  The two grids may differ in
+  // size, hence the bounds-checked lookup into the hypothesis.
+  int in_ref = 0;
+  int matched = 0;
+  const int ref_w = ref_grid.width();
+  const int ref_h = ref_grid.height();
+  for (int x = 0; x < ref_w; ++x) {
+    for (int y = 0; y < ref_h; ++y) {
+      if (!ref_grid(x, y)) continue;
+      ++in_ref;
+      if (Safe(hyp_grid, x, y)) ++matched;
+    }
+  }
+
+  return new AERScore(matched, predicted, in_ref);
+}
+
+// Rebuilds a score from the byte layout written by AERScore::Encode: three
+// ints in native representation, in the order matches, predicted, in_ref.
+// Returns a newly allocated AERScore; the caller owns it.  Aborts the process
+// if `in` holds fewer than 3 * sizeof(int) bytes; any bytes beyond that are
+// ignored.
+Score* AERScorer::ScoreFromString(const string& in) {
+  int fields[3] = { 0, 0, 0 };
+  if (in.size() < sizeof(fields)) {
+    // Reading past the end of the string would be undefined behavior, so fail
+    // loudly in the same way the constructor does on bad input.
+    cerr << "AERScorer: encoded score too short (" << in.size()
+         << " bytes, need " << sizeof(fields) << ")\n";
+    abort();
+  }
+  // Copy byte-wise into a real int array rather than casting the string's
+  // char buffer to int*, which violates the aliasing rules and assumes the
+  // buffer is int-aligned.
+  in.copy(reinterpret_cast<char*>(fields), sizeof(fields));
+
+  AERScore* res = new AERScore;
+  res->num_matches = fields[0];
+  res->num_predicted = fields[1];
+  res->num_in_ref = fields[2];
+  return res;
+}
+
+// Returns a pointer to the source string stored by the constructor.  The
+// pointer refers to a member of this scorer.
+const std::string* AERScorer::GetSource() const {
+  return &src_;
+}
+