summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/optimization/util/StaticTools.java
diff options
context:
space:
mode:
authordesaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-09 16:59:55 +0000
committerdesaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-09 16:59:55 +0000
commit7f69c868c41e4b36eecf9d3b1dc22f3f3aa1540c (patch)
treed22aa7b6f47248ed6da02b77a0680b6b83e67b63 /gi/posterior-regularisation/prjava/src/optimization/util/StaticTools.java
parent4e37402323c3227e90a89345387834e149732b5c (diff)
add optimization library source code
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@204 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/optimization/util/StaticTools.java')
-rw-r--r--gi/posterior-regularisation/prjava/src/optimization/util/StaticTools.java180
1 files changed, 180 insertions, 0 deletions
diff --git a/gi/posterior-regularisation/prjava/src/optimization/util/StaticTools.java b/gi/posterior-regularisation/prjava/src/optimization/util/StaticTools.java
new file mode 100644
index 00000000..bcabee06
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/optimization/util/StaticTools.java
@@ -0,0 +1,180 @@
+package optimization.util;
+
+
+import java.io.File;
+import java.io.PrintStream;
+
+public class StaticTools {
+
+ static java.text.DecimalFormat fmt = new java.text.DecimalFormat();
+
+ public static void createDir(String directory) {
+
+ File dir = new File(directory);
+ if (!dir.isDirectory()) {
+ boolean success = dir.mkdirs();
+ if (!success) {
+ System.out.println("Unable to create directory " + directory);
+ System.exit(0);
+ }
+ System.out.println("Created directory " + directory);
+ } else {
+ System.out.println("Reusing directory " + directory);
+ }
+ }
+
+ /*
+ * q and p are indexed by source/foreign Sum_S(q) = 1 the same for p KL(q,p) =
+ * Eq*q/p
+ */
+ public static double KLDistance(double[][] p, double[][] q, int sourceSize,
+ int foreignSize) {
+ double totalKL = 0;
+ // common.StaticTools.printMatrix(q, sourceSize, foreignSize, "q",
+ // System.out);
+ // common.StaticTools.printMatrix(p, sourceSize, foreignSize, "p",
+ // System.out);
+ for (int i = 0; i < sourceSize; i++) {
+ double kl = 0;
+ for (int j = 0; j < foreignSize; j++) {
+ assert !Double.isNaN(q[i][j]) : "KLDistance q: prob is NaN";
+ assert !Double.isNaN(p[i][j]) : "KLDistance p: prob is NaN";
+ if (p[i][j] == 0 || q[i][j] == 0) {
+ continue;
+ } else {
+ kl += q[i][j] * Math.log(q[i][j] / p[i][j]);
+ }
+
+ }
+ totalKL += kl;
+ }
+ assert !Double.isNaN(totalKL) : "KLDistance: prob is NaN";
+ if (totalKL < -1.0E-10) {
+ System.out.println("KL Smaller than zero " + totalKL);
+ System.out.println("Source Size" + sourceSize);
+ System.out.println("Foreign Size" + foreignSize);
+ StaticTools.printMatrix(q, sourceSize, foreignSize, "q",
+ System.out);
+ StaticTools.printMatrix(p, sourceSize, foreignSize, "p",
+ System.out);
+ System.exit(-1);
+ }
+ return totalKL / sourceSize;
+ }
+
+ /*
+ * indexed the by [fi][si]
+ */
+ public static double KLDistancePrime(double[][] p, double[][] q,
+ int sourceSize, int foreignSize) {
+ double totalKL = 0;
+ for (int i = 0; i < sourceSize; i++) {
+ double kl = 0;
+ for (int j = 0; j < foreignSize; j++) {
+ assert !Double.isNaN(q[j][i]) : "KLDistance q: prob is NaN";
+ assert !Double.isNaN(p[j][i]) : "KLDistance p: prob is NaN";
+ if (p[j][i] == 0 || q[j][i] == 0) {
+ continue;
+ } else {
+ kl += q[j][i] * Math.log(q[j][i] / p[j][i]);
+ }
+
+ }
+ totalKL += kl;
+ }
+ assert !Double.isNaN(totalKL) : "KLDistance: prob is NaN";
+ return totalKL / sourceSize;
+ }
+
+ public static double Entropy(double[][] p, int sourceSize, int foreignSize) {
+ double totalE = 0;
+ for (int i = 0; i < foreignSize; i++) {
+ double e = 0;
+ for (int j = 0; j < sourceSize; j++) {
+ e += p[i][j] * Math.log(p[i][j]);
+ }
+ totalE += e;
+ }
+ return totalE / sourceSize;
+ }
+
+ public static double[][] copyMatrix(double[][] original, int sourceSize,
+ int foreignSize) {
+ double[][] result = new double[sourceSize][foreignSize];
+ for (int i = 0; i < sourceSize; i++) {
+ for (int j = 0; j < foreignSize; j++) {
+ result[i][j] = original[i][j];
+ }
+ }
+ return result;
+ }
+
+ public static void printMatrix(double[][] matrix, int sourceSize,
+ int foreignSize, String info, PrintStream out) {
+
+ java.text.DecimalFormat fmt = new java.text.DecimalFormat();
+ fmt.setMaximumFractionDigits(3);
+ fmt.setMaximumIntegerDigits(3);
+ fmt.setMinimumFractionDigits(3);
+ fmt.setMinimumIntegerDigits(3);
+
+ out.println(info);
+
+ for (int i = 0; i < foreignSize; i++) {
+ for (int j = 0; j < sourceSize; j++) {
+ out.print(prettyPrint(matrix[j][i], ".00E00", 6) + " ");
+ }
+ out.println();
+ }
+ out.println();
+ out.println();
+ }
+
+ public static void printMatrix(int[][] matrix, int sourceSize,
+ int foreignSize, String info, PrintStream out) {
+
+ out.println(info);
+ for (int i = 0; i < foreignSize; i++) {
+ for (int j = 0; j < sourceSize; j++) {
+ out.print(matrix[j][i] + " ");
+ }
+ out.println();
+ }
+ out.println();
+ out.println();
+ }
+
+ public static String formatTime(long duration) {
+ StringBuilder sb = new StringBuilder();
+ double d = duration / 1000;
+ fmt.applyPattern("00");
+ sb.append(fmt.format((int) (d / (60 * 60))) + ":");
+ d -= ((int) d / (60 * 60)) * 60 * 60;
+ sb.append(fmt.format((int) (d / 60)) + ":");
+ d -= ((int) d / 60) * 60;
+ fmt.applyPattern("00.0");
+ sb.append(fmt.format(d));
+ return sb.toString();
+ }
+
+ public static String prettyPrint(double d, String patt, int len) {
+ fmt.applyPattern(patt);
+ String s = fmt.format(d);
+ while (s.length() < len) {
+ s = " " + s;
+ }
+ return s;
+ }
+
+
+ public static long getUsedMemory(){
+ System.gc();
+ return (Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory())/ (1024 * 1024);
+ }
+
+ public final static boolean compareDoubles(double d1, double d2){
+ return Math.abs(d1-d2) <= 1.E-10;
+ }
+
+
+}