package optimization.gradientBasedMethods;

import optimization.gradientBasedMethods.stats.OptimizerStats;
import optimization.linesearch.DifferentiableLineSearchObjective;
import optimization.linesearch.LineSearchMethod;
import optimization.stopCriteria.StopingCriteria;
import optimization.util.MathUtils;

/**
 * 
 * @author javg
 *
 */
/**
 * Skeleton for gradient-based optimizers driven by a line search.
 *
 * <p>Subclasses supply the search direction through {@link #getDirection()}. This class runs the
 * iteration loop: evaluate the objective and its gradient, check the stopping criterion, verify
 * that the direction is a descent direction, pick a step with the configured
 * {@link LineSearchMethod}, and collect statistics.
 *
 * @author javg
 */
public abstract class AbstractGradientBaseMethod implements Optimizer{
	
	/** Iteration bound used by {@link #optimize}; see the note there on how it is applied. */
	protected int maxNumberOfIterations=10000;
	
	/** 1-based index of the iteration currently being executed by {@link #optimize}. */
	protected int currentProjectionIteration;
	/** Objective value read at the start of the current iteration. */
	protected double currValue;
	/** Value that {@link #currValue} held before it was last refreshed. */
	protected double previousValue = Double.MAX_VALUE;
	/** Step size chosen by the line search in the latest iteration (-1 signals failure). */
	protected double step;
	/** Gradient of the objective at the current point, as returned by {@code Objective.getGradient()}. */
	protected double[] gradient;
	/**
	 * Search direction for the current iteration. {@link #optimize} allocates it and reads it
	 * after calling {@link #getDirection()}, so implementations are expected to fill this field.
	 */
	public double[] direction;
	
	/**
	 * Original values. Reset to 0 by {@link #reset()} and not otherwise written in this class.
	 * NOTE(review): presumably set by subclasses or stats code; confirm.
	 */
	protected double originalGradientL2Norm;
	
	/** Line search used to choose the step size; must be set before {@link #optimize} takes a step. */
	protected LineSearchMethod lineSearch;
	/** Line-search view of the objective; created in {@link #initializeStructures}. */
	DifferentiableLineSearchObjective lso;
	
	/** Clears per-run state so the optimizer can be reused for a fresh run. */
	public void reset(){
		direction = null;
		gradient = null;
		previousValue = Double.MAX_VALUE;
		currentProjectionIteration = 0;
		originalGradientL2Norm = 0;
		step = 0;
		currValue = 0;
	}
	
	/**
	 * Hook called once at the start of {@link #optimize}. The default implementation wraps the
	 * objective in a {@link DifferentiableLineSearchObjective}; overrides should call super.
	 */
	public void initializeStructures(Objective o,OptimizerStats stats, StopingCriteria stop){
		lso =   new DifferentiableLineSearchObjective(o);
	}

	/** Hook called after the direction is validated and before the line search; no-op by default. */
	public void updateStructuresBeforeStep(Objective o,OptimizerStats stats, StopingCriteria stop){
	}
	
	/** Hook called after a successful line search; no-op by default. */
	public void updateStructuresAfterStep(Objective o,OptimizerStats stats, StopingCriteria stop){
	}
	
	/**
	 * Runs the optimization loop on {@code o}.
	 *
	 * <p>Iterations are numbered from 1, and the loop runs while the index is strictly below
	 * {@link #maxNumberOfIterations}. As a result, at most {@code maxNumberOfIterations - 1}
	 * iterations run. NOTE(review): possibly an off-by-one; confirm intent before changing, because
	 * {@link #getCurrentIteration()} reports the final index.
	 *
	 * @param o the objective to minimize
	 * @param stats collector notified at init, after each iteration, and at termination
	 * @param stop criterion checked at the start of every iteration
	 * @return {@code true} if the stopping criterion was met; {@code false} if the computed
	 *     direction was not a descent direction, the line search failed, or the iteration budget
	 *     ran out
	 */
	public boolean optimize(Objective o,OptimizerStats stats, StopingCriteria stop){
		stats.collectInitStats(this, o);
		direction = new double[o.getNumParameters()];
		initializeStructures(o, stats, stop);
		for (currentProjectionIteration = 1; currentProjectionIteration < maxNumberOfIterations; currentProjectionIteration++){
			previousValue = currValue;
			currValue = o.getValue();
			gradient = o.getGradient();
			if(stop.stopOptimization(o)){
				stats.collectFinalStats(this, o);
				return true;
			}
			
			// The return value is ignored; subclasses are expected to populate `direction`.
			getDirection();
			// A positive inner product with the gradient means the objective increases (to first
			// order) along `direction`. Previously this case called System.exit(-1), which killed
			// the caller's JVM. It now ends the run cleanly, like a line-search failure does.
			if(MathUtils.dotProduct(gradient, direction) > 0){
				System.out.println("Not a descent direction");
				System.out.println(" current stats " + stats.prettyPrint(1));
				stats.collectFinalStats(this, o);
				return false;
			}
			updateStructuresBeforeStep(o, stats, stop);
			lso.reset(direction);
			step = lineSearch.getStepSize(lso);
			// The line search reports failure with the sentinel value -1.
			if(step==-1){
				System.out.println("Failed to find step");
				stats.collectFinalStats(this, o);
				return false;		
			}
			updateStructuresAfterStep( o, stats,  stop);
			stats.collectIterationStats(this, o);
		}
		stats.collectFinalStats(this, o);
		return false;
	}
	
	/** Returns the 1-based index of the current (or last executed) iteration. */
	public int getCurrentIteration() {
		return currentProjectionIteration;
	}

	/**
	 * Computes the search direction for the current iteration (method specific).
	 * {@link #optimize} ignores the returned array and reads the {@link #direction} field.
	 *
	 * @return the search direction
	 */
	public abstract double[] getDirection();

	/** Returns the step size chosen by the line search in the latest iteration. */
	public double getCurrentStep() {
		return step;
	}

	/** Sets the iteration bound used by {@link #optimize}. */
	public void setMaxIterations(int max) {
		maxNumberOfIterations = max;
	}

	/** Returns the objective value read at the start of the latest iteration. */
	public double getCurrentValue() {
		return currValue;
	}
}