From ded34c668ca87b9e0a0ebca68944c6648602593a Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Mon, 28 May 2012 00:19:10 -0400 Subject: cache metric computation in pro --- pro-train/mr_pro_map.cc | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) (limited to 'pro-train') diff --git a/pro-train/mr_pro_map.cc b/pro-train/mr_pro_map.cc index 2aa0dc6f..bb13fdf4 100644 --- a/pro-train/mr_pro_map.cc +++ b/pro-train/mr_pro_map.cc @@ -92,7 +92,6 @@ struct DiffOrder { void Sample(const unsigned gamma, const unsigned xi, const training::CandidateSet& J_i, - const SegmentEvaluator& scorer, const EvaluationMetric* metric, vector* pv) { const bool invert_score = metric->IsErrorMetric(); @@ -102,8 +101,8 @@ void Sample(const unsigned gamma, const size_t a = rng->inclusive(0, J_i.size() - 1)(); const size_t b = rng->inclusive(0, J_i.size() - 1)(); if (a == b) continue; - float ga = J_i[a].g(scorer, metric); - float gb = J_i[b].g(scorer, metric); + float ga = metric->ComputeScore(J_i[a].score_stats); + float gb = metric->ComputeScore(J_i[b].score_stats); bool positive = gb < ga; if (invert_score) positive = !positive; const float gdiff = fabs(ga - gb); @@ -187,10 +186,10 @@ int main(int argc, char** argv) { J_i.ReadFromFile(kbest_file); HypergraphIO::ReadFromJSON(rf.stream(), &hg); hg.Reweight(weights); - J_i.AddKBestCandidates(hg, kbest_size); + J_i.AddKBestCandidates(hg, kbest_size, ds[sent_id]); J_i.WriteToFile(kbest_file); - Sample(gamma, xi, J_i, *ds[sent_id], metric, &v); + Sample(gamma, xi, J_i, metric, &v); for (unsigned i = 0; i < v.size(); ++i) { const TrainingInstance& vi = v[i]; cout << vi.y << "\t" << vi.x << endl; -- cgit v1.2.3