summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/optimization/examples/x2y2.java
blob: f087681ea0f0031f4191501c007c0983ed4d1b40 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package optimization.examples;


import optimization.gradientBasedMethods.ConjugateGradient;

import optimization.gradientBasedMethods.GradientDescent;
import optimization.gradientBasedMethods.LBFGS;
import optimization.gradientBasedMethods.Objective;
import optimization.gradientBasedMethods.stats.OptimizerStats;
import optimization.linesearch.GenericPickFirstStep;
import optimization.linesearch.LineSearchMethod;
import optimization.linesearch.WolfRuleLineSearch;
import optimization.stopCriteria.GradientL2Norm;
import optimization.stopCriteria.StopingCriteria;


/**
 * @author javg
 *
 */
public class x2y2 extends Objective{

	
	//Implements function ax2+ by2 
	double a, b;
	public x2y2(double a, double b){
		this.a = a;
		this.b = b;
		parameters = new double[2];
		parameters[0] = 4;
		parameters[1] = 4;
		gradient = new double[2];
	}
	
	public double getValue() {
		functionCalls++;
		return a*parameters[0]*parameters[0]+b*parameters[1]*parameters[1];
	}

	public double[] getGradient() {
		gradientCalls++;
		gradient[0]=2*a*parameters[0];
		gradient[1]=2*b*parameters[1];
		return gradient;
//		if(debugLevel >=2){
//			double[] numericalGradient = DebugHelpers.getNumericalGradient(this, parameters, 0.000001);
//			for(int i = 0; i < parameters.length; i++){
//				double diff = Math.abs(gradient[i]-numericalGradient[i]);
//				if(diff > 0.00001){
//					System.out.println("Numerical Gradient does not match");
//					System.exit(1);
//				}
//			}
//		}
	}

	
	
	public void optimizeWithGradientDescent(LineSearchMethod ls, OptimizerStats stats, x2y2 o){
		GradientDescent optimizer = new GradientDescent(ls);
		StopingCriteria stop = new GradientL2Norm(0.001);
//		optimizer.setGradientConvergenceValue(0.001);
		optimizer.setMaxIterations(100);
		boolean succed = optimizer.optimize(o,stats,stop);
		System.out.println("Ended optimzation Gradient Descent\n" + stats.prettyPrint(1));
		System.out.println("Solution: " + " x0 " + o.parameters[0]+ " x1 " + o.parameters[1]);
		if(succed){
			System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
		}else{
			System.out.println("Failed to optimize");
		}
	}
	
	public void optimizeWithConjugateGradient(LineSearchMethod ls, OptimizerStats stats, x2y2 o){
		ConjugateGradient optimizer = new ConjugateGradient(ls);
		StopingCriteria stop = new GradientL2Norm(0.001);

		optimizer.setMaxIterations(10);
		boolean succed = optimizer.optimize(o,stats,stop);
		System.out.println("Ended optimzation Conjugate Gradient\n" + stats.prettyPrint(1));
		System.out.println("Solution: " + " x0 " + o.parameters[0]+ " x1 " + o.parameters[1]);
		if(succed){
			System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
		}else{
			System.out.println("Failed to optimize");
		}
	}
	
	public void optimizeWithLBFGS(LineSearchMethod ls, OptimizerStats stats, x2y2 o){
		LBFGS optimizer = new LBFGS(ls,10);
		StopingCriteria stop = new GradientL2Norm(0.001);
		optimizer.setMaxIterations(10);
		boolean succed = optimizer.optimize(o,stats,stop);
		System.out.println("Ended optimzation LBFGS\n" + stats.prettyPrint(1));
		System.out.println("Solution: " + " x0 " + o.parameters[0]+ " x1 " + o.parameters[1]);
		if(succed){
			System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
		}else{
			System.out.println("Failed to optimize");
		}
	}
	
	public static void main(String[] args) {
		x2y2 o = new x2y2(1,10);
		System.out.println("Starting optimization " + " x0 " + o.parameters[0]+ " x1 " + o.parameters[1]);
		o.setDebugLevel(4);
		LineSearchMethod wolfe = new WolfRuleLineSearch(new GenericPickFirstStep(1),0.001,0.9);;
//		LineSearchMethod ls = new ArmijoLineSearchMinimization();
		OptimizerStats stats = new OptimizerStats();
		o.optimizeWithGradientDescent(wolfe, stats, o);
		o = new x2y2(1,10);
		System.out.println("Starting optimization " + " x0 " + o.parameters[0]+ " x1 " + o.parameters[1]);
//		ls = new ArmijoLineSearchMinimization();
		stats = new OptimizerStats();
		o.optimizeWithConjugateGradient(wolfe, stats, o);
		o = new x2y2(1,10);
		System.out.println("Starting optimization " + " x0 " + o.parameters[0]+ " x1 " + o.parameters[1]);
//		ls = new ArmijoLineSearchMinimization();
		stats = new OptimizerStats();
		o.optimizeWithLBFGS(wolfe, stats, o);	
	}
	
	public String toString(){
		return "P1: " + parameters[0] + " P2: " + parameters[1] + " value " + getValue();
	}
	
	
}