1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
|
package optimization.gradientBasedMethods;
import optimization.gradientBasedMethods.stats.OptimizerStats;
import optimization.linesearch.DifferentiableLineSearchObjective;
import optimization.linesearch.LineSearchMethod;
import optimization.stopCriteria.StopingCriteria;
import optimization.util.MathUtils;
/**
 * Skeleton of a line-search based gradient optimizer.
 *
 * <p>Each iteration of {@link #optimize} reads the objective's value and gradient, asks the
 * stopping criterion whether to stop, obtains a search direction from the subclass via
 * {@link #getDirection()}, checks that it is a descent direction, and then lets
 * {@link #lineSearch} pick the step length. Subclasses supply the direction and may hook into
 * {@link #initializeStructures}, {@link #updateStructuresBeforeStep} and
 * {@link #updateStructuresAfterStep}.
 *
 * @author javg
 */
public abstract class AbstractGradientBaseMethod implements Optimizer {

    /** Iteration budget; see the note in {@link #optimize} about the exact iteration count. */
    protected int maxNumberOfIterations = 10000;

    /** Index of the iteration being executed; 1-based while {@link #optimize} runs. */
    protected int currentProjectionIteration;

    /** Objective value read at the start of the current iteration. */
    protected double currValue;

    /**
     * Objective value from the previous iteration. It is {@code Double.MAX_VALUE} after
     * construction or {@link #reset()}. NOTE(review): on the first iteration of
     * {@link #optimize} it is overwritten with whatever {@link #currValue} held before the run
     * (0 after {@link #reset()}), not with a real objective value.
     */
    protected double previousValue = Double.MAX_VALUE;

    /** Step size returned by the line search in the latest iteration. */
    protected double step;

    /** Gradient read at the start of the current iteration. */
    protected double[] gradient;

    /**
     * Search direction for the current iteration. {@link #optimize} allocates it with one
     * entry per parameter, and {@link #getDirection()} implementations must fill it in.
     */
    public double[] direction;

    // Original values
    /** Not assigned in this class except by {@link #reset()}; presumably set by subclasses. */
    protected double originalGradientL2Norm;

    /**
     * Line search used to choose the step length. It is never assigned in this class, so a
     * subclass or caller must set it before {@link #optimize} is called.
     */
    protected LineSearchMethod lineSearch;

    /** Line-search view of the objective, created by {@link #initializeStructures}. */
    DifferentiableLineSearchObjective lso;

    /**
     * Clears the per-run state: direction, gradient, previous/current values, step and
     * iteration counter. The line search, {@link #lso} and the iteration budget are not touched.
     */
    public void reset() {
        direction = null;
        gradient = null;
        previousValue = Double.MAX_VALUE;
        currentProjectionIteration = 0;
        originalGradientL2Norm = 0;
        step = 0;
        currValue = 0;
    }

    /**
     * Hook called once at the start of {@link #optimize}. It creates the line-search objective
     * that wraps {@code o}. Overrides must call {@code super.initializeStructures(...)};
     * otherwise {@link #lso} stays unset and {@link #optimize} fails when it resets it.
     *
     * @param o the objective being optimized
     * @param stats statistics collector for this run
     * @param stop stopping criterion for this run
     */
    public void initializeStructures(Objective o, OptimizerStats stats, StopingCriteria stop) {
        lso = new DifferentiableLineSearchObjective(o);
    }

    /**
     * Hook called each iteration after the direction has been validated and before the line
     * search runs. Does nothing by default.
     *
     * @param o the objective being optimized
     * @param stats statistics collector for this run
     * @param stop stopping criterion for this run
     */
    public void updateStructuresBeforeStep(Objective o, OptimizerStats stats, StopingCriteria stop) {
    }

    /**
     * Hook called each iteration after a successful line search and before the iteration
     * statistics are collected. Does nothing by default.
     *
     * @param o the objective being optimized
     * @param stats statistics collector for this run
     * @param stop stopping criterion for this run
     */
    public void updateStructuresAfterStep(Objective o, OptimizerStats stats, StopingCriteria stop) {
    }

    /**
     * Runs the optimization loop on {@code o}.
     *
     * @param o the objective being optimized
     * @param stats receives init, per-iteration and final statistics
     * @param stop queried at the start of every iteration
     * @return {@code true} if the stopping criterion was met; {@code false} if the line search
     *     failed to find a step or the iteration budget ran out
     * @throws IllegalStateException if {@link #getDirection()} produces a direction whose dot
     *     product with the gradient is positive (not a descent direction)
     */
    public boolean optimize(Objective o, OptimizerStats stats, StopingCriteria stop) {
        stats.collectInitStats(this, o);
        direction = new double[o.getNumParameters()];
        initializeStructures(o, stats, stop);
        // NOTE(review): the strict '<' means at most maxNumberOfIterations - 1 iterations run
        // (none at all when the budget is 1). Kept as-is to preserve existing behaviour.
        for (currentProjectionIteration = 1;
                currentProjectionIteration < maxNumberOfIterations;
                currentProjectionIteration++) {
            previousValue = currValue;
            currValue = o.getValue();
            gradient = o.getGradient();
            if (stop.stopOptimization(o)) {
                stats.collectFinalStats(this, o);
                return true;
            }
            // The returned array is ignored; implementations must write into 'direction'.
            getDirection();
            if (MathUtils.dotProduct(gradient, direction) > 0) {
                // A positive directional derivative means the objective increases (to first
                // order) along 'direction'. Previously this called System.exit(-1), which
                // killed the whole JVM. Fail loudly but let the caller decide how to recover.
                System.out.println("Not a descent direction");
                System.out.println(" current stats " + stats.prettyPrint(1));
                throw new IllegalStateException(
                        "Not a descent direction at iteration " + currentProjectionIteration);
            }
            updateStructuresBeforeStep(o, stats, stop);
            lso.reset(direction);
            step = lineSearch.getStepSize(lso);
            // The line search signals failure with the sentinel value -1.
            if (step == -1) {
                System.out.println("Failed to find step");
                stats.collectFinalStats(this, o);
                return false;
            }
            updateStructuresAfterStep(o, stats, stop);
            stats.collectIterationStats(this, o);
        }
        stats.collectFinalStats(this, o);
        return false;
    }

    /**
     * @return the index of the current (or, after {@link #optimize} returns, the last) iteration
     */
    public int getCurrentIteration() {
        return currentProjectionIteration;
    }

    /**
     * Method specific: computes the search direction for the current iteration.
     *
     * <p>{@link #optimize} ignores the return value and reads the {@link #direction} field
     * instead, so implementations must store their result there. When called, {@link #gradient}
     * and {@link #currValue} hold the values read at the start of the iteration.
     *
     * @return the search direction
     */
    public abstract double[] getDirection();

    /** @return the step size chosen by the line search in the latest iteration */
    public double getCurrentStep() {
        return step;
    }

    /**
     * Sets the iteration budget used by {@link #optimize}.
     *
     * @param max the new budget
     */
    public void setMaxIterations(int max) {
        maxNumberOfIterations = max;
    }

    /** @return the objective value read at the start of the latest iteration */
    public double getCurrentValue() {
        return currValue;
    }
}
|