summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/optimization/linesearch/DifferentiableLineSearchObjective.java
diff options
context:
space:
mode:
authordesaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-09 16:59:55 +0000
committerdesaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-09 16:59:55 +0000
commitbdea91300c85539ab7153ccba58689612f66bb4d (patch)
treee778ffa1ea4d04a239b58c6e6191c0d4549006f0 /gi/posterior-regularisation/prjava/src/optimization/linesearch/DifferentiableLineSearchObjective.java
parent0d1d84630a08f1c901cf09b4bcc9356c4165302f (diff)
add optimization library source code
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@204 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/optimization/linesearch/DifferentiableLineSearchObjective.java')
-rw-r--r--gi/posterior-regularisation/prjava/src/optimization/linesearch/DifferentiableLineSearchObjective.java185
1 files changed, 185 insertions, 0 deletions
diff --git a/gi/posterior-regularisation/prjava/src/optimization/linesearch/DifferentiableLineSearchObjective.java b/gi/posterior-regularisation/prjava/src/optimization/linesearch/DifferentiableLineSearchObjective.java
new file mode 100644
index 00000000..a5bc958e
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/optimization/linesearch/DifferentiableLineSearchObjective.java
@@ -0,0 +1,185 @@
+package optimization.linesearch;
+
+import gnu.trove.TDoubleArrayList;
+import gnu.trove.TIntArrayList;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+
+import optimization.gradientBasedMethods.Objective;
+import optimization.util.MathUtils;
+import optimization.util.StaticTools;
+
+
+
+import util.MathUtil;
+import util.Printing;
+
+
+/**
+ * A wrapper class for the actual objective in order to perform
+ * line search. The optimization code assumes that this does a lot
+ * of caching in order to simplify legibility. For the applications
+ * we use it for, caching the entire history of evaluations should be
+ * a win.
+ *
+ * Note: the lastEvaluatedAt value is very important, since we will use
+ * it to avoid doing an evaluation of the gradient after the line search.
+ *
+ * The differentiable line search objective defines a search along the ray
+ * given by a direction of the main objective.
+ * It defines the following function,
+ * where f is the original objective function:
+ * g(alpha) = f(x_0 + alpha*direction)
+ * g'(alpha) = f'(x_0 + alpha*direction)*direction
+ *
+ * @author joao
+ *
+ */
+public class DifferentiableLineSearchObjective {
+
+
+
+ Objective o;
+ int nrIterations;
+ TDoubleArrayList steps;
+ TDoubleArrayList values;
+ TDoubleArrayList gradients;
+
+ //This variables cannot change
+ public double[] originalParameters;
+ public double[] searchDirection;
+
+
+ /**
+ * Defines a line search objective:
+ * Receives:
+ * Objective to each we are performing the line search, is used to calculate values and gradients
+ * Direction where to do the ray search, note that the direction does not depend of the
+ * objective but depends from the method.
+ * @param o
+ * @param direction
+ */
+ public DifferentiableLineSearchObjective(Objective o) {
+ this.o = o;
+ originalParameters = new double[o.getNumParameters()];
+ searchDirection = new double[o.getNumParameters()];
+ steps = new TDoubleArrayList();
+ values = new TDoubleArrayList();
+ gradients = new TDoubleArrayList();
+ }
+ /**
+ * Called whenever we start a new iteration.
+ * Receives the ray where we are searching for and resets all values
+ *
+ */
+ public void reset(double[] direction){
+ //Copy initial values
+ System.arraycopy(o.getParameters(), 0, originalParameters, 0, o.getNumParameters());
+ System.arraycopy(direction, 0, searchDirection, 0, o.getNumParameters());
+
+ //Initialize variables
+ nrIterations = 0;
+ steps.clear();
+ values.clear();
+ gradients.clear();
+
+ values.add(o.getValue());
+ gradients.add(MathUtils.dotProduct(o.getGradient(),direction));
+ steps.add(0);
+ }
+
+
+ /**
+ * update the current value of alpha.
+ * Takes a step with that alpha in direction
+ * Get the real objective value and gradient and calculate all required information.
+ */
+ public void updateAlpha(double alpha){
+ if(alpha < 0){
+ System.out.println("alpha may not be smaller that zero");
+ throw new RuntimeException();
+ }
+ nrIterations++;
+ steps.add(alpha);
+ //x_t+1 = x_t + alpha*direction
+ System.arraycopy(originalParameters,0, o.getParameters(), 0, originalParameters.length);
+ MathUtils.plusEquals(o.getParameters(), searchDirection, alpha);
+ o.setParameters(o.getParameters());
+// System.out.println("Took a step of " + alpha + " new value " + o.getValue());
+ values.add(o.getValue());
+ gradients.add(MathUtils.dotProduct(o.getGradient(),searchDirection));
+ }
+
+
+
+ public int getNrIterations(){
+ return nrIterations;
+ }
+
+ /**
+ * return g(alpha) for the current value of alpha
+ * @param iter
+ * @return
+ */
+ public double getValue(int iter){
+ return values.get(iter);
+ }
+
+ public double getCurrentValue(){
+ return values.get(nrIterations);
+ }
+
+ public double getOriginalValue(){
+ return values.get(0);
+ }
+
+ /**
+ * return g'(alpha) for the current value of alpha
+ * @param iter
+ * @return
+ */
+ public double getGradient(int iter){
+ return gradients.get(iter);
+ }
+
+ public double getCurrentGradient(){
+ return gradients.get(nrIterations);
+ }
+
+ public double getInitialGradient(){
+ return gradients.get(0);
+ }
+
+
+
+
+ public double getAlpha(){
+ return steps.get(nrIterations);
+ }
+
+ public void printLineSearchSteps(){
+ System.out.println(
+ " Steps size "+steps.size() +
+ "Values size "+values.size() +
+ "Gradeients size "+gradients.size());
+ for(int i =0; i < steps.size();i++){
+ System.out.println("Iter " + i + " step " + steps.get(i) +
+ " value " + values.get(i) + " grad " + gradients.get(i));
+ }
+ }
+
+ public void printSmallLineSearchSteps(){
+ for(int i =0; i < steps.size();i++){
+ System.out.print(StaticTools.prettyPrint(steps.get(i), "0.0000E00",8) + " ");
+ }
+ System.out.println();
+ }
+
+ public static void main(String[] args) {
+
+ }
+
+}