summaryrefslogtreecommitdiff
path: root/training/candidate_set.cc
diff options
context:
space:
mode:
Diffstat (limited to 'training/candidate_set.cc')
-rw-r--r--training/candidate_set.cc39
1 files changed, 19 insertions, 20 deletions
diff --git a/training/candidate_set.cc b/training/candidate_set.cc
index 5ab4558a..e2ca9ad2 100644
--- a/training/candidate_set.cc
+++ b/training/candidate_set.cc
@@ -62,15 +62,6 @@ struct ApproxVectorEquals {
}
};
-double Candidate::g(const SegmentEvaluator& scorer, const EvaluationMetric* metric) const {
- if (g_ == -100.0f) {
- SufficientStats ss;
- scorer.Evaluate(ewords, &ss);
- g_ = metric->ComputeScore(ss);
- }
- return g_;
-}
-
struct CandidateCompare {
bool operator()(const Candidate& a, const Candidate& b) const {
ApproxVectorEquals eq;
@@ -88,16 +79,6 @@ struct CandidateHasher {
}
};
-void CandidateSet::WriteToFile(const string& file) const {
- WriteFile wf(file);
- ostream& out = *wf.stream();
- out.precision(10);
- for (unsigned i = 0; i < cs.size(); ++i) {
- out << TD::GetString(cs[i].ewords) << endl;
- out << cs[i].fmap << endl;
- }
-}
-
static void ParseSparseVector(string& line, size_t cur, SparseVector<double>* out) {
SparseVector<double>& x = *out;
size_t last_start = cur;
@@ -123,18 +104,34 @@ static void ParseSparseVector(string& line, size_t cur, SparseVector<double>* ou
}
}
+void CandidateSet::WriteToFile(const string& file) const {
+ WriteFile wf(file);
+ ostream& out = *wf.stream();
+ out.precision(10);
+ string ss;
+ for (unsigned i = 0; i < cs.size(); ++i) {
+ out << TD::GetString(cs[i].ewords) << endl;
+ out << cs[i].fmap << endl;
+ cs[i].score_stats.Encode(&ss);
+ out << ss << endl;
+ }
+}
+
void CandidateSet::ReadFromFile(const string& file) {
cerr << "Reading candidates from " << file << endl;
ReadFile rf(file);
istream& in = *rf.stream();
string cand;
string feats;
+ string ss;
while(getline(in, cand)) {
getline(in, feats);
+ getline(in, ss);
assert(in);
cs.push_back(Candidate());
TD::ConvertSentence(cand, &cs.back().ewords);
ParseSparseVector(feats, 0, &cs.back().fmap);
+ cs.back().score_stats = SufficientStats(ss);
}
cerr << " read " << cs.size() << " candidates\n";
}
@@ -154,7 +151,7 @@ void CandidateSet::Dedup() {
cerr << " out=" << cs.size() << endl;
}
-void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size) {
+void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer) {
KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(hg, kbest_size);
for (unsigned i = 0; i < kbest_size; ++i) {
@@ -162,6 +159,8 @@ void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size) {
kbest.LazyKthBest(hg.nodes_.size() - 1, i);
if (!d) break;
cs.push_back(Candidate(d->yield, d->feature_values));
+ if (scorer)
+ scorer->Evaluate(d->yield, &cs.back().score_stats);
}
Dedup();
}