diff options
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/optimization/examples/GeneralizedRosenbrock.java')
-rw-r--r-- | gi/posterior-regularisation/prjava/src/optimization/examples/GeneralizedRosenbrock.java | 110 |
1 files changed, 110 insertions, 0 deletions
diff --git a/gi/posterior-regularisation/prjava/src/optimization/examples/GeneralizedRosenbrock.java b/gi/posterior-regularisation/prjava/src/optimization/examples/GeneralizedRosenbrock.java new file mode 100644 index 00000000..25fa7f09 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/optimization/examples/GeneralizedRosenbrock.java @@ -0,0 +1,110 @@ +package optimization.examples; + + +import optimization.gradientBasedMethods.ConjugateGradient; +import optimization.gradientBasedMethods.GradientDescent; +import optimization.gradientBasedMethods.LBFGS; +import optimization.gradientBasedMethods.Objective; +import optimization.gradientBasedMethods.Optimizer; +import optimization.gradientBasedMethods.stats.OptimizerStats; +import optimization.linesearch.ArmijoLineSearchMinimization; +import optimization.linesearch.LineSearchMethod; +import optimization.stopCriteria.GradientL2Norm; +import optimization.stopCriteria.StopingCriteria; +import optimization.util.MathUtils; + +/** + * + * @author javg + * f(x) = \sum_{i=1}^{N-1} \left[ (1-x_i)^2+ 100 (x_{i+1} - x_i^2 )^2 \right] \quad \forall x\in\mathbb{R}^N. + */ +public class GeneralizedRosenbrock extends Objective{ + + + + public GeneralizedRosenbrock(int dimensions){ + parameters = new double[dimensions]; + java.util.Arrays.fill(parameters, 0); + gradient = new double[dimensions]; + + } + + public GeneralizedRosenbrock(int dimensions, double[] params){ + parameters = params; + gradient = new double[dimensions]; + } + + + public double getValue() { + functionCalls++; + double value = 0; + for(int i = 0; i < parameters.length-1; i++){ + value += MathUtils.square(1-parameters[i]) + 100*MathUtils.square(parameters[i+1] - MathUtils.square(parameters[i])); + } + + return value; + } + + /** + * gx = -2(1-x) -2x200(y-x^2) + * gy = 200(y-x^2) + */ + public double[] getGradient() { + gradientCalls++; + java.util.Arrays.fill(gradient,0); + for(int i = 0; i < parameters.length-1; i++){ + gradient[i]+=-2*(1-parameters[i]) - 400*parameters[i]*(parameters[i+1] - MathUtils.square(parameters[i])); + gradient[i+1]+=200*(parameters[i+1] - MathUtils.square(parameters[i])); + } + return gradient; + } + + + + + + + + public String toString(){ + String res =""; + for(int i = 0; i < parameters.length; i++){ + res += "P" + i+ " " + parameters[i]; + } + res += " Value " + getValue(); + return res; + } + + public static void main(String[] args) { + + GeneralizedRosenbrock o = new GeneralizedRosenbrock(2); + System.out.println("Starting optimization " + " x0 " + o.parameters[0]+ " x1 " + o.parameters[1]); + ; + + System.out.println("Doing Gradient descent"); + //LineSearchMethod wolfe = new WolfRuleLineSearch(new InterpolationPickFirstStep(1),100,0.001,0.1); + StopingCriteria stop = new GradientL2Norm(0.001); + LineSearchMethod ls = new ArmijoLineSearchMinimization(); + Optimizer optimizer = new GradientDescent(ls); + OptimizerStats stats = new OptimizerStats(); + optimizer.setMaxIterations(1000); + boolean succed = optimizer.optimize(o,stats, stop); + System.out.println("Suceess " + succed + "/n"+stats.prettyPrint(1)); + System.out.println("Doing Conjugate Gradient descent"); + o = new GeneralizedRosenbrock(2); + // wolfe = new WolfRuleLineSearch(new InterpolationPickFirstStep(1),100,0.001,0.1); + optimizer = new ConjugateGradient(ls); + stats = new OptimizerStats(); + optimizer.setMaxIterations(1000); + succed = optimizer.optimize(o,stats,stop); + System.out.println("Suceess " + succed + "/n"+stats.prettyPrint(1)); + System.out.println("Doing Quasi newton descent"); + o = new GeneralizedRosenbrock(2); + optimizer = new LBFGS(ls,10); + stats = new OptimizerStats(); + optimizer.setMaxIterations(1000); + succed = optimizer.optimize(o,stats,stop); + System.out.println("Suceess " + succed + "/n"+stats.prettyPrint(1)); + + } + +} |