diff options
author | desaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-09 18:32:00 +0000 |
---|---|---|
committer | desaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-09 18:32:00 +0000 |
commit | 23c1f218e0cb1b0849fe5170526b8949cd8c969e (patch) | |
tree | d0e466bd001fea02866eed585850ea7ceecb7d00 /gi/posterior-regularisation/prjava/src/util | |
parent | b128b2b107b3e7b77cfd5dbc578c060e52767afa (diff) |
forgot to add util folder in optimization library
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@206 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/util')
13 files changed, 875 insertions, 0 deletions
diff --git a/gi/posterior-regularisation/prjava/src/util/Array.java b/gi/posterior-regularisation/prjava/src/util/Array.java new file mode 100644 index 00000000..cc4725af --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/util/Array.java @@ -0,0 +1,41 @@ +package util; + +import java.util.Arrays; + +public class Array { + + + + public static void sortDescending(double[] ds){ + for (int i = 0; i < ds.length; i++) ds[i] = -ds[i]; + Arrays.sort(ds); + for (int i = 0; i < ds.length; i++) ds[i] = -ds[i]; + } + + /** + * Return a new reversed array + * @param array + * @return + */ + public static int[] reverseIntArray(int[] array){ + int[] reversed = new int[array.length]; + for (int i = 0; i < reversed.length; i++) { + reversed[i] = array[reversed.length-1-i]; + } + return reversed; + } + + public static String[] sumArray(String[] in, int from){ + String[] res = new String[in.length-from]; + for (int i = from; i < in.length; i++) { + res[i-from] = in[i]; + } + return res; + } + + public static void main(String[] args) { + int[] i = {1,2,3,4}; + util.Printing.printIntArray(i, null, "original"); + util.Printing.printIntArray(reverseIntArray(i), null, "reversed"); + } +} diff --git a/gi/posterior-regularisation/prjava/src/util/ArrayMath.java b/gi/posterior-regularisation/prjava/src/util/ArrayMath.java new file mode 100644 index 00000000..398a13a2 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/util/ArrayMath.java @@ -0,0 +1,186 @@ +package util; + +import java.util.Arrays; + +public class ArrayMath { + + public static double dotProduct(double[] v1, double[] v2) { + assert(v1.length == v2.length); + double result = 0; + for(int i = 0; i < v1.length; i++) + result += v1[i]*v2[i]; + return result; + } + + public static double twoNormSquared(double[] v) { + double result = 0; + for(double d : v) + result += d*d; + return result; + } + + public static boolean containsInvalid(double[] v) { + for(int i = 0; i < v.length; i++) + if(Double.isNaN(v[i]) || Double.isInfinite(v[i])) + return true; + return false; + } + + + + public static double safeAdd(double[] toAdd) { + // Make sure there are no positive infinities + double sum = 0; + for(int i = 0; i < toAdd.length; i++) { + assert(!(Double.isInfinite(toAdd[i]) && toAdd[i] > 0)); + assert(!Double.isNaN(toAdd[i])); + sum += toAdd[i]; + } + + return sum; + } + + /* Methods for filling integer and double arrays (of up to four dimensions) with the given value. */ + + public static void set(int[][][][] array, int value) { + for(int i = 0; i < array.length; i++) { + set(array[i], value); + } + } + + public static void set(int[][][] array, int value) { + for(int i = 0; i < array.length; i++) { + set(array[i], value); + } + } + + public static void set(int[][] array, int value) { + for(int i = 0; i < array.length; i++) { + set(array[i], value); + } + } + + public static void set(int[] array, int value) { + Arrays.fill(array, value); + } + + + public static void set(double[][][][] array, double value) { + for(int i = 0; i < array.length; i++) { + set(array[i], value); + } + } + + public static void set(double[][][] array, double value) { + for(int i = 0; i < array.length; i++) { + set(array[i], value); + } + } + + public static void set(double[][] array, double value) { + for(int i = 0; i < array.length; i++) { + set(array[i], value); + } + } + + public static void set(double[] array, double value) { + Arrays.fill(array, value); + } + + public static void setEqual(double[][][][] dest, double[][][][] source){ + for (int i = 0; i < source.length; i++) { + setEqual(dest[i],source[i]); + } + } + + + public static void setEqual(double[][][] dest, double[][][] source){ + for (int i = 0; i < source.length; i++) { + set(dest[i],source[i]); + } + } + + + public static void set(double[][] dest, double[][] source){ + for (int i = 0; i < source.length; i++) { + setEqual(dest[i],source[i]); + } + } + + public static void setEqual(double[] dest, double[] source){ + System.arraycopy(source, 0, dest, 0, source.length); + } + + public static void plusEquals(double[][][][] array, double val){ + for (int i = 0; i < array.length; i++) { + plusEquals(array[i], val); + } + } + + public static void plusEquals(double[][][] array, double val){ + for (int i = 0; i < array.length; i++) { + plusEquals(array[i], val); + } + } + + public static void plusEquals(double[][] array, double val){ + for (int i = 0; i < array.length; i++) { + plusEquals(array[i], val); + } + } + + public static void plusEquals(double[] array, double val){ + for (int i = 0; i < array.length; i++) { + array[i] += val; + } + } + + + public static double sum(double[] array) { + double res = 0; + for (int i = 0; i < array.length; i++) res += array[i]; + return res; + } + + + + public static double[][] deepclone(double[][] in){ + double[][] res = new double[in.length][]; + for (int i = 0; i < res.length; i++) { + res[i] = in[i].clone(); + } + return res; + } + + + public static double[][][] deepclone(double[][][] in){ + double[][][] res = new double[in.length][][]; + for (int i = 0; i < res.length; i++) { + res[i] = deepclone(in[i]); + } + return res; + } + + public static double cosine(double[] a, + double[] b) { + return (dotProduct(a, b)+1e-5)/(Math.sqrt(dotProduct(a, a)+1e-5)*Math.sqrt(dotProduct(b, b)+1e-5)); + } + + public static double max(double[] ds) { + double max = Double.NEGATIVE_INFINITY; + for(double d:ds) max = Math.max(d,max); + return max; + } + + public static void exponentiate(double[] a) { + for (int i = 0; i < a.length; i++) { + a[i] = Math.exp(a[i]); + } + } + + public static int sum(int[] array) { + int res = 0; + for (int i = 0; i < array.length; i++) res += array[i]; + return res; + } +} diff --git a/gi/posterior-regularisation/prjava/src/util/DifferentiableObjective.java b/gi/posterior-regularisation/prjava/src/util/DifferentiableObjective.java new file mode 100644 index 00000000..1ff1ae4a --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/util/DifferentiableObjective.java @@ -0,0 +1,14 @@ +package util; + +public interface DifferentiableObjective { + + public double getValue(); + + public void getGradient(double[] gradient); + + public void getParameters(double[] params); + + public void setParameters(double[] newParameters); + + public int getNumParameters(); +} diff --git a/gi/posterior-regularisation/prjava/src/util/DigammaFunction.java b/gi/posterior-regularisation/prjava/src/util/DigammaFunction.java new file mode 100644 index 00000000..ff1478ad --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/util/DigammaFunction.java @@ -0,0 +1,21 @@ +package util; + +public class DigammaFunction { + public static double expDigamma(double number){ + if(number==0)return number; + return Math.exp(digamma(number)); + } + + public static double digamma(double number){ + if(number > 7){ + return digammApprox(number-0.5); + }else{ + return digamma(number+1) - 1.0/number; + } + } + + private static double digammApprox(double value){ + return Math.log(value) + 0.04167*Math.pow(value, -2) - 0.00729*Math.pow(value, -4) + + 0.00384*Math.pow(value, -6) - 0.00413*Math.pow(value, -8); + } +} diff --git a/gi/posterior-regularisation/prjava/src/util/FileSystem.java b/gi/posterior-regularisation/prjava/src/util/FileSystem.java new file mode 100644 index 00000000..d7812e40 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/util/FileSystem.java @@ -0,0 +1,21 @@ +package util; + +import java.io.File; + +public class FileSystem { + public static boolean createDir(String directory) { + + File dir = new File(directory); + if (!dir.isDirectory()) { + boolean success = dir.mkdirs(); + if (!success) { + System.out.println("Unable to create directory " + directory); + return false; + } + System.out.println("Created directory " + directory); + } else { + System.out.println("Reusing directory " + directory); + } + return true; + } +} diff --git a/gi/posterior-regularisation/prjava/src/util/InputOutput.java b/gi/posterior-regularisation/prjava/src/util/InputOutput.java new file mode 100644 index 00000000..da7f71bf --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/util/InputOutput.java @@ -0,0 +1,67 @@ +package util; + +import java.io.BufferedReader; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.io.PrintStream; +import java.io.UnsupportedEncodingException; +import java.util.Properties; +import java.util.zip.GZIPInputStream; +import java.util.zip.GZIPOutputStream; + +public class InputOutput { + + /** + * Opens a file either compress with gzip or not compressed. + */ + public static BufferedReader openReader(String fileName) throws UnsupportedEncodingException, FileNotFoundException, IOException{ + System.out.println("Reading: " + fileName); + BufferedReader reader; + fileName = fileName.trim(); + if(fileName.endsWith("gz")){ + reader = new BufferedReader( + new InputStreamReader(new GZIPInputStream(new FileInputStream(fileName)),"UTF8")); + }else{ + reader = new BufferedReader(new InputStreamReader( + new FileInputStream(fileName), "UTF8")); + } + + return reader; + } + + + public static PrintStream openWriter(String fileName) + throws UnsupportedEncodingException, FileNotFoundException, IOException{ + System.out.println("Writting to file: " + fileName); + PrintStream writter; + fileName = fileName.trim(); + if(fileName.endsWith("gz")){ + writter = new PrintStream(new GZIPOutputStream(new FileOutputStream(fileName)), + true, "UTF-8"); + + }else{ + writter = new PrintStream(new FileOutputStream(fileName), + true, "UTF-8"); + + } + + return writter; + } + + public static Properties readPropertiesFile(String fileName) { + Properties properties = new Properties(); + try { + properties.load(new FileInputStream(fileName)); + } catch (IOException e) { + e.printStackTrace(); + throw new AssertionError("Wrong properties file " + fileName); + } + System.out.println(properties.toString()); + + return properties; + } +} diff --git a/gi/posterior-regularisation/prjava/src/util/LogSummer.java b/gi/posterior-regularisation/prjava/src/util/LogSummer.java new file mode 100644 index 00000000..117393b9 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/util/LogSummer.java @@ -0,0 +1,86 @@ +package util; + +import java.lang.Math; + +/* + * Math tool for computing logs of sums, when the terms of the sum are already in log form. + * (Useful if the terms of the sum are very small numbers.) + */ +public class LogSummer { + + private LogSummer() { + } + + /** + * Given log(a) and log(b), computes log(a + b). + * + * @param loga log of first sum term + * @param logb log of second sum term + * @return log(sum), where sum = a + b + */ + public static double sum(double loga, double logb) { + assert(!Double.isNaN(loga)); + assert(!Double.isNaN(logb)); + + if(Double.isInfinite(loga)) + return logb; + if(Double.isInfinite(logb)) + return loga; + + double maxLog; + double difference; + if(loga > logb) { + difference = logb - loga; + maxLog = loga; + } + else { + difference = loga - logb; + maxLog = logb; + } + + return Math.log1p(Math.exp(difference)) + maxLog; + } + + /** + * Computes log(exp(array[index]) + b), and + * modifies array[index] to contain this new value. + * + * @param array array to modify + * @param index index at which to modify + * @param logb log of the second sum term + */ + public static void sum(double[] array, int index, double logb) { + array[index] = sum(array[index], logb); + } + + /** + * Computes log(a + b + c + ...) from log(a), log(b), log(c), ... + * by recursively splitting the input and delegating to the sum method. + * + * @param terms an array containing the log of all the terms for the sum + * @return log(sum), where sum = exp(terms[0]) + exp(terms[1]) + ... + */ + public static double sumAll(double... terms) { + return sumAllHelper(terms, 0, terms.length); + } + + /** + * Computes log(a_0 + a_1 + ...) from a_0 = exp(terms[begin]), + * a_1 = exp(terms[begin + 1]), ..., a_{end - 1 - begin} = exp(terms[end - 1]). + * + * @param terms an array containing the log of all the terms for the sum, + * and possibly some other terms that will not go into the sum + * @return log of the sum of the elements in the [begin, end) region of the terms array + */ + private static double sumAllHelper(final double[] terms, final int begin, final int end) { + int length = end - begin; + switch(length) { + case 0: return Double.NEGATIVE_INFINITY; + case 1: return terms[begin]; + default: + int midIndex = begin + length/2; + return sum(sumAllHelper(terms, begin, midIndex), sumAllHelper(terms, midIndex, end)); + } + } + +}
\ No newline at end of file diff --git a/gi/posterior-regularisation/prjava/src/util/MathUtil.java b/gi/posterior-regularisation/prjava/src/util/MathUtil.java new file mode 100644 index 00000000..799b1faf --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/util/MathUtil.java @@ -0,0 +1,148 @@ +package util; + +import java.util.Random; + +public class MathUtil { + public static final boolean closeToOne(double number){ + return Math.abs(number-1) < 1.E-10; + } + + public static final boolean closeToZero(double number){ + return Math.abs(number) < 1.E-5; + } + + /** + * Return a ramdom multinominal distribution. + * + * @param size + * @return + */ + public static final double[] randomVector(int size, Random r){ + double[] random = new double[size]; + double sum=0; + for(int i = 0; i < size; i++){ + double number = r.nextDouble(); + random[i] = number; + sum+=number; + } + for(int i = 0; i < size; i++){ + random[i] = random[i]/sum; + } + return random; + } + + + + public static double sum(double[] ds) { + double res = 0; + for (int i = 0; i < ds.length; i++) { + res+=ds[i]; + } + return res; + } + + public static double max(double[] ds) { + double res = Double.NEGATIVE_INFINITY; + for (int i = 0; i < ds.length; i++) { + res = Math.max(res, ds[i]); + } + return res; + } + + public static double min(double[] ds) { + double res = Double.POSITIVE_INFINITY; + for (int i = 0; i < ds.length; i++) { + res = Math.min(res, ds[i]); + } + return res; + } + + + public static double KLDistance(double[] p, double[] q) { + int len = p.length; + double kl = 0; + for (int j = 0; j < len; j++) { + if (p[j] == 0 || q[j] == 0) { + continue; + } else { + kl += q[j] * Math.log(q[j] / p[j]); + } + + } + return kl; + } + + public static double L2Distance(double[] p, double[] q) { + int len = p.length; + double l2 = 0; + for (int j = 0; j < len; j++) { + if (p[j] == 0 || q[j] == 0) { + continue; + } else { + l2 += (q[j] - p[j])*(q[j] - p[j]); + } + + } + return Math.sqrt(l2); + } + + public static double L1Distance(double[] p, double[] q) { + int len = p.length; + double l1 = 0; + for (int j = 0; j < len; j++) { + if (p[j] == 0 || q[j] == 0) { + continue; + } else { + l1 += Math.abs(q[j] - p[j]); + } + + } + return l1; + } + + public static double dot(double[] ds, double[] ds2) { + double res = 0; + for (int i = 0; i < ds2.length; i++) { + res+= ds[i]*ds2[i]; + } + return res; + } + + public static double expDigamma(double number){ + return Math.exp(digamma(number)); + } + + public static double digamma(double number){ + if(number > 7){ + return digammApprox(number-0.5); + }else{ + return digamma(number+1) - 1.0/number; + } + } + + private static double digammApprox(double value){ + return Math.log(value) + 0.04167*Math.pow(value, -2) - 0.00729*Math.pow(value, -4) + + 0.00384*Math.pow(value, -6) - 0.00413*Math.pow(value, -8); + } + + public static double eulerGamma = 0.57721566490152386060651209008240243; + // FIXME -- so far just the initialization from Minka's paper "Estimating a Dirichlet distribution". + public static double invDigamma(double y) { + if (y>= -2.22) return Math.exp(y)+0.5; + return -1.0/(y+eulerGamma); + } + + + + public static void main(String[] args) { + for(double i = 0; i < 10 ; i+=0.1){ + System.out.println(i+"\t"+expDigamma(i)+"\t"+(i-0.5)); + } +// double gammaValue = (expDigamma(3)/expDigamma(10) + expDigamma(3)/expDigamma(10) + expDigamma(4)/expDigamma(10)); +// double normalValue = 3/10+3/4+10/10; +// System.out.println("Gamma " + gammaValue + " normal " + normalValue); + } + + + +} diff --git a/gi/posterior-regularisation/prjava/src/util/Matrix.java b/gi/posterior-regularisation/prjava/src/util/Matrix.java new file mode 100644 index 00000000..8fb6d911 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/util/Matrix.java @@ -0,0 +1,16 @@ +package util; + +public class Matrix { + int x; + int y; + double[][] values; + + public Matrix(int x, int y){ + this.x = x; + this.y=y; + values = new double[x][y]; + } + + + +} diff --git a/gi/posterior-regularisation/prjava/src/util/MemoryTracker.java b/gi/posterior-regularisation/prjava/src/util/MemoryTracker.java new file mode 100644 index 00000000..83a65611 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/util/MemoryTracker.java @@ -0,0 +1,47 @@ +package util; + + +public class MemoryTracker { + + double initM,finalM; + boolean start = false,finish = false; + + public MemoryTracker(){ + + } + + public void start(){ + System.gc(); + System.gc(); + System.gc(); + initM = (Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory())/(1024*1024); + start = true; + } + + public void finish(){ + if(!start){ + throw new RuntimeException("Canot stop before starting"); + } + System.gc(); + System.gc(); + System.gc(); + finalM = (Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory())/(1024*1024); + finish = true; + } + + public String print(){ + if(!finish){ + throw new RuntimeException("Canot print before stopping"); + } + return "Used: " + (finalM - initM) + "MB"; + } + + public void clear(){ + initM = 0; + finalM = 0; + finish = false; + start = false; + } + + +} diff --git a/gi/posterior-regularisation/prjava/src/util/Pair.java b/gi/posterior-regularisation/prjava/src/util/Pair.java new file mode 100644 index 00000000..7b1f108d --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/util/Pair.java @@ -0,0 +1,31 @@ +package util; + +public class Pair<O1, O2> { + public O1 _first; + public O2 _second; + + public final O1 first() { + return _first; + } + + public final O2 second() { + return _second; + } + + public final void setFirst(O1 value){ + _first = value; + } + + public final void setSecond(O2 value){ + _second = value; + } + + public Pair(O1 first, O2 second) { + _first = first; + _second = second; + } + + public String toString(){ + return _first + " " + _second; + } +} diff --git a/gi/posterior-regularisation/prjava/src/util/Printing.java b/gi/posterior-regularisation/prjava/src/util/Printing.java new file mode 100644 index 00000000..14fcbe91 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/util/Printing.java @@ -0,0 +1,158 @@ +package util; + +public class Printing { + static java.text.DecimalFormat fmt = new java.text.DecimalFormat(); + + public static String padWithSpace(String s, int len){ + StringBuffer sb = new StringBuffer(); + while(sb.length() +s.length() < len){ + sb.append(" "); + } + sb.append(s); + return sb.toString(); + } + + public static String prettyPrint(double d, String patt, int len) { + fmt.applyPattern(patt); + String s = fmt.format(d); + while (s.length() < len) { + s = " " + s; + } + return s; + } + + public static String formatTime(long duration) { + StringBuilder sb = new StringBuilder(); + double d = duration / 1000; + fmt.applyPattern("00"); + sb.append(fmt.format((int) (d / (60 * 60))) + ":"); + d -= ((int) d / (60 * 60)) * 60 * 60; + sb.append(fmt.format((int) (d / 60)) + ":"); + d -= ((int) d / 60) * 60; + fmt.applyPattern("00.0"); + sb.append(fmt.format(d)); + return sb.toString(); + } + + + public static String doubleArrayToString(double[] array, String[] labels, String arrayName) { + StringBuffer res = new StringBuffer(); + res.append(arrayName); + res.append("\n"); + for (int i = 0; i < array.length; i++) { + if (labels == null){ + res.append(i+" \t"); + }else{ + res.append(labels[i]+ "\t"); + } + } + res.append("sum\n"); + double sum = 0; + for (int i = 0; i < array.length; i++) { + res.append(prettyPrint(array[i], + "0.00000E00", 8) + "\t"); + sum+=array[i]; + } + res.append(prettyPrint(sum, + "0.00000E00", 8)+"\n"); + return res.toString(); + } + + + + public static void printDoubleArray(double[] array, String labels[], String arrayName) { + System.out.println(doubleArrayToString(array, labels,arrayName)); + } + + + public static String doubleArrayToString(double[][] array, String[] labels1, String[] labels2, + String arrayName){ + StringBuffer res = new StringBuffer(); + res.append(arrayName); + res.append("\n\t"); + //Calculates the column sum to keeps the sums + double[] sums = new double[array[0].length+1]; + //Prints rows headings + for (int i = 0; i < array[0].length; i++) { + if (labels1 == null){ + res.append(i+" \t"); + }else{ + res.append(labels1[i]+" \t"); + } + } + res.append("sum\n"); + double sum = 0; + //For each row print heading + for (int i = 0; i < array.length; i++) { + if (labels2 == null){ + res.append(i+"\t"); + }else{ + res.append(labels2[i]+"\t"); + } + //Print values for that row + for (int j = 0; j < array[0].length; j++) { + res.append(" " + prettyPrint(array[i][j], + "0.00000E00", 8) + "\t"); + sums[j] += array[i][j]; + sum+=array[i][j]; //Sum all values of that row + } + //Print row sum + res.append(prettyPrint(sum,"0.00000E00", 8)+"\n"); + sums[array[0].length]+=sum; + sum=0; + } + res.append("sum\t"); + //Print values for colums sum + for (int i = 0; i < array[0].length+1; i++) { + res.append(prettyPrint(sums[i],"0.00000E00", 8)+"\t"); + } + res.append("\n"); + return res.toString(); + } + + public static void printDoubleArray(double[][] array, String[] labels1, String[] labels2 + , String arrayName) { + System.out.println(doubleArrayToString(array, labels1,labels2,arrayName)); + } + + + public static void printIntArray(int[][] array, String[] labels1, String[] labels2, String arrayName, + int size1, int size2) { + System.out.println(arrayName); + for (int i = 0; i < size1; i++) { + for (int j = 0; j < size2; j++) { + System.out.print(" " + array[i][j] + " "); + + } + System.out.println(); + } + System.out.println(); + } + + public static String intArrayToString(int[] array, String[] labels, String arrayName) { + StringBuffer res = new StringBuffer(); + res.append(arrayName); + for (int i = 0; i < array.length; i++) { + res.append(" " + array[i] + " "); + + } + res.append("\n"); + return res.toString(); + } + + public static void printIntArray(int[] array, String[] labels, String arrayName) { + System.out.println(intArrayToString(array, labels,arrayName)); + } + + public static String toString(double[][] d){ + StringBuffer sb = new StringBuffer(); + for (int i = 0; i < d.length; i++) { + for (int j = 0; j < d[0].length; j++) { + sb.append(prettyPrint(d[i][j], "0.00E0", 10)); + } + sb.append("\n"); + } + return sb.toString(); + } + +} diff --git a/gi/posterior-regularisation/prjava/src/util/Sorters.java b/gi/posterior-regularisation/prjava/src/util/Sorters.java new file mode 100644 index 00000000..836444e5 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/util/Sorters.java @@ -0,0 +1,39 @@ +package util; + +import java.util.Comparator; + +public class Sorters { + public static class sortWordsCounts implements Comparator{ + + /** + * Sorter for a pair of word id, counts. Sort ascending by counts + */ + public int compare(Object arg0, Object arg1) { + Pair<Integer,Integer> p1 = (Pair<Integer,Integer>)arg0; + Pair<Integer,Integer> p2 = (Pair<Integer,Integer>)arg1; + if(p1.second() > p2.second()){ + return 1; + }else{ + return -1; + } + } + + } + +public static class sortWordsDouble implements Comparator{ + + /** + * Sorter for a pair of word id, counts. Sort by counts + */ + public int compare(Object arg0, Object arg1) { + Pair<Integer,Double> p1 = (Pair<Integer,Double>)arg0; + Pair<Integer,Double> p2 = (Pair<Integer,Double>)arg1; + if(p1.second() < p2.second()){ + return 1; + }else{ + return -1; + } + } + + } +} |