package phrase; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import optimization.gradientBasedMethods.ProjectedGradientDescent; import optimization.gradientBasedMethods.ProjectedObjective; import optimization.gradientBasedMethods.stats.OptimizerStats; import optimization.linesearch.ArmijoLineSearchMinimizationAlongProjectionArc; import optimization.linesearch.InterpolationPickFirstStep; import optimization.linesearch.LineSearchMethod; import optimization.projections.SimplexProjection; import optimization.stopCriteria.CompositeStopingCriteria; import optimization.stopCriteria.ProjectedGradientL2Norm; import optimization.stopCriteria.StopingCriteria; import optimization.stopCriteria.ValueDifference; import optimization.util.MathUtils; import phrase.Corpus.Edge; public class PhraseContextObjective extends ProjectedObjective { private static final double GRAD_DIFF = 0.00002; private static double INIT_STEP_SIZE = 300; private static double VAL_DIFF = 1e-8; private static int ITERATIONS = 20; boolean debug = false; private PhraseCluster c; // un-regularized unnormalized posterior, p[edge][tag] // P(tag|edge) \propto P(tag|phrase)P(context|tag) private double p[][]; // regularized unnormalized posterior // q[edge][tag] propto p[edge][tag]*exp(-lambda) private double q[][]; private List<Corpus.Edge> data; // log likelihood under q private double loglikelihood; private SimplexProjection projectionPhrase; private SimplexProjection projectionContext; double[] newPoint; private int n_param; // likelihood under p public double llh; private static Map<Corpus.Edge, Integer> edgeIndex; private long projectionTime; private long objectiveTime; private long actualProjectionTime; private ExecutorService pool; double scalePT; double scaleCT; public PhraseContextObjective(PhraseCluster cluster, double[] startingParameters, ExecutorService pool, double scalePT, double scaleCT) { c=cluster; data=c.c.getEdges(); n_param=data.size()*c.K*2; this.pool=pool; this.scalePT = scalePT; this.scaleCT = scaleCT; parameters = startingParameters; if (parameters == null) parameters = new double[n_param]; System.out.println("Num parameters " + n_param); newPoint = new double[n_param]; gradient = new double[n_param]; initP(); projectionPhrase = new SimplexProjection(scalePT); projectionContext = new SimplexProjection(scaleCT); q=new double [data.size()][c.K]; if (edgeIndex == null) { edgeIndex = new HashMap<Edge, Integer>(); for (int e=0; e<data.size(); e++) { edgeIndex.put(data.get(e), e); //if (debug) System.out.println("Edge " + data.get(e) + " index " + e); } } setParameters(parameters); } private void initP(){ p=new double[data.size()][]; for(int edge=0;edge<data.size();edge++) { p[edge]=c.posterior(data.get(edge)); llh += data.get(edge).getCount() * Math.log(arr.F.l1norm(p[edge])); arr.F.l1normalize(p[edge]); } } @Override public void setParameters(double[] params) { //System.out.println("setParameters " + Arrays.toString(parameters)); // TODO: test if params have changed and skip update otherwise super.setParameters(params); updateFunction(); } private void updateFunction() { updateCalls++; loglikelihood=0; System.out.print("."); System.out.flush(); long begin = System.currentTimeMillis(); for (int e=0; e<data.size(); e++) { Edge edge = data.get(e); for(int tag=0; tag<c.K; tag++) { int ip = index(e, tag, true); int ic = index(e, tag, false); q[e][tag] = p[e][tag]* Math.exp((-parameters[ip]-parameters[ic]) / edge.getCount()); //if (debug) //System.out.println("\tposterior " + edge + " with tag " + tag + " p " + p[e][tag] + " params " + parameters[ip] + " and " + parameters[ic] + " q " + q[e][tag]); } } for(int edge=0;edge<data.size();edge++) { loglikelihood+=data.get(edge).getCount() * Math.log(arr.F.l1norm(q[edge])); arr.F.l1normalize(q[edge]); } for (int e=0; e<data.size(); e++) { for(int tag=0; tag<c.K; tag++) { int ip = index(e, tag, true); int ic = index(e, tag, false); gradient[ip]=-q[e][tag]; gradient[ic]=-q[e][tag]; } } //if (debug) { //System.out.println("objective " + loglikelihood + " ||gradient||_2: " + arr.F.l2norm(gradient)); //System.out.println("gradient " + Arrays.toString(gradient)); //} objectiveTime += System.currentTimeMillis() - begin; } @Override public double[] projectPoint(double[] point) { long begin = System.currentTimeMillis(); List<Future<?>> tasks = new ArrayList<Future<?>>(); System.out.print(","); System.out.flush(); Arrays.fill(newPoint, 0, newPoint.length, 0); // first project using the phrase-tag constraints, // for all p,t: sum_c lambda_ptc < scaleP if (pool == null) { for (int p = 0; p < c.c.getNumPhrases(); ++p) { List<Edge> edges = c.c.getEdgesForPhrase(p); double[] toProject = new double[edges.size()]; for(int tag=0;tag<c.K;tag++) { // FIXME: slow hash lookup for e (twice) for(int e=0; e<edges.size(); e++) toProject[e] = point[index(edges.get(e), tag, true)]; long lbegin = System.currentTimeMillis(); projectionPhrase.project(toProject); actualProjectionTime += System.currentTimeMillis() - lbegin; for(int e=0; e<edges.size(); e++) newPoint[index(edges.get(e), tag, true)] = toProject[e]; } } } else // do above in parallel using thread pool { for (int p = 0; p < c.c.getNumPhrases(); ++p) { final int phrase = p; final double[] inPoint = point; Runnable task = new Runnable() { public void run() { List<Edge> edges = c.c.getEdgesForPhrase(phrase); double toProject[] = new double[edges.size()]; for(int tag=0;tag<c.K;tag++) { // FIXME: slow hash lookup for e for(int e=0; e<edges.size(); e++) toProject[e] = inPoint[index(edges.get(e), tag, true)]; projectionPhrase.project(toProject); for(int e=0; e<edges.size(); e++) newPoint[index(edges.get(e), tag, true)] = toProject[e]; } } }; tasks.add(pool.submit(task)); } } //System.out.println("after PT " + Arrays.toString(newPoint)); // now project using the context-tag constraints, // for all c,t: sum_p omega_pct < scaleC if (pool == null) { for (int ctx = 0; ctx < c.c.getNumContexts(); ++ctx) { List<Edge> edges = c.c.getEdgesForContext(ctx); double toProject[] = new double[edges.size()]; for(int tag=0;tag<c.K;tag++) { // FIXME: slow hash lookup for e for(int e=0; e<edges.size(); e++) toProject[e] = point[index(edges.get(e), tag, false)]; long lbegin = System.currentTimeMillis(); projectionContext.project(toProject); actualProjectionTime += System.currentTimeMillis() - lbegin; for(int e=0; e<edges.size(); e++) newPoint[index(edges.get(e), tag, false)] = toProject[e]; } } } else { // do above in parallel using thread pool for (int ctx = 0; ctx < c.c.getNumContexts(); ++ctx) { final int context = ctx; final double[] inPoint = point; Runnable task = new Runnable() { public void run() { List<Edge> edges = c.c.getEdgesForContext(context); double toProject[] = new double[edges.size()]; for(int tag=0;tag<c.K;tag++) { // FIXME: slow hash lookup for e for(int e=0; e<edges.size(); e++) toProject[e] = inPoint[index(edges.get(e), tag, false)]; projectionContext.project(toProject); for(int e=0; e<edges.size(); e++) newPoint[index(edges.get(e), tag, false)] = toProject[e]; } } }; tasks.add(pool.submit(task)); } } if (pool != null) { // wait for all the jobs to complete Exception failure = null; for (Future<?> task: tasks) { try { task.get(); } catch (InterruptedException e) { System.err.println("ERROR: Projection thread interrupted"); e.printStackTrace(); failure = e; } catch (ExecutionException e) { System.err.println("ERROR: Projection thread died"); e.printStackTrace(); failure = e; } } // rethrow the exception if (failure != null) { pool.shutdownNow(); throw new RuntimeException(failure); } } double[] tmp = newPoint; newPoint = point; projectionTime += System.currentTimeMillis() - begin; //if (debug) //System.out.println("\t\treturning " + Arrays.toString(tmp)); return tmp; } private int index(Edge edge, int tag, boolean phrase) { // NB if indexing changes must also change code in updateFunction and constructor if (phrase) return tag * edgeIndex.size() + edgeIndex.get(edge); else return (c.K + tag) * edgeIndex.size() + edgeIndex.get(edge); } private int index(int e, int tag, boolean phrase) { // NB if indexing changes must also change code in updateFunction and constructor if (phrase) return tag * edgeIndex.size() + e; else return (c.K + tag) * edgeIndex.size() + e; } @Override public double[] getGradient() { gradientCalls++; return gradient; } @Override public double getValue() { functionCalls++; return loglikelihood; } @Override public String toString() { return "No need for pointless toString"; } public double []posterior(int edgeIndex){ return q[edgeIndex]; } public boolean optimizeWithProjectedGradientDescent() { projectionTime = 0; actualProjectionTime = 0; objectiveTime = 0; long start = System.currentTimeMillis(); LineSearchMethod ls = new ArmijoLineSearchMinimizationAlongProjectionArc (new InterpolationPickFirstStep(INIT_STEP_SIZE)); //LineSearchMethod ls = new WolfRuleLineSearch( // (new InterpolationPickFirstStep(INIT_STEP_SIZE)), c1, c2); OptimizerStats stats = new OptimizerStats(); ProjectedGradientDescent optimizer = new ProjectedGradientDescent(ls); StopingCriteria stopGrad = new ProjectedGradientL2Norm(GRAD_DIFF); StopingCriteria stopValue = new ValueDifference(VAL_DIFF*(-llh)); CompositeStopingCriteria compositeStop = new CompositeStopingCriteria(); compositeStop.add(stopGrad); compositeStop.add(stopValue); optimizer.setMaxIterations(ITERATIONS); updateFunction(); boolean success = optimizer.optimize(this,stats,compositeStop); System.out.println(); System.out.println(stats.prettyPrint(1)); if (success) System.out.print("\toptimization took " + optimizer.getCurrentIteration() + " iterations"); else System.out.print("\toptimization failed to converge"); long total = System.currentTimeMillis() - start; System.out.println(" and " + total + " ms: projection " + projectionTime + " actual " + actualProjectionTime + " objective " + objectiveTime); return success; } double loglikelihood() { return llh; } double KL_divergence() { return -loglikelihood + MathUtils.dotProduct(parameters, gradient); } double phrase_l1lmax() { // \sum_{tag,phrase} max_{context} P(tag|context,phrase) double sum=0; for (int p = 0; p < c.c.getNumPhrases(); ++p) { List<Edge> edges = c.c.getEdgesForPhrase(p); for(int tag=0;tag<c.K;tag++) { double max=0; for (Edge edge: edges) max = Math.max(max, q[edgeIndex.get(edge)][tag]); sum+=max; } } return sum; } double context_l1lmax() { // \sum_{tag,context} max_{phrase} P(tag|context,phrase) double sum=0; for (int ctx = 0; ctx < c.c.getNumContexts(); ++ctx) { List<Edge> edges = c.c.getEdgesForContext(ctx); for(int tag=0; tag<c.K; tag++) { double max=0; for (Edge edge: edges) max = Math.max(max, q[edgeIndex.get(edge)][tag]); sum+=max; } } return sum; } // L - KL(q||p) - scalePT * l1lmax_phrase - scaleCT * l1lmax_context public double primal() { return loglikelihood() - KL_divergence() - scalePT * phrase_l1lmax() - scaleCT * context_l1lmax(); } }