summaryrefslogtreecommitdiff
path: root/dtrain/test/log_reg/bin_class.h
blob: 3466109afb151fb61c6dd137c7d883951c99c3f7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
#ifndef _BIN_CLASS_H_
#define _BIN_CLASS_H_

#include <vector>
#include "sparse_vector.h"

// One labeled training example: a sparse feature representation plus a
// boolean label.
// NOTE(review): the bool label suggests binary classification; confirm against
// the concrete Objective implementations.
// Caution: `y` has no initializer, so a default-initialized TrainingInstance
// (e.g. `TrainingInstance t;`) holds an indeterminate `y`. Value-initialize it
// (`TrainingInstance t = TrainingInstance();`) or assign `y` explicitly before
// reading it.
struct TrainingInstance {
  // TODO add other info? loss for MIRA-type updates?
  // Sparse features of this example; presumably feature id -> feature value
  // (the container's semantics are defined in sparse_vector.h).
  SparseVector<double> x_feature_map;
  // Class label. Which value denotes the "positive" class is not fixed by this
  // header -- TODO confirm against the Objective implementations.
  bool y;
};

// Abstract interface for a training objective that can report both its value
// and its gradient over a set of TrainingInstances. Concrete objectives derive
// from this struct and implement ObjectiveAndGradient().
struct Objective {
  // Virtual so a derived objective can be safely destroyed through an
  // Objective*. Declared but not defined in this header: a definition must be
  // provided out of line (presumably in a companion .cc -- confirm), otherwise
  // linking any derived class will fail.
  virtual ~Objective();

  // Evaluates the objective at `x` over `training_instances`.
  //   x                  - point at which the objective and its gradient are
  //                        evaluated (presumably the model's weight vector).
  //   training_instances - examples the objective is computed over.
  //   g                  - out-parameter that receives the gradient f'(x).
  // Returns f(x), the objective value.
  // const: evaluation does not modify the Objective object itself.
  // NOTE(review): this interface does not specify whether implementations
  // clear *g before writing or accumulate into it, nor whether g may be null.
  // Confirm against the implementations before relying on either.
  virtual double ObjectiveAndGradient(const SparseVector<double>& x,
                  const std::vector<TrainingInstance>& training_instances,
                  SparseVector<double>* g) const = 0;
};

#endif