summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/stats
diff options
context:
space:
mode:
authordesaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-09 16:59:55 +0000
committerdesaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-09 16:59:55 +0000
commitbdea91300c85539ab7153ccba58689612f66bb4d (patch)
treee778ffa1ea4d04a239b58c6e6191c0d4549006f0 /gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/stats
parent0d1d84630a08f1c901cf09b4bcc9356c4165302f (diff)
add optimization library source code
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@204 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/stats')
-rw-r--r--gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/stats/OptimizerStats.java86
-rw-r--r--gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/stats/ProjectedOptimizerStats.java70
2 files changed, 156 insertions, 0 deletions
diff --git a/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/stats/OptimizerStats.java b/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/stats/OptimizerStats.java
new file mode 100644
index 00000000..6340ef73
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/stats/OptimizerStats.java
@@ -0,0 +1,86 @@
+package optimization.gradientBasedMethods.stats;
+
+import java.util.ArrayList;
+
+import optimization.gradientBasedMethods.Objective;
+import optimization.gradientBasedMethods.Optimizer;
+import optimization.util.MathUtils;
+import optimization.util.StaticTools;
+
+
+public class OptimizerStats {
+
+ double start = 0;
+ double totalTime = 0;
+
+ String objectiveFinalStats;
+
+ ArrayList<Double> gradientNorms = new ArrayList<Double>();
+ ArrayList<Double> steps = new ArrayList<Double>();
+ ArrayList<Double> value = new ArrayList<Double>();
+ ArrayList<Integer> iterations = new ArrayList<Integer>();
+ double prevValue =0;
+
+ public void reset(){
+ start = 0;
+ totalTime = 0;
+
+ objectiveFinalStats="";
+
+ gradientNorms.clear();
+ steps.clear();
+ value.clear();
+ iterations.clear();
+ prevValue =0;
+ }
+
+ public void startTime() {
+ start = System.currentTimeMillis();
+ }
+ public void stopTime() {
+ totalTime += System.currentTimeMillis() - start;
+ }
+
+ public String prettyPrint(int level){
+ StringBuffer res = new StringBuffer();
+ res.append("Total time " + totalTime/1000 + " seconds \n" + "Iterations " + iterations.size() + "\n");
+ res.append(objectiveFinalStats+"\n");
+ if(level > 0){
+ if(iterations.size() > 0){
+ res.append("\tIteration"+iterations.get(0)+"\tstep: "+StaticTools.prettyPrint(steps.get(0), "0.00E00", 6)+ "\tgradientNorm "+
+ StaticTools.prettyPrint(gradientNorms.get(0), "0.00000E00", 10)+ "\tvalue "+ StaticTools.prettyPrint(value.get(0), "0.000000E00",11)+"\n");
+ }
+ for(int i = 1; i < iterations.size(); i++){
+ res.append("\tIteration:\t"+iterations.get(i)+"\tstep:"+StaticTools.prettyPrint(steps.get(i), "0.00E00", 6)+ "\tgradientNorm "+
+ StaticTools.prettyPrint(gradientNorms.get(i), "0.00000E00", 10)+
+ "\tvalue:\t"+ StaticTools.prettyPrint(value.get(i), "0.000000E00",11)+
+ "\tvalueDiff:\t"+ StaticTools.prettyPrint((value.get(i-1)-value.get(i)), "0.000000E00",11)+
+ "\n");
+ }
+ }
+ return res.toString();
+ }
+
+
+ public void collectInitStats(Optimizer optimizer, Objective objective){
+ startTime();
+ iterations.add(-1);
+ gradientNorms.add(MathUtils.L2Norm(objective.getGradient()));
+ steps.add(0.0);
+ value.add(objective.getValue());
+ }
+
+ public void collectIterationStats(Optimizer optimizer, Objective objective){
+ iterations.add(optimizer.getCurrentIteration());
+ gradientNorms.add(MathUtils.L2Norm(objective.getGradient()));
+ steps.add(optimizer.getCurrentStep());
+ value.add(optimizer.getCurrentValue());
+ }
+
+
+ public void collectFinalStats(Optimizer optimizer, Objective objective){
+ stopTime();
+ objectiveFinalStats = objective.finalInfoString();
+ }
+
+}
diff --git a/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/stats/ProjectedOptimizerStats.java b/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/stats/ProjectedOptimizerStats.java
new file mode 100644
index 00000000..d65a1267
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/stats/ProjectedOptimizerStats.java
@@ -0,0 +1,70 @@
+package optimization.gradientBasedMethods.stats;
+
+import java.util.ArrayList;
+
+import optimization.gradientBasedMethods.Objective;
+import optimization.gradientBasedMethods.Optimizer;
+import optimization.gradientBasedMethods.ProjectedObjective;
+import optimization.gradientBasedMethods.ProjectedOptimizer;
+import optimization.util.MathUtils;
+import optimization.util.StaticTools;
+
+
+public class ProjectedOptimizerStats extends OptimizerStats{
+
+
+
+ public void reset(){
+ super.reset();
+ projectedGradientNorms.clear();
+ }
+
+ ArrayList<Double> projectedGradientNorms = new ArrayList<Double>();
+
+ public String prettyPrint(int level){
+ StringBuffer res = new StringBuffer();
+ res.append("Total time " + totalTime/1000 + " seconds \n" + "Iterations " + iterations.size() + "\n");
+ res.append(objectiveFinalStats+"\n");
+ if(level > 0){
+ if(iterations.size() > 0){
+ res.append("\tIteration"+iterations.get(0)+"\tstep: "+
+ StaticTools.prettyPrint(steps.get(0), "0.00E00", 6)+ "\tgradientNorm "+
+ StaticTools.prettyPrint(gradientNorms.get(0), "0.00000E00", 10)
+ + "\tdirection"+
+ StaticTools.prettyPrint(projectedGradientNorms.get(0), "0.00000E00", 10)+
+ "\tvalue "+ StaticTools.prettyPrint(value.get(0), "0.000000E00",11)+"\n");
+ }
+ for(int i = 1; i < iterations.size(); i++){
+ res.append("\tIteration"+iterations.get(i)+"\tstep: "+StaticTools.prettyPrint(steps.get(i), "0.00E00", 6)+ "\tgradientNorm "+
+ StaticTools.prettyPrint(gradientNorms.get(i), "0.00000E00", 10)+
+ "\t direction "+
+ StaticTools.prettyPrint(projectedGradientNorms.get(i), "0.00000E00", 10)+
+ "\tvalue "+ StaticTools.prettyPrint(value.get(i), "0.000000E00",11)+
+ "\tvalueDiff "+ StaticTools.prettyPrint((value.get(i-1)-value.get(i)), "0.000000E00",11)+
+ "\n");
+ }
+ }
+ return res.toString();
+ }
+
+
+ public void collectInitStats(Optimizer optimizer, Objective objective){
+ startTime();
+ }
+
+ public void collectIterationStats(Optimizer optimizer, Objective objective){
+ iterations.add(optimizer.getCurrentIteration());
+ gradientNorms.add(MathUtils.L2Norm(objective.getGradient()));
+ projectedGradientNorms.add(MathUtils.L2Norm(optimizer.getDirection()));
+ steps.add(optimizer.getCurrentStep());
+ value.add(optimizer.getCurrentValue());
+ }
+
+
+
+ public void collectFinalStats(Optimizer optimizer, Objective objective){
+ stopTime();
+ objectiveFinalStats = objective.finalInfoString();
+ }
+
+}