summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java')
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java436
1 files changed, 0 insertions, 436 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
deleted file mode 100644
index 646ff392..00000000
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
+++ /dev/null
@@ -1,436 +0,0 @@
-package phrase;
-
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Future;
-
-import optimization.gradientBasedMethods.ProjectedGradientDescent;
-import optimization.gradientBasedMethods.ProjectedObjective;
-import optimization.gradientBasedMethods.stats.OptimizerStats;
-import optimization.linesearch.ArmijoLineSearchMinimizationAlongProjectionArc;
-import optimization.linesearch.InterpolationPickFirstStep;
-import optimization.linesearch.LineSearchMethod;
-import optimization.projections.SimplexProjection;
-import optimization.stopCriteria.CompositeStopingCriteria;
-import optimization.stopCriteria.ProjectedGradientL2Norm;
-import optimization.stopCriteria.StopingCriteria;
-import optimization.stopCriteria.ValueDifference;
-import optimization.util.MathUtils;
-import phrase.Corpus.Edge;
-
-public class PhraseContextObjective extends ProjectedObjective
-{
- private static final double GRAD_DIFF = 0.00002;
- private static double INIT_STEP_SIZE = 300;
- private static double VAL_DIFF = 1e-8;
- private static int ITERATIONS = 20;
- boolean debug = false;
-
- private PhraseCluster c;
-
- // un-regularized unnormalized posterior, p[edge][tag]
- // P(tag|edge) \propto P(tag|phrase)P(context|tag)
- private double p[][];
-
- // regularized unnormalized posterior
- // q[edge][tag] propto p[edge][tag]*exp(-lambda)
- private double q[][];
- private List<Corpus.Edge> data;
-
- // log likelihood under q
- private double loglikelihood;
- private SimplexProjection projectionPhrase;
- private SimplexProjection projectionContext;
-
- double[] newPoint;
- private int n_param;
-
- // likelihood under p
- public double llh;
-
- private static Map<Corpus.Edge, Integer> edgeIndex;
-
- private long projectionTime;
- private long objectiveTime;
- private long actualProjectionTime;
- private ExecutorService pool;
-
- double scalePT;
- double scaleCT;
-
- public PhraseContextObjective(PhraseCluster cluster, double[] startingParameters, ExecutorService pool,
- double scalePT, double scaleCT)
- {
- c=cluster;
- data=c.c.getEdges();
- n_param=data.size()*c.K*2;
- this.pool=pool;
- this.scalePT = scalePT;
- this.scaleCT = scaleCT;
-
- parameters = startingParameters;
- if (parameters == null)
- parameters = new double[n_param];
-
- System.out.println("Num parameters " + n_param);
- newPoint = new double[n_param];
- gradient = new double[n_param];
- initP();
- projectionPhrase = new SimplexProjection(scalePT);
- projectionContext = new SimplexProjection(scaleCT);
- q=new double [data.size()][c.K];
-
- if (edgeIndex == null) {
- edgeIndex = new HashMap<Edge, Integer>();
- for (int e=0; e<data.size(); e++)
- {
- edgeIndex.put(data.get(e), e);
- //if (debug) System.out.println("Edge " + data.get(e) + " index " + e);
- }
- }
-
- setParameters(parameters);
- }
-
- private void initP(){
- p=new double[data.size()][];
- for(int edge=0;edge<data.size();edge++)
- {
- p[edge]=c.posterior(data.get(edge));
- llh += data.get(edge).getCount() * Math.log(arr.F.l1norm(p[edge]));
- arr.F.l1normalize(p[edge]);
- }
- }
-
- @Override
- public void setParameters(double[] params) {
- //System.out.println("setParameters " + Arrays.toString(parameters));
- // TODO: test if params have changed and skip update otherwise
- super.setParameters(params);
- updateFunction();
- }
-
- private void updateFunction()
- {
- updateCalls++;
- loglikelihood=0;
-
- System.out.print(".");
- System.out.flush();
-
- long begin = System.currentTimeMillis();
- for (int e=0; e<data.size(); e++)
- {
- Edge edge = data.get(e);
- for(int tag=0; tag<c.K; tag++)
- {
- int ip = index(e, tag, true);
- int ic = index(e, tag, false);
- q[e][tag] = p[e][tag]*
- Math.exp((-parameters[ip]-parameters[ic]) / edge.getCount());
- //if (debug)
- //System.out.println("\tposterior " + edge + " with tag " + tag + " p " + p[e][tag] + " params " + parameters[ip] + " and " + parameters[ic] + " q " + q[e][tag]);
- }
- }
-
- for(int edge=0;edge<data.size();edge++) {
- loglikelihood+=data.get(edge).getCount() * Math.log(arr.F.l1norm(q[edge]));
- arr.F.l1normalize(q[edge]);
- }
-
- for (int e=0; e<data.size(); e++)
- {
- for(int tag=0; tag<c.K; tag++)
- {
- int ip = index(e, tag, true);
- int ic = index(e, tag, false);
- gradient[ip]=-q[e][tag];
- gradient[ic]=-q[e][tag];
- }
- }
- //if (debug) {
- //System.out.println("objective " + loglikelihood + " ||gradient||_2: " + arr.F.l2norm(gradient));
- //System.out.println("gradient " + Arrays.toString(gradient));
- //}
- objectiveTime += System.currentTimeMillis() - begin;
- }
-
- @Override
- public double[] projectPoint(double[] point)
- {
- long begin = System.currentTimeMillis();
- List<Future<?>> tasks = new ArrayList<Future<?>>();
-
- System.out.print(",");
- System.out.flush();
-
- Arrays.fill(newPoint, 0, newPoint.length, 0);
-
- // first project using the phrase-tag constraints,
- // for all p,t: sum_c lambda_ptc < scaleP
- if (pool == null)
- {
- for (int p = 0; p < c.c.getNumPhrases(); ++p)
- {
- List<Edge> edges = c.c.getEdgesForPhrase(p);
- double[] toProject = new double[edges.size()];
- for(int tag=0;tag<c.K;tag++)
- {
- // FIXME: slow hash lookup for e (twice)
- for(int e=0; e<edges.size(); e++)
- toProject[e] = point[index(edges.get(e), tag, true)];
- long lbegin = System.currentTimeMillis();
- projectionPhrase.project(toProject);
- actualProjectionTime += System.currentTimeMillis() - lbegin;
- for(int e=0; e<edges.size(); e++)
- newPoint[index(edges.get(e), tag, true)] = toProject[e];
- }
- }
- }
- else // do above in parallel using thread pool
- {
- for (int p = 0; p < c.c.getNumPhrases(); ++p)
- {
- final int phrase = p;
- final double[] inPoint = point;
- Runnable task = new Runnable()
- {
- public void run()
- {
- List<Edge> edges = c.c.getEdgesForPhrase(phrase);
- double toProject[] = new double[edges.size()];
- for(int tag=0;tag<c.K;tag++)
- {
- // FIXME: slow hash lookup for e
- for(int e=0; e<edges.size(); e++)
- toProject[e] = inPoint[index(edges.get(e), tag, true)];
- projectionPhrase.project(toProject);
- for(int e=0; e<edges.size(); e++)
- newPoint[index(edges.get(e), tag, true)] = toProject[e];
- }
- }
- };
- tasks.add(pool.submit(task));
- }
- }
- //System.out.println("after PT " + Arrays.toString(newPoint));
-
- // now project using the context-tag constraints,
- // for all c,t: sum_p omega_pct < scaleC
- if (pool == null)
- {
- for (int ctx = 0; ctx < c.c.getNumContexts(); ++ctx)
- {
- List<Edge> edges = c.c.getEdgesForContext(ctx);
- double toProject[] = new double[edges.size()];
- for(int tag=0;tag<c.K;tag++)
- {
- // FIXME: slow hash lookup for e
- for(int e=0; e<edges.size(); e++)
- toProject[e] = point[index(edges.get(e), tag, false)];
- long lbegin = System.currentTimeMillis();
- projectionContext.project(toProject);
- actualProjectionTime += System.currentTimeMillis() - lbegin;
- for(int e=0; e<edges.size(); e++)
- newPoint[index(edges.get(e), tag, false)] = toProject[e];
- }
- }
- }
- else
- {
- // do above in parallel using thread pool
- for (int ctx = 0; ctx < c.c.getNumContexts(); ++ctx)
- {
- final int context = ctx;
- final double[] inPoint = point;
- Runnable task = new Runnable()
- {
- public void run()
- {
- List<Edge> edges = c.c.getEdgesForContext(context);
- double toProject[] = new double[edges.size()];
- for(int tag=0;tag<c.K;tag++)
- {
- // FIXME: slow hash lookup for e
- for(int e=0; e<edges.size(); e++)
- toProject[e] = inPoint[index(edges.get(e), tag, false)];
- projectionContext.project(toProject);
- for(int e=0; e<edges.size(); e++)
- newPoint[index(edges.get(e), tag, false)] = toProject[e];
- }
- }
- };
- tasks.add(pool.submit(task));
- }
- }
-
- if (pool != null)
- {
- // wait for all the jobs to complete
- Exception failure = null;
- for (Future<?> task: tasks)
- {
- try {
- task.get();
- } catch (InterruptedException e) {
- System.err.println("ERROR: Projection thread interrupted");
- e.printStackTrace();
- failure = e;
- } catch (ExecutionException e) {
- System.err.println("ERROR: Projection thread died");
- e.printStackTrace();
- failure = e;
- }
- }
- // rethrow the exception
- if (failure != null)
- {
- pool.shutdownNow();
- throw new RuntimeException(failure);
- }
- }
-
- double[] tmp = newPoint;
- newPoint = point;
- projectionTime += System.currentTimeMillis() - begin;
-
- //if (debug)
- //System.out.println("\t\treturning " + Arrays.toString(tmp));
- return tmp;
- }
-
- private int index(Edge edge, int tag, boolean phrase)
- {
- // NB if indexing changes must also change code in updateFunction and constructor
- if (phrase)
- return tag * edgeIndex.size() + edgeIndex.get(edge);
- else
- return (c.K + tag) * edgeIndex.size() + edgeIndex.get(edge);
- }
-
- private int index(int e, int tag, boolean phrase)
- {
- // NB if indexing changes must also change code in updateFunction and constructor
- if (phrase)
- return tag * edgeIndex.size() + e;
- else
- return (c.K + tag) * edgeIndex.size() + e;
- }
-
- @Override
- public double[] getGradient() {
- gradientCalls++;
- return gradient;
- }
-
- @Override
- public double getValue() {
- functionCalls++;
- return loglikelihood;
- }
-
- @Override
- public String toString() {
- return "No need for pointless toString";
- }
-
- public double []posterior(int edgeIndex){
- return q[edgeIndex];
- }
-
- public boolean optimizeWithProjectedGradientDescent()
- {
- projectionTime = 0;
- actualProjectionTime = 0;
- objectiveTime = 0;
- long start = System.currentTimeMillis();
-
- LineSearchMethod ls =
- new ArmijoLineSearchMinimizationAlongProjectionArc
- (new InterpolationPickFirstStep(INIT_STEP_SIZE));
- //LineSearchMethod ls = new WolfRuleLineSearch(
- // (new InterpolationPickFirstStep(INIT_STEP_SIZE)), c1, c2);
- OptimizerStats stats = new OptimizerStats();
-
-
- ProjectedGradientDescent optimizer = new ProjectedGradientDescent(ls);
- StopingCriteria stopGrad = new ProjectedGradientL2Norm(GRAD_DIFF);
- StopingCriteria stopValue = new ValueDifference(VAL_DIFF*(-llh));
- CompositeStopingCriteria compositeStop = new CompositeStopingCriteria();
- compositeStop.add(stopGrad);
- compositeStop.add(stopValue);
- optimizer.setMaxIterations(ITERATIONS);
- updateFunction();
- boolean success = optimizer.optimize(this,stats,compositeStop);
-
- System.out.println();
- System.out.println(stats.prettyPrint(1));
-
- if (success)
- System.out.print("\toptimization took " + optimizer.getCurrentIteration() + " iterations");
- else
- System.out.print("\toptimization failed to converge");
- long total = System.currentTimeMillis() - start;
- System.out.println(" and " + total + " ms: projection " + projectionTime +
- " actual " + actualProjectionTime + " objective " + objectiveTime);
-
- return success;
- }
-
- double loglikelihood()
- {
- return llh;
- }
-
- double KL_divergence()
- {
- return -loglikelihood + MathUtils.dotProduct(parameters, gradient);
- }
-
- double phrase_l1lmax()
- {
- // \sum_{tag,phrase} max_{context} P(tag|context,phrase)
- double sum=0;
- for (int p = 0; p < c.c.getNumPhrases(); ++p)
- {
- List<Edge> edges = c.c.getEdgesForPhrase(p);
- for(int tag=0;tag<c.K;tag++)
- {
- double max=0;
- for (Edge edge: edges)
- max = Math.max(max, q[edgeIndex.get(edge)][tag]);
- sum+=max;
- }
- }
- return sum;
- }
-
- double context_l1lmax()
- {
- // \sum_{tag,context} max_{phrase} P(tag|context,phrase)
- double sum=0;
- for (int ctx = 0; ctx < c.c.getNumContexts(); ++ctx)
- {
- List<Edge> edges = c.c.getEdgesForContext(ctx);
- for(int tag=0; tag<c.K; tag++)
- {
- double max=0;
- for (Edge edge: edges)
- max = Math.max(max, q[edgeIndex.get(edge)][tag]);
- sum+=max;
- }
- }
- return sum;
- }
-
- // L - KL(q||p) - scalePT * l1lmax_phrase - scaleCT * l1lmax_context
- public double primal()
- {
- return loglikelihood() - KL_divergence() - scalePT * phrase_l1lmax() - scaleCT * context_l1lmax();
- }
-} \ No newline at end of file