diff options
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/optimization/util')
5 files changed, 591 insertions, 0 deletions
// File: gi/posterior-regularisation/prjava/src/optimization/util/Interpolation.java
// Declared in package optimization.util in the original source tree.

/**
 * Closed-form minimisers of cubic and quadratic interpolating polynomials.
 */
public class Interpolation {

    /**
     * Returns the minimiser of the cubic polynomial that matches the function
     * values and derivatives at the two points {@code a} and {@code b}.
     *
     * <p>The formula is the one referenced by the original author as
     * "NonLinear Programming, appendix C". It is only applicable when either
     * {@code gradB >= 0} or {@code funcB >= funcA}; when that precondition is
     * violated a message is printed to standard output and {@code -1} is
     * returned.
     *
     * @param a     first point
     * @param funcA function value at {@code a}
     * @param gradA derivative at {@code a}
     * @param b     second point
     * @param funcB function value at {@code b}
     * @param gradB derivative at {@code b}
     * @return the minimiser of the interpolating cubic, or {@code -1} when the
     *         precondition does not hold
     */
    public final static double cubicInterpolation(double a,
            double funcA, double gradA, double b, double funcB, double gradB) {
        if (gradB < 0 && funcA > funcB) {
            System.out.println("Cannot call cubic interpolation");
            return -1;
        }

        double width = b - a;
        double z = 3 * (funcA - funcB) / width + gradA + gradB;
        double w = Math.sqrt(z * z - gradA * gradB);
        return b - (gradB + w - z) * width / (gradB - gradA + 2 * w);
    }

    /**
     * Returns the minimiser of the quadratic {@code q} defined by
     * {@code q(0) = initFValue}, {@code q'(0) = initGrad} and
     * {@code q(point) = pointFValue}.
     *
     * <p>No guard is applied: if the three values are collinear the
     * denominator is zero and the result is infinite or NaN.
     *
     * @param initFValue  function value at 0
     * @param initGrad    derivative at 0
     * @param point       second sample location
     * @param pointFValue function value at {@code point}
     * @return the stationary point of the interpolating quadratic
     */
    public final static double quadraticInterpolation(double initFValue,
            double initGrad, double point, double pointFValue) {
        double numerator = -1 * initGrad * point * point;
        double denominator = 2 * (pointFValue - initGrad * point - initFValue);
        return numerator / denominator;
    }

    /** Empty entry point kept from the original source. */
    public static void main(String[] args) {

    }
}

// File: gi/posterior-regularisation/prjava/src/optimization/util/Logger.java
// Declared as a public class in package optimization.util in its own file;
// written without the modifier here only because it shares this listing.

/** Empty placeholder class; it declares no members. */
class Logger {

}

// The diff continues with
// gi/posterior-regularisation/prjava/src/optimization/util/MathUtils.java (new file, 339 lines).
// File: gi/posterior-regularisation/prjava/src/optimization/util/MathUtils.java
// Declared in package optimization.util in the original source tree.

import java.util.Arrays;

/**
 * Dense vector and matrix helpers operating directly on {@code double} arrays.
 * Methods named {@code xxxEquals} update their first argument in place.
 */
public class MathUtils {

    /**
     * @param vector input vector
     * @return the Euclidean norm of {@code vector}
     */
    public static double L2Norm(double[] vector) {
        double sumOfSquares = 0;
        for (double component : vector) {
            sumOfSquares += component * component;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * @param v input vector
     * @return the sum of all entries of {@code v}
     */
    public static double sum(double[] v) {
        double total = 0;
        for (double entry : v) {
            total += entry;
        }
        return total;
    }

    /**
     * w = w + v (in place).
     *
     * @param w vector that is updated
     * @param v vector that is added; must be at least as long as {@code w}
     */
    public static void plusEquals(double[] w, double[] v) {
        for (int i = 0; i < w.length; i++) {
            // The previous "w[i] += w[i] + v[i]" doubled w before adding v.
            w[i] += v[i];
        }
    }

    /**
     * w[i] = w[i] + v (in place).
     *
     * @param w vector that is updated
     * @param v scalar added to every entry
     */
    public static void plusEquals(double[] w, double v) {
        for (int i = 0; i < w.length; i++) {
            // The previous "w[i] += w[i] + v" doubled w before adding v.
            w[i] += v;
        }
    }

    /**
     * w[i] = w[i] - v (in place).
     *
     * @param w vector that is updated
     * @param v scalar subtracted from every entry
     */
    public static void minusEquals(double[] w, double v) {
        for (int i = 0; i < w.length; i++) {
            // The previous "w[i] -= w[i] + v" always produced -v.
            w[i] -= v;
        }
    }

    /**
     * w = w + a*v (in place).
     *
     * @param w vector that is updated
     * @param v direction
     * @param a scale applied to {@code v}
     */
    public static void plusEquals(double[] w, double[] v, double a) {
        for (int i = 0; i < w.length; i++) {
            w[i] += a * v[i];
        }
    }

    /**
     * w = w - a*v (in place).
     *
     * @param w vector that is updated
     * @param v direction
     * @param a scale applied to {@code v}
     */
    public static void minusEquals(double[] w, double[] v, double a) {
        for (int i = 0; i < w.length; i++) {
            w[i] -= a * v[i];
        }
    }

    /**
     * v = w - a*v; note that it is {@code v}, not {@code w}, that is overwritten.
     *
     * @param w left operand (read only)
     * @param v right operand, replaced by the result
     * @param a scale applied to the old {@code v}
     */
    public static void minusEqualsInverse(double[] w, double[] v, double a) {
        for (int i = 0; i < w.length; i++) {
            v[i] = w[i] - a * v[i];
        }
    }

    /**
     * @return the inner product of {@code w} and {@code v} over the length of {@code w}
     */
    public static double dotProduct(double[] w, double[] v) {
        double accum = 0;
        for (int i = 0; i < w.length; i++) {
            accum += w[i] * v[i];
        }
        return accum;
    }

    /**
     * @return a new array holding {@code w - v}; the inputs are not modified
     */
    public static double[] arrayMinus(double[] w, double[] v) {
        double[] difference = w.clone();
        for (int i = 0; i < w.length; i++) {
            difference[i] -= v[i];
        }
        return difference;
    }

    /**
     * Writes {@code w - v} into {@code result}.
     *
     * @param result destination, at least as long as {@code w}
     * @return {@code result}, for chaining
     */
    public static double[] arrayMinus(double[] result, double[] w, double[] v) {
        for (int i = 0; i < w.length; i++) {
            result[i] = w[i] - v[i];
        }
        return result;
    }

    /**
     * @return a new array holding {@code -w}
     */
    public static double[] negation(double[] w) {
        double[] negated = new double[w.length];
        for (int i = 0; i < w.length; i++) {
            negated[i] = -w[i];
        }
        return negated;
    }

    /**
     * @return {@code value * value}
     */
    public static double square(double value) {
        return value * value;
    }

    /**
     * @return the {@code w.length x v.length} matrix with entries {@code w[i]*v[j]}
     */
    public static double[][] outerProduct(double[] w, double[] v) {
        double[][] result = new double[w.length][v.length];
        for (int i = 0; i < w.length; i++) {
            for (int j = 0; j < v.length; j++) {
                result[i][j] = w[i] * v[j];
            }
        }
        return result;
    }

    /**
     * result = a * w * v^T.
     *
     * @param w column factor
     * @param v row factor
     * @param a scalar weight
     * @return the {@code w.length x v.length} matrix with entries {@code a*w[i]*v[j]}
     */
    public static double[][] weightedouterProduct(double[] w, double[] v, double a) {
        double[][] result = new double[w.length][v.length];
        for (int i = 0; i < w.length; i++) {
            double scaledRow = a * w[i];
            for (int j = 0; j < v.length; j++) {
                result[i][j] = scaledRow * v[j];
            }
        }
        return result;
    }

    /**
     * @return the {@code size x size} identity matrix
     */
    public static double[][] identity(int size) {
        double[][] result = new double[size][size];
        for (int i = 0; i < size; i++) {
            result[i][i] = 1;
        }
        return result;
    }

    /**
     * w -= v (in place, element-wise).
     *
     * @param w matrix that is updated
     * @param v matrix that is subtracted
     */
    public static void minusEquals(double[][] w, double[][] v) {
        for (int i = 0; i < w.length; i++) {
            for (int j = 0; j < w[i].length; j++) {
                w[i][j] -= v[i][j];
            }
        }
    }

    /**
     * w[i][j] -= a*v[i][j] (in place).
     *
     * @param w matrix that is updated
     * @param v matrix that is scaled and subtracted
     * @param a scale applied to {@code v}
     */
    public static void minusEquals(double[][] w, double[][] v, double a) {
        for (int i = 0; i < w.length; i++) {
            for (int j = 0; j < w[i].length; j++) {
                w[i][j] -= a * v[i][j];
            }
        }
    }

    /**
     * w += v (in place, element-wise).
     *
     * @param w matrix that is updated
     * @param v matrix that is added
     */
    public static void plusEquals(double[][] w, double[][] v) {
        for (int i = 0; i < w.length; i++) {
            for (int j = 0; j < w[i].length; j++) {
                w[i][j] += v[i][j];
            }
        }
    }

    /**
     * w[i][j] += a*v[i][j] (in place).
     *
     * @param w matrix that is updated
     * @param v matrix that is scaled and added
     * @param a scale applied to {@code v}
     */
    public static void plusEquals(double[][] w, double[][] v, double a) {
        for (int i = 0; i < w.length; i++) {
            for (int j = 0; j < w[i].length; j++) {
                w[i][j] += a * v[i][j];
            }
        }
    }

    /**
     * result = w * v (matrix product).
     *
     * <p>If the inner dimensions disagree a message is printed and the JVM is
     * terminated with status -1.
     *
     * @param w left factor, non-empty
     * @param v right factor, non-empty
     * @return a new {@code w.length x v[0].length} matrix
     */
    public static double[][] matrixMultiplication(double[][] w, double[][] v) {
        int rows = w.length;
        int inner = w[0].length;
        int cols = v[0].length;

        if (inner != v.length) {
            System.out.println("Matrix dimensions do not agree...");
            System.exit(-1);
        }

        double[][] result = new double[rows][cols];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                double sum = 0;
                for (int k = 0; k < inner; k++) {
                    sum += w[r][k] * v[k][c];
                }
                result[r][c] = sum;
            }
        }
        return result;
    }

    /**
     * w = v * w: scales every entry of the matrix in place.
     *
     * @param w matrix that is updated, non-empty
     * @param v scalar factor
     */
    public static void matrixScalarMultiplication(double[][] w, double v) {
        int rows = w.length;
        int cols = w[0].length;
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                w[r][c] *= v;
            }
        }
    }

    /**
     * w = v * w: scales every entry of the vector in place.
     */
    public static void scalarMultiplication(double[] w, double v) {
        for (int i = 0; i < w.length; i++) {
            w[i] *= v;
        }
    }

    /**
     * result = w * v (matrix-vector product).
     *
     * <p>If the dimensions disagree a message is printed and the JVM is
     * terminated with status -1.
     *
     * @param w matrix, non-empty
     * @param v vector whose length equals the column count of {@code w}
     * @return a new vector of length {@code w.length}
     */
    public static double[] matrixVector(double[][] w, double[] v) {
        int rows = w.length;
        int cols = w[0].length;

        if (cols != v.length) {
            System.out.println("Matrix dimensions do not agree...");
            System.exit(-1);
        }

        double[] result = new double[rows];
        for (int r = 0; r < rows; r++) {
            double sum = 0;
            for (int c = 0; c < cols; c++) {
                sum += w[r][c] * v[c];
            }
            result[r] = sum;
        }
        return result;
    }

    /**
     * @return {@code false} as soon as a strictly negative entry is found;
     *         zeros are accepted, so this is really a non-negativity check
     */
    public static boolean allPositive(double[] array) {
        for (double entry : array) {
            if (entry < 0) {
                return false;
            }
        }
        return true;
    }

    /** Small manual demo: multiplies two constant 2x2 matrices and prints the results. */
    public static void main(String[] args) {
        double[][] m1 = new double[2][2];
        m1[0][0] = 2;
        m1[1][0] = 2;
        m1[0][1] = 2;
        m1[1][1] = 2;
        MatrixOutput.printDoubleArray(m1, "m1");
        double[][] m2 = new double[2][2];
        m2[0][0] = 3;
        m2[1][0] = 3;
        m2[0][1] = 3;
        m2[1][1] = 3;
        MatrixOutput.printDoubleArray(m2, "m2");
        double[][] result = matrixMultiplication(m1, m2);
        MatrixOutput.printDoubleArray(result, "result");
        matrixScalarMultiplication(result, 3);
        MatrixOutput.printDoubleArray(result, "result after multiply by 3");
    }

    /**
     * Relative comparison: true when {@code |a-b| / |a+b| <= prec}, or when
     * both values are within 1e-30 of zero.
     */
    public static boolean almost(double a, double b, double prec) {
        return Math.abs(a - b) / Math.abs(a + b) <= prec || (almostZero(a) && almostZero(b));
    }

    /** Same as {@link #almost(double, double, double)} with a precision of 1e-10. */
    public static boolean almost(double a, double b) {
        return almost(a, b, 1e-10);
    }

    /**
     * @return true when {@code |a| <= 1e-30}
     */
    public static boolean almostZero(double a) {
        return Math.abs(a) <= 1e-30;
    }

}

// File: gi/posterior-regularisation/prjava/src/optimization/util/MatrixOutput.java
// Declared as a public class in package optimization.util in its own file;
// written without the modifier here only because it shares this listing.

/**
 * Prints vectors and matrices to standard output in scientific notation.
 * Values are formatted with a locally created formatter, so printing has no
 * side effects on formatting state held elsewhere.
 */
class MatrixOutput {

    /** Number pattern used for every printed entry. */
    private static final String PATTERN = "00.00E00";

    /** Minimum width an entry is left-padded to. */
    private static final int MIN_WIDTH = 4;

    /**
     * Prints {@code arrayName}, then one line per row, then an empty line.
     * An empty matrix prints just the name and the empty line.
     */
    public static void printDoubleArray(double[][] array, String arrayName) {
        System.out.println(arrayName);
        for (double[] row : array) {
            for (double cell : row) {
                System.out.print(" " + format(cell) + " ");
            }
            System.out.println();
        }
        System.out.println();
    }

    /** Prints {@code arrayName} followed by all entries on a single line. */
    public static void printDoubleArray(double[] array, String arrayName) {
        System.out.println(arrayName);
        for (double cell : array) {
            System.out.print(" " + format(cell) + " ");
        }
        System.out.println();
    }

    /** Formats one value with {@link #PATTERN}, left-padded with spaces to {@link #MIN_WIDTH}. */
    private static String format(double value) {
        String text = new java.text.DecimalFormat(PATTERN).format(value);
        StringBuilder padded = new StringBuilder();
        for (int i = text.length(); i < MIN_WIDTH; i++) {
            padded.append(' ');
        }
        return padded.append(text).toString();
    }
}

// The diff continues with
// gi/posterior-regularisation/prjava/src/optimization/util/StaticTools.java (new file, 180 lines).
// File: gi/posterior-regularisation/prjava/src/optimization/util/StaticTools.java
// Declared in package optimization.util in the original source tree.

import java.io.File;
import java.io.PrintStream;

/**
 * Miscellaneous static helpers: directory creation, divergence measures over
 * probability tables, matrix copying/printing and number/time formatting.
 *
 * <p>The shared {@link #fmt} formatter is mutated by {@link #formatTime} and
 * {@link #prettyPrint}; neither method synchronises access to it.
 */
public class StaticTools {

    /** Shared formatter whose pattern is replaced on every formatting call. */
    static java.text.DecimalFormat fmt = new java.text.DecimalFormat();

    /**
     * Creates {@code directory} (including missing parents) unless it already
     * exists, reporting the outcome on standard output. If creation fails the
     * JVM is terminated.
     *
     * @param directory path of the directory to create or reuse
     */
    public static void createDir(String directory) {
        File dir = new File(directory);
        if (dir.isDirectory()) {
            System.out.println("Reusing directory " + directory);
            return;
        }
        if (!dir.mkdirs()) {
            System.out.println("Unable to create directory " + directory);
            // Failure must not look like success to the calling process; the
            // other fatal paths in this code base also exit with -1.
            System.exit(-1);
        }
        System.out.println("Created directory " + directory);
    }

    /**
     * Computes sum_ij q[i][j] * log(q[i][j] / p[i][j]) over a table indexed
     * as {@code [source][foreign]}, divided by {@code sourceSize}. Cells where
     * either {@code p} or {@code q} is exactly zero are skipped.
     *
     * <p>If the total is below -1e-10 both tables are dumped to standard
     * output and the JVM is terminated with status -1.
     *
     * @param p           reference table
     * @param q           table whose divergence from {@code p} is measured
     * @param sourceSize  number of rows to visit
     * @param foreignSize number of columns to visit
     * @return the summed divergence divided by {@code sourceSize}
     */
    public static double KLDistance(double[][] p, double[][] q, int sourceSize,
            int foreignSize) {
        double totalKL = 0;
        for (int i = 0; i < sourceSize; i++) {
            double kl = 0;
            for (int j = 0; j < foreignSize; j++) {
                assert !Double.isNaN(q[i][j]) : "KLDistance q: prob is NaN";
                assert !Double.isNaN(p[i][j]) : "KLDistance p: prob is NaN";
                if (p[i][j] == 0 || q[i][j] == 0) {
                    continue;
                }
                kl += q[i][j] * Math.log(q[i][j] / p[i][j]);
            }
            totalKL += kl;
        }
        assert !Double.isNaN(totalKL) : "KLDistance: prob is NaN";
        if (totalKL < -1.0E-10) {
            System.out.println("KL Smaller than zero " + totalKL);
            System.out.println("Source Size" + sourceSize);
            System.out.println("Foreign Size" + foreignSize);
            StaticTools.printMatrix(q, sourceSize, foreignSize, "q",
                    System.out);
            StaticTools.printMatrix(p, sourceSize, foreignSize, "p",
                    System.out);
            System.exit(-1);
        }
        return totalKL / sourceSize;
    }

    /**
     * Same sum as {@link #KLDistance} but for tables indexed as
     * {@code [foreign][source]}, and without the negative-value check.
     *
     * @return the summed divergence divided by {@code sourceSize}
     */
    public static double KLDistancePrime(double[][] p, double[][] q,
            int sourceSize, int foreignSize) {
        double totalKL = 0;
        for (int i = 0; i < sourceSize; i++) {
            double kl = 0;
            for (int j = 0; j < foreignSize; j++) {
                assert !Double.isNaN(q[j][i]) : "KLDistance q: prob is NaN";
                assert !Double.isNaN(p[j][i]) : "KLDistance p: prob is NaN";
                if (p[j][i] == 0 || q[j][i] == 0) {
                    continue;
                }
                kl += q[j][i] * Math.log(q[j][i] / p[j][i]);
            }
            totalKL += kl;
        }
        assert !Double.isNaN(totalKL) : "KLDistance: prob is NaN";
        return totalKL / sourceSize;
    }

    /**
     * Computes sum_ij p[i][j] * log(p[i][j]) for a table indexed as
     * {@code [foreign][source]}, divided by {@code sourceSize}.
     *
     * <p>Note that no sign flip is applied, so the result is the negated
     * Shannon entropy. Zero cells contribute nothing (0 * log 0 is taken as
     * 0) instead of turning the whole result into NaN.
     *
     * @param p           probability table
     * @param sourceSize  number of columns to visit, also the divisor
     * @param foreignSize number of rows to visit
     */
    public static double Entropy(double[][] p, int sourceSize, int foreignSize) {
        double totalE = 0;
        for (int i = 0; i < foreignSize; i++) {
            double e = 0;
            for (int j = 0; j < sourceSize; j++) {
                if (p[i][j] == 0) {
                    continue;
                }
                e += p[i][j] * Math.log(p[i][j]);
            }
            totalE += e;
        }
        return totalE / sourceSize;
    }

    /**
     * @return a new {@code sourceSize x foreignSize} matrix holding the
     *         top-left corner of {@code original}
     */
    public static double[][] copyMatrix(double[][] original, int sourceSize,
            int foreignSize) {
        double[][] result = new double[sourceSize][foreignSize];
        for (int i = 0; i < sourceSize; i++) {
            System.arraycopy(original[i], 0, result[i], 0, foreignSize);
        }
        return result;
    }

    /**
     * Prints {@code info} and then the transpose of {@code matrix}: one output
     * line per foreign index, entries formatted with pattern {@code .00E00}.
     * Two empty lines follow.
     */
    public static void printMatrix(double[][] matrix, int sourceSize,
            int foreignSize, String info, PrintStream out) {
        out.println(info);

        for (int i = 0; i < foreignSize; i++) {
            for (int j = 0; j < sourceSize; j++) {
                out.print(prettyPrint(matrix[j][i], ".00E00", 6) + " ");
            }
            out.println();
        }
        out.println();
        out.println();
    }

    /**
     * Integer variant of {@link #printMatrix(double[][], int, int, String, PrintStream)};
     * entries are printed unformatted.
     */
    public static void printMatrix(int[][] matrix, int sourceSize,
            int foreignSize, String info, PrintStream out) {
        out.println(info);
        for (int i = 0; i < foreignSize; i++) {
            for (int j = 0; j < sourceSize; j++) {
                out.print(matrix[j][i] + " ");
            }
            out.println();
        }
        out.println();
        out.println();
    }

    /**
     * Formats a duration as {@code HH:MM:SS.s}.
     *
     * @param duration elapsed time in milliseconds
     * @return hours and minutes zero-padded to two digits, seconds with one decimal
     */
    public static String formatTime(long duration) {
        // Divide in floating point: "duration / 1000" truncated to whole
        // seconds, so the tenths digit was always zero.
        double seconds = duration / 1000.0;
        int hours = (int) (seconds / (60 * 60));
        seconds -= hours * 60 * 60;
        int minutes = (int) (seconds / 60);
        seconds -= minutes * 60;

        StringBuilder sb = new StringBuilder();
        fmt.applyPattern("00");
        sb.append(fmt.format(hours)).append(':');
        sb.append(fmt.format(minutes)).append(':');
        fmt.applyPattern("00.0");
        sb.append(fmt.format(seconds));
        return sb.toString();
    }

    /**
     * Formats {@code d} with the given {@link java.text.DecimalFormat}
     * pattern and left-pads the result with spaces to at least {@code len}
     * characters.
     */
    public static String prettyPrint(double d, String patt, int len) {
        fmt.applyPattern(patt);
        String formatted = fmt.format(d);
        StringBuilder padded = new StringBuilder();
        for (int i = formatted.length(); i < len; i++) {
            padded.append(' ');
        }
        return padded.append(formatted).toString();
    }

    /**
     * Requests a garbage collection and then reports heap usage.
     *
     * @return used heap (total minus free) in whole mebibytes
     */
    public static long getUsedMemory() {
        System.gc();
        Runtime runtime = Runtime.getRuntime();
        return (runtime.totalMemory() - runtime.freeMemory()) / (1024 * 1024);
    }

    /**
     * @return true when the two values differ by at most 1e-10 in absolute terms
     */
    public final static boolean compareDoubles(double d1, double d2) {
        return Math.abs(d1 - d2) <= 1.E-10;
    }

}