diff options
author | desaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-05 15:26:42 +0000 |
---|---|---|
committer | desaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-05 15:26:42 +0000 |
commit | 33994330b8395c4c44ad0ddc1e678372404c3566 (patch) | |
tree | 563eb14b957c9a2cda5e49be3ef79ee5c7043718 /gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java | |
parent | 5605a42b8aa6568cc6f11f84fd1f9b0ac2dd596d (diff) |
forget to add files
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@126 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java')
-rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java | 229 |
1 files changed, 229 insertions, 0 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java new file mode 100644 index 00000000..e9e063d6 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java @@ -0,0 +1,229 @@ +package phrase;
+
+import java.io.PrintStream;
+import java.util.Arrays;
+
+import optimization.gradientBasedMethods.ProjectedGradientDescent;
+import optimization.gradientBasedMethods.ProjectedObjective;
+import optimization.gradientBasedMethods.stats.OptimizerStats;
+import optimization.linesearch.ArmijoLineSearchMinimizationAlongProjectionArc;
+import optimization.linesearch.InterpolationPickFirstStep;
+import optimization.linesearch.LineSearchMethod;
+import optimization.linesearch.WolfRuleLineSearch;
+import optimization.projections.SimplexProjection;
+import optimization.stopCriteria.CompositeStopingCriteria;
+import optimization.stopCriteria.ProjectedGradientL2Norm;
+import optimization.stopCriteria.StopingCriteria;
+import optimization.stopCriteria.ValueDifference;
+import optimization.util.MathUtils;
+
+public class PhraseObjective extends ProjectedObjective{
+
+ private static final double GRAD_DIFF = 0.002;
+ public static double INIT_STEP_SIZE=1;
+ public static double VAL_DIFF=0.001;
+ private double scale=5;
+ private double c1=0.0001;
+ private double c2=0.9;
+
+ private PhraseCluster c;
+
+ /**@brief
+ * for debugging purposes
+ */
+ public static PrintStream ps;
+
+ /**@brief current phrase being optimzed*/
+ public int phrase;
+
+ /**@brief un-regularized posterior
+ * unnormalized
+ * p[edge][tag]
+ * P(tag|edge) \propto P(tag|phrase)P(context|tag)
+ */
+ private double[][]p;
+
+ /**@brief regularized posterior
+ * q[edge][tag] propto p[edge][tag]*exp(-lambda)
+ */
+ private double q[][];
+ private int data[][];
+
+ /**@brief log likelihood of the associated phrase
+ *
+ */
+ private double loglikelihood;
+ private SimplexProjection projection;
+
+ double[] newPoint ;
+
+ private int n_param;
+
+ /**@brief likelihood under p
+ *
+ */
+ private double llh;
+
+ public PhraseObjective(PhraseCluster cluster, int phraseIdx){
+ phrase=phraseIdx;
+ c=cluster;
+ data=c.c.data[phrase];
+ n_param=data.length*c.K;
+ parameters=new double [n_param];
+ newPoint = new double[n_param];
+ gradient = new double[n_param];
+ initP();
+ projection=new SimplexProjection (scale);
+ q=new double [data.length][c.K];
+
+ setParameters(parameters);
+ }
+
+ private void initP(){
+ int countIdx=data[0].length-1;
+
+ p=new double[data.length][];
+ for(int edge=0;edge<data.length;edge++){
+ p[edge]=c.posterior(phrase,data[edge]);
+ }
+ for(int edge=0;edge<data.length;edge++){
+ llh+=Math.log
+ (data[edge][countIdx]*arr.F.l1norm(p[edge]));
+ arr.F.l1normalize(p[edge]);
+ }
+ }
+
+ @Override
+ public void setParameters(double[] params) {
+ super.setParameters(params);
+ updateFunction();
+ }
+
+ private void updateFunction(){
+ updateCalls++;
+ loglikelihood=0;
+ int countIdx=data[0].length-1;
+ for(int tag=0;tag<c.K;tag++){
+ for(int edge=0;edge<data.length;edge++){
+ q[edge][tag]=p[edge][tag]*
+ Math.exp(-parameters[tag*data.length+edge]/data[edge][countIdx]);
+ }
+ }
+
+ for(int edge=0;edge<data.length;edge++){
+ loglikelihood+=Math.log
+ (data[edge][countIdx]*arr.F.l1norm(q[edge]));
+ arr.F.l1normalize(q[edge]);
+ }
+
+ for(int tag=0;tag<c.K;tag++){
+ for(int edge=0;edge<data.length;edge++){
+ gradient[tag*data.length+edge]=-q[edge][tag];
+ }
+ }
+ }
+
+ @Override
+ // TODO Auto-generated method stub
+ public double[] projectPoint(double[] point) {
+ double toProject[]=new double[data.length];
+ for(int tag=0;tag<c.K;tag++){
+ for(int edge=0;edge<data.length;edge++){
+ toProject[edge]=point[tag*data.length+edge];
+ }
+ projection.project(toProject);
+ for(int edge=0;edge<data.length;edge++){
+ newPoint[tag*data.length+edge]=toProject[edge];
+ }
+ }
+ return newPoint;
+ }
+
+ @Override
+ public double[] getGradient() {
+ // TODO Auto-generated method stub
+ gradientCalls++;
+ return gradient;
+ }
+
+ @Override
+ public double getValue() {
+ // TODO Auto-generated method stub
+ functionCalls++;
+ return loglikelihood;
+ }
+
+ @Override
+ public String toString() {
+ // TODO Auto-generated method stub
+ return "";
+ }
+
+ public double [][]posterior(){
+ return q;
+ }
+
+ public void optimizeWithProjectedGradientDescent(){
+ LineSearchMethod ls =
+ new ArmijoLineSearchMinimizationAlongProjectionArc
+ (new InterpolationPickFirstStep(INIT_STEP_SIZE));
+ //LineSearchMethod ls = new WolfRuleLineSearch(
+ // (new InterpolationPickFirstStep(INIT_STEP_SIZE)), c1, c2);
+ OptimizerStats stats = new OptimizerStats();
+
+
+ ProjectedGradientDescent optimizer = new ProjectedGradientDescent(ls);
+ StopingCriteria stopGrad = new ProjectedGradientL2Norm(GRAD_DIFF);
+ StopingCriteria stopValue = new ValueDifference(VAL_DIFF);
+ CompositeStopingCriteria compositeStop = new CompositeStopingCriteria();
+ compositeStop.add(stopGrad);
+ compositeStop.add(stopValue);
+ optimizer.setMaxIterations(100);
+ updateFunction();
+ boolean succed = optimizer.optimize(this,stats,compositeStop);
+// System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
+ if(succed){
+ System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
+ }else{
+ System.out.println("Failed to optimize");
+ }
+
+ // ps.println(Arrays.toString(parameters));
+
+ // for(int edge=0;edge<data.length;edge++){
+ // ps.println(Arrays.toString(q[edge]));
+ // }
+
+ }
+
+ /**
+ * L - KL(q||p) -
+ * scale * \sum_{tag,phrase} max_i P(tag|i th occurrence of phrase)
+ * @return
+ */
+ public double primal()
+ {
+
+ double l=llh;
+
+// ps.print("Phrase "+phrase+": "+l);
+ double kl=-loglikelihood
+ +MathUtils.dotProduct(parameters, gradient);
+// ps.print(", "+kl);
+ l=l-kl;
+ double sum=0;
+ for(int tag=0;tag<c.K;tag++){
+ double max=0;
+ for(int edge=0;edge<data.length;edge++){
+ if(q[edge][tag]>max){
+ max=q[edge][tag];
+ }
+ }
+ sum+=max;
+ }
+// ps.println(", "+sum);
+ l=l-scale*sum;
+ return l;
+ }
+
+}
|