summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/LBFGS.java
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/LBFGS.java')
-rw-r--r--gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/LBFGS.java234
1 files changed, 234 insertions, 0 deletions
diff --git a/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/LBFGS.java b/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/LBFGS.java
new file mode 100644
index 00000000..dedbc942
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/LBFGS.java
@@ -0,0 +1,234 @@
+package optimization.gradientBasedMethods;
+
+
+import optimization.gradientBasedMethods.stats.OptimizerStats;
+import optimization.linesearch.DifferentiableLineSearchObjective;
+import optimization.linesearch.LineSearchMethod;
+import optimization.stopCriteria.StopingCriteria;
+import optimization.util.MathUtils;
+
+public class LBFGS extends AbstractGradientBaseMethod{
+
+ //Maximum number of (sk, yk) pairs being saved; sizes skList and ykList
+ int history;
+ //Saved parameter differences (parameters after a step minus previousParameters)
+ double[][] skList;
+ //Saved gradient differences (gradient after a step minus previousGradient)
+ double[][] ykList;
+ //Scalar handed to LBFGSTwoLoopRecursion as the constant initial Hessian
+ double initialHessianParameters;
+ //Copy of the gradient taken before the step
+ double[] previousGradient;
+ //Copy of the parameters taken before the step
+ double[] previousParameters;
+
+ //auxiliary structures, allocated on first use by LBFGSTwoLoopRecursion
+ //work vector: starts as a copy of the gradient and is returned as the result
+ double q[];
+ //roi[i] = 1 / (ykList[i] . skList[i])
+ double[] roi;
+ //alphai[i] = (skList[i] . q) * roi[i]
+ double[] alphai;
+
+ /**
+  * Creates the optimizer with empty correction-pair storage.
+  *
+  * @param ls line search method stored in lineSearch
+  * @param history number of (sk, yk) pairs to keep
+  */
+ public LBFGS(LineSearchMethod ls, int history) {
+     lineSearch = ls;
+     this.history = history;
+     // One slot per remembered step; slots stay null until a step is recorded.
+     this.ykList = new double[history][];
+     this.skList = new double[history][];
+ }
+
+ /**
+  * Delegates to super.reset(), then clears the saved pairs, the pre-step
+  * snapshots, the scratch arrays and the initial Hessian scalar.
+  */
+ public void reset(){
+     super.reset();
+     // Scratch arrays are re-created by the next LBFGSTwoLoopRecursion call.
+     alphai = null;
+     roi = null;
+     q = null;
+     // Start again with an empty history of the same capacity.
+     ykList = new double[history][];
+     skList = new double[history][];
+     previousGradient = null;
+     previousParameters = null;
+     initialHessianParameters = 0;
+ }
+
+ /**
+  * Runs the L-BFGS two-loop recursion on the current gradient using the saved
+  * (sk, yk) pairs and a constant initial Hessian.
+  *
+  * Pairs are stored from index 0 (oldest) upwards, so while the history is not
+  * yet full the highest slots are still null. The number of stored pairs is
+  * counted first and both loops run over exactly that range; starting the
+  * backward loop at history-1 would skip it entirely until the history is full.
+  *
+  * @param hessianConst scalar the work vector is multiplied by between the two loops
+  * @return the internal work array q (reused and overwritten by every call)
+  */
+ public double[] LBFGSTwoLoopRecursion(double hessianConst){
+     //Only create the arrays once
+     if(q == null){
+         q = new double[gradient.length];
+     }
+     if(roi == null){
+         roi = new double[history];
+     }
+     if(alphai == null){
+         alphai = new double[history];
+     }
+     System.arraycopy(gradient, 0, q, 0, gradient.length);
+
+     //Count the contiguous pairs available, oldest at index 0
+     int stored = 0;
+     while(stored < history && skList[stored] != null && ykList[stored] != null){
+         stored++;
+     }
+
+     //First loop: newest pair down to the oldest
+     for(int i = stored-1; i >= 0; i--){
+         double[] si = skList[i];
+         double[] yi = ykList[i];
+         roi[i] = 1.0/MathUtils.dotProduct(yi,si);
+         alphai[i] = MathUtils.dotProduct(si, q)*roi[i];
+         MathUtils.plusEquals(q, yi, -alphai[i]);
+     }
+     //Initial Hessian is just a constant
+     MathUtils.scalarMultiplication(q, hessianConst);
+     //Second loop: oldest pair up to the newest, reusing roi and alphai from the first loop
+     for(int i = 0; i < stored; i++){
+         double beta = MathUtils.dotProduct(ykList[i], q)*roi[i];
+         MathUtils.plusEquals(q, skList[i], (alphai[i]-beta));
+     }
+     return q;
+ }
+
+
+
+
+ @Override
+ public double[] getDirection() {
+
+ calculateInitialHessianParameter();
+// System.out.println("Initial hessian " + initialHessianParameters);
+ return direction = MathUtils.negation(LBFGSTwoLoopRecursion(initialHessianParameters));
+ }
+
+ /**
+  * Sets initialHessianParameters: 1 on the first iteration, otherwise
+  * (sk . yk) / (yk . yk) for the most recently saved pair.
+  */
+ public void calculateInitialHessianParameter(){
+     if(currentProjectionIteration == 1){
+         //No pair saved yet, use the plain gradient
+         initialHessianParameters = 1;
+         return;
+     }
+     //While the history is still filling, the newest pair sits at iteration-2;
+     //once it is full the newest pair is always in the last slot
+     int newest = (currentProjectionIteration <= history) ? currentProjectionIteration-2 : history-1;
+     double[] s = skList[newest];
+     double[] y = ykList[newest];
+     double sDotY = MathUtils.dotProduct(s, y);
+     double yDotY = MathUtils.dotProduct(y, y);
+     initialHessianParameters = sDotY/yDotY;
+ }
+
+ //TODO if structures exist just reset them to zero
+ /**
+  * Delegates to super.initializeStructures, then allocates the two arrays that
+  * hold the pre-step parameters and gradient, sized by o.getNumParameters().
+  */
+ public void initializeStructures(Objective o,OptimizerStats stats, StopingCriteria stop){
+     super.initializeStructures(o, stats, stop);
+     final int numParameters = o.getNumParameters();
+     previousParameters = new double[numParameters];
+     previousGradient = new double[numParameters];
+ }
+ public void updateStructuresBeforeStep(Objective o,OptimizerStats stats, StopingCriteria stop){
+ super.initializeStructures(o, stats, stop);
+ System.arraycopy(o.getParameters(), 0, previousParameters, 0, previousParameters.length);
+ System.arraycopy(gradient, 0, previousGradient, 0, gradient.length);
+ }
+
+ /**
+  * Records the step just taken: the parameter difference goes into skList and
+  * the gradient difference into ykList. Once the history is full the oldest
+  * pair is discarded to make room for the new one.
+  */
+ public void updateStructuresAfterStep( Objective o,OptimizerStats stats, StopingCriteria stop){
+     double[] parameterDelta = MathUtils.arrayMinus(o.getParameters(), previousParameters);
+     double[] gradientDelta = MathUtils.arrayMinus(gradient, previousGradient);
+     int slot;
+     if(currentProjectionIteration <= history){
+         //Still filling up: the next free slot is iteration-1
+         slot = currentProjectionIteration-1;
+     }else{
+         //Full: shift everything one slot towards index 0, dropping the oldest pair
+         System.arraycopy(skList, 1, skList, 0, history-1);
+         System.arraycopy(ykList, 1, ykList, 0, history-1);
+         slot = history-1;
+     }
+     skList[slot] = parameterDelta;
+     ykList[slot] = gradientDelta;
+ }
+
+// public boolean optimize(Objective o, OptimizerStats stats, StopingCriteria stop) {
+// DifferentiableLineSearchObjective lso = new DifferentiableLineSearchObjective(o);
+// gradient = o.getGradient();
+// direction = new double[o.getNumParameters()];
+// previousGradient = new double[o.getNumParameters()];
+//
+// previousParameters = new double[o.getNumParameters()];
+//
+// stats.collectInitStats(this, o);
+// previousValue = Double.MAX_VALUE;
+// currValue= o.getValue();
+// //Used for stopping criteria
+// double[] originalGradient = o.getGradient();
+//
+// originalGradientL2Norm = MathUtils.L2Norm(originalGradient);
+// if(stop.stopOptimization(originalGradient)){
+// stats.collectFinalStats(this, o);
+// return true;
+// }
+// for (currentProjectionIteration = 1; currentProjectionIteration < maxNumberOfIterations; currentProjectionIteration++){
+//
+//
+// currValue = o.getValue();
+// gradient = o.getGradient();
+// currParameters = o.getParameters();
+//
+//
+// if(currentProjectionIteration == 1){
+// //Use gradient
+// initialHessianParameters = 1;
+// }else if(currentProjectionIteration <= history){
+// double[] sk = skList[currentProjectionIteration-2];
+// double[] yk = ykList[currentProjectionIteration-2];
+// initialHessianParameters = MathUtils.dotProduct(sk, yk)/MathUtils.dotProduct(yk, yk);
+// }else{
+// //get the last one
+// double[] sk = skList[history-1];
+// double[] yk = ykList[history-1];
+// initialHessianParameters = MathUtils.dotProduct(sk, yk)/MathUtils.dotProduct(yk, yk);
+// }
+//
+// getDirection();
+//
+// //MatrixOutput.printDoubleArray(direction, "direction");
+// double dot = MathUtils.dotProduct(direction, gradient);
+// if(dot > 0){
+// throw new RuntimeException("Not a descent direction");
+// } if (Double.isNaN(dot)){
+// throw new RuntimeException("dot is not a number!!");
+// }
+// System.arraycopy(currParameters, 0, previousParameters, 0, currParameters.length);
+// System.arraycopy(gradient, 0, previousGradient, 0, gradient.length);
+// lso.reset(direction);
+// step = lineSearch.getStepSize(lso);
+// if(step==-1){
+// System.out.println("Failed to find a step size");
+//// lso.printLineSearchSteps();
+//// System.out.println(stats.prettyPrint(1));
+// stats.collectFinalStats(this, o);
+// return false;
+// }
+// stats.collectIterationStats(this, o);
+//
+// //We are not updating the alpha since it is done in line search already
+// currParameters = o.getParameters();
+// gradient = o.getGradient();
+//
+// if(stop.stopOptimization(gradient)){
+// stats.collectFinalStats(this, o);
+// return true;
+// }
+// double[] diffX = MathUtils.arrayMinus(currParameters, previousParameters);
+// double[] diffGrad = MathUtils.arrayMinus(gradient, previousGradient);
+// //Save new values and discard new ones
+// if(currentProjectionIteration > history){
+// for(int i = 0; i < history-1;i++){
+// skList[i]=skList[i+1];
+// ykList[i]=ykList[i+1];
+// }
+// skList[history-1]=diffX;
+// ykList[history-1]=diffGrad;
+// }else{
+// skList[currentProjectionIteration-1]=diffX;
+// ykList[currentProjectionIteration-1]=diffGrad;
+// }
+// previousValue = currValue;
+// }
+// stats.collectFinalStats(this, o);
+// return false;
+// }
+
+
+
+
+
+
+
+
+
+
+}