From bdea91300c85539ab7153ccba58689612f66bb4d Mon Sep 17 00:00:00 2001 From: desaicwtf Date: Fri, 9 Jul 2010 16:59:55 +0000 Subject: add optimization library source code git-svn-id: https://ws10smt.googlecode.com/svn/trunk@204 ec762483-ff6d-05da-a07a-a48fb63a330f --- .../DifferentiableLineSearchObjective.java | 185 +++++++++++++++++++++ 1 file changed, 185 insertions(+) create mode 100644 gi/posterior-regularisation/prjava/src/optimization/linesearch/DifferentiableLineSearchObjective.java (limited to 'gi/posterior-regularisation/prjava/src/optimization/linesearch/DifferentiableLineSearchObjective.java') diff --git a/gi/posterior-regularisation/prjava/src/optimization/linesearch/DifferentiableLineSearchObjective.java b/gi/posterior-regularisation/prjava/src/optimization/linesearch/DifferentiableLineSearchObjective.java new file mode 100644 index 00000000..a5bc958e --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/optimization/linesearch/DifferentiableLineSearchObjective.java @@ -0,0 +1,185 @@ +package optimization.linesearch; + +import gnu.trove.TDoubleArrayList; +import gnu.trove.TIntArrayList; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; + +import optimization.gradientBasedMethods.Objective; +import optimization.util.MathUtils; +import optimization.util.StaticTools; + + + +import util.MathUtil; +import util.Printing; + + +/** + * A wrapper class for the actual objective in order to perform + * line search. The optimization code assumes that this does a lot + * of caching in order to simplify legibility. For the applications + * we use it for, caching the entire history of evaluations should be + * a win. + * + * Note: the lastEvaluatedAt value is very important, since we will use + * it to avoid doing an evaluation of the gradient after the line search. + * + * The differentiable line search objective defines a search along the ray + * given by a direction of the main objective. + * It defines the following function, + * where f is the original objective function: + * g(alpha) = f(x_0 + alpha*direction) + * g'(alpha) = f'(x_0 + alpha*direction)*direction + * + * @author joao + * + */ +public class DifferentiableLineSearchObjective { + + + + Objective o; + int nrIterations; + TDoubleArrayList steps; + TDoubleArrayList values; + TDoubleArrayList gradients; + + //This variables cannot change + public double[] originalParameters; + public double[] searchDirection; + + + /** + * Defines a line search objective: + * Receives: + * Objective to each we are performing the line search, is used to calculate values and gradients + * Direction where to do the ray search, note that the direction does not depend of the + * objective but depends from the method. + * @param o + * @param direction + */ + public DifferentiableLineSearchObjective(Objective o) { + this.o = o; + originalParameters = new double[o.getNumParameters()]; + searchDirection = new double[o.getNumParameters()]; + steps = new TDoubleArrayList(); + values = new TDoubleArrayList(); + gradients = new TDoubleArrayList(); + } + /** + * Called whenever we start a new iteration. + * Receives the ray where we are searching for and resets all values + * + */ + public void reset(double[] direction){ + //Copy initial values + System.arraycopy(o.getParameters(), 0, originalParameters, 0, o.getNumParameters()); + System.arraycopy(direction, 0, searchDirection, 0, o.getNumParameters()); + + //Initialize variables + nrIterations = 0; + steps.clear(); + values.clear(); + gradients.clear(); + + values.add(o.getValue()); + gradients.add(MathUtils.dotProduct(o.getGradient(),direction)); + steps.add(0); + } + + + /** + * update the current value of alpha. + * Takes a step with that alpha in direction + * Get the real objective value and gradient and calculate all required information. + */ + public void updateAlpha(double alpha){ + if(alpha < 0){ + System.out.println("alpha may not be smaller that zero"); + throw new RuntimeException(); + } + nrIterations++; + steps.add(alpha); + //x_t+1 = x_t + alpha*direction + System.arraycopy(originalParameters,0, o.getParameters(), 0, originalParameters.length); + MathUtils.plusEquals(o.getParameters(), searchDirection, alpha); + o.setParameters(o.getParameters()); +// System.out.println("Took a step of " + alpha + " new value " + o.getValue()); + values.add(o.getValue()); + gradients.add(MathUtils.dotProduct(o.getGradient(),searchDirection)); + } + + + + public int getNrIterations(){ + return nrIterations; + } + + /** + * return g(alpha) for the current value of alpha + * @param iter + * @return + */ + public double getValue(int iter){ + return values.get(iter); + } + + public double getCurrentValue(){ + return values.get(nrIterations); + } + + public double getOriginalValue(){ + return values.get(0); + } + + /** + * return g'(alpha) for the current value of alpha + * @param iter + * @return + */ + public double getGradient(int iter){ + return gradients.get(iter); + } + + public double getCurrentGradient(){ + return gradients.get(nrIterations); + } + + public double getInitialGradient(){ + return gradients.get(0); + } + + + + + public double getAlpha(){ + return steps.get(nrIterations); + } + + public void printLineSearchSteps(){ + System.out.println( + " Steps size "+steps.size() + + "Values size "+values.size() + + "Gradeients size "+gradients.size()); + for(int i =0; i < steps.size();i++){ + System.out.println("Iter " + i + " step " + steps.get(i) + + " value " + values.get(i) + " grad " + gradients.get(i)); + } + } + + public void printSmallLineSearchSteps(){ + for(int i =0; i < steps.size();i++){ + System.out.print(StaticTools.prettyPrint(steps.get(i), "0.0000E00",8) + " "); + } + System.out.println(); + } + + public static void main(String[] args) { + + } + +} -- cgit v1.2.3