summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/optimization/linesearch/DifferentiableLineSearchObjective.java
blob: a5bc958e46eab1c4304d1793025ea9b4c7fa2256 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
package optimization.linesearch;

import gnu.trove.TDoubleArrayList;
import gnu.trove.TIntArrayList;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;

import optimization.gradientBasedMethods.Objective;
import optimization.util.MathUtils;
import optimization.util.StaticTools;



import util.MathUtil;
import util.Printing;


/**
 * A wrapper around the actual objective used to perform a line search.
 * Every step size tried during the current search is recorded together
 * with the objective value and the directional derivative obtained at
 * that step, so the line search code can read back any earlier
 * evaluation without recomputing it.
 * 
 * Note: {@link #updateAlpha(double)} writes the new point into the wrapped
 * objective, so after a line search the objective holds the parameters of
 * the last step that was tried (presumably this lets the caller avoid
 * another gradient evaluation after the search -- confirm against callers).
 * 
 * The differentiable line search objective defines a search along the ray
 * given by a direction of the main objective.
 * It defines the following function, 
 * where f is the original objective function:
 * g(alpha) = f(x_0 + alpha*direction)
 * g'(alpha) = f'(x_0 + alpha*direction)*direction
 * 
 * Usage: call {@link #reset(double[])} once at the start of each search,
 * then {@link #updateAlpha(double)} for every step size to evaluate.
 * The accessors index into the recorded history and therefore require
 * that {@code reset} has been called at least once.
 * 
 * @author joao
 *
 */
public class DifferentiableLineSearchObjective {

	
	
	/** Objective whose values and gradients define g and g'. */
	Objective o;
	/** Number of calls to updateAlpha since the last reset; also the index of the latest entry. */
	int nrIterations;
	/** Step sizes tried so far; entry 0 is always the zero step added by reset. */
	TDoubleArrayList steps;
	/** g(alpha) for each entry of steps. */
	TDoubleArrayList values;
	/** g'(alpha) for each entry of steps. */
	TDoubleArrayList gradients;
	
	//These variables must not change during a line search; they are only rewritten by reset
	public double[] originalParameters;
	public double[] searchDirection;

	
	/**
	 * Creates a line search wrapper for the given objective.
	 * The objective is used to calculate values and gradients. The search
	 * direction is not supplied here: it does not depend on the objective
	 * but on the optimization method, and is passed to
	 * {@link #reset(double[])} at the start of every search.
	 * 
	 * @param o objective on which the line search is performed
	 */
	public DifferentiableLineSearchObjective(Objective o) {
		this.o = o;
		originalParameters = new double[o.getNumParameters()];
		searchDirection = new double[o.getNumParameters()];
		steps = new TDoubleArrayList();
		values = new TDoubleArrayList();
		gradients = new TDoubleArrayList();
	}
	/**
	 * Called whenever we start a new iteration. 
	 * Stores copies of the objective's current parameters (x_0) and of the
	 * ray we are searching along, clears the recorded history and records
	 * the evaluation at alpha = 0 as entry 0.
	 * 
	 * @param direction search direction; its first getNumParameters() entries are copied
	 */
	public void reset(double[] direction){
		//Copy initial values
		System.arraycopy(o.getParameters(), 0, originalParameters, 0, o.getNumParameters());
		System.arraycopy(direction, 0, searchDirection, 0, o.getNumParameters());
		
		//Initialize variables
		nrIterations = 0;
		steps.clear();
		values.clear();
		gradients.clear();
	
		//Entry 0 of the history is the starting point itself (alpha = 0)
		values.add(o.getValue());
		gradients.add(MathUtils.dotProduct(o.getGradient(),direction));	
		steps.add(0.0);
	}
	
	
	/**
	 * Update the current value of alpha.
	 * Takes a step with that alpha along the search direction, starting
	 * from the parameters stored by reset, then gets the real objective
	 * value and gradient and appends them to the history.
	 * 
	 * @param alpha step size to evaluate; must not be negative
	 * @throws IllegalArgumentException if alpha is smaller than zero
	 */
	public void updateAlpha(double alpha){
		if(alpha < 0){
			throw new IllegalArgumentException("alpha may not be smaller than zero, got " + alpha);
		}
		nrIterations++;
		steps.add(alpha);
		//x_t+1 = x_0 + alpha*direction
		//Fetch the parameter array once so that the copy, the update and the
		//setParameters call are guaranteed to operate on the same array.
		double[] parameters = o.getParameters();
		System.arraycopy(originalParameters, 0, parameters, 0, originalParameters.length);
		//presumably adds alpha*searchDirection to parameters in place -- helper not visible here
		MathUtils.plusEquals(parameters, searchDirection, alpha);
		o.setParameters(parameters);
		values.add(o.getValue());
		gradients.add(MathUtils.dotProduct(o.getGradient(),searchDirection));		
	}

	
	
	/**
	 * @return number of steps evaluated with updateAlpha since the last reset
	 */
	public int getNrIterations(){
		return nrIterations;
	}
	
	/**
	 * return g(alpha) for the step recorded at the given index
	 * @param iter index into the history; 0 is the starting point
	 * @return the recorded objective value
	 */
	public double getValue(int iter){
		return values.get(iter);
	}
	
	/**
	 * @return g(alpha) for the most recently evaluated alpha
	 */
	public double getCurrentValue(){
		return values.get(nrIterations);
	}
	
	/**
	 * @return g(0), the objective value recorded by reset
	 */
	public double getOriginalValue(){
		return values.get(0);
	}

	/**
	 * return g'(alpha) for the step recorded at the given index
	 * @param iter index into the history; 0 is the starting point
	 * @return the recorded directional derivative
	 */
	public double getGradient(int iter){
		return gradients.get(iter);
	}
	
	/**
	 * @return g'(alpha) for the most recently evaluated alpha
	 */
	public double getCurrentGradient(){
		return gradients.get(nrIterations);
	}
	
	/**
	 * @return g'(0), the directional derivative recorded by reset
	 */
	public double getInitialGradient(){
		return gradients.get(0);
	}
	
	
	
	
	/**
	 * @return the most recently evaluated step size (0 right after reset)
	 */
	public double getAlpha(){
		return steps.get(nrIterations);
	}
	
	/**
	 * Prints the sizes of the three history lists followed by one line
	 * per recorded step with its step size, value and gradient.
	 */
	public void printLineSearchSteps(){
		System.out.println(
				" Steps size "+steps.size() + 
				" Values size "+values.size() +
				" Gradients size "+gradients.size());
		for(int i =0; i < steps.size();i++){
			System.out.println("Iter " + i + " step " + steps.get(i) +
					" value " + values.get(i) + " grad "  + gradients.get(i));
		}
	}
	
	/**
	 * Prints all recorded step sizes on a single line, each formatted
	 * through StaticTools.prettyPrint.
	 */
	public void printSmallLineSearchSteps(){
		for(int i =0; i < steps.size();i++){
			System.out.print(StaticTools.prettyPrint(steps.get(i), "0.0000E00",8) + " ");
		}
		System.out.println();
	}
	
	/**
	 * Empty entry point; does nothing.
	 */
	public static void main(String[] args) {
		
	}
	
}