1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
|
package optimization.gradientBasedMethods;
import optimization.gradientBasedMethods.stats.OptimizerStats;
import optimization.linesearch.DifferentiableLineSearchObjective;
import optimization.linesearch.LineSearchMethod;
import optimization.stopCriteria.StopingCriteria;
import optimization.util.MathUtils;
public class ConjugateGradient extends AbstractGradientBaseMethod{
double[] previousGradient;
double[] previousDirection;
public ConjugateGradient(LineSearchMethod lineSearch) {
this.lineSearch = lineSearch;
}
public void reset(){
super.reset();
java.util.Arrays.fill(previousDirection, 0);
java.util.Arrays.fill(previousGradient, 0);
}
public void initializeStructures(Objective o,OptimizerStats stats, StopingCriteria stop){
super.initializeStructures(o, stats, stop);
previousGradient = new double[o.getNumParameters()];
previousDirection = new double[o.getNumParameters()];
}
public void updateStructuresBeforeStep(Objective o,OptimizerStats stats, StopingCriteria stop){
System.arraycopy(gradient, 0, previousGradient, 0, gradient.length);
System.arraycopy(direction, 0, previousDirection, 0, direction.length);
}
// public boolean optimize(Objective o,OptimizerStats stats, StopingCriteria stop){
// DifferentiableLineSearchObjective lso = new DifferentiableLineSearchObjective(o);
// stats.collectInitStats(this, o);
// direction = new double[o.getNumParameters()];
// initializeStructures(o, stats, stop);
// for (currentProjectionIteration = 0; currentProjectionIteration < maxNumberOfIterations; currentProjectionIteration++){
// previousValue = currValue;
// currValue = o.getValue();
// gradient =o.getGradient();
// if(stop.stopOptimization(gradient)){
// stats.collectFinalStats(this, o);
// return true;
// }
// getDirection();
// updateStructures(o, stats, stop);
// lso.reset(direction);
// step = lineSearch.getStepSize(lso);
// if(step==-1){
// System.out.println("Failed to find a step size");
// System.out.println("Failed to find step");
// stats.collectFinalStats(this, o);
// return false;
// }
//
// stats.collectIterationStats(this, o);
// }
// stats.collectFinalStats(this, o);
// return false;
// }
public double[] getDirection(){
direction = MathUtils.negation(gradient);
if(currentProjectionIteration != 1){
//Using Polak-Ribiere method (book equation 5.45)
double b = MathUtils.dotProduct(gradient, MathUtils.arrayMinus(gradient, previousGradient))
/MathUtils.dotProduct(previousGradient, previousGradient);
if(b<0){
System.out.println("Defaulting to gradient descent");
b = Math.max(b, 0);
}
MathUtils.plusEquals(direction, previousDirection, b);
//Debug code
if(MathUtils.dotProduct(direction, gradient) > 0){
System.out.println("Not an descent direction reseting to gradien");
direction = MathUtils.negation(gradient);
}
}
return direction;
}
}
|