package phrase; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import optimization.gradientBasedMethods.ProjectedGradientDescent; import optimization.gradientBasedMethods.ProjectedObjective; import optimization.gradientBasedMethods.stats.OptimizerStats; import optimization.linesearch.ArmijoLineSearchMinimizationAlongProjectionArc; import optimization.linesearch.InterpolationPickFirstStep; import optimization.linesearch.LineSearchMethod; import optimization.projections.SimplexProjection; import optimization.stopCriteria.CompositeStopingCriteria; import optimization.stopCriteria.ProjectedGradientL2Norm; import optimization.stopCriteria.StopingCriteria; import optimization.stopCriteria.ValueDifference; import optimization.util.MathUtils; import phrase.Corpus.Edge; public class PhraseContextObjective extends ProjectedObjective { private static final double GRAD_DIFF = 0.00002; private static double INIT_STEP_SIZE = 300; private static double VAL_DIFF = 1e-4; // FIXME needs to be tuned private static int ITERATIONS = 100; private PhraseCluster c; // un-regularized unnormalized posterior, p[edge][tag] // P(tag|edge) \propto P(tag|phrase)P(context|tag) private double p[][]; // regularized unnormalized posterior // q[edge][tag] propto p[edge][tag]*exp(-lambda) private double q[][]; private List data; // log likelihood under q private double loglikelihood; private SimplexProjection projectionPhrase; private SimplexProjection projectionContext; double[] newPoint; private int n_param; // likelihood under p public double llh; private Map edgeIndex; private long projectionTime; private long objectiveTime; private long actualProjectionTime; private ExecutorService pool; double scalePT; double scaleCT; public PhraseContextObjective(PhraseCluster cluster, double[] 
startingParameters, ExecutorService pool, double scalePT, double scaleCT) { c=cluster; data=c.c.getEdges(); n_param=data.size()*c.K*2; this.pool=pool; this.scalePT = scalePT; this.scaleCT = scaleCT; parameters = startingParameters; if (parameters == null) parameters = new double[n_param]; newPoint = new double[n_param]; gradient = new double[n_param]; initP(); projectionPhrase = new SimplexProjection(scalePT); projectionContext = new SimplexProjection(scaleCT); q=new double [data.size()][c.K]; edgeIndex = new HashMap(); for (int e=0; e> tasks = new ArrayList>(); //System.out.println("\t\tprojectPoint: " + Arrays.toString(point)); Arrays.fill(newPoint, 0, newPoint.length, 0); // first project using the phrase-tag constraints, // for all p,t: sum_c lambda_ptc < scaleP if (pool == null) { for (int p = 0; p < c.c.getNumPhrases(); ++p) { List edges = c.c.getEdgesForPhrase(p); double[] toProject = new double[edges.size()]; for(int tag=0;tag edges = c.c.getEdgesForPhrase(phrase); double toProject[] = new double[edges.size()]; for(int tag=0;tag edges = c.c.getEdgesForContext(ctx); double toProject[] = new double[edges.size()]; for(int tag=0;tag edges = c.c.getEdgesForContext(context); double toProject[] = new double[edges.size()]; for(int tag=0;tag task: tasks) { try { task.get(); } catch (InterruptedException e) { System.err.println("ERROR: Projection thread interrupted"); e.printStackTrace(); failure = e; } catch (ExecutionException e) { System.err.println("ERROR: Projection thread died"); e.printStackTrace(); failure = e; } } // rethrow the exception if (failure != null) throw new RuntimeException(failure); } double[] tmp = newPoint; newPoint = point; projectionTime += System.currentTimeMillis() - begin; //System.out.println("\t\treturning " + Arrays.toString(tmp)); return tmp; } private int index(Edge edge, int tag, boolean phrase) { // NB if indexing changes must also change code in updateFunction and constructor if (phrase) return edgeIndex.get(edge)*c.K*2 + tag*2; 
else return edgeIndex.get(edge)*c.K*2 + tag*2 + 1; }
    // ^ closing part of index(Edge, int, boolean), which begins on the previous line

    /**
     * Hands out the stored gradient and counts the request.
     *
     * @return the {@code gradient} array itself; no copy is made
     */
    @Override
    public double[] getGradient() {
        gradientCalls += 1;
        return gradient;
    }

    /**
     * Hands out the stored objective value and counts the request.
     *
     * @return the {@code loglikelihood} field (declared as the log likelihood under q)
     */
    @Override
    public double getValue() {
        functionCalls += 1;
        return loglikelihood;
    }

    /**
     * @return a fixed placeholder string; no state is rendered
     */
    @Override
    public String toString() {
        return "No need for pointless toString";
    }

    /**
     * Looks up one row of {@code q}.
     *
     * @param edgeIndex row to read; note that this parameter hides the
     *                  {@code edgeIndex} map field inside this method
     * @return the row {@code q[edgeIndex]} itself; no copy is made
     */
    public double[] posterior(int edgeIndex) {
        return q[edgeIndex];
    }

    /**
     * Runs projected gradient descent on this objective, using an Armijo line
     * search along the projection arc and stopping on either the projected
     * gradient norm or the value difference, capped at {@code ITERATIONS}
     * iterations. A one-line summary with the timing counters is written to
     * standard output.
     *
     * @return the {@code parameters} array as it stands after the optimizer returns
     */
    public double[] optimizeWithProjectedGradientDescent() {
        // Zero the timing counters that are reported at the end of this run.
        objectiveTime = 0;
        actualProjectionTime = 0;
        projectionTime = 0;
        final long startMillis = System.currentTimeMillis();

        final LineSearchMethod lineSearch =
            new ArmijoLineSearchMinimizationAlongProjectionArc(
                new InterpolationPickFirstStep(INIT_STEP_SIZE));
        // Disabled alternative line search:
        //LineSearchMethod ls = new WolfRuleLineSearch(
        //    (new InterpolationPickFirstStep(INIT_STEP_SIZE)), c1, c2);
        final OptimizerStats stats = new OptimizerStats();
        final ProjectedGradientDescent optimizer = new ProjectedGradientDescent(lineSearch);

        // The value threshold is scaled by -llh as it stands *before* the
        // updateFunction() call below, so these statements must stay in this order.
        final StopingCriteria gradientCriterion = new ProjectedGradientL2Norm(GRAD_DIFF);
        final StopingCriteria valueCriterion = new ValueDifference(VAL_DIFF * (-llh));
        final CompositeStopingCriteria stopping = new CompositeStopingCriteria();
        stopping.add(gradientCriterion);
        stopping.add(valueCriterion);

        optimizer.setMaxIterations(ITERATIONS);
        updateFunction();
        final boolean converged = optimizer.optimize(this, stats, stopping);
        // System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));

        // The iteration count is only queried when the optimizer reports success.
        final String outcome = converged
            ? "\toptimization took " + optimizer.getCurrentIteration() + " iterations"
            : "\toptimization failed to converge";
        System.out.print(outcome);

        final long elapsedMillis = System.currentTimeMillis() - startMillis;
        System.out.println(" and " + elapsedMillis
            + " ms: projection " + projectionTime
            + " actual " + actualProjectionTime
            + " objective " + objectiveTime);

        return parameters;
    }

    /**
     * @return the {@code llh} field (declared as the likelihood under p)
     */
    double loglikelihood() {
        return llh;
    }

    /**
     * @return {@code -loglikelihood} plus the dot product of {@code parameters}
     *         and {@code gradient}
     */
    double KL_divergence() {
        final double parameterDotGradient = MathUtils.dotProduct(parameters, gradient);
        return -loglikelihood + parameterDotGradient;
    }

    double phrase_l1lmax() {
        // \sum_{tag,phrase} max_{context} P(tag|context,phrase)
        double sum=0;
        // loop header continues on the next line of the file
        for (int p = 0; p < c.c.getNumPhrases();
++p) { List edges = c.c.getEdgesForPhrase(p); for(int tag=0;tag edges = c.c.getEdgesForContext(ctx); for(int tag=0; tag