package optimization.linesearch;

import gnu.trove.TDoubleArrayList;
import gnu.trove.TIntArrayList;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;

import optimization.gradientBasedMethods.Objective;
import optimization.util.MathUtils;
import optimization.util.StaticTools;



import util.MathUtil;
import util.Printing;


/**
 * A wrapper class for the actual objective in order to perform 
 * line search.  The optimization code assumes that this does a lot 
 * of caching in order to simplify legibility.  For the applications 
 * we use it for, caching the entire history of evaluations should be 
 * a win. 
 * 
 * Note: the lastEvaluatedAt value is very important, since we will use
 * it to avoid doing an evaluation of the gradient after the line search.  
 * 
 * The differentiable line search objective defines a search along the ray
 * given by a direction of the main objective.
 * It defines the following function, 
 * where f is the original objective function:
 * g(alpha) = f(x_0 + alpha*direction)
 * g'(alpha) = f'(x_0 + alpha*direction)*direction
 * 
 * @author joao
 *
 */
public class DifferentiableLineSearchObjective {

	/** Objective being searched; its parameters are overwritten on every {@link #updateAlpha(double)}. */
	Objective o;
	/** Number of steps taken since the last {@link #reset(double[])}; also the index of the latest entry in the history lists. */
	int nrIterations;
	/** Step sizes tried so far; index 0 is always the zero step recorded by {@link #reset(double[])}. */
	TDoubleArrayList steps;
	/** Objective value recorded for each entry of {@link #steps}. */
	TDoubleArrayList values;
	/** Directional derivative (gradient dotted with the search direction) recorded for each entry of {@link #steps}. */
	TDoubleArrayList gradients;

	// These arrays are only meant to be rewritten by reset(); callers should treat them as read-only.
	/** Copy of the objective's parameters taken at the last {@link #reset(double[])} (the point x_0). */
	public double[] originalParameters;
	/** Copy of the direction passed to the last {@link #reset(double[])}. */
	public double[] searchDirection;


	/**
	 * Creates a line search wrapper around the given objective.
	 * The objective is used to compute values and gradients; the search
	 * direction is not supplied here but on each call to {@link #reset(double[])}.
	 *
	 * @param o objective on which the line search is performed
	 */
	public DifferentiableLineSearchObjective(Objective o) {
		this.o = o;
		originalParameters = new double[o.getNumParameters()];
		searchDirection = new double[o.getNumParameters()];
		steps = new TDoubleArrayList();
		values = new TDoubleArrayList();
		gradients = new TDoubleArrayList();
	}

	/**
	 * Starts a new line search along the given ray.
	 * Snapshots the objective's current parameters and the direction, clears the
	 * recorded history, and stores the objective's current value and directional
	 * derivative as iteration 0 with step size 0.
	 *
	 * @param direction search direction; its first {@code o.getNumParameters()} entries are copied
	 */
	public void reset(double[] direction){
		// Copy initial values so later steps are always taken from the same starting point
		System.arraycopy(o.getParameters(), 0, originalParameters, 0, o.getNumParameters());
		System.arraycopy(direction, 0, searchDirection, 0, o.getNumParameters());

		// Initialize variables
		nrIterations = 0;
		steps.clear();
		values.clear();
		gradients.clear();

		// Entry 0 of the history: the starting point itself (alpha = 0)
		values.add(o.getValue());
		gradients.add(MathUtils.dotProduct(o.getGradient(),direction));
		steps.add(0.0);
	}


	/**
	 * Moves the objective to {@code x_0 + alpha * direction} and records the
	 * step size, the objective value and the directional derivative there as
	 * a new iteration.
	 *
	 * @param alpha step size along the search direction; must not be negative
	 * @throws IllegalArgumentException if {@code alpha} is negative
	 */
	public void updateAlpha(double alpha){
		if(alpha < 0){
			// IllegalArgumentException is a RuntimeException, so existing catch sites still match
			throw new IllegalArgumentException("alpha may not be smaller than zero: " + alpha);
		}
		nrIterations++;
		steps.add(alpha);
		// x_t+1 = x_0 + alpha*direction, always restarting from the saved origin
		System.arraycopy(originalParameters,0, o.getParameters(), 0, originalParameters.length);
		MathUtils.plusEquals(o.getParameters(), searchDirection, alpha);
		// The array was modified in place; it is still passed through setParameters,
		// presumably so the objective can refresh any cached state -- TODO confirm against Objective
		o.setParameters(o.getParameters());
		values.add(o.getValue());
		gradients.add(MathUtils.dotProduct(o.getGradient(),searchDirection));
	}


	/**
	 * @return number of steps taken since the last {@link #reset(double[])}
	 */
	public int getNrIterations(){
		return nrIterations;
	}

	/**
	 * Returns g(alpha) recorded at the given iteration.
	 *
	 * @param iter iteration index, where 0 is the starting point
	 * @return objective value stored for that iteration
	 */
	public double getValue(int iter){
		return values.get(iter);
	}

	/**
	 * @return objective value recorded at the most recent iteration
	 */
	public double getCurrentValue(){
		return values.get(nrIterations);
	}

	/**
	 * @return objective value recorded at the starting point (iteration 0)
	 */
	public double getOriginalValue(){
		return values.get(0);
	}

	/**
	 * Returns g'(alpha) recorded at the given iteration.
	 *
	 * @param iter iteration index, where 0 is the starting point
	 * @return directional derivative stored for that iteration
	 */
	public double getGradient(int iter){
		return gradients.get(iter);
	}

	/**
	 * @return directional derivative recorded at the most recent iteration
	 */
	public double getCurrentGradient(){
		return gradients.get(nrIterations);
	}

	/**
	 * @return directional derivative recorded at the starting point (iteration 0)
	 */
	public double getInitialGradient(){
		return gradients.get(0);
	}


	/**
	 * @return step size recorded at the most recent iteration (0 right after a reset)
	 */
	public double getAlpha(){
		return steps.get(nrIterations);
	}

	/**
	 * Prints the sizes of the three history lists followed by one line per
	 * recorded iteration with its step, value and directional derivative.
	 */
	public void printLineSearchSteps(){
		System.out.println(
				" Steps size "+steps.size() +
				" Values size "+values.size() +
				" Gradients size "+gradients.size());
		for(int i =0; i < steps.size();i++){
			System.out.println("Iter " + i + " step " + steps.get(i) +
					" value " + values.get(i) + " grad "  + gradients.get(i));
		}
	}

	/**
	 * Prints all recorded step sizes on a single line, each formatted through
	 * {@code StaticTools.prettyPrint} and separated by a space.
	 */
	public void printSmallLineSearchSteps(){
		for(int i =0; i < steps.size();i++){
			System.out.print(StaticTools.prettyPrint(steps.get(i), "0.0000E00",8) + " ");
		}
		System.out.println();
	}

	/**
	 * Empty entry point; does nothing.
	 */
	public static void main(String[] args) {

	}

}