summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/AbstractGradientBaseMethod.java
diff options
context:
space:
mode:
authordesaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-09 16:59:55 +0000
committerdesaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-09 16:59:55 +0000
commit7f69c868c41e4b36eecf9d3b1dc22f3f3aa1540c (patch)
treed22aa7b6f47248ed6da02b77a0680b6b83e67b63 /gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/AbstractGradientBaseMethod.java
parent4e37402323c3227e90a89345387834e149732b5c (diff)
add optimization library source code
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@204 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/AbstractGradientBaseMethod.java')
-rw-r--r--gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/AbstractGradientBaseMethod.java119
1 files changed, 119 insertions, 0 deletions
diff --git a/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/AbstractGradientBaseMethod.java b/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/AbstractGradientBaseMethod.java
new file mode 100644
index 00000000..0a4a5445
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/AbstractGradientBaseMethod.java
@@ -0,0 +1,119 @@
+package optimization.gradientBasedMethods;
+
+import optimization.gradientBasedMethods.stats.OptimizerStats;
+import optimization.linesearch.DifferentiableLineSearchObjective;
+import optimization.linesearch.LineSearchMethod;
+import optimization.stopCriteria.StopingCriteria;
+import optimization.util.MathUtils;
+
+/**
+ *
+ * @author javg
+ *
+ */
+public abstract class AbstractGradientBaseMethod implements Optimizer{
+
+ protected int maxNumberOfIterations=10000;
+
+
+
+ protected int currentProjectionIteration;
+ protected double currValue;
+ protected double previousValue = Double.MAX_VALUE;;
+ protected double step;
+ protected double[] gradient;
+ public double[] direction;
+
+ //Original values
+ protected double originalGradientL2Norm;
+
+ protected LineSearchMethod lineSearch;
+ DifferentiableLineSearchObjective lso;
+
+
+ public void reset(){
+ direction = null;
+ gradient = null;
+ previousValue = Double.MAX_VALUE;
+ currentProjectionIteration = 0;
+ originalGradientL2Norm = 0;
+ step = 0;
+ currValue = 0;
+ }
+
+ public void initializeStructures(Objective o,OptimizerStats stats, StopingCriteria stop){
+ lso = new DifferentiableLineSearchObjective(o);
+ }
+ public void updateStructuresBeforeStep(Objective o,OptimizerStats stats, StopingCriteria stop){
+ }
+
+ public void updateStructuresAfterStep(Objective o,OptimizerStats stats, StopingCriteria stop){
+ }
+
+ public boolean optimize(Objective o,OptimizerStats stats, StopingCriteria stop){
+ //Initialize structures
+
+ stats.collectInitStats(this, o);
+ direction = new double[o.getNumParameters()];
+ initializeStructures(o, stats, stop);
+ for (currentProjectionIteration = 1; currentProjectionIteration < maxNumberOfIterations; currentProjectionIteration++){
+// System.out.println("starting iterations: parameters:" );
+// o.printParameters();
+ previousValue = currValue;
+ currValue = o.getValue();
+ gradient = o.getGradient();
+ if(stop.stopOptimization(o)){
+ stats.collectFinalStats(this, o);
+ return true;
+ }
+
+ getDirection();
+ if(MathUtils.dotProduct(gradient, direction) > 0){
+ System.out.println("Not a descent direction");
+ System.out.println(" current stats " + stats.prettyPrint(1));
+ System.exit(-1);
+ }
+ updateStructuresBeforeStep(o, stats, stop);
+ lso.reset(direction);
+ step = lineSearch.getStepSize(lso);
+// System.out.println("Leave with step: " + step);
+ if(step==-1){
+ System.out.println("Failed to find step");
+ stats.collectFinalStats(this, o);
+ return false;
+ }
+ updateStructuresAfterStep( o, stats, stop);
+// previousValue = currValue;
+// currValue = o.getValue();
+// gradient = o.getGradient();
+ stats.collectIterationStats(this, o);
+ }
+ stats.collectFinalStats(this, o);
+ return false;
+ }
+
+
+ public int getCurrentIteration() {
+ return currentProjectionIteration;
+ }
+
+
+ /**
+ * Method specific
+ */
+ public abstract double[] getDirection();
+
+ public double getCurrentStep() {
+ return step;
+ }
+
+
+
+ public void setMaxIterations(int max) {
+ maxNumberOfIterations = max;
+ }
+
+ public double getCurrentValue() {
+ return currValue;
+ }
+}