summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2012-05-28 00:19:10 -0400
committerChris Dyer <cdyer@cs.cmu.edu>2012-05-28 00:19:10 -0400
commitded34c668ca87b9e0a0ebca68944c6648602593a (patch)
tree3c55eba0af516976bebcbcdab80571efff0aab01 /training
parent104aad02a868c1fc6320276d9b3b9b0e1f41f457 (diff)
cache metric computation in pro
Diffstat (limited to 'training')
-rw-r--r--training/candidate_set.cc39
-rw-r--r--training/candidate_set.h18
2 files changed, 27 insertions, 30 deletions
diff --git a/training/candidate_set.cc b/training/candidate_set.cc
index 5ab4558a..e2ca9ad2 100644
--- a/training/candidate_set.cc
+++ b/training/candidate_set.cc
@@ -62,15 +62,6 @@ struct ApproxVectorEquals {
}
};
-double Candidate::g(const SegmentEvaluator& scorer, const EvaluationMetric* metric) const {
- if (g_ == -100.0f) {
- SufficientStats ss;
- scorer.Evaluate(ewords, &ss);
- g_ = metric->ComputeScore(ss);
- }
- return g_;
-}
-
struct CandidateCompare {
bool operator()(const Candidate& a, const Candidate& b) const {
ApproxVectorEquals eq;
@@ -88,16 +79,6 @@ struct CandidateHasher {
}
};
-void CandidateSet::WriteToFile(const string& file) const {
- WriteFile wf(file);
- ostream& out = *wf.stream();
- out.precision(10);
- for (unsigned i = 0; i < cs.size(); ++i) {
- out << TD::GetString(cs[i].ewords) << endl;
- out << cs[i].fmap << endl;
- }
-}
-
static void ParseSparseVector(string& line, size_t cur, SparseVector<double>* out) {
SparseVector<double>& x = *out;
size_t last_start = cur;
@@ -123,18 +104,34 @@ static void ParseSparseVector(string& line, size_t cur, SparseVector<double>* ou
}
}
+void CandidateSet::WriteToFile(const string& file) const {
+ WriteFile wf(file);
+ ostream& out = *wf.stream();
+ out.precision(10);
+ string ss;
+ for (unsigned i = 0; i < cs.size(); ++i) {
+ out << TD::GetString(cs[i].ewords) << endl;
+ out << cs[i].fmap << endl;
+ cs[i].score_stats.Encode(&ss);
+ out << ss << endl;
+ }
+}
+
void CandidateSet::ReadFromFile(const string& file) {
cerr << "Reading candidates from " << file << endl;
ReadFile rf(file);
istream& in = *rf.stream();
string cand;
string feats;
+ string ss;
while(getline(in, cand)) {
getline(in, feats);
+ getline(in, ss);
assert(in);
cs.push_back(Candidate());
TD::ConvertSentence(cand, &cs.back().ewords);
ParseSparseVector(feats, 0, &cs.back().fmap);
+ cs.back().score_stats = SufficientStats(ss);
}
cerr << " read " << cs.size() << " candidates\n";
}
@@ -154,7 +151,7 @@ void CandidateSet::Dedup() {
cerr << " out=" << cs.size() << endl;
}
-void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size) {
+void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer) {
KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(hg, kbest_size);
for (unsigned i = 0; i < kbest_size; ++i) {
@@ -162,6 +159,8 @@ void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size) {
kbest.LazyKthBest(hg.nodes_.size() - 1, i);
if (!d) break;
cs.push_back(Candidate(d->yield, d->feature_values));
+ if (scorer)
+ scorer->Evaluate(d->yield, &cs.back().score_stats);
}
Dedup();
}
diff --git a/training/candidate_set.h b/training/candidate_set.h
index e2b0b1ba..824a4de2 100644
--- a/training/candidate_set.h
+++ b/training/candidate_set.h
@@ -4,29 +4,27 @@
#include <vector>
#include <algorithm>
+#include "ns.h"
#include "wordid.h"
#include "sparse_vector.h"
class Hypergraph;
-struct SegmentEvaluator;
-struct EvaluationMetric;
namespace training {
struct Candidate {
- Candidate() : g_(-100.0f) {}
- Candidate(const std::vector<WordID>& e, const SparseVector<double>& fm) : ewords(e), fmap(fm), g_(-100.0f) {}
+ Candidate() {}
+ Candidate(const std::vector<WordID>& e, const SparseVector<double>& fm) :
+ ewords(e),
+ fmap(fm) {}
std::vector<WordID> ewords;
SparseVector<double> fmap;
- double g(const SegmentEvaluator& scorer, const EvaluationMetric* metric) const;
+ SufficientStats score_stats;
void swap(Candidate& other) {
- std::swap(g_, other.g_);
+ score_stats.swap(other.score_stats);
ewords.swap(other.ewords);
fmap.swap(other.fmap);
}
- private:
- mutable float g_;
- //SufficientStats score_stats;
};
// represents some kind of collection of translation candidates, e.g.
@@ -39,7 +37,7 @@ class CandidateSet {
void ReadFromFile(const std::string& file);
void WriteToFile(const std::string& file) const;
- void AddKBestCandidates(const Hypergraph& hg, size_t kbest_size);
+ void AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer = NULL);
// TODO add code to do unique k-best
// TODO add code to draw k samples