summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
diff options
context:
space:
mode:
authortrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-08 21:46:05 +0000
committertrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-08 21:46:05 +0000
commita034f92b1fe0c6368ebb140bc691f0718dd23a23 (patch)
treeaa691e12661a4f1e695a79842310d0dd7c08b51c /gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
parentaadddcd945f72983e7803369037362cb38921db6 (diff)
New context constraints.
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@190 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java')
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java302
1 files changed, 302 insertions, 0 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
new file mode 100644
index 00000000..3273f0ad
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
@@ -0,0 +1,302 @@
+package phrase;
+
+import java.io.PrintStream;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import optimization.gradientBasedMethods.ProjectedGradientDescent;
+import optimization.gradientBasedMethods.ProjectedObjective;
+import optimization.gradientBasedMethods.stats.OptimizerStats;
+import optimization.linesearch.ArmijoLineSearchMinimizationAlongProjectionArc;
+import optimization.linesearch.InterpolationPickFirstStep;
+import optimization.linesearch.LineSearchMethod;
+import optimization.linesearch.WolfRuleLineSearch;
+import optimization.projections.SimplexProjection;
+import optimization.stopCriteria.CompositeStopingCriteria;
+import optimization.stopCriteria.ProjectedGradientL2Norm;
+import optimization.stopCriteria.StopingCriteria;
+import optimization.stopCriteria.ValueDifference;
+import optimization.util.MathUtils;
+import phrase.Corpus.Edge;
+
+public class PhraseContextObjective extends ProjectedObjective
+{
+ private static final double GRAD_DIFF = 0.00002;
+ private static double INIT_STEP_SIZE = 10;
+ private static double VAL_DIFF = 1e-4; // FIXME needs to be tuned
+ private static int ITERATIONS = 100;
+
+ private PhraseCluster c;
+
+ // un-regularized unnormalized posterior, p[edge][tag]
+ // P(tag|edge) \propto P(tag|phrase)P(context|tag)
+ private double p[][];
+
+ // regularized unnormalized posterior
+ // q[edge][tag] propto p[edge][tag]*exp(-lambda)
+ private double q[][];
+ private List<Corpus.Edge> data;
+
+ // log likelihood under q
+ private double loglikelihood;
+ private SimplexProjection projectionPhrase;
+ private SimplexProjection projectionContext;
+
+ double[] newPoint;
+ private int n_param;
+
+ // likelihood under p
+ public double llh;
+
+ private Map<Corpus.Edge, Integer> edgeIndex;
+
+ public PhraseContextObjective(PhraseCluster cluster, double[] startingParameters)
+ {
+ c=cluster;
+ data=c.c.getEdges();
+ n_param=data.size()*c.K*2;
+
+ parameters = startingParameters;
+ if (parameters == null)
+ parameters = new double[n_param];
+
+ newPoint = new double[n_param];
+ gradient = new double[n_param];
+ initP();
+ projectionPhrase = new SimplexProjection(c.scalePT);
+ projectionContext = new SimplexProjection(c.scaleCT);
+ q=new double [data.size()][c.K];
+
+ edgeIndex = new HashMap<Edge, Integer>();
+ for (int e=0; e<data.size(); e++)
+ edgeIndex.put(data.get(e), e);
+
+ setParameters(parameters);
+ }
+
+ private void initP(){
+ p=new double[data.size()][];
+ for(int edge=0;edge<data.size();edge++)
+ {
+ p[edge]=c.posterior(data.get(edge));
+ llh += data.get(edge).getCount() * Math.log(arr.F.l1norm(p[edge]));
+ arr.F.l1normalize(p[edge]);
+ }
+ }
+
+ @Override
+ public void setParameters(double[] params) {
+ //System.out.println("setParameters " + Arrays.toString(parameters));
+ // TODO: test if params have changed and skip update otherwise
+ super.setParameters(params);
+ updateFunction();
+ }
+
+ private void updateFunction()
+ {
+ updateCalls++;
+ loglikelihood=0;
+
+ for (int e=0; e<data.size(); e++)
+ {
+ Edge edge = data.get(e);
+ int offset = edgeIndex.get(edge)*c.K*2;
+ for(int tag=0; tag<c.K; tag++)
+ {
+ int ip = offset + tag*2;
+ int ic = ip + 1;
+ q[e][tag] = p[e][tag]*
+ Math.exp((-parameters[ip]-parameters[ic]) / edge.getCount());
+ }
+ }
+
+ for(int edge=0;edge<data.size();edge++){
+ loglikelihood+=data.get(edge).getCount() * Math.log(arr.F.l1norm(q[edge]));
+ arr.F.l1normalize(q[edge]);
+ }
+
+ for (int e=0; e<data.size(); e++)
+ {
+ Edge edge = data.get(e);
+ int offset = edgeIndex.get(edge)*c.K*2;
+ for(int tag=0; tag<c.K; tag++)
+ {
+ int ip = offset + tag*2;
+ int ic = ip + 1;
+ gradient[ip]=-q[e][tag];
+ gradient[ic]=-q[e][tag];
+ }
+ }
+ //System.out.println("objective " + loglikelihood + " gradient: " + Arrays.toString(gradient));
+ }
+
+ @Override
+ public double[] projectPoint(double[] point)
+ {
+ //System.out.println("projectPoint: " + Arrays.toString(point));
+ Arrays.fill(newPoint, 0, newPoint.length, 0);
+ if (c.scalePT > 0)
+ {
+ // first project using the phrase-tag constraints,
+ // for all p,t: sum_c lambda_ptc < scaleP
+ for (int p = 0; p < c.c.getNumPhrases(); ++p)
+ {
+ List<Edge> edges = c.c.getEdgesForPhrase(p);
+ double toProject[] = new double[edges.size()];
+ for(int tag=0;tag<c.K;tag++)
+ {
+ for(int e=0; e<edges.size(); e++)
+ toProject[e] = point[index(edges.get(e), tag, true)];
+ projectionPhrase.project(toProject);
+ for(int e=0; e<edges.size(); e++)
+ newPoint[index(edges.get(e),tag, true)] = toProject[e];
+ }
+ }
+ }
+ //System.out.println("after PT " + Arrays.toString(newPoint));
+
+ if (c.scaleCT > 1e-6)
+ {
+ // now project using the context-tag constraints,
+ // for all c,t: sum_p omega_pct < scaleC
+ for (int ctx = 0; ctx < c.c.getNumContexts(); ++ctx)
+ {
+ List<Edge> edges = c.c.getEdgesForContext(ctx);
+ double toProject[] = new double[edges.size()];
+ for(int tag=0;tag<c.K;tag++)
+ {
+ for(int e=0; e<edges.size(); e++)
+ toProject[e] = point[index(edges.get(e), tag, false)];
+ projectionContext.project(toProject);
+ for(int e=0; e<edges.size(); e++)
+ newPoint[index(edges.get(e),tag, false)] = toProject[e];
+ }
+ }
+ }
+ double[] tmp = newPoint;
+ newPoint = point;
+
+ //System.out.println("\treturning " + Arrays.toString(tmp));
+ return tmp;
+ }
+
+ private int index(Edge edge, int tag, boolean phrase)
+ {
+ // NB if indexing changes must also change code in updateFunction and constructor
+ if (phrase)
+ return edgeIndex.get(edge)*c.K*2 + tag*2;
+ else
+ return edgeIndex.get(edge)*c.K*2 + tag*2 + 1;
+ }
+
+ @Override
+ public double[] getGradient() {
+ gradientCalls++;
+ return gradient;
+ }
+
+ @Override
+ public double getValue() {
+ functionCalls++;
+ return loglikelihood;
+ }
+
+ @Override
+ public String toString() {
+ return "No need for pointless toString";
+ }
+
+ public double []posterior(int edgeIndex){
+ return q[edgeIndex];
+ }
+
+ public double[] optimizeWithProjectedGradientDescent()
+ {
+ LineSearchMethod ls =
+ new ArmijoLineSearchMinimizationAlongProjectionArc
+ (new InterpolationPickFirstStep(INIT_STEP_SIZE));
+ //LineSearchMethod ls = new WolfRuleLineSearch(
+ // (new InterpolationPickFirstStep(INIT_STEP_SIZE)), c1, c2);
+ OptimizerStats stats = new OptimizerStats();
+
+
+ ProjectedGradientDescent optimizer = new ProjectedGradientDescent(ls);
+ StopingCriteria stopGrad = new ProjectedGradientL2Norm(GRAD_DIFF);
+ StopingCriteria stopValue = new ValueDifference(VAL_DIFF*(-llh));
+ CompositeStopingCriteria compositeStop = new CompositeStopingCriteria();
+ compositeStop.add(stopGrad);
+ compositeStop.add(stopValue);
+ optimizer.setMaxIterations(ITERATIONS);
+ updateFunction();
+ boolean succed = optimizer.optimize(this,stats,compositeStop);
+// System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
+ if(succed){
+ //System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
+ }else{
+ System.out.println("Failed to optimize");
+ }
+ // ps.println(Arrays.toString(parameters));
+
+ // for(int edge=0;edge<data.getSize();edge++){
+ // ps.println(Arrays.toString(q[edge]));
+ // }
+ //System.out.println(Arrays.toString(parameters));
+
+ return parameters;
+ }
+
+ double loglikelihood()
+ {
+ return llh;
+ }
+
+ double KL_divergence()
+ {
+ return -loglikelihood + MathUtils.dotProduct(parameters, gradient);
+ }
+
+ double phrase_l1lmax()
+ {
+ // \sum_{tag,phrase} max_{context} P(tag|context,phrase)
+ double sum=0;
+ for (int p = 0; p < c.c.getNumPhrases(); ++p)
+ {
+ List<Edge> edges = c.c.getEdgesForPhrase(p);
+ for(int tag=0;tag<c.K;tag++)
+ {
+ double max=0;
+ for (Edge edge: edges)
+ max = Math.max(max, q[edgeIndex.get(edge)][tag]);
+ sum+=max;
+ }
+ }
+ return sum;
+ }
+
+ double context_l1lmax()
+ {
+ // \sum_{tag,context} max_{phrase} P(tag|context,phrase)
+ double sum=0;
+ for (int ctx = 0; ctx < c.c.getNumContexts(); ++ctx)
+ {
+ List<Edge> edges = c.c.getEdgesForContext(ctx);
+ for(int tag=0; tag<c.K; tag++)
+ {
+ double max=0;
+ for (Edge edge: edges)
+ max = Math.max(max, q[edgeIndex.get(edge)][tag]);
+ sum+=max;
+ }
+ }
+ return sum;
+ }
+
+ // L - KL(q||p) - scalePT * l1lmax_phrase - scaleCT * l1lmax_context
+ public double primal()
+ {
+ return loglikelihood() - KL_divergence() - c.scalePT * phrase_l1lmax() - c.scalePT * context_l1lmax();
+ }
+
+}