diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/candidate_set.cc | 6 | ||||
-rw-r--r-- | training/candidate_set.h | 21 | ||||
-rw-r--r-- | training/mpi_flex_optimize.cc | 10 |
3 files changed, 22 insertions, 15 deletions
diff --git a/training/candidate_set.cc b/training/candidate_set.cc index e2ca9ad2..8c086ece 100644 --- a/training/candidate_set.cc +++ b/training/candidate_set.cc @@ -112,7 +112,7 @@ void CandidateSet::WriteToFile(const string& file) const { for (unsigned i = 0; i < cs.size(); ++i) { out << TD::GetString(cs[i].ewords) << endl; out << cs[i].fmap << endl; - cs[i].score_stats.Encode(&ss); + cs[i].eval_feats.Encode(&ss); out << ss << endl; } } @@ -131,7 +131,7 @@ void CandidateSet::ReadFromFile(const string& file) { cs.push_back(Candidate()); TD::ConvertSentence(cand, &cs.back().ewords); ParseSparseVector(feats, 0, &cs.back().fmap); - cs.back().score_stats = SufficientStats(ss); + cs.back().eval_feats = SufficientStats(ss); } cerr << " read " << cs.size() << " candidates\n"; } @@ -160,7 +160,7 @@ void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, c if (!d) break; cs.push_back(Candidate(d->yield, d->feature_values)); if (scorer) - scorer->Evaluate(d->yield, &cs.back().score_stats); + scorer->Evaluate(d->yield, &cs.back().eval_feats); } Dedup(); } diff --git a/training/candidate_set.h b/training/candidate_set.h index 824a4de2..9d326ed0 100644 --- a/training/candidate_set.h +++ b/training/candidate_set.h @@ -15,16 +15,25 @@ namespace training { struct Candidate { Candidate() {} Candidate(const std::vector<WordID>& e, const SparseVector<double>& fm) : - ewords(e), - fmap(fm) {} - std::vector<WordID> ewords; - SparseVector<double> fmap; - SufficientStats score_stats; + ewords(e), + fmap(fm) {} + Candidate(const std::vector<WordID>& e, + const SparseVector<double>& fm, + const SegmentEvaluator& se) : + ewords(e), + fmap(fm) { + se.Evaluate(ewords, &eval_feats); + } + void swap(Candidate& other) { - score_stats.swap(other.score_stats); + eval_feats.swap(other.eval_feats); ewords.swap(other.ewords); fmap.swap(other.fmap); } + + std::vector<WordID> ewords; + SparseVector<double> fmap; + SufficientStats eval_feats; }; // represents some kind of collection of translation candidates, e.g. diff --git a/training/mpi_flex_optimize.cc b/training/mpi_flex_optimize.cc index a9197208..a9ead018 100644 --- a/training/mpi_flex_optimize.cc +++ b/training/mpi_flex_optimize.cc @@ -179,18 +179,16 @@ double ApplyRegularizationTerms(const double C, const double T, const vector<double>& weights, const vector<double>& prev_weights, - vector<double>* g) { - assert(weights.size() == g->size()); + double* g) { double reg = 0; for (size_t i = 0; i < weights.size(); ++i) { const double prev_w_i = (i < prev_weights.size() ? prev_weights[i] : 0.0); const double& w_i = weights[i]; - double& g_i = (*g)[i]; reg += C * w_i * w_i; - g_i += 2 * C * w_i; + g[i] += 2 * C * w_i; reg += T * (w_i - prev_w_i) * (w_i - prev_w_i); - g_i += 2 * T * (w_i - prev_w_i); + g[i] += 2 * T * (w_i - prev_w_i); } return reg; } @@ -365,7 +363,7 @@ int main(int argc, char** argv) { time_series_strength, // * (iter == 0 ? 0.0 : 1.0), cur_weights, prev_weights, - &gg); + &gg[0]); obj += r; if (mi == 0 || mi == (minibatch_iterations - 1)) { if (!mi) cerr << iter << ' '; else cerr << ' '; |