summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/ProjectedGradientDescent.java
diff options
context:
space:
mode:
authordesaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-09 16:59:55 +0000
committerdesaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-09 16:59:55 +0000
commitbdea91300c85539ab7153ccba58689612f66bb4d (patch)
treee778ffa1ea4d04a239b58c6e6191c0d4549006f0 /gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/ProjectedGradientDescent.java
parent0d1d84630a08f1c901cf09b4bcc9356c4165302f (diff)
add optimization library source code
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@204 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/ProjectedGradientDescent.java')
-rw-r--r--gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/ProjectedGradientDescent.java154
1 files changed, 154 insertions, 0 deletions
diff --git a/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/ProjectedGradientDescent.java b/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/ProjectedGradientDescent.java
new file mode 100644
index 00000000..0186e945
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/ProjectedGradientDescent.java
@@ -0,0 +1,154 @@
+package optimization.gradientBasedMethods;
+
+import java.io.IOException;
+
+import optimization.gradientBasedMethods.stats.OptimizerStats;
+import optimization.linesearch.DifferentiableLineSearchObjective;
+import optimization.linesearch.LineSearchMethod;
+import optimization.linesearch.ProjectedDifferentiableLineSearchObjective;
+import optimization.stopCriteria.StopingCriteria;
+import optimization.util.MathUtils;
+
+
+/**
+ * This class implements the projected gradiend
+ * as described in Bertsekas "Non Linear Programming"
+ * section 2.3.
+ *
+ * The update is given by:
+ * x_k+1 = x_k + alpha^k(xbar_k-x_k)
+ * Where xbar is:
+ * xbar = [x_k -s_k grad(f(x_k))]+
+ * where []+ is the projection into the feasibility set
+ *
+ * alpha is the step size
+ * s_k - is a positive scalar which can be view as a step size as well, by
+ * setting alpha to 1, then x_k+1 = [x_k -s_k grad(f(x_k))]+
+ * This is called taking a step size along the projection arc (Bertsekas) which
+ * we will use by default.
+ *
+ * Note that the only place where we actually take a step size is on pick a step size
+ * so this is going to be just like a normal gradient descent but use a different
+ * armijo line search where we project after taking a step.
+ *
+ *
+ * @author javg
+ *
+ */
+public class ProjectedGradientDescent extends ProjectedAbstractGradientBaseMethod{
+
+
+
+
+ public ProjectedGradientDescent(LineSearchMethod lineSearch) {
+ this.lineSearch = lineSearch;
+ }
+
+ //Use projected differential objective instead
+ public void initializeStructures(Objective o, OptimizerStats stats, StopingCriteria stop) {
+ lso = new ProjectedDifferentiableLineSearchObjective(o);
+ };
+
+
+ ProjectedObjective obj;
+ public boolean optimize(ProjectedObjective o,OptimizerStats stats, StopingCriteria stop){
+ obj = o;
+ return super.optimize(o, stats, stop);
+ }
+
+ public double[] getDirection(){
+ for(int i = 0; i< gradient.length; i++){
+ direction[i] = -gradient[i];
+ }
+ return direction;
+ }
+
+
+
+
+}
+
+
+
+
+
+
+
+///OLD CODE
+
+//Use projected gradient norm
+//public boolean stopCriteria(double[] gradient){
+// if(originalDirenctionL2Norm == 0){
+// System.out.println("Leaving original direction norm is zero");
+// return true;
+// }
+// if(MathUtils.L2Norm(direction)/originalDirenctionL2Norm < gradientConvergenceValue){
+// System.out.println("Leaving projected gradient Norm smaller than epsilon");
+// return true;
+// }
+// if((previousValue - currValue)/Math.abs(previousValue) < valueConvergenceValue) {
+// System.out.println("Leaving value change below treshold " + previousValue + " - " + currValue);
+// System.out.println(previousValue/currValue + " - " + currValue/currValue
+// + " = " + (previousValue - currValue)/Math.abs(previousValue));
+// return true;
+// }
+// return false;
+//}
+//
+
+//public boolean optimize(ProjectedObjective o,OptimizerStats stats, StopingCriteria stop){
+// stats.collectInitStats(this, o);
+// obj = o;
+// step = 0;
+// currValue = o.getValue();
+// previousValue = Double.MAX_VALUE;
+// gradient = o.getGradient();
+// originalGradientL2Norm = MathUtils.L2Norm(gradient);
+// parameterChange = new double[gradient.length];
+// getDirection();
+// ProjectedDifferentiableLineSearchObjective lso = new ProjectedDifferentiableLineSearchObjective(o,direction);
+//
+// originalDirenctionL2Norm = MathUtils.L2Norm(direction);
+// //MatrixOutput.printDoubleArray(currParameters, "parameters");
+// for (currentProjectionIteration = 0; currentProjectionIteration < maxNumberOfIterations; currentProjectionIteration++){
+// // System.out.println("Iter " + currentProjectionIteration);
+// //o.printParameters();
+//
+//
+//
+// if(stop.stopOptimization(gradient)){
+// stats.collectFinalStats(this, o);
+// lastStepUsed = step;
+// return true;
+// }
+// lso.reset(direction);
+// step = lineSearch.getStepSize(lso);
+// if(step==-1){
+// System.out.println("Failed to find step");
+// stats.collectFinalStats(this, o);
+// return false;
+//
+// }
+//
+// //Update the direction for stopping criteria
+// previousValue = currValue;
+// currValue = o.getValue();
+// gradient = o.getGradient();
+// direction = getDirection();
+// if(MathUtils.dotProduct(gradient, direction) > 0){
+// System.out.println("Not a descent direction");
+// System.out.println(" current stats " + stats.prettyPrint(1));
+// System.exit(-1);
+// }
+// stats.collectIterationStats(this, o);
+// }
+// lastStepUsed = step;
+// stats.collectFinalStats(this, o);
+// return false;
+// }
+
+//public boolean optimize(Objective o,OptimizerStats stats, StopingCriteria stop){
+// System.out.println("Objective is not a projected objective");
+// throw new RuntimeException();
+//}
+