diff options
-rw-r--r-- | pro-train/mr_pro_map.cc | 9 | ||||
-rw-r--r-- | training/candidate_set.cc | 39 | ||||
-rw-r--r-- | training/candidate_set.h | 18 |
3 files changed, 31 insertions, 35 deletions
diff --git a/pro-train/mr_pro_map.cc b/pro-train/mr_pro_map.cc index 2aa0dc6f..bb13fdf4 100644 --- a/pro-train/mr_pro_map.cc +++ b/pro-train/mr_pro_map.cc @@ -92,7 +92,6 @@ struct DiffOrder { void Sample(const unsigned gamma, const unsigned xi, const training::CandidateSet& J_i, - const SegmentEvaluator& scorer, const EvaluationMetric* metric, vector<TrainingInstance>* pv) { const bool invert_score = metric->IsErrorMetric(); @@ -102,8 +101,8 @@ void Sample(const unsigned gamma, const size_t a = rng->inclusive(0, J_i.size() - 1)(); const size_t b = rng->inclusive(0, J_i.size() - 1)(); if (a == b) continue; - float ga = J_i[a].g(scorer, metric); - float gb = J_i[b].g(scorer, metric); + float ga = metric->ComputeScore(J_i[a].score_stats); + float gb = metric->ComputeScore(J_i[b].score_stats); bool positive = gb < ga; if (invert_score) positive = !positive; const float gdiff = fabs(ga - gb); @@ -187,10 +186,10 @@ int main(int argc, char** argv) { J_i.ReadFromFile(kbest_file); HypergraphIO::ReadFromJSON(rf.stream(), &hg); hg.Reweight(weights); - J_i.AddKBestCandidates(hg, kbest_size); + J_i.AddKBestCandidates(hg, kbest_size, ds[sent_id]); J_i.WriteToFile(kbest_file); - Sample(gamma, xi, J_i, *ds[sent_id], metric, &v); + Sample(gamma, xi, J_i, metric, &v); for (unsigned i = 0; i < v.size(); ++i) { const TrainingInstance& vi = v[i]; cout << vi.y << "\t" << vi.x << endl; diff --git a/training/candidate_set.cc b/training/candidate_set.cc index 5ab4558a..e2ca9ad2 100644 --- a/training/candidate_set.cc +++ b/training/candidate_set.cc @@ -62,15 +62,6 @@ struct ApproxVectorEquals { } }; -double Candidate::g(const SegmentEvaluator& scorer, const EvaluationMetric* metric) const { - if (g_ == -100.0f) { - SufficientStats ss; - scorer.Evaluate(ewords, &ss); - g_ = metric->ComputeScore(ss); - } - return g_; -} - struct CandidateCompare { bool operator()(const Candidate& a, const Candidate& b) const { ApproxVectorEquals eq; @@ -88,16 +79,6 @@ struct CandidateHasher { } }; -void CandidateSet::WriteToFile(const string& file) const { - WriteFile wf(file); - ostream& out = *wf.stream(); - out.precision(10); - for (unsigned i = 0; i < cs.size(); ++i) { - out << TD::GetString(cs[i].ewords) << endl; - out << cs[i].fmap << endl; - } -} - static void ParseSparseVector(string& line, size_t cur, SparseVector<double>* out) { SparseVector<double>& x = *out; size_t last_start = cur; @@ -123,18 +104,34 @@ static void ParseSparseVector(string& line, size_t cur, SparseVector<double>* ou } } +void CandidateSet::WriteToFile(const string& file) const { + WriteFile wf(file); + ostream& out = *wf.stream(); + out.precision(10); + string ss; + for (unsigned i = 0; i < cs.size(); ++i) { + out << TD::GetString(cs[i].ewords) << endl; + out << cs[i].fmap << endl; + cs[i].score_stats.Encode(&ss); + out << ss << endl; + } +} + void CandidateSet::ReadFromFile(const string& file) { cerr << "Reading candidates from " << file << endl; ReadFile rf(file); istream& in = *rf.stream(); string cand; string feats; + string ss; while(getline(in, cand)) { getline(in, feats); + getline(in, ss); assert(in); cs.push_back(Candidate()); TD::ConvertSentence(cand, &cs.back().ewords); ParseSparseVector(feats, 0, &cs.back().fmap); + cs.back().score_stats = SufficientStats(ss); } cerr << " read " << cs.size() << " candidates\n"; } @@ -154,7 +151,7 @@ void CandidateSet::Dedup() { cerr << " out=" << cs.size() << endl; } -void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size) { +void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer) { KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(hg, kbest_size); for (unsigned i = 0; i < kbest_size; ++i) { @@ -162,6 +159,8 @@ void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size) { kbest.LazyKthBest(hg.nodes_.size() - 1, i); if (!d) break; cs.push_back(Candidate(d->yield, d->feature_values)); + if (scorer) + scorer->Evaluate(d->yield, &cs.back().score_stats); } Dedup(); } diff --git a/training/candidate_set.h b/training/candidate_set.h index e2b0b1ba..824a4de2 100644 --- a/training/candidate_set.h +++ b/training/candidate_set.h @@ -4,29 +4,27 @@ #include <vector> #include <algorithm> +#include "ns.h" #include "wordid.h" #include "sparse_vector.h" class Hypergraph; -struct SegmentEvaluator; -struct EvaluationMetric; namespace training { struct Candidate { - Candidate() : g_(-100.0f) {} - Candidate(const std::vector<WordID>& e, const SparseVector<double>& fm) : ewords(e), fmap(fm), g_(-100.0f) {} + Candidate() {} + Candidate(const std::vector<WordID>& e, const SparseVector<double>& fm) : + ewords(e), + fmap(fm) {} std::vector<WordID> ewords; SparseVector<double> fmap; - double g(const SegmentEvaluator& scorer, const EvaluationMetric* metric) const; + SufficientStats score_stats; void swap(Candidate& other) { - std::swap(g_, other.g_); + score_stats.swap(other.score_stats); ewords.swap(other.ewords); fmap.swap(other.fmap); } - private: - mutable float g_; - //SufficientStats score_stats; }; // represents some kind of collection of translation candidates, e.g. @@ -39,7 +37,7 @@ class CandidateSet { void ReadFromFile(const std::string& file); void WriteToFile(const std::string& file) const; - void AddKBestCandidates(const Hypergraph& hg, size_t kbest_size); + void AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer = NULL); // TODO add code to do unique k-best // TODO add code to draw k samples |