diff options
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/util')
13 files changed, 875 insertions, 0 deletions
diff --git a/gi/posterior-regularisation/prjava/src/util/Array.java b/gi/posterior-regularisation/prjava/src/util/Array.java new file mode 100644 index 00000000..cc4725af --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/util/Array.java @@ -0,0 +1,41 @@ +package util; + +import java.util.Arrays; + +public class Array { + +	 +	 +	public  static void sortDescending(double[] ds){ +		for (int i = 0; i < ds.length; i++) ds[i] = -ds[i]; +		Arrays.sort(ds); +		for (int i = 0; i < ds.length; i++) ds[i] = -ds[i]; +	} +	 +	/**  +	 * Return a new reversed array +	 * @param array +	 * @return +	 */ +	public static int[] reverseIntArray(int[] array){ +		int[] reversed = new int[array.length]; +		for (int i = 0; i < reversed.length; i++) { +			reversed[i] = array[reversed.length-1-i]; +		} +		return reversed; +	} +	 +	public static String[] sumArray(String[] in, int from){ +		String[] res = new String[in.length-from]; +		for (int i = from; i < in.length; i++) { +			res[i-from] = in[i]; +		} +		return res; +	} +	 +	public static void main(String[] args) { +		int[] i = {1,2,3,4}; +		util.Printing.printIntArray(i, null, "original"); +		util.Printing.printIntArray(reverseIntArray(i), null, "reversed"); +	} +} diff --git a/gi/posterior-regularisation/prjava/src/util/ArrayMath.java b/gi/posterior-regularisation/prjava/src/util/ArrayMath.java new file mode 100644 index 00000000..398a13a2 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/util/ArrayMath.java @@ -0,0 +1,186 @@ +package util; + +import java.util.Arrays; + +public class ArrayMath { + +	public static double dotProduct(double[] v1, double[] v2) { +		assert(v1.length == v2.length); +		double result = 0; +		for(int i = 0; i < v1.length; i++) +			result += v1[i]*v2[i]; +		return result; +	} + +	public static double twoNormSquared(double[] v) { +		double result = 0; +		for(double d : v) +			result += d*d; +		return result; +	} + +	public static boolean containsInvalid(double[] v) { +		for(int i = 0; i < v.length; i++) +			if(Double.isNaN(v[i]) || Double.isInfinite(v[i])) +				return true; +		return false; +	} + + +	 +	public static double safeAdd(double[] toAdd) { +		// Make sure there are no positive infinities +		double sum = 0; +		for(int i = 0; i < toAdd.length; i++) { +			assert(!(Double.isInfinite(toAdd[i]) && toAdd[i] > 0)); +			assert(!Double.isNaN(toAdd[i])); +			sum += toAdd[i]; +		} +		 +		return sum; +	} + +	/* Methods for filling integer and double arrays (of up to four dimensions) with the given value. */ +	 +	public static void set(int[][][][] array, int value) { +		for(int i = 0; i < array.length; i++) { +			set(array[i], value); +		} +	} +	 +	public static void set(int[][][] array, int value) { +		for(int i = 0; i < array.length; i++) { +			set(array[i], value); +		} +	} +	 +	public static void set(int[][] array, int value) { +		for(int i = 0; i < array.length; i++) { +			set(array[i], value); +		} +	} +	 +	public static void set(int[] array, int value) { +		Arrays.fill(array, value); +	} +	 +	 +	public static void set(double[][][][] array, double value) { +		for(int i = 0; i < array.length; i++) { +			set(array[i], value); +		} +	} +	 +	public static void set(double[][][] array, double value) { +		for(int i = 0; i < array.length; i++) { +			set(array[i], value); +		} +	} +	 +	public static void set(double[][] array, double value) { +		for(int i = 0; i < array.length; i++) { +			set(array[i], value); +		} +	} +	 +	public static void set(double[] array, double value) { +		Arrays.fill(array, value); +	} + +	public static void setEqual(double[][][][] dest, double[][][][] source){ +		for (int i = 0; i < source.length; i++) { +			setEqual(dest[i],source[i]); +		} +	} + +	 +	public static void setEqual(double[][][] dest, double[][][] source){ +		for (int i = 0; i < source.length; i++) { +			set(dest[i],source[i]); +		} +	} + +	 +	public static void set(double[][] dest, double[][] source){ +		for (int i = 0; i < source.length; i++) { +			setEqual(dest[i],source[i]); +		} +	} + +	public static void setEqual(double[] dest, double[] source){ +		System.arraycopy(source, 0, dest, 0, source.length); +	} + +	public static void plusEquals(double[][][][] array, double val){ +		for (int i = 0; i < array.length; i++) { +			plusEquals(array[i], val); +		} +	}	 +	 +	public static void plusEquals(double[][][] array, double val){ +		for (int i = 0; i < array.length; i++) { +			plusEquals(array[i], val); +		} +	}	 +	 +	public static void plusEquals(double[][] array, double val){ +		for (int i = 0; i < array.length; i++) { +			plusEquals(array[i], val); +		} +	}	 +	 +	public static void plusEquals(double[] array, double val){ +		for (int i = 0; i < array.length; i++) { +			array[i] += val; +		} +	} + +	 +	public static double sum(double[] array) { +		double res = 0; +		for (int i = 0; i < array.length; i++) res += array[i]; +		return res; +	} + + +	 +	public static  double[][] deepclone(double[][] in){ +		double[][] res = new double[in.length][]; +		for (int i = 0; i < res.length; i++) { +			res[i] = in[i].clone(); +		} +		return res; +	} + +	 +	public static  double[][][] deepclone(double[][][] in){ +		double[][][] res = new double[in.length][][]; +		for (int i = 0; i < res.length; i++) { +			res[i] = deepclone(in[i]); +		} +		return res; +	} + +	public static double cosine(double[] a, +			double[] b) { +		return (dotProduct(a, b)+1e-5)/(Math.sqrt(dotProduct(a, a)+1e-5)*Math.sqrt(dotProduct(b, b)+1e-5)); +	} + +	public static double max(double[] ds) { +		double max = Double.NEGATIVE_INFINITY; +		for(double d:ds) max = Math.max(d,max); +		return max; +	} + +	public static void exponentiate(double[] a) { +		for (int i = 0; i < a.length; i++) { +			a[i] = Math.exp(a[i]); +		} +	} + +	public static int sum(int[] array) { +		int res = 0; +		for (int i = 0; i < array.length; i++) res += array[i]; +		return res; +	} +} diff --git a/gi/posterior-regularisation/prjava/src/util/DifferentiableObjective.java b/gi/posterior-regularisation/prjava/src/util/DifferentiableObjective.java new file mode 100644 index 00000000..1ff1ae4a --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/util/DifferentiableObjective.java @@ -0,0 +1,14 @@ +package util; + +public interface DifferentiableObjective { + +	public double getValue(); + +	public void getGradient(double[] gradient); + +	public void getParameters(double[] params); + +	public void setParameters(double[] newParameters); + +	public int getNumParameters(); +} diff --git a/gi/posterior-regularisation/prjava/src/util/DigammaFunction.java b/gi/posterior-regularisation/prjava/src/util/DigammaFunction.java new file mode 100644 index 00000000..ff1478ad --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/util/DigammaFunction.java @@ -0,0 +1,21 @@ +package util; + +public class DigammaFunction { +	public static double expDigamma(double number){ +		if(number==0)return number; +		return Math.exp(digamma(number)); +	} +	 +	public static double digamma(double number){ +		if(number > 7){ +			return digammApprox(number-0.5); +		}else{ +			return digamma(number+1) - 1.0/number; +		} +	} +	 +	private static double digammApprox(double value){ +		return Math.log(value) + 0.04167*Math.pow(value, -2) - 0.00729*Math.pow(value, -4)  +		+  0.00384*Math.pow(value, -6) - 0.00413*Math.pow(value, -8); +	} +} diff --git a/gi/posterior-regularisation/prjava/src/util/FileSystem.java b/gi/posterior-regularisation/prjava/src/util/FileSystem.java new file mode 100644 index 00000000..d7812e40 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/util/FileSystem.java @@ -0,0 +1,21 @@ +package util; + +import java.io.File; + +public class FileSystem { +	public static boolean createDir(String directory) { + +		File dir = new File(directory); +		if (!dir.isDirectory()) { +			boolean success = dir.mkdirs(); +			if (!success) { +				System.out.println("Unable to create directory " + directory); +				return false; +			} +			System.out.println("Created directory " + directory); +		} else { +			System.out.println("Reusing directory " + directory); +		} +		return true; +	} +} diff --git a/gi/posterior-regularisation/prjava/src/util/InputOutput.java b/gi/posterior-regularisation/prjava/src/util/InputOutput.java new file mode 100644 index 00000000..da7f71bf --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/util/InputOutput.java @@ -0,0 +1,67 @@ +package util; + +import java.io.BufferedReader; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.io.PrintStream; +import java.io.UnsupportedEncodingException; +import java.util.Properties; +import java.util.zip.GZIPInputStream; +import java.util.zip.GZIPOutputStream; + +public class InputOutput { + +	/** +	 * Opens a file either compress with gzip or not compressed. +	 */ +	public static BufferedReader openReader(String fileName) throws UnsupportedEncodingException, FileNotFoundException, IOException{ +		System.out.println("Reading: " + fileName); +		BufferedReader reader; +		fileName = fileName.trim(); +		if(fileName.endsWith("gz")){ +			reader = new BufferedReader( +			new InputStreamReader(new GZIPInputStream(new FileInputStream(fileName)),"UTF8")); +		}else{ +			reader = new BufferedReader(new InputStreamReader( +					new FileInputStream(fileName), "UTF8")); +		} +		 +		return reader; +	} +	 +	 +	public static PrintStream openWriter(String fileName)  +	throws UnsupportedEncodingException, FileNotFoundException, IOException{ +		System.out.println("Writting to file: " + fileName); +		PrintStream writter; +		fileName = fileName.trim(); +		if(fileName.endsWith("gz")){ +			writter = new PrintStream(new GZIPOutputStream(new FileOutputStream(fileName)), +					true, "UTF-8"); + +		}else{ +			writter = new PrintStream(new FileOutputStream(fileName), +					true, "UTF-8"); + +		} +		 +		return writter; +	} +	 +	public static Properties readPropertiesFile(String fileName) { +		Properties properties = new Properties(); +		try { +			properties.load(new FileInputStream(fileName)); +		} catch (IOException e) { +			e.printStackTrace(); +			throw new AssertionError("Wrong properties file " + fileName); +		} +		System.out.println(properties.toString()); +		 +		return properties; +	} +} diff --git a/gi/posterior-regularisation/prjava/src/util/LogSummer.java b/gi/posterior-regularisation/prjava/src/util/LogSummer.java new file mode 100644 index 00000000..117393b9 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/util/LogSummer.java @@ -0,0 +1,86 @@ +package util; + +import java.lang.Math; + +/* + * Math tool for computing logs of sums, when the terms of the sum are already in log form. + * (Useful if the terms of the sum are very small numbers.) + */ +public class LogSummer { +	 +	private LogSummer() { +	} +		 +	/** +	 * Given log(a) and log(b), computes log(a + b). +	 *  +	 * @param  loga log of first sum term +	 * @param  logb log of second sum term +	 * @return     log(sum), where sum = a + b +	 */ +	public static double sum(double loga, double logb) { +		assert(!Double.isNaN(loga)); +		assert(!Double.isNaN(logb)); +		 +		if(Double.isInfinite(loga)) +			return logb; +		if(Double.isInfinite(logb)) +			return loga; + +		double maxLog; +		double difference; +		if(loga > logb) { +			difference = logb - loga; +			maxLog = loga; +		} +		else { +			difference = loga - logb; +			maxLog = logb; +		} + +		return Math.log1p(Math.exp(difference)) + maxLog; +	} + +	/** +	 * Computes log(exp(array[index]) + b), and +	 * modifies array[index] to contain this new value. +	 *  +	 * @param array array to modify +	 * @param index index at which to modify +	 * @param logb  log of the second sum term +	 */ +	public static void sum(double[] array, int index, double logb) { +		array[index] = sum(array[index], logb); +	} +	 +	/** +	 * Computes log(a + b + c + ...) from log(a), log(b), log(c), ... +	 * by recursively splitting the input and delegating to the sum method. +	 *  +	 * @param  terms an array containing the log of all the terms for the sum +	 * @return log(sum), where sum = exp(terms[0]) + exp(terms[1]) + ... +	 */ +	public static double sumAll(double... terms) { +		return sumAllHelper(terms, 0, terms.length); +	} +	 +	/** +	 * Computes log(a_0 + a_1 + ...) from a_0 = exp(terms[begin]), +	 * a_1 = exp(terms[begin + 1]), ..., a_{end - 1 - begin} = exp(terms[end - 1]). +	 *  +	 * @param  terms an array containing the log of all the terms for the sum, +	 *               and possibly some other terms that will not go into the sum +	 * @return log of the sum of the elements in the [begin, end) region of the terms array +	 */ +	private static double sumAllHelper(final double[] terms, final int begin, final int end) { +		int length = end - begin; +		switch(length) { +			case 0: return Double.NEGATIVE_INFINITY; +			case 1: return terms[begin]; +			default: +				int midIndex = begin + length/2; +				return sum(sumAllHelper(terms, begin, midIndex), sumAllHelper(terms, midIndex, end)); +		} +	} + +}
\ No newline at end of file diff --git a/gi/posterior-regularisation/prjava/src/util/MathUtil.java b/gi/posterior-regularisation/prjava/src/util/MathUtil.java new file mode 100644 index 00000000..799b1faf --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/util/MathUtil.java @@ -0,0 +1,148 @@ +package util; + +import java.util.Random; + +public class MathUtil { +	public static final boolean closeToOne(double number){ +		return Math.abs(number-1) < 1.E-10; +	} +	 +	public static final boolean closeToZero(double number){ +		return Math.abs(number) < 1.E-5; +	} +	 +	/** +	 * Return a ramdom multinominal distribution. +	 *  +	 * @param size +	 * @return +	 */ +	public static final double[] randomVector(int size, Random r){ +		double[] random = new double[size]; +		double sum=0; +		for(int i = 0; i < size; i++){ +			double number = r.nextDouble(); +			random[i] = number; +			sum+=number; +		} +		for(int i = 0; i < size; i++){ +			random[i] = random[i]/sum; +		} +		return random; +	} +	 +	 + +	public static double sum(double[] ds) { +		double res = 0; +		for (int i = 0; i < ds.length; i++) { +			res+=ds[i]; +		} +		return res; +	} + +	public static double max(double[] ds) { +		double res = Double.NEGATIVE_INFINITY; +		for (int i = 0; i < ds.length; i++) { +			res = Math.max(res, ds[i]); +		} +		return res; +	} + +	public static double min(double[] ds) { +		double res = Double.POSITIVE_INFINITY; +		for (int i = 0; i < ds.length; i++) { +			res = Math.min(res, ds[i]); +		} +		return res; +	} + +	 +	public static double KLDistance(double[] p, double[] q) { +		int len = p.length; +		double kl = 0; +		for (int j = 0; j < len; j++) { +				if (p[j] == 0 || q[j] == 0) { +					continue; +				} else { +					kl += q[j] * Math.log(q[j] / p[j]); +				} + +		} +		return kl; +	} +	 +	public static double L2Distance(double[] p, double[] q) { +		int len = p.length; +		double l2 = 0; +		for (int j = 0; j < len; j++) { +				if (p[j] == 0 || q[j] == 0) { +					continue; +				} else { +					l2 += (q[j] - p[j])*(q[j] - p[j]); +				} + +		} +		return Math.sqrt(l2); +	} +	 +	public static double L1Distance(double[] p, double[] q) { +		int len = p.length; +		double l1 = 0; +		for (int j = 0; j < len; j++) { +				if (p[j] == 0 || q[j] == 0) { +					continue; +				} else { +					l1 += Math.abs(q[j] - p[j]); +				} + +		} +		return l1; +	} + +	public static double dot(double[] ds, double[] ds2) { +		double res = 0; +		for (int i = 0; i < ds2.length; i++) { +			res+= ds[i]*ds2[i]; +		} +		return res; +	} +	 +	public static double expDigamma(double number){ +		return Math.exp(digamma(number)); +	} +	 +	public static double digamma(double number){ +		if(number > 7){ +			return digammApprox(number-0.5); +		}else{ +			return digamma(number+1) - 1.0/number; +		} +	} +	 +	private static double digammApprox(double value){ +		return Math.log(value) + 0.04167*Math.pow(value, -2) - 0.00729*Math.pow(value, -4)  +		+  0.00384*Math.pow(value, -6) - 0.00413*Math.pow(value, -8); +	} + +	public static double eulerGamma = 0.57721566490152386060651209008240243; +	// FIXME -- so far just the initialization from Minka's paper "Estimating a Dirichlet distribution".  +	public static double invDigamma(double y) { +		if (y>= -2.22) return Math.exp(y)+0.5; +		return -1.0/(y+eulerGamma); +	} + +	 +	 +	public static void main(String[] args) { +		for(double i = 0; i < 10 ; i+=0.1){ +			System.out.println(i+"\t"+expDigamma(i)+"\t"+(i-0.5)); +		} +//		double gammaValue = (expDigamma(3)/expDigamma(10) + expDigamma(3)/expDigamma(10) + expDigamma(4)/expDigamma(10)); +//		double normalValue = 3/10+3/4+10/10; +//		System.out.println("Gamma " + gammaValue + " normal " + normalValue); +	} + +	 +	 +} diff --git a/gi/posterior-regularisation/prjava/src/util/Matrix.java b/gi/posterior-regularisation/prjava/src/util/Matrix.java new file mode 100644 index 00000000..8fb6d911 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/util/Matrix.java @@ -0,0 +1,16 @@ +package util; + +public class Matrix { +	int x; +	int y; +	double[][] values; +	 +	public Matrix(int x, int y){ +		this.x = x; +		this.y=y; +		values = new double[x][y]; +	} +	 +	 +	 +} diff --git a/gi/posterior-regularisation/prjava/src/util/MemoryTracker.java b/gi/posterior-regularisation/prjava/src/util/MemoryTracker.java new file mode 100644 index 00000000..83a65611 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/util/MemoryTracker.java @@ -0,0 +1,47 @@ +package util; + + +public class MemoryTracker { +	 +	double initM,finalM; +	boolean start = false,finish = false; +	 +	public MemoryTracker(){ +		 +	} +	 +	public void start(){ +		System.gc(); +	    System.gc(); +	    System.gc(); +	    initM = (Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory())/(1024*1024);   +	    start = true; +	} +	 +	public void finish(){ +		if(!start){ +			throw new RuntimeException("Canot stop before starting"); +		} +		System.gc(); +	    System.gc(); +	    System.gc(); +	    finalM = (Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory())/(1024*1024);   +	    finish = true; +	} +	 +	public String print(){ +		if(!finish){ +			throw new RuntimeException("Canot print before stopping"); +		} +		return "Used: " + (finalM - initM) + "MB"; +	} +	 +	public void clear(){ +		initM = 0; +		finalM = 0; +		finish = false; +		start = false; +	} +	 +	 +} diff --git a/gi/posterior-regularisation/prjava/src/util/Pair.java b/gi/posterior-regularisation/prjava/src/util/Pair.java new file mode 100644 index 00000000..7b1f108d --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/util/Pair.java @@ -0,0 +1,31 @@ +package util; + +public class Pair<O1, O2> { +	public O1 _first; +	public O2 _second; + +	public final O1 first() { +		return _first; +	} + +	public final O2 second() { +		return _second; +	} + +	public final void setFirst(O1 value){ +		_first = value; +	} +	 +	public final void setSecond(O2 value){ +		_second = value; +	} +	 +	public Pair(O1 first, O2 second) { +		_first = first; +		_second = second; +	} + +	public String toString(){ +		return _first + " " + _second;  +	} +} diff --git a/gi/posterior-regularisation/prjava/src/util/Printing.java b/gi/posterior-regularisation/prjava/src/util/Printing.java new file mode 100644 index 00000000..14fcbe91 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/util/Printing.java @@ -0,0 +1,158 @@ +package util; + +public class Printing { +	static java.text.DecimalFormat fmt = new java.text.DecimalFormat(); + +	public static String padWithSpace(String s, int len){ +		StringBuffer sb = new StringBuffer(); +		while(sb.length() +s.length() < len){ +			sb.append(" "); +		} +		sb.append(s); +		return sb.toString(); +	} +	 +	public static String prettyPrint(double d, String patt, int len) { +		fmt.applyPattern(patt); +		String s = fmt.format(d); +		while (s.length() < len) { +			s = " " + s; +		} +		return s; +	} +	 +	public static  String formatTime(long duration) { +		StringBuilder sb = new StringBuilder(); +		double d = duration / 1000; +		fmt.applyPattern("00"); +		sb.append(fmt.format((int) (d / (60 * 60))) + ":"); +		d -= ((int) d / (60 * 60)) * 60 * 60; +		sb.append(fmt.format((int) (d / 60)) + ":"); +		d -= ((int) d / 60) * 60; +		fmt.applyPattern("00.0"); +		sb.append(fmt.format(d)); +		return sb.toString(); +	} +	 +	 +	public static String doubleArrayToString(double[] array, String[] labels, String arrayName) { +		StringBuffer res = new StringBuffer(); +		res.append(arrayName); +		res.append("\n"); +		for (int i = 0; i < array.length; i++) { +			if (labels == null){ +				res.append(i+"       \t"); +			}else{ +				res.append(labels[i]+     "\t"); +			} +		} +		res.append("sum\n"); +		double sum = 0; +		for (int i = 0; i < array.length; i++) { +			res.append(prettyPrint(array[i], +					"0.00000E00", 8) + "\t"); +			sum+=array[i]; +		} +		res.append(prettyPrint(sum, +				"0.00000E00", 8)+"\n"); +		return res.toString(); +	} +	 +	 +	 +	public static void printDoubleArray(double[] array, String labels[], String arrayName) { +		System.out.println(doubleArrayToString(array, labels,arrayName)); +	} +	 +	 +	public static String doubleArrayToString(double[][] array, String[] labels1, String[] labels2, +			String arrayName){ +		StringBuffer res = new StringBuffer(); +		res.append(arrayName); +		res.append("\n\t"); +		//Calculates the column sum to keeps the sums +		double[] sums = new double[array[0].length+1]; +		//Prints rows headings +		for (int i = 0; i < array[0].length; i++) { +			if (labels1 == null){ +				res.append(i+"        \t"); +			}else{ +				res.append(labels1[i]+"        \t"); +			} +		} +		res.append("sum\n"); +		double sum = 0; +		//For each row print heading +		for (int i = 0; i < array.length; i++) { +			if (labels2 == null){ +				res.append(i+"\t"); +			}else{ +				res.append(labels2[i]+"\t"); +			} +			//Print values for that row +			for (int j = 0; j < array[0].length; j++) { +				res.append(" " + prettyPrint(array[i][j], +						"0.00000E00", 8) + "\t"); +				sums[j] += array[i][j];  +				sum+=array[i][j]; //Sum all values of that row +			} +			//Print row sum +			res.append(prettyPrint(sum,"0.00000E00", 8)+"\n"); +			sums[array[0].length]+=sum; +			sum=0; +		} +		res.append("sum\t"); +		//Print values for colums sum +		for (int i = 0; i < array[0].length+1; i++) { +			res.append(prettyPrint(sums[i],"0.00000E00", 8)+"\t"); +		} +		res.append("\n"); +		return res.toString(); +	} +	 +	public static void printDoubleArray(double[][] array, String[] labels1, String[] labels2 +			, String arrayName) { +		System.out.println(doubleArrayToString(array, labels1,labels2,arrayName)); +	} +	 +	 +	public static void printIntArray(int[][] array, String[] labels1, String[] labels2, String arrayName, +			int size1, int size2) { +		System.out.println(arrayName); +		for (int i = 0; i < size1; i++) { +			for (int j = 0; j < size2; j++) { +				System.out.print(" " + array[i][j] +  " "); + +			} +			System.out.println(); +		} +		System.out.println(); +	} +	 +	public static String intArrayToString(int[] array, String[] labels, String arrayName) { +		StringBuffer res = new StringBuffer(); +		res.append(arrayName); +		for (int i = 0; i < array.length; i++) { +			res.append(" " + array[i] + " "); +			 +		} +		res.append("\n"); +		return res.toString(); +	} +	 +	public static void printIntArray(int[] array, String[] labels, String arrayName) { +		System.out.println(intArrayToString(array, labels,arrayName)); +	} +	 +	public static String toString(double[][] d){ +		StringBuffer sb = new StringBuffer(); +		for (int i = 0; i < d.length; i++) { +			for (int j = 0; j < d[0].length; j++) { +				sb.append(prettyPrint(d[i][j], "0.00E0", 10)); +			} +			sb.append("\n"); +		} +		return sb.toString(); +	} +	 +} diff --git a/gi/posterior-regularisation/prjava/src/util/Sorters.java b/gi/posterior-regularisation/prjava/src/util/Sorters.java new file mode 100644 index 00000000..836444e5 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/util/Sorters.java @@ -0,0 +1,39 @@ +package util; + +import java.util.Comparator; + +public class Sorters { +	public static class sortWordsCounts implements Comparator{ +		 +		/** +		 * Sorter for a pair of word id, counts. Sort ascending by counts +		 */ +		public int compare(Object arg0, Object arg1) { +			Pair<Integer,Integer> p1 = (Pair<Integer,Integer>)arg0; +			Pair<Integer,Integer> p2 = (Pair<Integer,Integer>)arg1; +			if(p1.second() > p2.second()){ +				return 1; +			}else{ +				return -1; +			} +		} +		 +	} +	 +public static class sortWordsDouble implements Comparator{ +		 +		/** +		 * Sorter for a pair of word id, counts. Sort by counts +		 */ +		public int compare(Object arg0, Object arg1) { +			Pair<Integer,Double> p1 = (Pair<Integer,Double>)arg0; +			Pair<Integer,Double> p2 = (Pair<Integer,Double>)arg1; +			if(p1.second() < p2.second()){ +				return 1; +			}else{ +				return -1; +			} +		} +		 +	} +}  | 
