summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/util
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/util')
-rw-r--r--gi/posterior-regularisation/prjava/src/util/Array.java41
-rw-r--r--gi/posterior-regularisation/prjava/src/util/ArrayMath.java186
-rw-r--r--gi/posterior-regularisation/prjava/src/util/DifferentiableObjective.java14
-rw-r--r--gi/posterior-regularisation/prjava/src/util/DigammaFunction.java21
-rw-r--r--gi/posterior-regularisation/prjava/src/util/FileSystem.java21
-rw-r--r--gi/posterior-regularisation/prjava/src/util/InputOutput.java67
-rw-r--r--gi/posterior-regularisation/prjava/src/util/LogSummer.java86
-rw-r--r--gi/posterior-regularisation/prjava/src/util/MathUtil.java148
-rw-r--r--gi/posterior-regularisation/prjava/src/util/Matrix.java16
-rw-r--r--gi/posterior-regularisation/prjava/src/util/MemoryTracker.java47
-rw-r--r--gi/posterior-regularisation/prjava/src/util/Pair.java31
-rw-r--r--gi/posterior-regularisation/prjava/src/util/Printing.java158
-rw-r--r--gi/posterior-regularisation/prjava/src/util/Sorters.java39
13 files changed, 875 insertions, 0 deletions
diff --git a/gi/posterior-regularisation/prjava/src/util/Array.java b/gi/posterior-regularisation/prjava/src/util/Array.java
new file mode 100644
index 00000000..cc4725af
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/util/Array.java
@@ -0,0 +1,41 @@
+package util;
+
+import java.util.Arrays;
+
+public class Array {
+
+
+
+ public static void sortDescending(double[] ds){
+ for (int i = 0; i < ds.length; i++) ds[i] = -ds[i];
+ Arrays.sort(ds);
+ for (int i = 0; i < ds.length; i++) ds[i] = -ds[i];
+ }
+
+ /**
+ * Return a new reversed array
+ * @param array
+ * @return
+ */
+ public static int[] reverseIntArray(int[] array){
+ int[] reversed = new int[array.length];
+ for (int i = 0; i < reversed.length; i++) {
+ reversed[i] = array[reversed.length-1-i];
+ }
+ return reversed;
+ }
+
+ public static String[] sumArray(String[] in, int from){
+ String[] res = new String[in.length-from];
+ for (int i = from; i < in.length; i++) {
+ res[i-from] = in[i];
+ }
+ return res;
+ }
+
+ public static void main(String[] args) {
+ int[] i = {1,2,3,4};
+ util.Printing.printIntArray(i, null, "original");
+ util.Printing.printIntArray(reverseIntArray(i), null, "reversed");
+ }
+}
diff --git a/gi/posterior-regularisation/prjava/src/util/ArrayMath.java b/gi/posterior-regularisation/prjava/src/util/ArrayMath.java
new file mode 100644
index 00000000..398a13a2
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/util/ArrayMath.java
@@ -0,0 +1,186 @@
+package util;
+
+import java.util.Arrays;
+
+public class ArrayMath {
+
+ public static double dotProduct(double[] v1, double[] v2) {
+ assert(v1.length == v2.length);
+ double result = 0;
+ for(int i = 0; i < v1.length; i++)
+ result += v1[i]*v2[i];
+ return result;
+ }
+
+ public static double twoNormSquared(double[] v) {
+ double result = 0;
+ for(double d : v)
+ result += d*d;
+ return result;
+ }
+
+ public static boolean containsInvalid(double[] v) {
+ for(int i = 0; i < v.length; i++)
+ if(Double.isNaN(v[i]) || Double.isInfinite(v[i]))
+ return true;
+ return false;
+ }
+
+
+
+ public static double safeAdd(double[] toAdd) {
+ // Make sure there are no positive infinities
+ double sum = 0;
+ for(int i = 0; i < toAdd.length; i++) {
+ assert(!(Double.isInfinite(toAdd[i]) && toAdd[i] > 0));
+ assert(!Double.isNaN(toAdd[i]));
+ sum += toAdd[i];
+ }
+
+ return sum;
+ }
+
+ /* Methods for filling integer and double arrays (of up to four dimensions) with the given value. */
+
+ public static void set(int[][][][] array, int value) {
+ for(int i = 0; i < array.length; i++) {
+ set(array[i], value);
+ }
+ }
+
+ public static void set(int[][][] array, int value) {
+ for(int i = 0; i < array.length; i++) {
+ set(array[i], value);
+ }
+ }
+
+ public static void set(int[][] array, int value) {
+ for(int i = 0; i < array.length; i++) {
+ set(array[i], value);
+ }
+ }
+
+ public static void set(int[] array, int value) {
+ Arrays.fill(array, value);
+ }
+
+
+ public static void set(double[][][][] array, double value) {
+ for(int i = 0; i < array.length; i++) {
+ set(array[i], value);
+ }
+ }
+
+ public static void set(double[][][] array, double value) {
+ for(int i = 0; i < array.length; i++) {
+ set(array[i], value);
+ }
+ }
+
+ public static void set(double[][] array, double value) {
+ for(int i = 0; i < array.length; i++) {
+ set(array[i], value);
+ }
+ }
+
+ public static void set(double[] array, double value) {
+ Arrays.fill(array, value);
+ }
+
+ public static void setEqual(double[][][][] dest, double[][][][] source){
+ for (int i = 0; i < source.length; i++) {
+ setEqual(dest[i],source[i]);
+ }
+ }
+
+
+ public static void setEqual(double[][][] dest, double[][][] source){
+ for (int i = 0; i < source.length; i++) {
+ set(dest[i],source[i]);
+ }
+ }
+
+
+ public static void set(double[][] dest, double[][] source){
+ for (int i = 0; i < source.length; i++) {
+ setEqual(dest[i],source[i]);
+ }
+ }
+
+ public static void setEqual(double[] dest, double[] source){
+ System.arraycopy(source, 0, dest, 0, source.length);
+ }
+
+ public static void plusEquals(double[][][][] array, double val){
+ for (int i = 0; i < array.length; i++) {
+ plusEquals(array[i], val);
+ }
+ }
+
+ public static void plusEquals(double[][][] array, double val){
+ for (int i = 0; i < array.length; i++) {
+ plusEquals(array[i], val);
+ }
+ }
+
+ public static void plusEquals(double[][] array, double val){
+ for (int i = 0; i < array.length; i++) {
+ plusEquals(array[i], val);
+ }
+ }
+
+ public static void plusEquals(double[] array, double val){
+ for (int i = 0; i < array.length; i++) {
+ array[i] += val;
+ }
+ }
+
+
+ public static double sum(double[] array) {
+ double res = 0;
+ for (int i = 0; i < array.length; i++) res += array[i];
+ return res;
+ }
+
+
+
+ public static double[][] deepclone(double[][] in){
+ double[][] res = new double[in.length][];
+ for (int i = 0; i < res.length; i++) {
+ res[i] = in[i].clone();
+ }
+ return res;
+ }
+
+
+ public static double[][][] deepclone(double[][][] in){
+ double[][][] res = new double[in.length][][];
+ for (int i = 0; i < res.length; i++) {
+ res[i] = deepclone(in[i]);
+ }
+ return res;
+ }
+
+ public static double cosine(double[] a,
+ double[] b) {
+ return (dotProduct(a, b)+1e-5)/(Math.sqrt(dotProduct(a, a)+1e-5)*Math.sqrt(dotProduct(b, b)+1e-5));
+ }
+
+ public static double max(double[] ds) {
+ double max = Double.NEGATIVE_INFINITY;
+ for(double d:ds) max = Math.max(d,max);
+ return max;
+ }
+
+ public static void exponentiate(double[] a) {
+ for (int i = 0; i < a.length; i++) {
+ a[i] = Math.exp(a[i]);
+ }
+ }
+
+ public static int sum(int[] array) {
+ int res = 0;
+ for (int i = 0; i < array.length; i++) res += array[i];
+ return res;
+ }
+}
diff --git a/gi/posterior-regularisation/prjava/src/util/DifferentiableObjective.java b/gi/posterior-regularisation/prjava/src/util/DifferentiableObjective.java
new file mode 100644
index 00000000..1ff1ae4a
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/util/DifferentiableObjective.java
@@ -0,0 +1,14 @@
+package util;
+
+public interface DifferentiableObjective {
+
+ public double getValue();
+
+ public void getGradient(double[] gradient);
+
+ public void getParameters(double[] params);
+
+ public void setParameters(double[] newParameters);
+
+ public int getNumParameters();
+}
diff --git a/gi/posterior-regularisation/prjava/src/util/DigammaFunction.java b/gi/posterior-regularisation/prjava/src/util/DigammaFunction.java
new file mode 100644
index 00000000..ff1478ad
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/util/DigammaFunction.java
@@ -0,0 +1,21 @@
+package util;
+
+public class DigammaFunction {
+ public static double expDigamma(double number){
+ if(number==0)return number;
+ return Math.exp(digamma(number));
+ }
+
+ public static double digamma(double number){
+ if(number > 7){
+ return digammApprox(number-0.5);
+ }else{
+ return digamma(number+1) - 1.0/number;
+ }
+ }
+
+ private static double digammApprox(double value){
+ return Math.log(value) + 0.04167*Math.pow(value, -2) - 0.00729*Math.pow(value, -4)
+ + 0.00384*Math.pow(value, -6) - 0.00413*Math.pow(value, -8);
+ }
+}
diff --git a/gi/posterior-regularisation/prjava/src/util/FileSystem.java b/gi/posterior-regularisation/prjava/src/util/FileSystem.java
new file mode 100644
index 00000000..d7812e40
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/util/FileSystem.java
@@ -0,0 +1,21 @@
+package util;
+
+import java.io.File;
+
+public class FileSystem {
+ public static boolean createDir(String directory) {
+
+ File dir = new File(directory);
+ if (!dir.isDirectory()) {
+ boolean success = dir.mkdirs();
+ if (!success) {
+ System.out.println("Unable to create directory " + directory);
+ return false;
+ }
+ System.out.println("Created directory " + directory);
+ } else {
+ System.out.println("Reusing directory " + directory);
+ }
+ return true;
+ }
+}
diff --git a/gi/posterior-regularisation/prjava/src/util/InputOutput.java b/gi/posterior-regularisation/prjava/src/util/InputOutput.java
new file mode 100644
index 00000000..da7f71bf
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/util/InputOutput.java
@@ -0,0 +1,67 @@
+package util;
+
+import java.io.BufferedReader;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.io.OutputStream;
+import java.io.PrintStream;
+import java.io.UnsupportedEncodingException;
+import java.util.Properties;
+import java.util.zip.GZIPInputStream;
+import java.util.zip.GZIPOutputStream;
+
+public class InputOutput {
+
+ /**
+ * Opens a file either compress with gzip or not compressed.
+ */
+ public static BufferedReader openReader(String fileName) throws UnsupportedEncodingException, FileNotFoundException, IOException{
+ System.out.println("Reading: " + fileName);
+ BufferedReader reader;
+ fileName = fileName.trim();
+ if(fileName.endsWith("gz")){
+ reader = new BufferedReader(
+ new InputStreamReader(new GZIPInputStream(new FileInputStream(fileName)),"UTF8"));
+ }else{
+ reader = new BufferedReader(new InputStreamReader(
+ new FileInputStream(fileName), "UTF8"));
+ }
+
+ return reader;
+ }
+
+
+ public static PrintStream openWriter(String fileName)
+ throws UnsupportedEncodingException, FileNotFoundException, IOException{
+ System.out.println("Writting to file: " + fileName);
+ PrintStream writter;
+ fileName = fileName.trim();
+ if(fileName.endsWith("gz")){
+ writter = new PrintStream(new GZIPOutputStream(new FileOutputStream(fileName)),
+ true, "UTF-8");
+
+ }else{
+ writter = new PrintStream(new FileOutputStream(fileName),
+ true, "UTF-8");
+
+ }
+
+ return writter;
+ }
+
+ public static Properties readPropertiesFile(String fileName) {
+ Properties properties = new Properties();
+ try {
+ properties.load(new FileInputStream(fileName));
+ } catch (IOException e) {
+ e.printStackTrace();
+ throw new AssertionError("Wrong properties file " + fileName);
+ }
+ System.out.println(properties.toString());
+
+ return properties;
+ }
+}
diff --git a/gi/posterior-regularisation/prjava/src/util/LogSummer.java b/gi/posterior-regularisation/prjava/src/util/LogSummer.java
new file mode 100644
index 00000000..117393b9
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/util/LogSummer.java
@@ -0,0 +1,86 @@
+package util;
+
+import java.lang.Math;
+
+/*
+ * Math tool for computing logs of sums, when the terms of the sum are already in log form.
+ * (Useful if the terms of the sum are very small numbers.)
+ */
+public class LogSummer {
+
+ private LogSummer() {
+ }
+
+ /**
+ * Given log(a) and log(b), computes log(a + b).
+ *
+ * @param loga log of first sum term
+ * @param logb log of second sum term
+ * @return log(sum), where sum = a + b
+ */
+ public static double sum(double loga, double logb) {
+ assert(!Double.isNaN(loga));
+ assert(!Double.isNaN(logb));
+
+ if(Double.isInfinite(loga))
+ return logb;
+ if(Double.isInfinite(logb))
+ return loga;
+
+ double maxLog;
+ double difference;
+ if(loga > logb) {
+ difference = logb - loga;
+ maxLog = loga;
+ }
+ else {
+ difference = loga - logb;
+ maxLog = logb;
+ }
+
+ return Math.log1p(Math.exp(difference)) + maxLog;
+ }
+
+ /**
+ * Computes log(exp(array[index]) + b), and
+ * modifies array[index] to contain this new value.
+ *
+ * @param array array to modify
+ * @param index index at which to modify
+ * @param logb log of the second sum term
+ */
+ public static void sum(double[] array, int index, double logb) {
+ array[index] = sum(array[index], logb);
+ }
+
+ /**
+ * Computes log(a + b + c + ...) from log(a), log(b), log(c), ...
+ * by recursively splitting the input and delegating to the sum method.
+ *
+ * @param terms an array containing the log of all the terms for the sum
+ * @return log(sum), where sum = exp(terms[0]) + exp(terms[1]) + ...
+ */
+ public static double sumAll(double... terms) {
+ return sumAllHelper(terms, 0, terms.length);
+ }
+
+ /**
+ * Computes log(a_0 + a_1 + ...) from a_0 = exp(terms[begin]),
+ * a_1 = exp(terms[begin + 1]), ..., a_{end - 1 - begin} = exp(terms[end - 1]).
+ *
+ * @param terms an array containing the log of all the terms for the sum,
+ * and possibly some other terms that will not go into the sum
+ * @return log of the sum of the elements in the [begin, end) region of the terms array
+ */
+ private static double sumAllHelper(final double[] terms, final int begin, final int end) {
+ int length = end - begin;
+ switch(length) {
+ case 0: return Double.NEGATIVE_INFINITY;
+ case 1: return terms[begin];
+ default:
+ int midIndex = begin + length/2;
+ return sum(sumAllHelper(terms, begin, midIndex), sumAllHelper(terms, midIndex, end));
+ }
+ }
+
+} \ No newline at end of file
diff --git a/gi/posterior-regularisation/prjava/src/util/MathUtil.java b/gi/posterior-regularisation/prjava/src/util/MathUtil.java
new file mode 100644
index 00000000..799b1faf
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/util/MathUtil.java
@@ -0,0 +1,148 @@
+package util;
+
+import java.util.Random;
+
+public class MathUtil {
+ public static final boolean closeToOne(double number){
+ return Math.abs(number-1) < 1.E-10;
+ }
+
+ public static final boolean closeToZero(double number){
+ return Math.abs(number) < 1.E-5;
+ }
+
+ /**
+ * Return a ramdom multinominal distribution.
+ *
+ * @param size
+ * @return
+ */
+ public static final double[] randomVector(int size, Random r){
+ double[] random = new double[size];
+ double sum=0;
+ for(int i = 0; i < size; i++){
+ double number = r.nextDouble();
+ random[i] = number;
+ sum+=number;
+ }
+ for(int i = 0; i < size; i++){
+ random[i] = random[i]/sum;
+ }
+ return random;
+ }
+
+
+
+ public static double sum(double[] ds) {
+ double res = 0;
+ for (int i = 0; i < ds.length; i++) {
+ res+=ds[i];
+ }
+ return res;
+ }
+
+ public static double max(double[] ds) {
+ double res = Double.NEGATIVE_INFINITY;
+ for (int i = 0; i < ds.length; i++) {
+ res = Math.max(res, ds[i]);
+ }
+ return res;
+ }
+
+ public static double min(double[] ds) {
+ double res = Double.POSITIVE_INFINITY;
+ for (int i = 0; i < ds.length; i++) {
+ res = Math.min(res, ds[i]);
+ }
+ return res;
+ }
+
+
+ public static double KLDistance(double[] p, double[] q) {
+ int len = p.length;
+ double kl = 0;
+ for (int j = 0; j < len; j++) {
+ if (p[j] == 0 || q[j] == 0) {
+ continue;
+ } else {
+ kl += q[j] * Math.log(q[j] / p[j]);
+ }
+
+ }
+ return kl;
+ }
+
+ public static double L2Distance(double[] p, double[] q) {
+ int len = p.length;
+ double l2 = 0;
+ for (int j = 0; j < len; j++) {
+ if (p[j] == 0 || q[j] == 0) {
+ continue;
+ } else {
+ l2 += (q[j] - p[j])*(q[j] - p[j]);
+ }
+
+ }
+ return Math.sqrt(l2);
+ }
+
+ public static double L1Distance(double[] p, double[] q) {
+ int len = p.length;
+ double l1 = 0;
+ for (int j = 0; j < len; j++) {
+ if (p[j] == 0 || q[j] == 0) {
+ continue;
+ } else {
+ l1 += Math.abs(q[j] - p[j]);
+ }
+
+ }
+ return l1;
+ }
+
+ public static double dot(double[] ds, double[] ds2) {
+ double res = 0;
+ for (int i = 0; i < ds2.length; i++) {
+ res+= ds[i]*ds2[i];
+ }
+ return res;
+ }
+
+ public static double expDigamma(double number){
+ return Math.exp(digamma(number));
+ }
+
+ public static double digamma(double number){
+ if(number > 7){
+ return digammApprox(number-0.5);
+ }else{
+ return digamma(number+1) - 1.0/number;
+ }
+ }
+
+ private static double digammApprox(double value){
+ return Math.log(value) + 0.04167*Math.pow(value, -2) - 0.00729*Math.pow(value, -4)
+ + 0.00384*Math.pow(value, -6) - 0.00413*Math.pow(value, -8);
+ }
+
+ public static double eulerGamma = 0.57721566490152386060651209008240243;
+ // FIXME -- so far just the initialization from Minka's paper "Estimating a Dirichlet distribution".
+ public static double invDigamma(double y) {
+ if (y>= -2.22) return Math.exp(y)+0.5;
+ return -1.0/(y+eulerGamma);
+ }
+
+
+
+ public static void main(String[] args) {
+ for(double i = 0; i < 10 ; i+=0.1){
+ System.out.println(i+"\t"+expDigamma(i)+"\t"+(i-0.5));
+ }
+// double gammaValue = (expDigamma(3)/expDigamma(10) + expDigamma(3)/expDigamma(10) + expDigamma(4)/expDigamma(10));
+// double normalValue = 3/10+3/4+10/10;
+// System.out.println("Gamma " + gammaValue + " normal " + normalValue);
+ }
+
+
+
+}
diff --git a/gi/posterior-regularisation/prjava/src/util/Matrix.java b/gi/posterior-regularisation/prjava/src/util/Matrix.java
new file mode 100644
index 00000000..8fb6d911
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/util/Matrix.java
@@ -0,0 +1,16 @@
+package util;
+
+public class Matrix {
+ int x;
+ int y;
+ double[][] values;
+
+ public Matrix(int x, int y){
+ this.x = x;
+ this.y=y;
+ values = new double[x][y];
+ }
+
+
+
+}
diff --git a/gi/posterior-regularisation/prjava/src/util/MemoryTracker.java b/gi/posterior-regularisation/prjava/src/util/MemoryTracker.java
new file mode 100644
index 00000000..83a65611
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/util/MemoryTracker.java
@@ -0,0 +1,47 @@
+package util;
+
+
+public class MemoryTracker {
+
+ double initM,finalM;
+ boolean start = false,finish = false;
+
+ public MemoryTracker(){
+
+ }
+
+ public void start(){
+ System.gc();
+ System.gc();
+ System.gc();
+ initM = (Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory())/(1024*1024);
+ start = true;
+ }
+
+ public void finish(){
+ if(!start){
+ throw new RuntimeException("Canot stop before starting");
+ }
+ System.gc();
+ System.gc();
+ System.gc();
+ finalM = (Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory())/(1024*1024);
+ finish = true;
+ }
+
+ public String print(){
+ if(!finish){
+ throw new RuntimeException("Canot print before stopping");
+ }
+ return "Used: " + (finalM - initM) + "MB";
+ }
+
+ public void clear(){
+ initM = 0;
+ finalM = 0;
+ finish = false;
+ start = false;
+ }
+
+
+}
diff --git a/gi/posterior-regularisation/prjava/src/util/Pair.java b/gi/posterior-regularisation/prjava/src/util/Pair.java
new file mode 100644
index 00000000..7b1f108d
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/util/Pair.java
@@ -0,0 +1,31 @@
+package util;
+
+public class Pair<O1, O2> {
+ public O1 _first;
+ public O2 _second;
+
+ public final O1 first() {
+ return _first;
+ }
+
+ public final O2 second() {
+ return _second;
+ }
+
+ public final void setFirst(O1 value){
+ _first = value;
+ }
+
+ public final void setSecond(O2 value){
+ _second = value;
+ }
+
+ public Pair(O1 first, O2 second) {
+ _first = first;
+ _second = second;
+ }
+
+ public String toString(){
+ return _first + " " + _second;
+ }
+}
diff --git a/gi/posterior-regularisation/prjava/src/util/Printing.java b/gi/posterior-regularisation/prjava/src/util/Printing.java
new file mode 100644
index 00000000..14fcbe91
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/util/Printing.java
@@ -0,0 +1,158 @@
+package util;
+
+public class Printing {
+ static java.text.DecimalFormat fmt = new java.text.DecimalFormat();
+
+ public static String padWithSpace(String s, int len){
+ StringBuffer sb = new StringBuffer();
+ while(sb.length() +s.length() < len){
+ sb.append(" ");
+ }
+ sb.append(s);
+ return sb.toString();
+ }
+
+ public static String prettyPrint(double d, String patt, int len) {
+ fmt.applyPattern(patt);
+ String s = fmt.format(d);
+ while (s.length() < len) {
+ s = " " + s;
+ }
+ return s;
+ }
+
+ public static String formatTime(long duration) {
+ StringBuilder sb = new StringBuilder();
+ double d = duration / 1000;
+ fmt.applyPattern("00");
+ sb.append(fmt.format((int) (d / (60 * 60))) + ":");
+ d -= ((int) d / (60 * 60)) * 60 * 60;
+ sb.append(fmt.format((int) (d / 60)) + ":");
+ d -= ((int) d / 60) * 60;
+ fmt.applyPattern("00.0");
+ sb.append(fmt.format(d));
+ return sb.toString();
+ }
+
+
+ public static String doubleArrayToString(double[] array, String[] labels, String arrayName) {
+ StringBuffer res = new StringBuffer();
+ res.append(arrayName);
+ res.append("\n");
+ for (int i = 0; i < array.length; i++) {
+ if (labels == null){
+ res.append(i+" \t");
+ }else{
+ res.append(labels[i]+ "\t");
+ }
+ }
+ res.append("sum\n");
+ double sum = 0;
+ for (int i = 0; i < array.length; i++) {
+ res.append(prettyPrint(array[i],
+ "0.00000E00", 8) + "\t");
+ sum+=array[i];
+ }
+ res.append(prettyPrint(sum,
+ "0.00000E00", 8)+"\n");
+ return res.toString();
+ }
+
+
+
+ public static void printDoubleArray(double[] array, String labels[], String arrayName) {
+ System.out.println(doubleArrayToString(array, labels,arrayName));
+ }
+
+
+ public static String doubleArrayToString(double[][] array, String[] labels1, String[] labels2,
+ String arrayName){
+ StringBuffer res = new StringBuffer();
+ res.append(arrayName);
+ res.append("\n\t");
+ //Calculates the column sum to keeps the sums
+ double[] sums = new double[array[0].length+1];
+ //Prints rows headings
+ for (int i = 0; i < array[0].length; i++) {
+ if (labels1 == null){
+ res.append(i+" \t");
+ }else{
+ res.append(labels1[i]+" \t");
+ }
+ }
+ res.append("sum\n");
+ double sum = 0;
+ //For each row print heading
+ for (int i = 0; i < array.length; i++) {
+ if (labels2 == null){
+ res.append(i+"\t");
+ }else{
+ res.append(labels2[i]+"\t");
+ }
+ //Print values for that row
+ for (int j = 0; j < array[0].length; j++) {
+ res.append(" " + prettyPrint(array[i][j],
+ "0.00000E00", 8) + "\t");
+ sums[j] += array[i][j];
+ sum+=array[i][j]; //Sum all values of that row
+ }
+ //Print row sum
+ res.append(prettyPrint(sum,"0.00000E00", 8)+"\n");
+ sums[array[0].length]+=sum;
+ sum=0;
+ }
+ res.append("sum\t");
+ //Print values for colums sum
+ for (int i = 0; i < array[0].length+1; i++) {
+ res.append(prettyPrint(sums[i],"0.00000E00", 8)+"\t");
+ }
+ res.append("\n");
+ return res.toString();
+ }
+
+ public static void printDoubleArray(double[][] array, String[] labels1, String[] labels2
+ , String arrayName) {
+ System.out.println(doubleArrayToString(array, labels1,labels2,arrayName));
+ }
+
+
+ public static void printIntArray(int[][] array, String[] labels1, String[] labels2, String arrayName,
+ int size1, int size2) {
+ System.out.println(arrayName);
+ for (int i = 0; i < size1; i++) {
+ for (int j = 0; j < size2; j++) {
+ System.out.print(" " + array[i][j] + " ");
+
+ }
+ System.out.println();
+ }
+ System.out.println();
+ }
+
+ public static String intArrayToString(int[] array, String[] labels, String arrayName) {
+ StringBuffer res = new StringBuffer();
+ res.append(arrayName);
+ for (int i = 0; i < array.length; i++) {
+ res.append(" " + array[i] + " ");
+
+ }
+ res.append("\n");
+ return res.toString();
+ }
+
+ public static void printIntArray(int[] array, String[] labels, String arrayName) {
+ System.out.println(intArrayToString(array, labels,arrayName));
+ }
+
+ public static String toString(double[][] d){
+ StringBuffer sb = new StringBuffer();
+ for (int i = 0; i < d.length; i++) {
+ for (int j = 0; j < d[0].length; j++) {
+ sb.append(prettyPrint(d[i][j], "0.00E0", 10));
+ }
+ sb.append("\n");
+ }
+ return sb.toString();
+ }
+
+}
diff --git a/gi/posterior-regularisation/prjava/src/util/Sorters.java b/gi/posterior-regularisation/prjava/src/util/Sorters.java
new file mode 100644
index 00000000..836444e5
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/util/Sorters.java
@@ -0,0 +1,39 @@
+package util;
+
+import java.util.Comparator;
+
+public class Sorters {
+ public static class sortWordsCounts implements Comparator{
+
+ /**
+ * Sorter for a pair of word id, counts. Sort ascending by counts
+ */
+ public int compare(Object arg0, Object arg1) {
+ Pair<Integer,Integer> p1 = (Pair<Integer,Integer>)arg0;
+ Pair<Integer,Integer> p2 = (Pair<Integer,Integer>)arg1;
+ if(p1.second() > p2.second()){
+ return 1;
+ }else{
+ return -1;
+ }
+ }
+
+ }
+
+public static class sortWordsDouble implements Comparator{
+
+ /**
+ * Sorter for a pair of word id, counts. Sort by counts
+ */
+ public int compare(Object arg0, Object arg1) {
+ Pair<Integer,Double> p1 = (Pair<Integer,Double>)arg0;
+ Pair<Integer,Double> p2 = (Pair<Integer,Double>)arg1;
+ if(p1.second() < p2.second()){
+ return 1;
+ }else{
+ return -1;
+ }
+ }
+
+ }
+}