summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/optimization/linesearch/ArmijoLineSearchMinimizationAlongProjectionArc.java
blob: e153f2dade00cb34efa24b1e5f17d93b268c53f3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
package optimization.linesearch;

import optimization.gradientBasedMethods.ProjectedObjective;
import optimization.util.Interpolation;
import optimization.util.MathUtils;





/**
 * Implements the Armijo rule line search along the projection arc (Nonlinear Programming, page 230).
 * To be used with projected gradient methods.
 * 
 * Recall that the Armijo rule tries successive step sizes alpha until the sufficient decrease condition is satisfied:
 * f(x+alpha*direction) < f(x) + alpha*c1*grad(f)*direction
 * 
 * In this case we are optimizing over a convex set X, so we must guarantee that the new point stays inside the 
 * constraints.
 * First, the direction has to be feasible (inside the constraints) and is defined as:
 * d = (x_k_f - x_k) where x_k_f is a feasible point.
 * Along the projection arc the trial point is x_k_f itself, which already depends on alpha, so the Armijo
 * condition can be rewritten as:
 * f(x_k_f) < f(x_k) + c1*grad(f)*(x_k_f - x_k)
 * and x_k_f is defined as:
 * [x_k-alpha*grad(f)]+
 * where []+ means a projection onto the feasible set.
 * So this means that we take a step along the negative gradient (gradient descent) and then project
 * that point onto the feasible set. 
 * Note that if the point is already feasible then we are back to the normal Armijo rule.
 * 
 * @author javg
 *
 */
public class ArmijoLineSearchMinimizationAlongProjectionArc implements LineSearchMethod{

	/**
	 * How much the step size is multiplied by when the interpolated step is rejected.
	 */
	double contractionFactor = 0.5;

	/**
	 * Sufficient-decrease constant of the Armijo condition.
	 */
	double c1 = 0.0001;

	/**
	 * First trial step of each search: 1 with the default constructor (or whatever
	 * {@link #setInitialStep} set), otherwise recomputed by the strategy on every call.
	 */
	double initialStep;

	/**
	 * Maximum number of step reductions before the search gives up and returns -1.
	 */
	int maxIterations = 100;

	/**
	 * Safeguard bounds for the interpolated step: a candidate is accepted only when it lies in
	 * [sigma1*alpha, sigma2*alpha], where alpha is the step that just failed the Armijo test.
	 */
	double sigma1 = 0.1;
	double sigma2 = 0.9;

	// Experimental history, exposed through the getters at the bottom of the class
	// (presumably consumed by the first-step strategy, which receives this object).
	double previousStepPicked = -1;
	double previousInitGradientDot = -1;
	double currentInitGradientDot = -1;

	/**
	 * Strategy that picks the first trial step; null when the default constructor is used.
	 */
	GenericPickFirstStep strategy;


	/**
	 * Clears the history recorded by previous searches.
	 */
	public void reset(){
		previousStepPicked = -1;
		previousInitGradientDot = -1;
		currentInitGradientDot = -1;
	}


	/**
	 * Creates a line search that starts from a step of 1 (changeable via {@link #setInitialStep}).
	 */
	public ArmijoLineSearchMinimizationAlongProjectionArc(){
		this.initialStep = 1;
	}

	/**
	 * Creates a line search whose first trial step is chosen by {@code strategy}.
	 *
	 * @param strategy supplies the first trial step of every search
	 */
	public ArmijoLineSearchMinimizationAlongProjectionArc(GenericPickFirstStep strategy){
		this.strategy = strategy;
		this.initialStep = strategy.getFirstStep(this);
	}


	/**
	 * Sets the first trial step. When a strategy is configured, it overrides this value
	 * at the start of every call to {@link #getStepSize}.
	 *
	 * @param initial the first trial step
	 */
	public void setInitialStep(double initial){
		this.initialStep = initial;
	}

	/**
	 * Backtracks from the initial step until the Armijo condition
	 * {@code o.getCurrentValue() <= o.getOriginalValue() + c1*o.getCurrentGradient()} holds.
	 * After each failed trial the next step is obtained by quadratic interpolation, safeguarded
	 * to [sigma1*alpha, sigma2*alpha]; outside that range (or if the interpolation yields NaN)
	 * the step is multiplied by {@code contractionFactor} instead.
	 *
	 * @param o objective restricted to the search arc; its alpha is updated in place
	 * @return the accepted step, or -1 if none was found within {@code maxIterations} reductions
	 *         (in which case a message is printed and {@code o.printLineSearchSteps()} is called)
	 */
	public double getStepSize(DifferentiableLineSearchObjective o) {
		// Without a strategy (default constructor) keep the configured initial step
		// instead of dereferencing a null strategy.
		if(strategy != null){
			initialStep = strategy.getFirstStep(this);
		}
		o.updateAlpha(initialStep);
		previousInitGradientDot = currentInitGradientDot;
		currentInitGradientDot = o.getCurrentGradient();
		int nrIterations = 0;

		// Armijo rule: the current value has to be smaller than the original value plus a small
		// fraction of the gradient term along the projection arc.
		while(o.getCurrentValue() > o.getOriginalValue() + c1*o.getCurrentGradient()){
			if(nrIterations >= maxIterations){
				System.out.println("Could not find a step leaving line search with -1");
				o.printLineSearchSteps();
				return -1;
			}
			double alpha = o.getAlpha();
			double interpolated =
				Interpolation.quadraticInterpolation(o.getOriginalValue(), o.getInitialGradient(), alpha, o.getCurrentValue());
			o.updateAlpha(nextTrialStep(alpha, interpolated));
			nrIterations++;
		}

		previousStepPicked = o.getAlpha();
		return previousStepPicked;
	}

	/**
	 * Chooses the next trial step after {@code alpha} failed the sufficient-decrease test.
	 * The interpolated step is used only if it lies in [sigma1*alpha, sigma2*alpha]; steps that
	 * are too small, not a reduction, negative or NaN fall back to {@code alpha*contractionFactor}.
	 */
	private double nextTrialStep(double alpha, double interpolated){
		if(interpolated >= sigma1*alpha && interpolated <= sigma2*alpha){
			return interpolated;
		}
		return alpha*contractionFactor;
	}

	/**
	 * @return the value of {@code o.getCurrentGradient()} recorded at the initial trial step
	 *         of the most recent search, or -1 if none has run since construction/reset
	 */
	public double getInitialGradient() {
		return currentInitGradientDot;
	}

	/**
	 * @return the value {@link #getInitialGradient()} held before the most recent search started
	 */
	public double getPreviousInitialGradient() {
		return previousInitGradientDot;
	}

	/**
	 * @return the step accepted by the most recent successful search, or -1 if none yet
	 */
	public double getPreviousStepUsed() {
		return previousStepPicked;
	}

}