summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/optimization/util
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/optimization/util')
-rw-r--r--gi/posterior-regularisation/prjava/src/optimization/util/Interpolation.java37
-rw-r--r--gi/posterior-regularisation/prjava/src/optimization/util/Logger.java7
-rw-r--r--gi/posterior-regularisation/prjava/src/optimization/util/MathUtils.java339
-rw-r--r--gi/posterior-regularisation/prjava/src/optimization/util/MatrixOutput.java28
-rw-r--r--gi/posterior-regularisation/prjava/src/optimization/util/StaticTools.java180
5 files changed, 591 insertions, 0 deletions
diff --git a/gi/posterior-regularisation/prjava/src/optimization/util/Interpolation.java b/gi/posterior-regularisation/prjava/src/optimization/util/Interpolation.java
new file mode 100644
index 00000000..cdbdefc6
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/optimization/util/Interpolation.java
@@ -0,0 +1,37 @@
+package optimization.util;
+
+public class Interpolation {
+
+ /**
+ * Fits a cubic polinomyal to a function given two points,
+ * such that either gradB is bigger than zero or funcB >= funcA
+ *
+ * NonLinear Programming appendix C
+ * @param funcA
+ * @param gradA
+ * @param funcB
+ * @param gradB
+ */
+ public final static double cubicInterpolation(double a,
+ double funcA, double gradA, double b,double funcB, double gradB ){
+ if(gradB < 0 && funcA > funcB){
+ System.out.println("Cannot call cubic interpolation");
+ return -1;
+ }
+
+ double z = 3*(funcA-funcB)/(b-a) + gradA + gradB;
+ double w = Math.sqrt(z*z - gradA*gradB);
+ double min = b -(gradB+w-z)*(b-a)/(gradB-gradA+2*w);
+ return min;
+ }
+
+ public final static double quadraticInterpolation(double initFValue,
+ double initGrad, double point,double pointFValue){
+ double min = -1*initGrad*point*point/(2*(pointFValue-initGrad*point-initFValue));
+ return min;
+ }
+
+ public static void main(String[] args) {
+
+ }
+}
diff --git a/gi/posterior-regularisation/prjava/src/optimization/util/Logger.java b/gi/posterior-regularisation/prjava/src/optimization/util/Logger.java
new file mode 100644
index 00000000..5343a39b
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/optimization/util/Logger.java
@@ -0,0 +1,7 @@
+package optimization.util;
+
+public class Logger {
+
+
+
+}
diff --git a/gi/posterior-regularisation/prjava/src/optimization/util/MathUtils.java b/gi/posterior-regularisation/prjava/src/optimization/util/MathUtils.java
new file mode 100644
index 00000000..af66f82c
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/optimization/util/MathUtils.java
@@ -0,0 +1,339 @@
+package optimization.util;
+
+import java.util.Arrays;
+
+
+
+public class MathUtils {
+
+ /**
+ *
+ * @param vector
+ * @return
+ */
+ public static double L2Norm(double[] vector){
+ double value = 0;
+ for(int i = 0; i < vector.length; i++){
+ double v = vector[i];
+ value+=v*v;
+ }
+ return Math.sqrt(value);
+ }
+
+ public static double sum(double[] v){
+ double sum = 0;
+ for (int i = 0; i < v.length; i++) {
+ sum+=v[i];
+ }
+ return sum;
+ }
+
+
+
+
+ /**
+ * w = w + v
+ * @param w
+ * @param v
+ */
+ public static void plusEquals(double[] w, double[] v) {
+ for(int i=0; i<w.length;i++){
+ w[i] += w[i] + v[i];
+ }
+ }
+
+ /**
+ * w[i] = w[i] + v
+ * @param w
+ * @param v
+ */
+ public static void plusEquals(double[] w, double v) {
+ for(int i=0; i<w.length;i++){
+ w[i] += w[i] + v;
+ }
+ }
+
+ /**
+ * w[i] = w[i] - v
+ * @param w
+ * @param v
+ */
+ public static void minusEquals(double[] w, double v) {
+ for(int i=0; i<w.length;i++){
+ w[i] -= w[i] + v;
+ }
+ }
+
+ /**
+ * w = w + a*v
+ * @param w
+ * @param v
+ * @param a
+ */
+ public static void plusEquals(double[] w, double[] v, double a) {
+ for(int i=0; i<w.length;i++){
+ w[i] += a*v[i];
+ }
+ }
+
+ /**
+ * w = w - a*v
+ * @param w
+ * @param v
+ * @param a
+ */
+ public static void minusEquals(double[] w, double[] v, double a) {
+ for(int i=0; i<w.length;i++){
+ w[i] -= a*v[i];
+ }
+ }
+ /**
+ * v = w - a*v
+ * @param w
+ * @param v
+ * @param a
+ */
+ public static void minusEqualsInverse(double[] w, double[] v, double a) {
+ for(int i=0; i<w.length;i++){
+ v[i] = w[i] - a*v[i];
+ }
+ }
+
+ public static double dotProduct(double[] w, double[] v){
+ double accum = 0;
+ for(int i=0; i<w.length;i++){
+ accum += w[i]*v[i];
+ }
+ return accum;
+ }
+
+ public static double[] arrayMinus(double[]w, double[]v){
+ double result[] = w.clone();
+ for(int i=0; i<w.length;i++){
+ result[i] -= v[i];
+ }
+ return result;
+ }
+
+ public static double[] arrayMinus(double[] result , double[]w, double[]v){
+ for(int i=0; i<w.length;i++){
+ result[i] = w[i]-v[i];
+ }
+ return result;
+ }
+
+ public static double[] negation(double[]w){
+ double result[] = new double[w.length];
+ for(int i=0; i<w.length;i++){
+ result[i] = -w[i];
+ }
+ return result;
+ }
+
+ public static double square(double value){
+ return value*value;
+ }
+ public static double[][] outerProduct(double[] w, double[] v){
+ double[][] result = new double[w.length][v.length];
+ for(int i = 0; i < w.length; i++){
+ for(int j = 0; j < v.length; j++){
+ result[i][j] = w[i]*v[j];
+ }
+ }
+ return result;
+ }
+ /**
+ * results = a*W*V
+ * @param w
+ * @param v
+ * @param a
+ * @return
+ */
+ public static double[][] weightedouterProduct(double[] w, double[] v, double a){
+ double[][] result = new double[w.length][v.length];
+ for(int i = 0; i < w.length; i++){
+ for(int j = 0; j < v.length; j++){
+ result[i][j] = a*w[i]*v[j];
+ }
+ }
+ return result;
+ }
+
+ public static double[][] identity(int size){
+ double[][] result = new double[size][size];
+ for(int i = 0; i < size; i++){
+ result[i][i] = 1;
+ }
+ return result;
+ }
+
+ /**
+ * v -= w
+ * @param v
+ * @param w
+ */
+ public static void minusEquals(double[][] w, double[][] v){
+ for(int i = 0; i < w.length; i++){
+ for(int j = 0; j < w[0].length; j++){
+ w[i][j] -= v[i][j];
+ }
+ }
+ }
+
+ /**
+ * v[i][j] -= a*w[i][j]
+ * @param v
+ * @param w
+ */
+ public static void minusEquals(double[][] w, double[][] v, double a){
+ for(int i = 0; i < w.length; i++){
+ for(int j = 0; j < w[0].length; j++){
+ w[i][j] -= a*v[i][j];
+ }
+ }
+ }
+
+ /**
+ * v += w
+ * @param v
+ * @param w
+ */
+ public static void plusEquals(double[][] w, double[][] v){
+ for(int i = 0; i < w.length; i++){
+ for(int j = 0; j < w[0].length; j++){
+ w[i][j] += v[i][j];
+ }
+ }
+ }
+
+ /**
+ * v[i][j] += a*w[i][j]
+ * @param v
+ * @param w
+ */
+ public static void plusEquals(double[][] w, double[][] v, double a){
+ for(int i = 0; i < w.length; i++){
+ for(int j = 0; j < w[0].length; j++){
+ w[i][j] += a*v[i][j];
+ }
+ }
+ }
+
+
+ /**
+ * results = w*v
+ * @param w
+ * @param v
+ * @return
+ */
+ public static double[][] matrixMultiplication(double[][] w,double[][] v){
+ int w1 = w.length;
+ int w2 = w[0].length;
+ int v1 = v.length;
+ int v2 = v[0].length;
+
+ if(w2 != v1){
+ System.out.println("Matrix dimensions do not agree...");
+ System.exit(-1);
+ }
+
+ double[][] result = new double[w1][v2];
+ for(int w_i1 = 0; w_i1 < w1; w_i1++){
+ for(int v_i2 = 0; v_i2 < v2; v_i2++){
+ double sum = 0;
+ for(int w_i2 = 0; w_i2 < w2; w_i2++){
+ sum += w[w_i1 ][w_i2]*v[w_i2][v_i2];
+ }
+ result[w_i1][v_i2] = sum;
+ }
+ }
+ return result;
+ }
+
+ /**
+ * w = w.*v
+ * @param w
+ * @param v
+ */
+ public static void matrixScalarMultiplication(double[][] w,double v){
+ int w1 = w.length;
+ int w2 = w[0].length;
+ for(int w_i1 = 0; w_i1 < w1; w_i1++){
+ for(int w_i2 = 0; w_i2 < w2; w_i2++){
+ w[w_i1 ][w_i2] *= v;
+ }
+ }
+ }
+
+ public static void scalarMultiplication(double[] w,double v){
+ int w1 = w.length;
+ for(int w_i1 = 0; w_i1 < w1; w_i1++){
+ w[w_i1 ] *= v;
+ }
+
+ }
+
+ public static double[] matrixVector(double[][] w,double[] v){
+ int w1 = w.length;
+ int w2 = w[0].length;
+ int v1 = v.length;
+
+ if(w2 != v1){
+ System.out.println("Matrix dimensions do not agree...");
+ System.exit(-1);
+ }
+
+ double[] result = new double[w1];
+ for(int w_i1 = 0; w_i1 < w1; w_i1++){
+ double sum = 0;
+ for(int w_i2 = 0; w_i2 < w2; w_i2++){
+ sum += w[w_i1 ][w_i2]*v[w_i2];
+ }
+ result[w_i1] = sum;
+ }
+ return result;
+ }
+
+ public static boolean allPositive(double[] array){
+ for (int i = 0; i < array.length; i++) {
+ if(array[i] < 0) return false;
+ }
+ return true;
+ }
+
+
+
+
+
+ public static void main(String[] args) {
+ double[][] m1 = new double[2][2];
+ m1[0][0]=2;
+ m1[1][0]=2;
+ m1[0][1]=2;
+ m1[1][1]=2;
+ MatrixOutput.printDoubleArray(m1, "m1");
+ double[][] m2 = new double[2][2];
+ m2[0][0]=3;
+ m2[1][0]=3;
+ m2[0][1]=3;
+ m2[1][1]=3;
+ MatrixOutput.printDoubleArray(m2, "m2");
+ double[][] result = matrixMultiplication(m1, m2);
+ MatrixOutput.printDoubleArray(result, "result");
+ matrixScalarMultiplication(result, 3);
+ MatrixOutput.printDoubleArray(result, "result after multiply by 3");
+ }
+
+ public static boolean almost(double a, double b, double prec){
+ return Math.abs(a-b)/Math.abs(a+b) <= prec || (almostZero(a) && almostZero(b));
+ }
+
+ public static boolean almost(double a, double b){
+ return Math.abs(a-b)/Math.abs(a+b) <= 1e-10 || (almostZero(a) && almostZero(b));
+ }
+
+ public static boolean almostZero(double a) {
+ return Math.abs(a) <= 1e-30;
+ }
+
+}
diff --git a/gi/posterior-regularisation/prjava/src/optimization/util/MatrixOutput.java b/gi/posterior-regularisation/prjava/src/optimization/util/MatrixOutput.java
new file mode 100644
index 00000000..9fbdf955
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/optimization/util/MatrixOutput.java
@@ -0,0 +1,28 @@
+package optimization.util;
+
+
+public class MatrixOutput {
+ public static void printDoubleArray(double[][] array, String arrayName) {
+ int size1 = array.length;
+ int size2 = array[0].length;
+ System.out.println(arrayName);
+ for (int i = 0; i < size1; i++) {
+ for (int j = 0; j < size2; j++) {
+ System.out.print(" " + StaticTools.prettyPrint(array[i][j],
+ "00.00E00", 4) + " ");
+
+ }
+ System.out.println();
+ }
+ System.out.println();
+ }
+
+ public static void printDoubleArray(double[] array, String arrayName) {
+ System.out.println(arrayName);
+ for (int i = 0; i < array.length; i++) {
+ System.out.print(" " + StaticTools.prettyPrint(array[i],
+ "00.00E00", 4) + " ");
+ }
+ System.out.println();
+ }
+}
diff --git a/gi/posterior-regularisation/prjava/src/optimization/util/StaticTools.java b/gi/posterior-regularisation/prjava/src/optimization/util/StaticTools.java
new file mode 100644
index 00000000..bcabee06
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/optimization/util/StaticTools.java
@@ -0,0 +1,180 @@
+package optimization.util;
+
+
+import java.io.File;
+import java.io.PrintStream;
+
+public class StaticTools {
+
+ static java.text.DecimalFormat fmt = new java.text.DecimalFormat();
+
+ public static void createDir(String directory) {
+
+ File dir = new File(directory);
+ if (!dir.isDirectory()) {
+ boolean success = dir.mkdirs();
+ if (!success) {
+ System.out.println("Unable to create directory " + directory);
+ System.exit(0);
+ }
+ System.out.println("Created directory " + directory);
+ } else {
+ System.out.println("Reusing directory " + directory);
+ }
+ }
+
+ /*
+ * q and p are indexed by source/foreign Sum_S(q) = 1 the same for p KL(q,p) =
+ * Eq*q/p
+ */
+ public static double KLDistance(double[][] p, double[][] q, int sourceSize,
+ int foreignSize) {
+ double totalKL = 0;
+ // common.StaticTools.printMatrix(q, sourceSize, foreignSize, "q",
+ // System.out);
+ // common.StaticTools.printMatrix(p, sourceSize, foreignSize, "p",
+ // System.out);
+ for (int i = 0; i < sourceSize; i++) {
+ double kl = 0;
+ for (int j = 0; j < foreignSize; j++) {
+ assert !Double.isNaN(q[i][j]) : "KLDistance q: prob is NaN";
+ assert !Double.isNaN(p[i][j]) : "KLDistance p: prob is NaN";
+ if (p[i][j] == 0 || q[i][j] == 0) {
+ continue;
+ } else {
+ kl += q[i][j] * Math.log(q[i][j] / p[i][j]);
+ }
+
+ }
+ totalKL += kl;
+ }
+ assert !Double.isNaN(totalKL) : "KLDistance: prob is NaN";
+ if (totalKL < -1.0E-10) {
+ System.out.println("KL Smaller than zero " + totalKL);
+ System.out.println("Source Size" + sourceSize);
+ System.out.println("Foreign Size" + foreignSize);
+ StaticTools.printMatrix(q, sourceSize, foreignSize, "q",
+ System.out);
+ StaticTools.printMatrix(p, sourceSize, foreignSize, "p",
+ System.out);
+ System.exit(-1);
+ }
+ return totalKL / sourceSize;
+ }
+
+ /*
+ * indexed the by [fi][si]
+ */
+ public static double KLDistancePrime(double[][] p, double[][] q,
+ int sourceSize, int foreignSize) {
+ double totalKL = 0;
+ for (int i = 0; i < sourceSize; i++) {
+ double kl = 0;
+ for (int j = 0; j < foreignSize; j++) {
+ assert !Double.isNaN(q[j][i]) : "KLDistance q: prob is NaN";
+ assert !Double.isNaN(p[j][i]) : "KLDistance p: prob is NaN";
+ if (p[j][i] == 0 || q[j][i] == 0) {
+ continue;
+ } else {
+ kl += q[j][i] * Math.log(q[j][i] / p[j][i]);
+ }
+
+ }
+ totalKL += kl;
+ }
+ assert !Double.isNaN(totalKL) : "KLDistance: prob is NaN";
+ return totalKL / sourceSize;
+ }
+
+ public static double Entropy(double[][] p, int sourceSize, int foreignSize) {
+ double totalE = 0;
+ for (int i = 0; i < foreignSize; i++) {
+ double e = 0;
+ for (int j = 0; j < sourceSize; j++) {
+ e += p[i][j] * Math.log(p[i][j]);
+ }
+ totalE += e;
+ }
+ return totalE / sourceSize;
+ }
+
+ public static double[][] copyMatrix(double[][] original, int sourceSize,
+ int foreignSize) {
+ double[][] result = new double[sourceSize][foreignSize];
+ for (int i = 0; i < sourceSize; i++) {
+ for (int j = 0; j < foreignSize; j++) {
+ result[i][j] = original[i][j];
+ }
+ }
+ return result;
+ }
+
+ public static void printMatrix(double[][] matrix, int sourceSize,
+ int foreignSize, String info, PrintStream out) {
+
+ java.text.DecimalFormat fmt = new java.text.DecimalFormat();
+ fmt.setMaximumFractionDigits(3);
+ fmt.setMaximumIntegerDigits(3);
+ fmt.setMinimumFractionDigits(3);
+ fmt.setMinimumIntegerDigits(3);
+
+ out.println(info);
+
+ for (int i = 0; i < foreignSize; i++) {
+ for (int j = 0; j < sourceSize; j++) {
+ out.print(prettyPrint(matrix[j][i], ".00E00", 6) + " ");
+ }
+ out.println();
+ }
+ out.println();
+ out.println();
+ }
+
+ public static void printMatrix(int[][] matrix, int sourceSize,
+ int foreignSize, String info, PrintStream out) {
+
+ out.println(info);
+ for (int i = 0; i < foreignSize; i++) {
+ for (int j = 0; j < sourceSize; j++) {
+ out.print(matrix[j][i] + " ");
+ }
+ out.println();
+ }
+ out.println();
+ out.println();
+ }
+
+ public static String formatTime(long duration) {
+ StringBuilder sb = new StringBuilder();
+ double d = duration / 1000;
+ fmt.applyPattern("00");
+ sb.append(fmt.format((int) (d / (60 * 60))) + ":");
+ d -= ((int) d / (60 * 60)) * 60 * 60;
+ sb.append(fmt.format((int) (d / 60)) + ":");
+ d -= ((int) d / 60) * 60;
+ fmt.applyPattern("00.0");
+ sb.append(fmt.format(d));
+ return sb.toString();
+ }
+
+ public static String prettyPrint(double d, String patt, int len) {
+ fmt.applyPattern(patt);
+ String s = fmt.format(d);
+ while (s.length() < len) {
+ s = " " + s;
+ }
+ return s;
+ }
+
+
+ public static long getUsedMemory(){
+ System.gc();
+ return (Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory())/ (1024 * 1024);
+ }
+
+ public final static boolean compareDoubles(double d1, double d2){
+ return Math.abs(d1-d2) <= 1.E-10;
+ }
+
+
+}