-rw-r--r--  pro-train/mr_pro_map.cc    9
-rw-r--r--  training/candidate_set.cc  39
-rw-r--r--  training/candidate_set.h   18
3 files changed, 31 insertions(+), 35 deletions(-)
diff --git a/pro-train/mr_pro_map.cc b/pro-train/mr_pro_map.cc
index 2aa0dc6f..bb13fdf4 100644
--- a/pro-train/mr_pro_map.cc
+++ b/pro-train/mr_pro_map.cc
@@ -92,7 +92,6 @@ struct DiffOrder {
void Sample(const unsigned gamma,
const unsigned xi,
const training::CandidateSet& J_i,
- const SegmentEvaluator& scorer,
const EvaluationMetric* metric,
vector<TrainingInstance>* pv) {
const bool invert_score = metric->IsErrorMetric();
@@ -102,8 +101,8 @@ void Sample(const unsigned gamma,
const size_t a = rng->inclusive(0, J_i.size() - 1)();
const size_t b = rng->inclusive(0, J_i.size() - 1)();
if (a == b) continue;
- float ga = J_i[a].g(scorer, metric);
- float gb = J_i[b].g(scorer, metric);
+ float ga = metric->ComputeScore(J_i[a].score_stats);
+ float gb = metric->ComputeScore(J_i[b].score_stats);
bool positive = gb < ga;
if (invert_score) positive = !positive;
const float gdiff = fabs(ga - gb);
@@ -187,10 +186,10 @@ int main(int argc, char** argv) {
J_i.ReadFromFile(kbest_file);
HypergraphIO::ReadFromJSON(rf.stream(), &hg);
hg.Reweight(weights);
- J_i.AddKBestCandidates(hg, kbest_size);
+ J_i.AddKBestCandidates(hg, kbest_size, ds[sent_id]);
J_i.WriteToFile(kbest_file);
- Sample(gamma, xi, J_i, *ds[sent_id], metric, &v);
+ Sample(gamma, xi, J_i, metric, &v);
for (unsigned i = 0; i < v.size(); ++i) {
const TrainingInstance& vi = v[i];
cout << vi.y << "\t" << vi.x << endl;
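
Note: the mr_pro_map.cc change drops the SegmentEvaluator argument from Sample()
because each Candidate now carries its own precomputed SufficientStats, so the
metric score can be recomputed directly from those stats. A minimal sketch of the
resulting call pattern, assuming the cdec ns.h interface used above
(EvaluationMetric::ComputeScore) and that training/candidate_set.h is on the
include path; the helper name CandidateScore is illustrative only:

#include "ns.h"
#include "candidate_set.h"

// Score candidate k of a set from its stored sufficient statistics; no
// SegmentEvaluator (and therefore no access to the references) is needed.
inline float CandidateScore(const training::CandidateSet& J_i,
                            size_t k,
                            const EvaluationMetric* metric) {
  return metric->ComputeScore(J_i[k].score_stats);
}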
diff --git a/training/candidate_set.cc b/training/candidate_set.cc
index 5ab4558a..e2ca9ad2 100644
--- a/training/candidate_set.cc
+++ b/training/candidate_set.cc
@@ -62,15 +62,6 @@ struct ApproxVectorEquals {
}
};
-double Candidate::g(const SegmentEvaluator& scorer, const EvaluationMetric* metric) const {
- if (g_ == -100.0f) {
- SufficientStats ss;
- scorer.Evaluate(ewords, &ss);
- g_ = metric->ComputeScore(ss);
- }
- return g_;
-}
-
struct CandidateCompare {
bool operator()(const Candidate& a, const Candidate& b) const {
ApproxVectorEquals eq;
@@ -88,16 +79,6 @@ struct CandidateHasher {
}
};
-void CandidateSet::WriteToFile(const string& file) const {
- WriteFile wf(file);
- ostream& out = *wf.stream();
- out.precision(10);
- for (unsigned i = 0; i < cs.size(); ++i) {
- out << TD::GetString(cs[i].ewords) << endl;
- out << cs[i].fmap << endl;
- }
-}
-
static void ParseSparseVector(string& line, size_t cur, SparseVector<double>* out) {
SparseVector<double>& x = *out;
size_t last_start = cur;
@@ -123,18 +104,34 @@ static void ParseSparseVector(string& line, size_t cur, SparseVector<double>* ou
}
}
+void CandidateSet::WriteToFile(const string& file) const {
+ WriteFile wf(file);
+ ostream& out = *wf.stream();
+ out.precision(10);
+ string ss;
+ for (unsigned i = 0; i < cs.size(); ++i) {
+ out << TD::GetString(cs[i].ewords) << endl;
+ out << cs[i].fmap << endl;
+ cs[i].score_stats.Encode(&ss);
+ out << ss << endl;
+ }
+}
+
void CandidateSet::ReadFromFile(const string& file) {
cerr << "Reading candidates from " << file << endl;
ReadFile rf(file);
istream& in = *rf.stream();
string cand;
string feats;
+ string ss;
while(getline(in, cand)) {
getline(in, feats);
+ getline(in, ss);
assert(in);
cs.push_back(Candidate());
TD::ConvertSentence(cand, &cs.back().ewords);
ParseSparseVector(feats, 0, &cs.back().fmap);
+ cs.back().score_stats = SufficientStats(ss);
}
cerr << " read " << cs.size() << " candidates\n";
}
@@ -154,7 +151,7 @@ void CandidateSet::Dedup() {
cerr << " out=" << cs.size() << endl;
}
-void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size) {
+void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer) {
KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(hg, kbest_size);
for (unsigned i = 0; i < kbest_size; ++i) {
@@ -162,6 +159,8 @@ void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size) {
kbest.LazyKthBest(hg.nodes_.size() - 1, i);
if (!d) break;
cs.push_back(Candidate(d->yield, d->feature_values));
+ if (scorer)
+ scorer->Evaluate(d->yield, &cs.back().score_stats);
}
Dedup();
}
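
Note: with this change a candidate file written by WriteToFile holds three lines
per candidate (target sentence, feature vector, encoded sufficient statistics),
ReadFromFile expects the same layout, and AddKBestCandidates fills score_stats
only when a scorer is supplied. A hedged round-trip sketch of the updated
CandidateSet API, using only calls that appear in this diff; the header paths,
the k-best size of 500, and the function name RefreshKBest are assumptions:

#include <cstdio>
#include <string>
#include "candidate_set.h"
#include "ns.h"
#include "hg.h"

void RefreshKBest(const Hypergraph& hg,
                  const SegmentEvaluator* scorer,      // may be NULL
                  const EvaluationMetric* metric,
                  const std::string& kbest_file) {
  training::CandidateSet cs;
  cs.ReadFromFile(kbest_file);             // sentence / features / stats triples
  cs.AddKBestCandidates(hg, 500, scorer);  // scorer (if given) fills score_stats
  cs.WriteToFile(kbest_file);              // stats persisted alongside features
  // Scores below are only meaningful for candidates whose stats were filled
  // by a scorer at k-best extraction time.
  for (unsigned i = 0; i < cs.size(); ++i)
    printf("cand %u  score=%f\n", i, metric->ComputeScore(cs[i].score_stats));
}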
diff --git a/training/candidate_set.h b/training/candidate_set.h
index e2b0b1ba..824a4de2 100644
--- a/training/candidate_set.h
+++ b/training/candidate_set.h
@@ -4,29 +4,27 @@
#include <vector>
#include <algorithm>
+#include "ns.h"
#include "wordid.h"
#include "sparse_vector.h"
class Hypergraph;
-struct SegmentEvaluator;
-struct EvaluationMetric;
namespace training {
struct Candidate {
- Candidate() : g_(-100.0f) {}
- Candidate(const std::vector<WordID>& e, const SparseVector<double>& fm) : ewords(e), fmap(fm), g_(-100.0f) {}
+ Candidate() {}
+ Candidate(const std::vector<WordID>& e, const SparseVector<double>& fm) :
+ ewords(e),
+ fmap(fm) {}
std::vector<WordID> ewords;
SparseVector<double> fmap;
- double g(const SegmentEvaluator& scorer, const EvaluationMetric* metric) const;
+ SufficientStats score_stats;
void swap(Candidate& other) {
- std::swap(g_, other.g_);
+ score_stats.swap(other.score_stats);
ewords.swap(other.ewords);
fmap.swap(other.fmap);
}
- private:
- mutable float g_;
- //SufficientStats score_stats;
};
// represents some kind of collection of translation candidates, e.g.
@@ -39,7 +37,7 @@ class CandidateSet {
void ReadFromFile(const std::string& file);
void WriteToFile(const std::string& file) const;
- void AddKBestCandidates(const Hypergraph& hg, size_t kbest_size);
+ void AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer = NULL);
// TODO add code to do unique k-best
// TODO add code to draw k samples
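
Note: on the header side, the lazily cached g_ value is replaced by an
always-present SufficientStats member (declared in ns.h), so any metric can
rescore a candidate after the fact and the stats survive serialization. A small
sketch exercising only the calls that appear in this diff
(SegmentEvaluator::Evaluate, SufficientStats::Encode, the string-decoding
constructor); header paths and the helper name are assumptions:

#include <string>
#include <vector>
#include "candidate_set.h"
#include "ns.h"

// Build a Candidate whose score_stats are filled immediately, then round-trip
// the stats through the same string encoding WriteToFile/ReadFromFile now use.
training::Candidate MakeScoredCandidate(const std::vector<WordID>& yield,
                                        const SparseVector<double>& feats,
                                        const SegmentEvaluator& scorer) {
  training::Candidate c(yield, feats);
  scorer.Evaluate(c.ewords, &c.score_stats);   // per-candidate sufficient stats

  std::string encoded;
  c.score_stats.Encode(&encoded);              // as written to the kbest file
  c.score_stats = SufficientStats(encoded);    // as read back in
  return c;
}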