diff options
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java')
-rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java | 120 |
1 file changed, 53 insertions, 67 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java index 0fdc169b..015ef106 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java @@ -2,6 +2,7 @@ package phrase; import java.io.PrintStream;
import java.util.Arrays;
+import java.util.List;
import optimization.gradientBasedMethods.ProjectedGradientDescent;
import optimization.gradientBasedMethods.ProjectedObjective;
@@ -17,11 +18,12 @@ import optimization.stopCriteria.StopingCriteria; import optimization.stopCriteria.ValueDifference;
import optimization.util.MathUtils;
-public class PhraseObjective extends ProjectedObjective{
-
- private static final double GRAD_DIFF = 0.00002;
- public static double INIT_STEP_SIZE = 10;
- public static double VAL_DIFF = 0.000001; // FIXME needs to be tuned
+public class PhraseObjective extends ProjectedObjective
+{
+ static final double GRAD_DIFF = 0.00002;
+ static double INIT_STEP_SIZE = 10;
+ static double VAL_DIFF = 1e-6; // FIXME needs to be tuned
+ static int ITERATIONS = 100;
//private double c1=0.0001; // wolf stuff
//private double c2=0.9;
private static double lambda[][];
@@ -46,7 +48,7 @@ public class PhraseObjective extends ProjectedObjective{ * q[edge][tag] propto p[edge][tag]*exp(-lambda)
*/
private double q[][];
- private int data[][];
+ private List<Corpus.Edge> data;
/**@brief log likelihood of the associated phrase
*
@@ -66,14 +68,14 @@ public class PhraseObjective extends ProjectedObjective{ public PhraseObjective(PhraseCluster cluster, int phraseIdx){
phrase=phraseIdx;
c=cluster;
- data=c.c.data[phrase];
- n_param=data.length*c.K;
+ data=c.c.getEdgesForPhrase(phrase);
+ n_param=data.size()*c.K;
- if( lambda==null){
- lambda=new double[c.c.data.length][];
+ if (lambda==null){
+ lambda=new double[c.c.getNumPhrases()][];
}
- if(lambda[phrase]==null){
+ if (lambda[phrase]==null){
lambda[phrase]=new double[n_param];
}
@@ -81,22 +83,17 @@ public class PhraseObjective extends ProjectedObjective{ newPoint = new double[n_param];
gradient = new double[n_param];
initP();
- projection=new SimplexProjection(c.scale);
- q=new double [data.length][c.K];
+ projection=new SimplexProjection(c.scalePT);
+ q=new double [data.size()][c.K];
setParameters(parameters);
}
private void initP(){
- int countIdx=data[0].length-1;
-
- p=new double[data.length][];
- for(int edge=0;edge<data.length;edge++){
- p[edge]=c.posterior(phrase,data[edge]);
- }
- for(int edge=0;edge<data.length;edge++){
- llh+=Math.log
- (data[edge][countIdx]*arr.F.l1norm(p[edge]));
+ p=new double[data.size()][];
+ for(int edge=0;edge<data.size();edge++){
+ p[edge]=c.posterior(data.get(edge));
+ llh += data.get(edge).getCount() * Math.log(arr.F.l1norm(p[edge])); // Was bug here - count inside log!
arr.F.l1normalize(p[edge]);
}
}
@@ -110,37 +107,36 @@ public class PhraseObjective extends ProjectedObjective{ private void updateFunction(){
updateCalls++;
loglikelihood=0;
- int countIdx=data[0].length-1;
+
for(int tag=0;tag<c.K;tag++){
- for(int edge=0;edge<data.length;edge++){
+ for(int edge=0;edge<data.size();edge++){
q[edge][tag]=p[edge][tag]*
- Math.exp(-parameters[tag*data.length+edge]/data[edge][countIdx]);
+ Math.exp(-parameters[tag*data.size()+edge]/data.get(edge).getCount());
}
}
- for(int edge=0;edge<data.length;edge++){
- loglikelihood+=data[edge][countIdx] * Math.log(arr.F.l1norm(q[edge]));
+ for(int edge=0;edge<data.size();edge++){
+ loglikelihood+=data.get(edge).getCount() * Math.log(arr.F.l1norm(q[edge]));
arr.F.l1normalize(q[edge]);
}
for(int tag=0;tag<c.K;tag++){
- for(int edge=0;edge<data.length;edge++){
- gradient[tag*data.length+edge]=-q[edge][tag];
+ for(int edge=0;edge<data.size();edge++){
+ gradient[tag*data.size()+edge]=-q[edge][tag];
}
}
}
@Override
- // TODO Auto-generated method stub
public double[] projectPoint(double[] point) {
- double toProject[]=new double[data.length];
+ double toProject[]=new double[data.size()];
for(int tag=0;tag<c.K;tag++){
- for(int edge=0;edge<data.length;edge++){
- toProject[edge]=point[tag*data.length+edge];
+ for(int edge=0;edge<data.size();edge++){
+ toProject[edge]=point[tag*data.size()+edge];
}
projection.project(toProject);
- for(int edge=0;edge<data.length;edge++){
- newPoint[tag*data.length+edge]=toProject[edge];
+ for(int edge=0;edge<data.size();edge++){
+ newPoint[tag*data.size()+edge]=toProject[edge];
}
}
return newPoint;
@@ -148,22 +144,19 @@ public class PhraseObjective extends ProjectedObjective{ @Override
public double[] getGradient() {
- // TODO Auto-generated method stub
gradientCalls++;
return gradient;
}
@Override
public double getValue() {
- // TODO Auto-generated method stub
functionCalls++;
return loglikelihood;
}
@Override
public String toString() {
- // TODO Auto-generated method stub
- return "";
+ return "No need for pointless toString";
}
public double [][]posterior(){
@@ -185,7 +178,7 @@ public class PhraseObjective extends ProjectedObjective{ CompositeStopingCriteria compositeStop = new CompositeStopingCriteria();
compositeStop.add(stopGrad);
compositeStop.add(stopValue);
- optimizer.setMaxIterations(100);
+ optimizer.setMaxIterations(ITERATIONS);
updateFunction();
boolean succed = optimizer.optimize(this,stats,compositeStop);
// System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
@@ -197,45 +190,38 @@ public class PhraseObjective extends ProjectedObjective{ lambda[phrase]=parameters;
// ps.println(Arrays.toString(parameters));
- // for(int edge=0;edge<data.length;edge++){
+ // for(int edge=0;edge<data.getSize();edge++){
// ps.println(Arrays.toString(q[edge]));
// }
}
- /**
- * L - KL(q||p) -
- * scale * \sum_{tag,phrase} max_i P(tag|i th occurrence of phrase)
- * @return
- */
- public double primal()
+ public double KL_divergence()
+ {
+ return -loglikelihood + MathUtils.dotProduct(parameters, gradient);
+ }
+
+ public double loglikelihood()
+ {
+ return llh;
+ }
+
+ public double l1lmax()
{
-
- double l=llh;
-
-// ps.print("Phrase "+phrase+": "+l);
- double kl=-loglikelihood
- +MathUtils.dotProduct(parameters, gradient);
-// ps.print(", "+kl);
- //System.out.println("llh " + llh);
- //System.out.println("kl " + kl);
-
-
- l=l-kl;
double sum=0;
for(int tag=0;tag<c.K;tag++){
double max=0;
- for(int edge=0;edge<data.length;edge++){
- if(q[edge][tag]>max){
+ for(int edge=0;edge<data.size();edge++){
+ if(q[edge][tag]>max)
max=q[edge][tag];
- }
}
sum+=max;
}
- //System.out.println("l1lmax " + sum);
-// ps.println(", "+sum);
- l=l-c.scale*sum;
- return l;
+ return sum;
+ }
+
+ public double primal()
+ {
+ return loglikelihood() - KL_divergence() - c.scalePT * l1lmax();
}
-
}
|