summaryrefslogtreecommitdiff
path: root/training/mr_optimize_reduce.cc
diff options
context:
space:
mode:
Diffstat (limited to 'training/mr_optimize_reduce.cc')
-rw-r--r--training/mr_optimize_reduce.cc19
1 files changed, 5 insertions, 14 deletions
diff --git a/training/mr_optimize_reduce.cc b/training/mr_optimize_reduce.cc
index b931991d..15e28fa1 100644
--- a/training/mr_optimize_reduce.cc
+++ b/training/mr_optimize_reduce.cc
@@ -88,25 +88,19 @@ int main(int argc, char** argv) {
const bool use_b64 = conf["input_format"].as<string>() == "b64";
- Weights weights;
- weights.InitFromFile(conf["input_weights"].as<string>());
+ vector<weight_t> lambdas;
+ Weights::InitFromFile(conf["input_weights"].as<string>(), &lambdas);
const string s_obj = "**OBJ**";
int num_feats = FD::NumFeats();
cerr << "Number of features: " << num_feats << endl;
const bool gaussian_prior = conf.count("gaussian_prior");
- vector<double> means(num_feats, 0);
+ vector<weight_t> means(num_feats, 0);
if (conf.count("means")) {
if (!gaussian_prior) {
cerr << "Don't use --means without --gaussian_prior!\n";
exit(1);
}
- Weights wm;
- wm.InitFromFile(conf["means"].as<string>());
- if (num_feats != FD::NumFeats()) {
- cerr << "[ERROR] Means file had unexpected features!\n";
- exit(1);
- }
- wm.InitVector(&means);
+ Weights::InitFromFile(conf["means"].as<string>(), &means);
}
shared_ptr<BatchOptimizer> o;
const string omethod = conf["optimization_method"].as<string>();
@@ -124,8 +118,6 @@ int main(int argc, char** argv) {
cerr << "No state file found, assuming ITERATION 1\n";
}
- vector<double> lambdas(num_feats, 0);
- weights.InitVector(&lambdas);
double objective = 0;
vector<double> gradient(num_feats, 0);
// 0<TAB>**OBJ**=12.2;Feat1=2.3;Feat2=-0.2;
@@ -223,8 +215,7 @@ int main(int argc, char** argv) {
old.clear();
SanityCheck(lambdas);
ShowLargestFeatures(lambdas);
- weights.InitFromVector(lambdas);
- weights.WriteToFile(conf["output_weights"].as<string>(), false);
+ Weights::WriteToFile(conf["output_weights"].as<string>(), lambdas, false);
const bool conv = o->HasConverged();
if (conv) { cerr << "OPTIMIZER REPORTS CONVERGENCE!\n"; }