diff options
Diffstat (limited to 'training/mr_optimize_reduce.cc')
-rw-r--r-- | training/mr_optimize_reduce.cc | 19 |
1 files changed, 5 insertions, 14 deletions
diff --git a/training/mr_optimize_reduce.cc b/training/mr_optimize_reduce.cc index b931991d..15e28fa1 100644 --- a/training/mr_optimize_reduce.cc +++ b/training/mr_optimize_reduce.cc @@ -88,25 +88,19 @@ int main(int argc, char** argv) { const bool use_b64 = conf["input_format"].as<string>() == "b64"; - Weights weights; - weights.InitFromFile(conf["input_weights"].as<string>()); + vector<weight_t> lambdas; + Weights::InitFromFile(conf["input_weights"].as<string>(), &lambdas); const string s_obj = "**OBJ**"; int num_feats = FD::NumFeats(); cerr << "Number of features: " << num_feats << endl; const bool gaussian_prior = conf.count("gaussian_prior"); - vector<double> means(num_feats, 0); + vector<weight_t> means(num_feats, 0); if (conf.count("means")) { if (!gaussian_prior) { cerr << "Don't use --means without --gaussian_prior!\n"; exit(1); } - Weights wm; - wm.InitFromFile(conf["means"].as<string>()); - if (num_feats != FD::NumFeats()) { - cerr << "[ERROR] Means file had unexpected features!\n"; - exit(1); - } - wm.InitVector(&means); + Weights::InitFromFile(conf["means"].as<string>(), &means); } shared_ptr<BatchOptimizer> o; const string omethod = conf["optimization_method"].as<string>(); @@ -124,8 +118,6 @@ int main(int argc, char** argv) { cerr << "No state file found, assuming ITERATION 1\n"; } - vector<double> lambdas(num_feats, 0); - weights.InitVector(&lambdas); double objective = 0; vector<double> gradient(num_feats, 0); // 0<TAB>**OBJ**=12.2;Feat1=2.3;Feat2=-0.2; @@ -223,8 +215,7 @@ int main(int argc, char** argv) { old.clear(); SanityCheck(lambdas); ShowLargestFeatures(lambdas); - weights.InitFromVector(lambdas); - weights.WriteToFile(conf["output_weights"].as<string>(), false); + Weights::WriteToFile(conf["output_weights"].as<string>(), lambdas, false); const bool conv = o->HasConverged(); if (conv) { cerr << "OPTIMIZER REPORTS CONVERGENCE!\n"; } |