 gi/posterior-regularisation/PhraseContextModel.java | 802
 1 file changed, 522 insertions(+), 280 deletions(-)
diff --git a/gi/posterior-regularisation/PhraseContextModel.java b/gi/posterior-regularisation/PhraseContextModel.java
index 3af72d54..d0a92dde 100644
--- a/gi/posterior-regularisation/PhraseContextModel.java
+++ b/gi/posterior-regularisation/PhraseContextModel.java
@@ -24,294 +24,536 @@
 // - agreement between phrase->context model and context->phrase model
 
 import java.io.*;
+import optimization.gradientBasedMethods.*;
+import optimization.gradientBasedMethods.stats.OptimizerStats;
+import optimization.gradientBasedMethods.stats.ProjectedOptimizerStats;
+import optimization.linesearch.ArmijoLineSearchMinimizationAlongProjectionArc;
+import optimization.linesearch.GenericPickFirstStep;
+import optimization.linesearch.InterpolationPickFirstStep;
+import optimization.linesearch.LineSearchMethod;
+import optimization.linesearch.WolfRuleLineSearch;
+import optimization.projections.SimplexProjection;
+import optimization.stopCriteria.CompositeStopingCriteria;
+import optimization.stopCriteria.NormalizedProjectedGradientL2Norm;
+import optimization.stopCriteria.NormalizedValueDifference;
+import optimization.stopCriteria.ProjectedGradientL2Norm;
+import optimization.stopCriteria.StopingCriteria;
+import optimization.stopCriteria.ValueDifference;
+import optimization.util.MathUtils;
+
 import java.util.*;
 import java.util.regex.*;
+import gnu.trove.TDoubleArrayList;
 
 import static java.lang.Math.*;
 
 class Lexicon<T>
 {
-    public int insert(T word)
-    {
-        Integer i = wordToIndex.get(word);
-        if (i == null)
-        {
-            i = indexToWord.size();
-            wordToIndex.put(word, i);
-            indexToWord.add(word);
-        }
-        return i;
-    }
-
-    public T lookup(int index)
-    {
-        return indexToWord.get(index);
-    }
-
-    public int size()
-    {
-        return indexToWord.size();
-    }
-
-    private Map<T,Integer> wordToIndex = new HashMap<T,Integer>();
-    private List<T> indexToWord = new ArrayList<T>();
+	public int insert(T word)
+	{
+		Integer i = wordToIndex.get(word);
+		if (i == null)
+		{
+			i = indexToWord.size();
+			wordToIndex.put(word, i);
+			indexToWord.add(word);
+		}
+		return i;
+	}
+
+	public T lookup(int index)
+	{
+		return indexToWord.get(index);
+	}
+
+	public int size()
+	{
+		return indexToWord.size();
+	}
+
+	private Map<T, Integer> wordToIndex = new HashMap<T, Integer>();
+	private List<T> indexToWord = new ArrayList<T>();
 }
 
 class PhraseContextModel
 {
-    // model/optimisation configuration parameters
-    int numTags;
-    int numPRIterations = 5;
-    boolean posteriorRegularisation = false;
-    double constraintScale = 10;
-
-    // book keeping
-    Lexicon<String> tokenLexicon = new Lexicon<String>();
-    int numPositions;
-    Random rng = new Random();
-
-    // training set; 1 entry for each unique phrase
-    PhraseAndContexts training[];
-
-    // model parameters (learnt)
-    double emissions[][][]; // position in 0 .. 3 x tag x word      Pr(word | tag, position)
-    double prior[][];       // phrase x tag                         Pr(tag | phrase)
-    //double lambda[][][];    // word x context x tag               langrange multipliers
-
-    PhraseContextModel(File infile, int tags) throws IOException
-    {
-        numTags = tags;
-        readTrainingFromFile(new FileReader(infile));
-        assert(training.length > 0);
-
-        // now initialise emissions
-        assert(training[0].contexts.length > 0);
-        numPositions = training[0].contexts[0].tokens.length;
-
-        emissions = new double[numPositions][numTags][tokenLexicon.size()];
-        prior = new double[training.length][numTags];
-        //lambda = new double[tokenLexicon.size()][???][numTags]
-
-        for (double[][] emissionTW: emissions)
-            for (double[] emissionW: emissionTW)
-                randomise(emissionW);
-
-        for (double[] priorTag: prior)
-            randomise(priorTag);
-    }
-
-    void expectationMaximisation(int numIterations)
-    {
-        for (int iteration = 0; iteration < numIterations; ++iteration)
-        {
-            double emissionsCounts[][][] = new double[numPositions][numTags][tokenLexicon.size()];
-            double priorCounts[][] = new double[training.length][numTags];
-
-            // E-step
-            double llh = 0;
-            for (int i = 0; i < training.length; ++i)
-            {
-                PhraseAndContexts instance = training[i];
-                for (Context ctx: instance.contexts)
-                {
-                    double probs[] = posterior(i, ctx);
-                    double z = normalise(probs);
-                    llh += log(z) * ctx.count;
-
-                    for (int t = 0; t < numTags; ++t)
-                    {
-                        priorCounts[i][t] += ctx.count * probs[t];
-                        for (int j = 0; j < ctx.tokens.length; ++j)
-                            emissionsCounts[j][t][ctx.tokens[j]] += ctx.count * probs[t];
-                    }
-                }
-            }
-
-            // M-step: normalise
-            for (double[][] emissionTW: emissionsCounts)
-                for (double[] emissionW: emissionTW)
-                    normalise(emissionW);
-
-            for (double[] priorTag: priorCounts)
-                normalise(priorTag);
-
-            emissions = emissionsCounts;
-            prior = priorCounts;
-
-            System.out.println("Iteration " + iteration + " llh " + llh);
-        }
-    }
-
-    static double normalise(double probs[])
-    {
-        double z = 0;
-        for (double p: probs)
-            z += p;
-        for (int i = 0; i < probs.length; ++i)
-            probs[i] /= z;
-        return z;
-    }
-
-    void randomise(double probs[])
-    {
-        double z = 0;
-        for (int i = 0; i < probs.length; ++i)
-        {
-            probs[i] = 10 + rng.nextDouble();
-            z += probs[i];
-        }
-
-        for (int i = 0; i < probs.length; ++i)
-            probs[i] /= z;
-    }
-
-    static int argmax(double probs[])
-    {
-        double m = Double.NEGATIVE_INFINITY;
-        int mi = -1;
-        for (int i = 0; i < probs.length; ++i)
-        {
-            if (probs[i] > m)
-            {
-                m = probs[i];
-                mi = i;
-            }
-        }
-        return mi;
-    }
-
-    double[] posterior(int phraseId, Context c) // unnormalised
-    {
-        double probs[] = new double[numTags];
-        for (int t = 0; t < numTags; ++t)
-        {
-            probs[t] = prior[phraseId][t];
-            for (int j = 0; j < c.tokens.length; ++j)
-                probs[t] *= emissions[j][t][c.tokens[j]];
-        }
-        return probs;
-    }
-
-    private void readTrainingFromFile(Reader in) throws IOException
-    {
-        // read in line-by-line
-        BufferedReader bin = new BufferedReader(in);
-        String line;
-        List<PhraseAndContexts> instances = new ArrayList<PhraseAndContexts>();
-        Pattern separator = Pattern.compile(" \\|\\|\\| ");
-
-        int numEdges = 0;
-        while ((line = bin.readLine()) != null)
-        {
-            // split into phrase and contexts
-            StringTokenizer st = new StringTokenizer(line, "\t");
-            assert(st.hasMoreTokens());
-            String phrase = st.nextToken();
-            assert(st.hasMoreTokens());
-            String rest = st.nextToken();
-            assert(!st.hasMoreTokens());
-
-            // process phrase
-            st = new StringTokenizer(phrase, " ");
-            List<Integer> ptoks = new ArrayList<Integer>();
-            while (st.hasMoreTokens())
-                ptoks.add(tokenLexicon.insert(st.nextToken()));
-
-            // process contexts
-            ArrayList<Context> contexts = new ArrayList<Context>();
-            String[] parts = separator.split(rest);
-            assert(parts.length % 2 == 0);
-            for (int i = 0; i < parts.length; i += 2)
-            {
-                // process pairs of strings - context and count
-                ArrayList<Integer> ctx = new ArrayList<Integer>();
-                String ctxString = parts[i];
-                String countString = parts[i+1];
-                StringTokenizer ctxStrtok = new StringTokenizer(ctxString, " ");
-                while (ctxStrtok.hasMoreTokens())
-                {
-                    String token = ctxStrtok.nextToken();
-                    if (!token.equals("<PHRASE>"))
-                        ctx.add(tokenLexicon.insert(token));
-                }
-
-                assert(countString.startsWith("C="));
-                Context c = new Context();
-                c.count = Integer.parseInt(countString.substring(2).trim());
-                // damn unboxing doesn't work with toArray
-                c.tokens = new int[ctx.size()];
-                for (int k = 0; k < ctx.size(); ++k)
-                    c.tokens[k] = ctx.get(k);
-                contexts.add(c);
-
-                numEdges += 1;
-            }
-
-            // package up
-            PhraseAndContexts instance = new PhraseAndContexts();
-            // damn unboxing doesn't work with toArray
-            instance.phraseTokens = new int[ptoks.size()];
-            for (int k = 0; k < ptoks.size(); ++k)
-                instance.phraseTokens[k] = ptoks.get(k);
-            instance.contexts = contexts.toArray(new Context[] {});
-            instances.add(instance);
-        }
-
-        training = instances.toArray(new PhraseAndContexts[] {});
-
-        System.out.println("Read in " + training.length + " phrases and " + numEdges + " edges");
-    }
-
-    void displayPosterior()
-    {
-        for (int i = 0; i < training.length; ++i)
-        {
-            PhraseAndContexts instance = training[i];
-            for (Context ctx: instance.contexts)
-            {
-                double probs[] = posterior(i, ctx);
-                double z = normalise(probs);
-
-                // emit phrase
-                for (int t: instance.phraseTokens)
-                    System.out.print(tokenLexicon.lookup(t) + " ");
-                System.out.print("\t");
-                for (int c: ctx.tokens)
-                    System.out.print(tokenLexicon.lookup(c) + " ");
-                System.out.print("||| C=" + ctx.count + " |||");
-
-                System.out.print(" " + argmax(probs));
-                //for (int t = 0; t < numTags; ++t)
-                    //System.out.print(" " + probs[t]);
-                System.out.println();
-            }
-        }
-    }
-
-    class PhraseAndContexts
-    {
-        int phraseTokens[];
-        Context contexts[];
-    }
-
-    class Context
-    {
-        int count;
-        int[] tokens;
-    }
-
-    public static void main(String []args)
-    {
-        assert(args.length >= 2);
-        try
-        {
-            PhraseContextModel model = new PhraseContextModel(new File(args[0]), Integer.parseInt(args[1]));
-            model.expectationMaximisation(Integer.parseInt(args[2]));
-            model.displayPosterior();
-        }
-        catch (IOException e)
-        {
-            System.out.println("Failed to read input file: " + args[0]);
-            e.printStackTrace();
-        }
-    }
+	// model/optimisation configuration parameters
+	int numTags, numEdges;
+	boolean posteriorRegularisation = true;
+	double constraintScale = 3; // FIXME: make configurable
+
+	// copied from L1LMax in depparsing code
+	final double c1= 0.0001, c2=0.9, stoppingPrecision = 1e-5, maxStep = 10;
+	final int maxZoomEvals = 10, maxExtrapolationIters = 200;
+	int maxProjectionIterations = 200;
+	int minOccurrencesForProjection = 0;
+
+	// book keeping
+	Lexicon<String> tokenLexicon = new Lexicon<String>();
+	int numPositions;
+	Random rng = new Random();
+
+	// training set; 1 entry for each unique phrase
+	PhraseAndContexts training[];
+
+	// model parameters (learnt)
+	double emissions[][][]; // position in 0 .. 3 x tag x word Pr(word | tag, position)
+	double prior[][]; // phrase x tag Pr(tag | phrase)
+	double lambda[]; // edge = (phrase, context) x tag flattened lagrange multipliers
+
+	PhraseContextModel(File infile, int tags) throws IOException
+	{
+		numTags = tags;
+		numEdges = 0;
+		readTrainingFromFile(new FileReader(infile));
+		assert (training.length > 0);
+
+		// now initialise emissions
+		assert (training[0].contexts.length > 0);
+		numPositions = training[0].contexts[0].tokens.length;
+
+		emissions = new double[numPositions][numTags][tokenLexicon.size()];
+		prior = new double[training.length][numTags];
+		if (posteriorRegularisation)
+			lambda = new double[numEdges * numTags];
+
+		for (double[][] emissionTW : emissions)
+			for (double[] emissionW : emissionTW)
+				randomise(emissionW);
+
+		for (double[] priorTag : prior)
+			randomise(priorTag);
+	}
+
+	void expectationMaximisation(int numIterations)
+	{
+		double lastLlh = Double.NEGATIVE_INFINITY;
+
+		for (int iteration = 0; iteration < numIterations; ++iteration)
+		{
+			double emissionsCounts[][][] = new double[numPositions][numTags][tokenLexicon.size()];
+			double priorCounts[][] = new double[training.length][numTags];
+
+			// E-step
+			double llh = 0;
+			if (posteriorRegularisation)
+			{
+				EStepDualObjective objective = new EStepDualObjective();
+
+				// copied from x2y2withconstraints
+				LineSearchMethod ls = new ArmijoLineSearchMinimizationAlongProjectionArc(new InterpolationPickFirstStep(1));
+				OptimizerStats stats = new OptimizerStats();
+				ProjectedGradientDescent optimizer = new ProjectedGradientDescent(ls);
+				CompositeStopingCriteria compositeStop = new CompositeStopingCriteria();
+				compositeStop.add(new ProjectedGradientL2Norm(0.001));
+				compositeStop.add(new ValueDifference(0.001));
+				optimizer.setMaxIterations(50);
+				boolean succeed = optimizer.optimize(objective,stats,compositeStop);
+
+				// copied from depparser l1lmaxobjective
+//				ProjectedOptimizerStats stats = new ProjectedOptimizerStats();
+//				GenericPickFirstStep pickFirstStep = new GenericPickFirstStep(1);
+//				LineSearchMethod linesearch = new WolfRuleLineSearch(pickFirstStep, c1, c2);
+//				ProjectedGradientDescent optimizer = new ProjectedGradientDescent(linesearch);
+//				optimizer.setMaxIterations(maxProjectionIterations);
+//				StopingCriteria stopGrad = new NormalizedProjectedGradientL2Norm(stoppingPrecision);
+//				StopingCriteria stopValue = new NormalizedValueDifference(stoppingPrecision);
+//				CompositeStopingCriteria stop = new CompositeStopingCriteria();
+//				stop.add(stopGrad);
+//				stop.add(stopValue);
+//				boolean succeed = optimizer.optimize(objective, stats, stop);
+
+				//System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
+				//System.out.println("Solution: " + objective.parameters);
+				if (!succeed)
+					System.out.println("Failed to optimize");
+				//System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
+
+				// make sure we update the dual params
+				//llh = objective.getValue();
+				llh = objective.primal();
+				// FIXME: this is the dual not the primal and omits the llh term
+
+				for (int i = 0; i < training.length; ++i)
+				{
+					PhraseAndContexts instance = training[i];
+					for (int j = 0; j < instance.contexts.length; ++j)
+					{
+						Context c = instance.contexts[j];
+						for (int t = 0; t < numTags; t++)
+						{
+							double p = objective.q.get(i).get(j).get(t);
+							priorCounts[i][t] += c.count * p;
+							for (int k = 0; k < c.tokens.length; ++k)
+								emissionsCounts[k][t][c.tokens[k]] += c.count * p;
+						}
+					}
+				}
+			}
+			else
+			{
+				for (int i = 0; i < training.length; ++i)
+				{
+					PhraseAndContexts instance = training[i];
+					for (Context ctx : instance.contexts)
+					{
+						double probs[] = posterior(i, ctx);
+						double z = normalise(probs);
+						llh += log(z) * ctx.count;
+
+						for (int t = 0; t < numTags; ++t)
+						{
+							priorCounts[i][t] += ctx.count * probs[t];
+							for (int j = 0; j < ctx.tokens.length; ++j)
+								emissionsCounts[j][t][ctx.tokens[j]] += ctx.count * probs[t];
+						}
+					}
+				}
+			}
+
+			// M-step: normalise
+			for (double[][] emissionTW : emissionsCounts)
+				for (double[] emissionW : emissionTW)
+					normalise(emissionW);
+
+			for (double[] priorTag : priorCounts)
+				normalise(priorTag);
+
+			emissions = emissionsCounts;
+			prior = priorCounts;
+
+			System.out.println("Iteration " + iteration + " llh " + llh);
+
+//			if (llh - lastLlh < 1e-4)
+//				break;
+//			else
+//				lastLlh = llh;
+		}
+	}
+
+	static double normalise(double probs[])
+	{
+		double z = 0;
+		for (double p : probs)
+			z += p;
+		for (int i = 0; i < probs.length; ++i)
+			probs[i] /= z;
+		return z;
+	}
+
+	void randomise(double probs[])
+	{
+		double z = 0;
+		for (int i = 0; i < probs.length; ++i)
+		{
+			probs[i] = 10 + rng.nextDouble();
+			z += probs[i];
+		}
+
+		for (int i = 0; i < probs.length; ++i)
+			probs[i] /= z;
+	}
+
+	static int argmax(double probs[])
+	{
+		double m = Double.NEGATIVE_INFINITY;
+		int mi = -1;
+		for (int i = 0; i < probs.length; ++i)
+		{
+			if (probs[i] > m)
+			{
+				m = probs[i];
+				mi = i;
+			}
+		}
+		return mi;
+	}
+
+	double[] posterior(int phraseId, Context c) // unnormalised
+	{
+		double probs[] = new double[numTags];
+		for (int t = 0; t < numTags; ++t)
+		{
+			probs[t] = prior[phraseId][t];
+			for (int j = 0; j < c.tokens.length; ++j)
+				probs[t] *= emissions[j][t][c.tokens[j]];
+		}
+		return probs;
+	}
+
+	private void readTrainingFromFile(Reader in) throws IOException
+	{
+		// read in line-by-line
+		BufferedReader bin = new BufferedReader(in);
+		String line;
+		List<PhraseAndContexts> instances = new ArrayList<PhraseAndContexts>();
+		Pattern separator = Pattern.compile(" \\|\\|\\| ");
+
+		while ((line = bin.readLine()) != null)
+		{
+			// split into phrase and contexts
+			StringTokenizer st = new StringTokenizer(line, "\t");
+			assert (st.hasMoreTokens());
+			String phrase = st.nextToken();
+			assert (st.hasMoreTokens());
+			String rest = st.nextToken();
+			assert (!st.hasMoreTokens());
+
+			// process phrase
+			st = new StringTokenizer(phrase, " ");
+			List<Integer> ptoks = new ArrayList<Integer>();
+			while (st.hasMoreTokens())
+				ptoks.add(tokenLexicon.insert(st.nextToken()));
+
+			// process contexts
+			ArrayList<Context> contexts = new ArrayList<Context>();
+			String[] parts = separator.split(rest);
+			assert (parts.length % 2 == 0);
+			for (int i = 0; i < parts.length; i += 2)
+			{
+				// process pairs of strings - context and count
+				ArrayList<Integer> ctx = new ArrayList<Integer>();
+				String ctxString = parts[i];
+				String countString = parts[i + 1];
+				StringTokenizer ctxStrtok = new StringTokenizer(ctxString, " ");
+				while (ctxStrtok.hasMoreTokens())
+				{
+					String token = ctxStrtok.nextToken();
+					if (!token.equals("<PHRASE>"))
+						ctx.add(tokenLexicon.insert(token));
+				}
+
+				assert (countString.startsWith("C="));
+				Context c = new Context();
+				c.count = Integer.parseInt(countString.substring(2).trim());
+				// damn unboxing doesn't work with toArray
+				c.tokens = new int[ctx.size()];
+				for (int k = 0; k < ctx.size(); ++k)
+					c.tokens[k] = ctx.get(k);
+				contexts.add(c);
+
+				numEdges += 1;
+			}
+
+			// package up
+			PhraseAndContexts instance = new PhraseAndContexts();
+			// damn unboxing doesn't work with toArray
+			instance.phraseTokens = new int[ptoks.size()];
+			for (int k = 0; k < ptoks.size(); ++k)
+				instance.phraseTokens[k] = ptoks.get(k);
+			instance.contexts = contexts.toArray(new Context[] {});
+			instances.add(instance);
+		}
+
+		training = instances.toArray(new PhraseAndContexts[] {});
+
+		System.out.println("Read in " + training.length + " phrases and " + numEdges + " edges");
+	}
+
+	void displayPosterior()
+	{
+		for (int i = 0; i < training.length; ++i)
+		{
+			PhraseAndContexts instance = training[i];
+			for (Context ctx : instance.contexts)
+			{
+				double probs[] = posterior(i, ctx);
+				normalise(probs);
+
+				// emit phrase
+				for (int t : instance.phraseTokens)
+					System.out.print(tokenLexicon.lookup(t) + " ");
+				System.out.print("\t");
+				for (int c : ctx.tokens)
+					System.out.print(tokenLexicon.lookup(c) + " ");
+				System.out.print("||| C=" + ctx.count + " |||");
+
+				int t = argmax(probs);
+				System.out.print(" " + t + " ||| " + probs[t]);
+				// for (int t = 0; t < numTags; ++t)
+				// System.out.print(" " + probs[t]);
+				System.out.println();
+			}
+		}
+	}
+
+	class PhraseAndContexts
+	{
+		int phraseTokens[];
+		Context contexts[];
+	}
+
+	class Context
+	{
+		int count;
+		int[] tokens;
+	}
+
+	public static void main(String[] args)
+	{
+		assert (args.length >= 2);
+		try
+		{
+			PhraseContextModel model = new PhraseContextModel(new File(args[0]), Integer.parseInt(args[1]));
+			model.expectationMaximisation(Integer.parseInt(args[2]));
+			model.displayPosterior();
+		}
+		catch (IOException e)
+		{
+			System.out.println("Failed to read input file: " + args[0]);
+			e.printStackTrace();
+		}
+	}
+
+	class EStepDualObjective extends ProjectedObjective
+	{
+		List<List<TDoubleArrayList>> conditionals; // phrase id x context # x tag - precomputed
+		List<List<TDoubleArrayList>> q; // ditto, but including exp(-lambda) terms
+		double objective = 0; // log(z)
+		// Objective.gradient = d log(z) / d lambda = E_q[phi]
+		double llh = 0;
+
+		public EStepDualObjective()
+		{
+			super();
+			// compute conditionals p(context, tag | phrase) for all training instances
+			conditionals = new ArrayList<List<TDoubleArrayList>>(training.length);
+			q = new ArrayList<List<TDoubleArrayList>>(training.length);
+			for (int i = 0; i < training.length; ++i)
+			{
+				PhraseAndContexts instance = training[i];
+				conditionals.add(new ArrayList<TDoubleArrayList>(instance.contexts.length));
+				q.add(new ArrayList<TDoubleArrayList>(instance.contexts.length));
+
+				for (int j = 0; j < instance.contexts.length; ++j)
+				{
+					Context c = instance.contexts[j];
+					double probs[] = posterior(i, c);
+					double z = normalise(probs);
+					llh += log(z) * c.count;
+					conditionals.get(i).add(new TDoubleArrayList(probs));
+					q.get(i).add(new TDoubleArrayList(probs));
+				}
+			}
+
+			gradient = new double[numEdges*numTags];
+			setInitialParameters(lambda);
+		}
+
+		@Override
+		public double[] projectPoint(double[] point)
+		{
+			SimplexProjection p = new SimplexProjection(constraintScale);
+
+			double[] newPoint = point.clone();
+			int edgeIndex = 0;
+			for (int i = 0; i < training.length; ++i)
+			{
+				PhraseAndContexts instance = training[i];
+
+				for (int t = 0; t < numTags; t++)
+				{
+					double[] subPoint = new double[instance.contexts.length];
+					for (int j = 0; j < instance.contexts.length; ++j)
+						subPoint[j] = point[edgeIndex+j*numTags+t];
+
+					p.project(subPoint);
+					for (int j = 0; j < instance.contexts.length; ++j)
+						newPoint[edgeIndex+j*numTags+t] = subPoint[j];
+				}
+
+				edgeIndex += instance.contexts.length * numTags;
+			}
+			//System.out.println("Project point: " + Arrays.toString(point)
+			//		+ " => " + Arrays.toString(newPoint));
+			return newPoint;
+		}
+
+		@Override
+		public void setParameters(double[] params)
+		{
+			super.setParameters(params);
+			computeObjectiveAndGradient();
+		}
+
+		@Override
+		public double[] getGradient()
+		{
+			return gradient;
+		}
+
+		@Override
+		public double getValue()
+		{
+			return objective;
+		}
+
+		public void computeObjectiveAndGradient()
+		{
+			int edgeIndex = 0;
+			objective = 0;
+			Arrays.fill(gradient, 0);
+			for (int i = 0; i < training.length; ++i)
+			{
+				PhraseAndContexts instance = training[i];
+
+				for (int j = 0; j < instance.contexts.length; ++j)
+				{
+					Context c = instance.contexts[j];
+
+					double z = 0;
+					for (int t = 0; t < numTags; t++)
+					{
+						double v = conditionals.get(i).get(j).get(t) * exp(-parameters[edgeIndex+t]);
+						q.get(i).get(j).set(t, v);
+						z += v;
+					}
+					objective = log(z) * c.count;
+
+					for (int t = 0; t < numTags; t++)
+					{
+						double v = q.get(i).get(j).get(t) / z;
+						q.get(i).get(j).set(t, v);
+						gradient[edgeIndex+t] -= c.count * v;
+					}
+
+					edgeIndex += numTags;
+				}
+			}
+//			System.out.println("computeObjectiveAndGradient logz=" + objective);
+//			System.out.println("gradient=" + Arrays.toString(gradient));
+		}
+
+		public String toString()
+		{
+			StringBuilder sb = new StringBuilder();
+			sb.append(getClass().getCanonicalName()).append(" with ");
+			sb.append(parameters.length).append(" parameters and ");
+			sb.append(training.length * numTags).append(" constraints");
+			return sb.toString();
+		}
+
+		double primal()
+		{
+			// primal = llh + KL(q||p) + scale * sum_pt max_c E_q[phi_pct]
+			// kl = sum_Y q(Y) log q(Y) / p(Y|X)
+			//    = sum_Y q(Y) { -lambda . phi(Y) - log Z }
+			//    = -log Z - lambda . E_q[phi]
+
+			double kl = -objective - MathUtils.dotProduct(parameters, gradient);
+			double l1lmax = 0;
+			for (int i = 0; i < training.length; ++i)
+			{
+				PhraseAndContexts instance = training[i];
+				for (int t = 0; t < numTags; t++)
+				{
+					double lmax = Double.NEGATIVE_INFINITY;
+					for (int j = 0; j < instance.contexts.length; ++j)
+						lmax = max(lmax, q.get(i).get(j).get(t));
+					l1lmax += lmax;
+				}
+			}
+
+			return llh + kl + constraintScale * l1lmax;
+		}
+	}
 }
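
For reference, the mixture model that posterior() and the plain E-step implement can be written out as follows. This is a reading of the code above, not part of the commit: p is a phrase, c one of its contexts with tokens c_1 .. c_n, t a tag; prior[p][t] stores P(t | p) and emissions[j][t][w] stores P(w | t, j).

    P(t \mid p, c) = \frac{P(t \mid p) \prod_{j=1}^{n} P(c_j \mid t, j)}{Z(p, c)},
    \qquad
    Z(p, c) = \sum_{t'} P(t' \mid p) \prod_{j=1}^{n} P(c_j \mid t', j)

    \mathrm{llh} = \sum_{p, c} \mathrm{count}(p, c) \, \log Z(p, c)

Z(p, c) is the value normalise() returns for the unnormalised posterior, i.e. the marginal P(c | p), which is why the E-step accumulates log(z) * ctx.count.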
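When posteriorRegularisation is on, EStepDualObjective replaces that E-step with projected gradient descent on the dual of a constrained problem. Reconstructed from the comments in primal() and the body of computeObjectiveAndGradient() — again a sketch, not part of the commit — with \phi_{pct} the indicator feature that edge (p, c) takes tag t, \lambda the flattened multipliers, and \sigma = constraintScale:

    q_\lambda(t \mid p, c) \propto P(t \mid p, c) \, e^{-\lambda_{pct}}

    \mathrm{dual}(\lambda) = \sum_{p,c} \mathrm{count}(p,c) \, \log \sum_t P(t \mid p, c) \, e^{-\lambda_{pct}}

    \mathrm{KL}(q \,\|\, p) = -\log Z(\lambda) - \lambda \cdot E_q[\phi]

    \mathrm{primal} = \mathrm{llh} + \mathrm{KL}(q \,\|\, p) + \sigma \sum_{p,t} \max_c E_q[\phi_{pct}]

projectPoint() keeps \lambda feasible by projecting, for each phrase p and tag t, the vector (\lambda_{pct})_c over that phrase's contexts onto a simplex scaled by \sigma (the SimplexProjection(constraintScale) call). Two details worth flagging when reading the code: computeObjectiveAndGradient() assigns objective = log(z) * c.count inside the loop rather than accumulating it, so getValue() reflects only the last edge rather than the sum the comments describe; and since the gradient array stores -E_q[\phi] (count-weighted), the kl = -objective - lambda . gradient line appears to give the dot-product term the opposite sign to the -\lambda \cdot E_q[\phi] in the derivation.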
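A minimal usage sketch, following main() and readTrainingFromFile(); the file name and sizes here are hypothetical:

    // (inside a method that declares throws IOException)
    // Each input line:  phrase \t context ||| C=count ||| context ||| C=count ...
    // where a context is a space-separated token window and the literal token
    // <PHRASE> (marking the phrase's own position) is skipped when reading.
    PhraseContextModel model = new PhraseContextModel(new File("contexts.txt"), 5); // 5 tags
    model.expectationMaximisation(20); // 20 EM iterations; PR E-step if posteriorRegularisation is set
    model.displayPosterior(); // phrase, context, count, argmax tag and its probability

Equivalently, from the command line: java PhraseContextModel contexts.txt 5 20. Note that main() asserts args.length >= 2 but reads args[2], so three arguments are effectively required.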
