summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/optimization/examples/GeneralizedRosenbrock.java
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/optimization/examples/GeneralizedRosenbrock.java')
-rw-r--r--gi/posterior-regularisation/prjava/src/optimization/examples/GeneralizedRosenbrock.java110
1 files changed, 110 insertions, 0 deletions
diff --git a/gi/posterior-regularisation/prjava/src/optimization/examples/GeneralizedRosenbrock.java b/gi/posterior-regularisation/prjava/src/optimization/examples/GeneralizedRosenbrock.java
new file mode 100644
index 00000000..25fa7f09
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/optimization/examples/GeneralizedRosenbrock.java
@@ -0,0 +1,110 @@
+package optimization.examples;
+
+
+import optimization.gradientBasedMethods.ConjugateGradient;
+import optimization.gradientBasedMethods.GradientDescent;
+import optimization.gradientBasedMethods.LBFGS;
+import optimization.gradientBasedMethods.Objective;
+import optimization.gradientBasedMethods.Optimizer;
+import optimization.gradientBasedMethods.stats.OptimizerStats;
+import optimization.linesearch.ArmijoLineSearchMinimization;
+import optimization.linesearch.LineSearchMethod;
+import optimization.stopCriteria.GradientL2Norm;
+import optimization.stopCriteria.StopingCriteria;
+import optimization.util.MathUtils;
+
+/**
+ *
+ * @author javg
+ * f(x) = \sum_{i=1}^{N-1} \left[ (1-x_i)^2+ 100 (x_{i+1} - x_i^2 )^2 \right] \quad \forall x\in\mathbb{R}^N.
+ */
+public class GeneralizedRosenbrock extends Objective{
+
+
+
+ public GeneralizedRosenbrock(int dimensions){
+ parameters = new double[dimensions];
+ java.util.Arrays.fill(parameters, 0);
+ gradient = new double[dimensions];
+
+ }
+
+ public GeneralizedRosenbrock(int dimensions, double[] params){
+ parameters = params;
+ gradient = new double[dimensions];
+ }
+
+
+ public double getValue() {
+ functionCalls++;
+ double value = 0;
+ for(int i = 0; i < parameters.length-1; i++){
+ value += MathUtils.square(1-parameters[i]) + 100*MathUtils.square(parameters[i+1] - MathUtils.square(parameters[i]));
+ }
+
+ return value;
+ }
+
+ /**
+ * gx = -2(1-x) -2x200(y-x^2)
+ * gy = 200(y-x^2)
+ */
+ public double[] getGradient() {
+ gradientCalls++;
+ java.util.Arrays.fill(gradient,0);
+ for(int i = 0; i < parameters.length-1; i++){
+ gradient[i]+=-2*(1-parameters[i]) - 400*parameters[i]*(parameters[i+1] - MathUtils.square(parameters[i]));
+ gradient[i+1]+=200*(parameters[i+1] - MathUtils.square(parameters[i]));
+ }
+ return gradient;
+ }
+
+
+
+
+
+
+
+ public String toString(){
+ String res ="";
+ for(int i = 0; i < parameters.length; i++){
+ res += "P" + i+ " " + parameters[i];
+ }
+ res += " Value " + getValue();
+ return res;
+ }
+
+ public static void main(String[] args) {
+
+ GeneralizedRosenbrock o = new GeneralizedRosenbrock(2);
+ System.out.println("Starting optimization " + " x0 " + o.parameters[0]+ " x1 " + o.parameters[1]);
+ ;
+
+ System.out.println("Doing Gradient descent");
+ //LineSearchMethod wolfe = new WolfRuleLineSearch(new InterpolationPickFirstStep(1),100,0.001,0.1);
+ StopingCriteria stop = new GradientL2Norm(0.001);
+ LineSearchMethod ls = new ArmijoLineSearchMinimization();
+ Optimizer optimizer = new GradientDescent(ls);
+ OptimizerStats stats = new OptimizerStats();
+ optimizer.setMaxIterations(1000);
+ boolean succed = optimizer.optimize(o,stats, stop);
+ System.out.println("Suceess " + succed + "/n"+stats.prettyPrint(1));
+ System.out.println("Doing Conjugate Gradient descent");
+ o = new GeneralizedRosenbrock(2);
+ // wolfe = new WolfRuleLineSearch(new InterpolationPickFirstStep(1),100,0.001,0.1);
+ optimizer = new ConjugateGradient(ls);
+ stats = new OptimizerStats();
+ optimizer.setMaxIterations(1000);
+ succed = optimizer.optimize(o,stats,stop);
+ System.out.println("Suceess " + succed + "/n"+stats.prettyPrint(1));
+ System.out.println("Doing Quasi newton descent");
+ o = new GeneralizedRosenbrock(2);
+ optimizer = new LBFGS(ls,10);
+ stats = new OptimizerStats();
+ optimizer.setMaxIterations(1000);
+ succed = optimizer.optimize(o,stats,stop);
+ System.out.println("Suceess " + succed + "/n"+stats.prettyPrint(1));
+
+ }
+
+}