diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2012-05-28 00:19:10 -0400 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2012-05-28 00:19:10 -0400 |
commit | 0af4919f399d009ca8bcb9cadefbcc148c174c20 (patch) | |
tree | 5ebf444006d62c3666108611d9e4288055eb489c /training | |
parent | e17505c233fe62528205580f3cb1a62423954c25 (diff) |
cache metric computation in pro
Diffstat (limited to 'training')
-rw-r--r-- | training/candidate_set.cc | 39 | ||||
-rw-r--r-- | training/candidate_set.h | 18 |
2 files changed, 27 insertions, 30 deletions
diff --git a/training/candidate_set.cc b/training/candidate_set.cc index 5ab4558a..e2ca9ad2 100644 --- a/training/candidate_set.cc +++ b/training/candidate_set.cc @@ -62,15 +62,6 @@ struct ApproxVectorEquals { } }; -double Candidate::g(const SegmentEvaluator& scorer, const EvaluationMetric* metric) const { - if (g_ == -100.0f) { - SufficientStats ss; - scorer.Evaluate(ewords, &ss); - g_ = metric->ComputeScore(ss); - } - return g_; -} - struct CandidateCompare { bool operator()(const Candidate& a, const Candidate& b) const { ApproxVectorEquals eq; @@ -88,16 +79,6 @@ struct CandidateHasher { } }; -void CandidateSet::WriteToFile(const string& file) const { - WriteFile wf(file); - ostream& out = *wf.stream(); - out.precision(10); - for (unsigned i = 0; i < cs.size(); ++i) { - out << TD::GetString(cs[i].ewords) << endl; - out << cs[i].fmap << endl; - } -} - static void ParseSparseVector(string& line, size_t cur, SparseVector<double>* out) { SparseVector<double>& x = *out; size_t last_start = cur; @@ -123,18 +104,34 @@ static void ParseSparseVector(string& line, size_t cur, SparseVector<double>* ou } } +void CandidateSet::WriteToFile(const string& file) const { + WriteFile wf(file); + ostream& out = *wf.stream(); + out.precision(10); + string ss; + for (unsigned i = 0; i < cs.size(); ++i) { + out << TD::GetString(cs[i].ewords) << endl; + out << cs[i].fmap << endl; + cs[i].score_stats.Encode(&ss); + out << ss << endl; + } +} + void CandidateSet::ReadFromFile(const string& file) { cerr << "Reading candidates from " << file << endl; ReadFile rf(file); istream& in = *rf.stream(); string cand; string feats; + string ss; while(getline(in, cand)) { getline(in, feats); + getline(in, ss); assert(in); cs.push_back(Candidate()); TD::ConvertSentence(cand, &cs.back().ewords); ParseSparseVector(feats, 0, &cs.back().fmap); + cs.back().score_stats = SufficientStats(ss); } cerr << " read " << cs.size() << " candidates\n"; } @@ -154,7 +151,7 @@ void CandidateSet::Dedup() { cerr << " out=" << cs.size() << endl; } -void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size) { +void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer) { KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(hg, kbest_size); for (unsigned i = 0; i < kbest_size; ++i) { @@ -162,6 +159,8 @@ void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size) { kbest.LazyKthBest(hg.nodes_.size() - 1, i); if (!d) break; cs.push_back(Candidate(d->yield, d->feature_values)); + if (scorer) + scorer->Evaluate(d->yield, &cs.back().score_stats); } Dedup(); } diff --git a/training/candidate_set.h b/training/candidate_set.h index e2b0b1ba..824a4de2 100644 --- a/training/candidate_set.h +++ b/training/candidate_set.h @@ -4,29 +4,27 @@ #include <vector> #include <algorithm> +#include "ns.h" #include "wordid.h" #include "sparse_vector.h" class Hypergraph; -struct SegmentEvaluator; -struct EvaluationMetric; namespace training { struct Candidate { - Candidate() : g_(-100.0f) {} - Candidate(const std::vector<WordID>& e, const SparseVector<double>& fm) : ewords(e), fmap(fm), g_(-100.0f) {} + Candidate() {} + Candidate(const std::vector<WordID>& e, const SparseVector<double>& fm) : + ewords(e), + fmap(fm) {} std::vector<WordID> ewords; SparseVector<double> fmap; - double g(const SegmentEvaluator& scorer, const EvaluationMetric* metric) const; + SufficientStats score_stats; void swap(Candidate& other) { - std::swap(g_, other.g_); + score_stats.swap(other.score_stats); ewords.swap(other.ewords); fmap.swap(other.fmap); } - private: - mutable float g_; - //SufficientStats score_stats; }; // represents some kind of collection of translation candidates, e.g. @@ -39,7 +37,7 @@ class CandidateSet { void ReadFromFile(const std::string& file); void WriteToFile(const std::string& file) const; - void AddKBestCandidates(const Hypergraph& hg, size_t kbest_size); + void AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer = NULL); // TODO add code to do unique k-best // TODO add code to draw k samples |