summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/candidate_set.cc6
-rw-r--r--training/candidate_set.h21
-rw-r--r--training/mpi_flex_optimize.cc10
3 files changed, 22 insertions, 15 deletions
diff --git a/training/candidate_set.cc b/training/candidate_set.cc
index e2ca9ad2..8c086ece 100644
--- a/training/candidate_set.cc
+++ b/training/candidate_set.cc
@@ -112,7 +112,7 @@ void CandidateSet::WriteToFile(const string& file) const {
for (unsigned i = 0; i < cs.size(); ++i) {
out << TD::GetString(cs[i].ewords) << endl;
out << cs[i].fmap << endl;
- cs[i].score_stats.Encode(&ss);
+ cs[i].eval_feats.Encode(&ss);
out << ss << endl;
}
}
@@ -131,7 +131,7 @@ void CandidateSet::ReadFromFile(const string& file) {
cs.push_back(Candidate());
TD::ConvertSentence(cand, &cs.back().ewords);
ParseSparseVector(feats, 0, &cs.back().fmap);
- cs.back().score_stats = SufficientStats(ss);
+ cs.back().eval_feats = SufficientStats(ss);
}
cerr << " read " << cs.size() << " candidates\n";
}
@@ -160,7 +160,7 @@ void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, c
if (!d) break;
cs.push_back(Candidate(d->yield, d->feature_values));
if (scorer)
- scorer->Evaluate(d->yield, &cs.back().score_stats);
+ scorer->Evaluate(d->yield, &cs.back().eval_feats);
}
Dedup();
}
diff --git a/training/candidate_set.h b/training/candidate_set.h
index 824a4de2..9d326ed0 100644
--- a/training/candidate_set.h
+++ b/training/candidate_set.h
@@ -15,16 +15,25 @@ namespace training {
struct Candidate {
Candidate() {}
Candidate(const std::vector<WordID>& e, const SparseVector<double>& fm) :
- ewords(e),
- fmap(fm) {}
- std::vector<WordID> ewords;
- SparseVector<double> fmap;
- SufficientStats score_stats;
+ ewords(e),
+ fmap(fm) {}
+ Candidate(const std::vector<WordID>& e,
+ const SparseVector<double>& fm,
+ const SegmentEvaluator& se) :
+ ewords(e),
+ fmap(fm) {
+ se.Evaluate(ewords, &eval_feats);
+ }
+
void swap(Candidate& other) {
- score_stats.swap(other.score_stats);
+ eval_feats.swap(other.eval_feats);
ewords.swap(other.ewords);
fmap.swap(other.fmap);
}
+
+ std::vector<WordID> ewords;
+ SparseVector<double> fmap;
+ SufficientStats eval_feats;
};
// represents some kind of collection of translation candidates, e.g.
diff --git a/training/mpi_flex_optimize.cc b/training/mpi_flex_optimize.cc
index a9197208..a9ead018 100644
--- a/training/mpi_flex_optimize.cc
+++ b/training/mpi_flex_optimize.cc
@@ -179,18 +179,16 @@ double ApplyRegularizationTerms(const double C,
const double T,
const vector<double>& weights,
const vector<double>& prev_weights,
- vector<double>* g) {
- assert(weights.size() == g->size());
+ double* g) {
double reg = 0;
for (size_t i = 0; i < weights.size(); ++i) {
const double prev_w_i = (i < prev_weights.size() ? prev_weights[i] : 0.0);
const double& w_i = weights[i];
- double& g_i = (*g)[i];
reg += C * w_i * w_i;
- g_i += 2 * C * w_i;
+ g[i] += 2 * C * w_i;
reg += T * (w_i - prev_w_i) * (w_i - prev_w_i);
- g_i += 2 * T * (w_i - prev_w_i);
+ g[i] += 2 * T * (w_i - prev_w_i);
}
return reg;
}
@@ -365,7 +363,7 @@ int main(int argc, char** argv) {
time_series_strength, // * (iter == 0 ? 0.0 : 1.0),
cur_weights,
prev_weights,
- &gg);
+ &gg[0]);
obj += r;
if (mi == 0 || mi == (minibatch_iterations - 1)) {
if (!mi) cerr << iter << ' '; else cerr << ' ';