summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/AbstractGradientBaseMethod.java
blob: 0a4a54456cea54b7baa3a06da8953402a564bd71 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
package optimization.gradientBasedMethods;

import optimization.gradientBasedMethods.stats.OptimizerStats;
import optimization.linesearch.DifferentiableLineSearchObjective;
import optimization.linesearch.LineSearchMethod;
import optimization.stopCriteria.StopingCriteria;
import optimization.util.MathUtils;

/**
 * Skeleton for line-search based gradient optimizers.
 *
 * <p>{@link #optimize} drives the main loop: read the objective value and
 * gradient, test the stopping criterion, let the concrete subclass compute a
 * search direction via {@link #getDirection()}, and ask {@link #lineSearch}
 * for a step size along that direction. Subclasses can hook into the loop via
 * {@link #initializeStructures}, {@link #updateStructuresBeforeStep} and
 * {@link #updateStructuresAfterStep}.
 *
 * <p>{@link #lineSearch} is never assigned in this class; a subclass (or other
 * code with access to the field) must set it before {@link #optimize} is called.
 *
 * @author javg
 */
public abstract class AbstractGradientBaseMethod implements Optimizer{
	
	/**
	 * Exclusive upper bound for the 1-based iteration counter in {@link #optimize},
	 * i.e. at most {@code maxNumberOfIterations - 1} iterations are run.
	 */
	protected int maxNumberOfIterations=10000;
	
	/** Iteration counter of the running (or most recent) {@link #optimize} call; starts at 1. */
	protected int currentProjectionIteration;
	/** Objective value read at the start of the current iteration. */
	protected double currValue;
	/**
	 * Objective value of the previous iteration. It holds {@link Double#MAX_VALUE}
	 * only until the first loop iteration of {@link #optimize}, which overwrites it
	 * with the then-current {@link #currValue}.
	 * NOTE(review): on the first iteration that value is 0 (or stale from an earlier
	 * run if {@link #reset()} was not called) -- confirm this is intended.
	 */
	protected double previousValue = Double.MAX_VALUE;
	/** Step size returned by the line search in the current iteration. */
	protected double step;
	/** Gradient read from the objective at the start of the current iteration. */
	protected double[] gradient;
	/**
	 * Search direction of the current iteration. It is allocated by {@link #optimize}
	 * and expected to be filled in by {@link #getDirection()}.
	 */
	public double[] direction;
	
	/** L2 norm of the original gradient; not set in this class (presumably set by subclasses). */
	protected double originalGradientL2Norm;
	
	/** Line search used to choose the step along {@link #direction}; must be set externally. */
	protected LineSearchMethod lineSearch;
	/** Line-search view of the objective, created in {@link #initializeStructures}. */
	DifferentiableLineSearchObjective lso;
	
	/**
	 * Clears all per-run state so the optimizer can be reused.
	 */
	public void reset(){
		direction = null;
		gradient = null;
		previousValue = Double.MAX_VALUE;
		currentProjectionIteration = 0;
		originalGradientL2Norm = 0;
		step = 0;
		currValue = 0;
	}
	
	/**
	 * Hook run once before the main loop. The default wraps {@code o} in a
	 * {@link DifferentiableLineSearchObjective} for the line search.
	 */
	public void initializeStructures(Objective o,OptimizerStats stats, StopingCriteria stop){
		lso =   new DifferentiableLineSearchObjective(o);
	}
	
	/** Hook run each iteration after the direction is computed and before the line search; no-op by default. */
	public void updateStructuresBeforeStep(Objective o,OptimizerStats stats, StopingCriteria stop){
	}
	
	/** Hook run each iteration after a successful line search; no-op by default. */
	public void updateStructuresAfterStep(Objective o,OptimizerStats stats, StopingCriteria stop){
	}
	
	/**
	 * Runs the optimization loop on {@code o}.
	 *
	 * @param o     objective to minimize
	 * @param stats collector notified at init, after each iteration, and at the end
	 * @param stop  stopping criterion tested at the start of each iteration
	 * @return {@code true} if {@code stop} was satisfied; {@code false} if the line
	 *         search failed (step {@code -1}), the computed direction was not a
	 *         descent direction, or the iteration limit was reached
	 */
	public boolean optimize(Objective o,OptimizerStats stats, StopingCriteria stop){
		stats.collectInitStats(this, o);
		direction = new double[o.getNumParameters()];
		initializeStructures(o, stats, stop);
		for (currentProjectionIteration = 1; currentProjectionIteration < maxNumberOfIterations; currentProjectionIteration++){
			previousValue = currValue;
			currValue = o.getValue();
			gradient = o.getGradient();
			if(stop.stopOptimization(o)){
				stats.collectFinalStats(this, o);
				return true;
			}
			
			// Subclasses write the direction into the 'direction' field; the return value is not used here.
			getDirection();
			// A positive inner product with the gradient means the direction points uphill.
			if(MathUtils.dotProduct(gradient, direction) > 0){
				System.out.println("Not a descent direction");
				System.out.println(" current stats " + stats.prettyPrint(1));
				// Report failure to the caller (like the line-search failure below)
				// instead of System.exit, which would terminate the entire JVM.
				stats.collectFinalStats(this, o);
				return false;
			}
			updateStructuresBeforeStep(o, stats, stop);
			lso.reset(direction);
			step = lineSearch.getStepSize(lso);
			// -1 is the line search's "no acceptable step" result.
			if(step==-1){
				System.out.println("Failed to find step");
				stats.collectFinalStats(this, o);
				return false;		
			}
			updateStructuresAfterStep( o, stats,  stop);
			stats.collectIterationStats(this, o);
		}
		stats.collectFinalStats(this, o);
		return false;
	}
	
	
	/** @return the current (1-based) iteration counter */
	public int getCurrentIteration() {
		return currentProjectionIteration;
	}

	
	/**
	 * Method specific: computes the search direction for the current iteration.
	 * {@link #optimize} ignores the returned array and reads {@link #direction}.
	 */
	public abstract double[] getDirection();

	/** @return the step size chosen by the most recent line search */
	public double getCurrentStep() {
		return step;
	}

	/**
	 * Sets the iteration limit; see {@link #maxNumberOfIterations} for how it bounds the loop.
	 *
	 * @param max new value for {@link #maxNumberOfIterations}
	 */
	public void setMaxIterations(int max) {
		maxNumberOfIterations = max;
	}

	/** @return the objective value read at the start of the current iteration */
	public double getCurrentValue() {
		return currValue;
	}
}