diff options
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase')
-rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java | 66 | ||||
-rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java | 9 |
2 files changed, 47 insertions, 28 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java index 5947c4be..646ff392 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java @@ -27,8 +27,9 @@ public class PhraseContextObjective extends ProjectedObjective {
private static final double GRAD_DIFF = 0.00002;
private static double INIT_STEP_SIZE = 300;
- private static double VAL_DIFF = 1e-4; // FIXME needs to be tuned
- private static int ITERATIONS = 100;
+ private static double VAL_DIFF = 1e-8;
+ private static int ITERATIONS = 20;
+ boolean debug = false;
private PhraseCluster c;
@@ -52,7 +53,7 @@ public class PhraseContextObjective extends ProjectedObjective // likelihood under p
public double llh;
- private Map<Corpus.Edge, Integer> edgeIndex;
+ private static Map<Corpus.Edge, Integer> edgeIndex;
private long projectionTime;
private long objectiveTime;
@@ -84,10 +85,15 @@ public class PhraseContextObjective extends ProjectedObjective projectionContext = new SimplexProjection(scaleCT);
q=new double [data.size()][c.K];
- edgeIndex = new HashMap<Edge, Integer>();
- for (int e=0; e<data.size(); e++)
- edgeIndex.put(data.get(e), e);
-
+ if (edgeIndex == null) {
+ edgeIndex = new HashMap<Edge, Integer>();
+ for (int e=0; e<data.size(); e++)
+ {
+ edgeIndex.put(data.get(e), e);
+ //if (debug) System.out.println("Edge " + data.get(e) + " index " + e);
+ }
+ }
+
setParameters(parameters);
}
@@ -113,6 +119,7 @@ public class PhraseContextObjective extends ProjectedObjective {
updateCalls++;
loglikelihood=0;
+
System.out.print(".");
System.out.flush();
@@ -120,34 +127,36 @@ public class PhraseContextObjective extends ProjectedObjective for (int e=0; e<data.size(); e++)
{
Edge edge = data.get(e);
- int offset = edgeIndex.get(edge)*c.K*2;
for(int tag=0; tag<c.K; tag++)
{
- int ip = offset + tag*2;
- int ic = ip + 1;
+ int ip = index(e, tag, true);
+ int ic = index(e, tag, false);
q[e][tag] = p[e][tag]*
Math.exp((-parameters[ip]-parameters[ic]) / edge.getCount());
+ //if (debug)
+ //System.out.println("\tposterior " + edge + " with tag " + tag + " p " + p[e][tag] + " params " + parameters[ip] + " and " + parameters[ic] + " q " + q[e][tag]);
}
}
- for(int edge=0;edge<data.size();edge++){
+ for(int edge=0;edge<data.size();edge++) {
loglikelihood+=data.get(edge).getCount() * Math.log(arr.F.l1norm(q[edge]));
arr.F.l1normalize(q[edge]);
}
for (int e=0; e<data.size(); e++)
{
- Edge edge = data.get(e);
- int offset = edgeIndex.get(edge)*c.K*2;
for(int tag=0; tag<c.K; tag++)
{
- int ip = offset + tag*2;
- int ic = ip + 1;
+ int ip = index(e, tag, true);
+ int ic = index(e, tag, false);
gradient[ip]=-q[e][tag];
gradient[ic]=-q[e][tag];
}
}
- //System.out.println("objective " + loglikelihood + " ||gradient||_2: " + arr.F.l2norm(gradient));
+ //if (debug) {
+ //System.out.println("objective " + loglikelihood + " ||gradient||_2: " + arr.F.l2norm(gradient));
+ //System.out.println("gradient " + Arrays.toString(gradient));
+ //}
objectiveTime += System.currentTimeMillis() - begin;
}
@@ -160,7 +169,6 @@ public class PhraseContextObjective extends ProjectedObjective System.out.print(",");
System.out.flush();
- //System.out.println("\t\tprojectPoint: " + Arrays.toString(point));
Arrays.fill(newPoint, 0, newPoint.length, 0);
// first project using the phrase-tag constraints,
@@ -173,7 +181,8 @@ public class PhraseContextObjective extends ProjectedObjective double[] toProject = new double[edges.size()];
for(int tag=0;tag<c.K;tag++)
{
- for(int e=0; e<edges.size(); e++)
+ // FIXME: slow hash lookup for e (twice)
+ for(int e=0; e<edges.size(); e++)
toProject[e] = point[index(edges.get(e), tag, true)];
long lbegin = System.currentTimeMillis();
projectionPhrase.project(toProject);
@@ -197,6 +206,7 @@ public class PhraseContextObjective extends ProjectedObjective double toProject[] = new double[edges.size()];
for(int tag=0;tag<c.K;tag++)
{
+ // FIXME: slow hash lookup for e
for(int e=0; e<edges.size(); e++)
toProject[e] = inPoint[index(edges.get(e), tag, true)];
projectionPhrase.project(toProject);
@@ -220,6 +230,7 @@ public class PhraseContextObjective extends ProjectedObjective double toProject[] = new double[edges.size()];
for(int tag=0;tag<c.K;tag++)
{
+ // FIXME: slow hash lookup for e
for(int e=0; e<edges.size(); e++)
toProject[e] = point[index(edges.get(e), tag, false)];
long lbegin = System.currentTimeMillis();
@@ -245,6 +256,7 @@ public class PhraseContextObjective extends ProjectedObjective double toProject[] = new double[edges.size()];
for(int tag=0;tag<c.K;tag++)
{
+ // FIXME: slow hash lookup for e
for(int e=0; e<edges.size(); e++)
toProject[e] = inPoint[index(edges.get(e), tag, false)];
projectionContext.project(toProject);
@@ -287,7 +299,8 @@ public class PhraseContextObjective extends ProjectedObjective newPoint = point;
projectionTime += System.currentTimeMillis() - begin;
- //System.out.println("\t\treturning " + Arrays.toString(tmp));
+ //if (debug)
+ //System.out.println("\t\treturning " + Arrays.toString(tmp));
return tmp;
}
@@ -295,11 +308,20 @@ public class PhraseContextObjective extends ProjectedObjective {
// NB if indexing changes must also change code in updateFunction and constructor
if (phrase)
- return edgeIndex.get(edge)*c.K*2 + tag*2;
+ return tag * edgeIndex.size() + edgeIndex.get(edge);
else
- return edgeIndex.get(edge)*c.K*2 + tag*2 + 1;
+ return (c.K + tag) * edgeIndex.size() + edgeIndex.get(edge);
}
+ private int index(int e, int tag, boolean phrase)
+ {
+ // NB if indexing changes must also change code in updateFunction and constructor
+ if (phrase)
+ return tag * edgeIndex.size() + e;
+ else
+ return (c.K + tag) * edgeIndex.size() + e;
+ }
+
@Override
public double[] getGradient() {
gradientCalls++;
@@ -345,9 +367,9 @@ public class PhraseContextObjective extends ProjectedObjective optimizer.setMaxIterations(ITERATIONS);
updateFunction();
boolean success = optimizer.optimize(this,stats,compositeStop);
-// System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
System.out.println();
+ System.out.println(stats.prettyPrint(1));
if (success)
System.out.print("\toptimization took " + optimizer.getCurrentIteration() + " iterations");
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java index 5efe778a..ac73a075 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java @@ -21,11 +21,8 @@ public class PhraseObjective extends ProjectedObjective {
static final double GRAD_DIFF = 0.00002;
static double INIT_STEP_SIZE = 300;
- static double VAL_DIFF = 1e-6; // FIXME needs to be tuned - and this might be too weak
+ static double VAL_DIFF = 1e-8; // tuned to BTEC subsample
static int ITERATIONS = 100;
- //private double c1=0.0001; // wolf stuff
- //private double c2=0.9;
- //private static double lambda[][];
private PhraseCluster c;
/**@brief
@@ -181,13 +178,13 @@ public class PhraseObjective extends ProjectedObjective optimizer.setMaxIterations(ITERATIONS);
updateFunction();
boolean success = optimizer.optimize(this,stats,compositeStop);
-// System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
+ //System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
//if(succed){
//System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
//}else{
// System.out.println("Failed to optimize");
//}
- // ps.println(Arrays.toString(parameters));
+ //System.out.println(Arrays.toString(parameters));
// for(int edge=0;edge<data.getSize();edge++){
// ps.println(Arrays.toString(q[edge]));
|