summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/test/X2y2WithConstraints.java
blob: 9059a59e534e6bd44bbda602daad3b0837cdc7d9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
package test;



import optimization.gradientBasedMethods.ProjectedGradientDescent;
import optimization.gradientBasedMethods.ProjectedObjective;
import optimization.gradientBasedMethods.stats.OptimizerStats;
import optimization.linesearch.ArmijoLineSearchMinimizationAlongProjectionArc;
import optimization.linesearch.InterpolationPickFirstStep;
import optimization.linesearch.LineSearchMethod;
import optimization.projections.BoundsProjection;
import optimization.projections.Projection;
import optimization.projections.SimplexProjection;
import optimization.stopCriteria.CompositeStopingCriteria;
import optimization.stopCriteria.GradientL2Norm;
import optimization.stopCriteria.ProjectedGradientL2Norm;
import optimization.stopCriteria.StopingCriteria;
import optimization.stopCriteria.ValueDifference;


/**
 * @author javg
 * 
 * 
 *ax2+ b(y2 -displacement)
 */
public class X2y2WithConstraints extends ProjectedObjective{


	double a, b;
	double dx;
	double dy;
	Projection projection;
	
	
	public X2y2WithConstraints(double a, double b, double[] params, double dx, double dy, Projection proj){
		//projection = new BoundsProjection(0.2,Double.MAX_VALUE);
		super();
		projection = proj;	
		this.a = a;
		this.b = b;
		this.dx = dx;
		this.dy = dy;
		setInitialParameters(params);
		System.out.println("Function " +a+"(x-"+dx+")^2 + "+b+"(y-"+dy+")^2");
		System.out.println("Gradient " +(2*a)+"(x-"+dx+") ; "+(b*2)+"(y-"+dy+")");
		printParameters();
		projection.project(parameters);
		printParameters();
		gradient = new double[2];
	}
	
	public double getValue() {
		functionCalls++;
		return a*(parameters[0]-dx)*(parameters[0]-dx)+b*((parameters[1]-dy)*(parameters[1]-dy));
	}

	public double[] getGradient() {
		if(gradient == null){
			gradient = new double[2];
		}
		gradientCalls++;
		gradient[0]=2*a*(parameters[0]-dx);
		gradient[1]=2*b*(parameters[1]-dy);
		return gradient;
	}
	
	
	public double[] projectPoint(double[] point) {
		double[] newPoint = point.clone();
		projection.project(newPoint);
		return newPoint;
	}	
	
	public void optimizeWithProjectedGradientDescent(LineSearchMethod ls, OptimizerStats stats, X2y2WithConstraints o){
		ProjectedGradientDescent optimizer = new ProjectedGradientDescent(ls);
		StopingCriteria stopGrad = new ProjectedGradientL2Norm(0.001);
		StopingCriteria stopValue = new ValueDifference(0.001);
		CompositeStopingCriteria compositeStop = new CompositeStopingCriteria();
		compositeStop.add(stopGrad);
		compositeStop.add(stopValue);
		
		optimizer.setMaxIterations(5);
		boolean succed = optimizer.optimize(o,stats,compositeStop);
		System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
		System.out.println("Solution: " + " x0 " + o.parameters[0]+ " x1 " + o.parameters[1]);
		if(succed){
			System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
		}else{
			System.out.println("Failed to optimize");
		}
	}
	
	
	
	public String toString(){
		
		return "P1: " + parameters[0] + " P2: " + parameters[1] + " value " + getValue() + " grad (" + getGradient()[0] + ":" + getGradient()[1]+")";
	}
	
	public static void main(String[] args) {
		double a = 1;
		double b=1;
		double x0 = 0;
		double y0  =1;
		double dx = 0.5;
		double dy = 0.2	;
		double [] parameters = new double[2];
		parameters[0] = x0;
		parameters[1] = y0;
		X2y2WithConstraints o = new X2y2WithConstraints(a,b,parameters,dx,dy, 
				new SimplexProjection(0.5)
				//new BoundsProjection(0.0,0.4)
		);
		System.out.println("Starting optimization " + " x0 " + o.parameters[0]+ " x1 " + o.parameters[1] + " a " + a + " b "+b );
		o.setDebugLevel(4);
		
		LineSearchMethod ls = new ArmijoLineSearchMinimizationAlongProjectionArc(new InterpolationPickFirstStep(1));
		
		OptimizerStats stats = new OptimizerStats();
		o.optimizeWithProjectedGradientDescent(ls, stats, o);
		
//		o = new x2y2WithConstraints(a,b,x0,y0,dx,dy);
//		stats = new OptimizerStats();
//		o.optimizeWithSpectralProjectedGradientDescent(stats, o);
	}
	
	
	
	
}