From 88a597f7bea6cd8325b48678dfaf874fae4d660d Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Fri, 30 Jan 2015 16:21:35 +0100 Subject: dtrain: fix_features --- training/dtrain/dtrain.cc | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index b180bc82..ae5b630a 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -44,6 +44,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) ("repeat", po::value()->default_value(1), "repeat optimization over kbest list this number of times") ("check", po::value()->zero_tokens(), "produce list of loss differentials") ("output_ranking", po::value()->default_value(""), "Output kbests with model scores and metric per iteration to this folder.") + ("fix_features", po::value()->zero_tokens(), "Ignore all features that are not in input_weights.") ("noup", po::value()->zero_tokens(), "do not update weights"); po::options_description cl("Command Line Options"); cl.add_options() @@ -115,6 +116,8 @@ main(int argc, char** argv) if (cfg.count("rescale")) rescale = true; bool keep = false; if (cfg.count("keep")) keep = true; + bool fix_features = false; + if (cfg.count("fix_features")) fix_features = true; const unsigned k = cfg["k"].as(); const unsigned N = cfg["N"].as(); @@ -193,8 +196,18 @@ main(int argc, char** argv) // init weights vector& decoder_weights = decoder.CurrentWeightVector(); - SparseVector lambdas, cumulative_penalties, w_average; - if (cfg.count("input_weights")) Weights::InitFromFile(cfg["input_weights"].as(), &decoder_weights); + + SparseVector lambdas, cumulative_penalties, w_average, fixed; + if (cfg.count("input_weights")) { + Weights::InitFromFile(cfg["input_weights"].as(), &decoder_weights); + if (fix_features) { + Weights::InitSparseVector(decoder_weights, &fixed); + SparseVector::iterator it = fixed.begin(); + for (; it != fixed.end(); ++it) { + it->second = 1.0; + } + } + } Weights::InitSparseVector(decoder_weights, &lambdas); // meta params for perceptron, SVM @@ -334,6 +347,8 @@ main(int argc, char** argv) if (next || stop) break; // weights + if (fix_features) + lambdas.cw_mult(fixed); lambdas.init_vector(&decoder_weights); // getting input @@ -642,6 +657,8 @@ main(int argc, char** argv) // write weights to file if (select_weights == "best" || keep) { + if (fix_features) + lambdas.cw_mult(fixed); lambdas.init_vector(&decoder_weights); string w_fn = "weights." + boost::lexical_cast(t) + ".gz"; Weights::WriteToFile(w_fn, decoder_weights, true); -- cgit v1.2.3