summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/optimization/examples/x2y2WithConstraints.java
blob: 391775b717947cef048f00cd6b72d2ad557ffe77 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package optimization.examples;


import optimization.gradientBasedMethods.ProjectedGradientDescent;
import optimization.gradientBasedMethods.ProjectedObjective;
import optimization.gradientBasedMethods.stats.OptimizerStats;
import optimization.linesearch.ArmijoLineSearchMinimizationAlongProjectionArc;
import optimization.linesearch.InterpolationPickFirstStep;
import optimization.linesearch.LineSearchMethod;
import optimization.projections.BoundsProjection;
import optimization.projections.Projection;
import optimization.projections.SimplexProjection;
import optimization.stopCriteria.CompositeStopingCriteria;
import optimization.stopCriteria.GradientL2Norm;
import optimization.stopCriteria.ProjectedGradientL2Norm;
import optimization.stopCriteria.StopingCriteria;
import optimization.stopCriteria.ValueDifference;


/**
 * @author javg
 * 
 * 
 *ax2+ b(y2 -displacement)
 */
public class x2y2WithConstraints extends ProjectedObjective{


	double a, b;
	double dx;
	double dy;
	Projection projection;
	
	
	public x2y2WithConstraints(double a, double b, double[] params, double dx, double dy, Projection proj){
		//projection = new BoundsProjection(0.2,Double.MAX_VALUE);
		super();
		projection = proj;	
		this.a = a;
		this.b = b;
		this.dx = dx;
		this.dy = dy;
		setInitialParameters(params);
		System.out.println("Function " +a+"(x-"+dx+")^2 + "+b+"(y-"+dy+")^2");
		System.out.println("Gradient " +(2*a)+"(x-"+dx+") ; "+(b*2)+"(y-"+dy+")");
		printParameters();
		projection.project(parameters);
		printParameters();
		gradient = new double[2];
	}
	
	public double getValue() {
		functionCalls++;
		return a*(parameters[0]-dx)*(parameters[0]-dx)+b*((parameters[1]-dy)*(parameters[1]-dy));
	}

	public double[] getGradient() {
		if(gradient == null){
			gradient = new double[2];
		}
		gradientCalls++;
		gradient[0]=2*a*(parameters[0]-dx);
		gradient[1]=2*b*(parameters[1]-dy);
		return gradient;
	}
	
	
	public double[] projectPoint(double[] point) {
		double[] newPoint = point.clone();
		projection.project(newPoint);
		return newPoint;
	}	
	
	public void optimizeWithProjectedGradientDescent(LineSearchMethod ls, OptimizerStats stats, x2y2WithConstraints o){
		ProjectedGradientDescent optimizer = new ProjectedGradientDescent(ls);
		StopingCriteria stopGrad = new ProjectedGradientL2Norm(0.001);
		StopingCriteria stopValue = new ValueDifference(0.001);
		CompositeStopingCriteria compositeStop = new CompositeStopingCriteria();
		compositeStop.add(stopGrad);
		compositeStop.add(stopValue);
		
		optimizer.setMaxIterations(5);
		boolean succed = optimizer.optimize(o,stats,compositeStop);
		System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
		System.out.println("Solution: " + " x0 " + o.parameters[0]+ " x1 " + o.parameters[1]);
		if(succed){
			System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
		}else{
			System.out.println("Failed to optimize");
		}
	}
	
	
	
	public String toString(){
		
		return "P1: " + parameters[0] + " P2: " + parameters[1] + " value " + getValue() + " grad (" + getGradient()[0] + ":" + getGradient()[1]+")";
	}
	
	public static void main(String[] args) {
		double a = 1;
		double b=1;
		double x0 = 0;
		double y0  =1;
		double dx = 0.5;
		double dy = 0.5	;
		double [] parameters = new double[2];
		parameters[0] = x0;
		parameters[1] = y0;
		x2y2WithConstraints o = new x2y2WithConstraints(a,b,parameters,dx,dy, new SimplexProjection(0.5));
		System.out.println("Starting optimization " + " x0 " + o.parameters[0]+ " x1 " + o.parameters[1] + " a " + a + " b "+b );
		o.setDebugLevel(4);
		
		LineSearchMethod ls = new ArmijoLineSearchMinimizationAlongProjectionArc(new InterpolationPickFirstStep(1));
		
		OptimizerStats stats = new OptimizerStats();
		o.optimizeWithProjectedGradientDescent(ls, stats, o);
		
//		o = new x2y2WithConstraints(a,b,x0,y0,dx,dy);
//		stats = new OptimizerStats();
//		o.optimizeWithSpectralProjectedGradientDescent(stats, o);
	}
	
	
	
	
}