From 88a597f7bea6cd8325b48678dfaf874fae4d660d Mon Sep 17 00:00:00 2001
From: Patrick Simianer
Date: Fri, 30 Jan 2015 16:21:35 +0100
Subject: dtrain: fix_features
---
training/dtrain/dtrain.cc | 21 +++++++++++++++++++--
1 file changed, 19 insertions(+), 2 deletions(-)
(limited to 'training')
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index b180bc82..ae5b630a 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -44,6 +44,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("repeat", po::value()->default_value(1), "repeat optimization over kbest list this number of times")
("check", po::value()->zero_tokens(), "produce list of loss differentials")
("output_ranking", po::value()->default_value(""), "Output kbests with model scores and metric per iteration to this folder.")
+ ("fix_features", po::value()->zero_tokens(), "Ignore all features that are not in input_weights.")
("noup", po::value()->zero_tokens(), "do not update weights");
po::options_description cl("Command Line Options");
cl.add_options()
@@ -115,6 +116,8 @@ main(int argc, char** argv)
if (cfg.count("rescale")) rescale = true;
bool keep = false;
if (cfg.count("keep")) keep = true;
+ bool fix_features = false;
+ if (cfg.count("fix_features")) fix_features = true;
const unsigned k = cfg["k"].as();
const unsigned N = cfg["N"].as();
@@ -193,8 +196,18 @@ main(int argc, char** argv)
// init weights
vector& decoder_weights = decoder.CurrentWeightVector();
- SparseVector lambdas, cumulative_penalties, w_average;
- if (cfg.count("input_weights")) Weights::InitFromFile(cfg["input_weights"].as(), &decoder_weights);
+
+ SparseVector lambdas, cumulative_penalties, w_average, fixed;
+ if (cfg.count("input_weights")) {
+ Weights::InitFromFile(cfg["input_weights"].as(), &decoder_weights);
+ if (fix_features) {
+ Weights::InitSparseVector(decoder_weights, &fixed);
+ SparseVector::iterator it = fixed.begin();
+ for (; it != fixed.end(); ++it) {
+ it->second = 1.0;
+ }
+ }
+ }
Weights::InitSparseVector(decoder_weights, &lambdas);
// meta params for perceptron, SVM
@@ -334,6 +347,8 @@ main(int argc, char** argv)
if (next || stop) break;
// weights
+ if (fix_features)
+ lambdas.cw_mult(fixed);
lambdas.init_vector(&decoder_weights);
// getting input
@@ -642,6 +657,8 @@ main(int argc, char** argv)
// write weights to file
if (select_weights == "best" || keep) {
+ if (fix_features)
+ lambdas.cw_mult(fixed);
lambdas.init_vector(&decoder_weights);
string w_fn = "weights." + boost::lexical_cast(t) + ".gz";
Weights::WriteToFile(w_fn, decoder_weights, true);
--
cgit v1.2.3