summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/optimization/linesearch/ArmijoLineSearchMinimization.java
blob: c9f9b8dfefd8077cfe3e6934c34c35ac891c80de (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
package optimization.linesearch;

import optimization.util.Interpolation;


/**
 * Implements backtracking line search as described on page 37 of Numerical Optimization,
 * also known as the Armijo rule.
 *
 * <p>Starting from {@link #initialStep}, the step length is reduced until
 * {@code WolfeConditions.suficientDecrease} accepts it (with constant {@link #c1}). Each new
 * trial step is obtained by quadratic interpolation. If the interpolated step falls outside the
 * safeguard interval {@code [sigma1 * alpha, sigma2 * alpha]}, it is rejected and the current
 * step is multiplied by {@link #contractionFactor} instead.
 *
 * @author javg
 */
public class ArmijoLineSearchMinimization implements LineSearchMethod{

	/**
	 * Factor applied to the current step when the interpolated step is rejected.
	 */
	double contractionFactor = 0.5;
	/** Constant passed to {@code WolfeConditions.suficientDecrease}. */
	double c1 = 0.0001;

	/** Lower safeguard: interpolated steps below {@code sigma1 * alpha} are rejected. */
	double sigma1 = 0.1;
	/** Upper safeguard: interpolated steps above {@code sigma2 * alpha} are rejected. */
	double sigma2 = 0.9;

	/** First step length tried on every call to {@link #getStepSize}. */
	double initialStep;
	/** Maximum number of step reductions before {@link #getStepSize} gives up. */
	int maxIterations = 10;

	public ArmijoLineSearchMinimization(){
		this.initialStep = 1;
	}

	// Experimental bookkeeping; -1 means "no value recorded yet".
	/** Step accepted by the last successful search. */
	double previousStepPicked = -1;
	/** Initial gradient value recorded by the last successful search. */
	double previousInitGradientDot = -1;
	/** Initial gradient value recorded at the start of the most recent search. */
	double currentInitGradientDot = -1;

	/**
	 * Clears the values recorded by previous searches (all set back to -1).
	 */
	public void reset(){
		previousStepPicked = -1;
		previousInitGradientDot = -1;
		currentInitGradientDot = -1;
	}

	/**
	 * Sets the first step length tried by subsequent calls to {@link #getStepSize}.
	 *
	 * @param initial the initial step length
	 */
	public void setInitialStep(double initial){
		initialStep = initial;
	}

	/**
	 * Searches for a step length along the direction encoded in {@code o} that satisfies the
	 * sufficient-decrease condition.
	 *
	 * <p>The objective is first moved to {@link #initialStep}. While sufficient decrease fails,
	 * a new trial step is computed by quadratic interpolation. The interpolation uses the
	 * original value, the initial gradient and the value at the current step. The interpolated
	 * step is used only if it lies in {@code [sigma1 * alpha, sigma2 * alpha]}; otherwise the
	 * current step is multiplied by {@link #contractionFactor}.
	 *
	 * @param o the line-search objective; {@code o.updateAlpha} is called with every step tried
	 * @return the accepted step length, or {@code -1} if sufficient decrease was not reached
	 *     within {@link #maxIterations} reductions (in that case {@code o.printLineSearchSteps()}
	 *     is called before returning)
	 */
	public double getStepSize(DifferentiableLineSearchObjective o) {
		currentInitGradientDot = o.getInitialGradient();
		// Should update everything tracked by the objective for the initial step.
		o.updateAlpha(initialStep);
		int nrIterations = 0;
		while(!WolfeConditions.suficientDecrease(o,c1)){
			if(nrIterations >= maxIterations){
				o.printLineSearchSteps();
				return -1;
			}
			double alpha = o.getAlpha();
			double alphaTemp =
				Interpolation.quadraticInterpolation(o.getOriginalValue(), o.getInitialGradient(), alpha, o.getCurrentValue());
			o.updateAlpha(safeguardedStep(alpha, alphaTemp));
			nrIterations++;
		}

		// Only successful searches update the "previous" bookkeeping.
		previousInitGradientDot = currentInitGradientDot;
		previousStepPicked = o.getAlpha();
		return o.getAlpha();
	}

	/**
	 * Chooses the next trial step from the current step and the interpolated candidate.
	 *
	 * <p>Returns the interpolated step when it lies in {@code [sigma1 * alpha, sigma2 * alpha]},
	 * otherwise {@code alpha * contractionFactor}. The interval prevents steps that shrink too
	 * little or collapse towards zero. For finite {@code alpha}, a NaN or infinite interpolated
	 * step fails at least one comparison and is therefore rejected.
	 *
	 * @param alpha the current (rejected) step length
	 * @param interpolated the step proposed by quadratic interpolation
	 * @return the step length to try next
	 */
	private double safeguardedStep(double alpha, double interpolated) {
		boolean insideSafeguard = interpolated >= sigma1 * alpha && interpolated <= sigma2 * alpha;
		return insideSafeguard ? interpolated : alpha * contractionFactor;
	}

	/**
	 * @return the value of {@code o.getInitialGradient()} recorded at the start of the most
	 *     recent call to {@link #getStepSize}, or -1 if none since construction/{@link #reset}
	 */
	public double getInitialGradient() {
		return currentInitGradientDot;
	}

	/**
	 * @return the initial gradient value recorded by the last successful search, or -1 if none
	 */
	public double getPreviousInitialGradient() {
		return previousInitGradientDot;
	}

	/**
	 * @return the step length accepted by the last successful search, or -1 if none
	 */
	public double getPreviousStepUsed() {
		return previousStepPicked;
	}

}