summaryrefslogtreecommitdiff
path: root/training/mpi_flex_optimize.cc
diff options
context:
space:
mode:
Diffstat (limited to 'training/mpi_flex_optimize.cc')
-rw-r--r--training/mpi_flex_optimize.cc10
1 files changed, 4 insertions, 6 deletions
diff --git a/training/mpi_flex_optimize.cc b/training/mpi_flex_optimize.cc
index a9197208..a9ead018 100644
--- a/training/mpi_flex_optimize.cc
+++ b/training/mpi_flex_optimize.cc
@@ -179,18 +179,16 @@ double ApplyRegularizationTerms(const double C,
const double T,
const vector<double>& weights,
const vector<double>& prev_weights,
- vector<double>* g) {
- assert(weights.size() == g->size());
+ double* g) {
double reg = 0;
for (size_t i = 0; i < weights.size(); ++i) {
const double prev_w_i = (i < prev_weights.size() ? prev_weights[i] : 0.0);
const double& w_i = weights[i];
- double& g_i = (*g)[i];
reg += C * w_i * w_i;
- g_i += 2 * C * w_i;
+ g[i] += 2 * C * w_i;
reg += T * (w_i - prev_w_i) * (w_i - prev_w_i);
- g_i += 2 * T * (w_i - prev_w_i);
+ g[i] += 2 * T * (w_i - prev_w_i);
}
return reg;
}
@@ -365,7 +363,7 @@ int main(int argc, char** argv) {
time_series_strength, // * (iter == 0 ? 0.0 : 1.0),
cur_weights,
prev_weights,
- &gg);
+ &gg[0]);
obj += r;
if (mi == 0 || mi == (minibatch_iterations - 1)) {
if (!mi) cerr << iter << ' '; else cerr << ' ';