diff options
Diffstat (limited to 'gi/posterior-regularisation')
8 files changed, 118 insertions, 110 deletions
diff --git a/gi/posterior-regularisation/prjava/Makefile b/gi/posterior-regularisation/prjava/Makefile index abd9b964..a16adcde 100644 --- a/gi/posterior-regularisation/prjava/Makefile +++ b/gi/posterior-regularisation/prjava/Makefile @@ -1,5 +1,5 @@  all: -	ant +	ant dist  clean:  	ant clean diff --git a/gi/posterior-regularisation/prjava/src/arr/F.java b/gi/posterior-regularisation/prjava/src/arr/F.java index 54dadeac..79de5d1a 100644 --- a/gi/posterior-regularisation/prjava/src/arr/F.java +++ b/gi/posterior-regularisation/prjava/src/arr/F.java @@ -56,6 +56,7 @@ public class F {  	}
  	public static double l1norm(double a[]){
 +		// FIXME: this isn't the l1 norm for a < 0
  		double norm=0;
  		for(int i=0;i<a.length;i++){
  			norm += a[i];
 @@ -63,6 +64,14 @@ public class F {  		return norm;
  	}
 +	public static double l2norm(double a[]){
 +		double norm=0;
 +		for(int i=0;i<a.length;i++){
 +			norm += a[i]*a[i];
 +		}
 +		return Math.sqrt(norm);
 +	}
 +	
  	public static int argmax(double probs[])
  	{
  		double m = Double.NEGATIVE_INFINITY;
 diff --git a/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/AbstractGradientBaseMethod.java b/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/AbstractGradientBaseMethod.java index 0a4a5445..2fcb7990 100644 --- a/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/AbstractGradientBaseMethod.java +++ b/gi/posterior-regularisation/prjava/src/optimization/gradientBasedMethods/AbstractGradientBaseMethod.java @@ -56,9 +56,10 @@ public abstract class AbstractGradientBaseMethod implements Optimizer{  		stats.collectInitStats(this, o);  		direction = new double[o.getNumParameters()];  		initializeStructures(o, stats, stop); -		for (currentProjectionIteration = 1; currentProjectionIteration < maxNumberOfIterations; currentProjectionIteration++){		 -//			System.out.println("starting iterations: parameters:" ); -//			o.printParameters(); +		for (currentProjectionIteration = 1; currentProjectionIteration < maxNumberOfIterations; currentProjectionIteration++){ +			//System.out.println("\tgradient descent iteration " + currentProjectionIteration); +			//System.out.print("\tparameters:" ); +			//o.printParameters();  			previousValue = currValue;  			currValue = o.getValue();  			gradient = o.getGradient(); @@ -76,7 +77,7 @@ public abstract class AbstractGradientBaseMethod implements Optimizer{  			updateStructuresBeforeStep(o, stats, stop);  			lso.reset(direction);  			step = lineSearch.getStepSize(lso); -//			System.out.println("Leave with step: " + step); +			//System.out.println("\t\tLeave with step: " + step);  			if(step==-1){  				System.out.println("Failed to find step");  				stats.collectFinalStats(this, o); diff --git a/gi/posterior-regularisation/prjava/src/optimization/projections/SimplexProjection.java b/gi/posterior-regularisation/prjava/src/optimization/projections/SimplexProjection.java index eec11bcf..f22afcaf 100644 --- a/gi/posterior-regularisation/prjava/src/optimization/projections/SimplexProjection.java +++ b/gi/posterior-regularisation/prjava/src/optimization/projections/SimplexProjection.java @@ -40,7 +40,7 @@ public class SimplexProjection extends Projection{  		for (int i = 0; i < ds.length; i++) {  			currentSum+=ds[i];  			theta = (currentSum-scale)/(i+1); -			if(ds[i]-theta <= 0){ +			if(ds[i]-theta < -1e-10){  				break;  			}  			previousTheta = theta; diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index abd868c4..68148248 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -373,12 +373,13 @@ public class PhraseCluster {  		return primal;
  	}
 +	
 +	double[] lambda;
  	public double PREM_phrase_context_constraints(double scalePT, double scaleCT)
  	{	
  		double[][][] exp_emit = new double [K][n_positions][n_words];
  		double[][] exp_pi = new double[n_phrases][K];
 -		double[] lambda = null;
  		//E step
  		PhraseContextObjective pco = new PhraseContextObjective(this, lambda, pool, scalePT, scaleCT);
 diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java index ff135a3d..a9d3529c 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java @@ -32,7 +32,7 @@ public class PhraseContextObjective extends ProjectedObjective  	private PhraseCluster c;
 -	// un-regularized  unnormalized posterior, p[edge][tag]
 +	// un-regularized unnormalized posterior, p[edge][tag]
  	// P(tag|edge) \propto P(tag|phrase)P(context|tag)
  	private double p[][];
 @@ -144,7 +144,7 @@ public class PhraseContextObjective extends ProjectedObjective  				gradient[ic]=-q[e][tag];
  			}
  		}
 -		//System.out.println("objective " + loglikelihood + " gradient: " + Arrays.toString(gradient));		
 +		//System.out.println("objective " + loglikelihood + " ||gradient||_2: " + arr.F.l2norm(gradient));		
  		objectiveTime += System.currentTimeMillis() - begin;
  	}
 @@ -154,106 +154,100 @@ public class PhraseContextObjective extends ProjectedObjective  		long begin = System.currentTimeMillis();
  		List<Future<?>> tasks = new ArrayList<Future<?>>();
 -		//System.out.println("projectPoint: " + Arrays.toString(point));
 +		//System.out.println("\t\tprojectPoint: " + Arrays.toString(point));
  		Arrays.fill(newPoint, 0, newPoint.length, 0);
 -		if (scalePT > 0)
 +		// first project using the phrase-tag constraints,
 +		// for all p,t: sum_c lambda_ptc < scaleP 
 +		if (pool == null)
  		{
 -			// first project using the phrase-tag constraints,
 -			// for all p,t: sum_c lambda_ptc < scaleP 
 -			if (pool == null)
 +			for (int p = 0; p < c.c.getNumPhrases(); ++p)
  			{
 -				for (int p = 0; p < c.c.getNumPhrases(); ++p)
 +				List<Edge> edges = c.c.getEdgesForPhrase(p);
 +				double[] toProject = new double[edges.size()];
 +				for(int tag=0;tag<c.K;tag++)
  				{
 -					List<Edge> edges = c.c.getEdgesForPhrase(p);
 -					double[] toProject = new double[edges.size()];
 -					for(int tag=0;tag<c.K;tag++)
 -					{
 -						for(int e=0; e<edges.size(); e++)
 -							toProject[e] = point[index(edges.get(e), tag, true)];
 -						long lbegin = System.currentTimeMillis();
 -						projectionPhrase.project(toProject);
 -						actualProjectionTime += System.currentTimeMillis() - lbegin;
 -						for(int e=0; e<edges.size(); e++)
 -							newPoint[index(edges.get(e), tag, true)] = toProject[e];
 -					}
 +					for(int e=0; e<edges.size(); e++)
 +						toProject[e] = point[index(edges.get(e), tag, true)];
 +					long lbegin = System.currentTimeMillis();
 +					projectionPhrase.project(toProject);
 +					actualProjectionTime += System.currentTimeMillis() - lbegin;
 +					for(int e=0; e<edges.size(); e++)
 +						newPoint[index(edges.get(e), tag, true)] = toProject[e];
  				}
  			}
 -			else // do above in parallel using thread pool
 -			{	
 -				for (int p = 0; p < c.c.getNumPhrases(); ++p)
 +		}
 +		else // do above in parallel using thread pool
 +		{	
 +			for (int p = 0; p < c.c.getNumPhrases(); ++p)
 +			{
 +				final int phrase = p;
 +				final double[] inPoint = point;
 +				Runnable task = new Runnable()
  				{
 -					final int phrase = p;
 -					final double[] inPoint = point;
 -					Runnable task = new Runnable()
 +					public void run()
  					{
 -						public void run()
 +						List<Edge> edges = c.c.getEdgesForPhrase(phrase);
 +						double toProject[] = new double[edges.size()];
 +						for(int tag=0;tag<c.K;tag++)
  						{
 -							List<Edge> edges = c.c.getEdgesForPhrase(phrase);
 -							double toProject[] = new double[edges.size()];
 -							for(int tag=0;tag<c.K;tag++)
 -							{
 -								for(int e=0; e<edges.size(); e++)
 -									toProject[e] = inPoint[index(edges.get(e), tag, true)];
 -								projectionPhrase.project(toProject);
 -								for(int e=0; e<edges.size(); e++)
 -									newPoint[index(edges.get(e), tag, true)] = toProject[e];
 -							}
 -						}		
 -					};
 -					tasks.add(pool.submit(task));
 -				}
 +							for(int e=0; e<edges.size(); e++)
 +								toProject[e] = inPoint[index(edges.get(e), tag, true)];
 +							projectionPhrase.project(toProject);
 +							for(int e=0; e<edges.size(); e++)
 +								newPoint[index(edges.get(e), tag, true)] = toProject[e];
 +						}
 +					}		
 +				};
 +				tasks.add(pool.submit(task));
  			}
  		}
  		//System.out.println("after PT " + Arrays.toString(newPoint));
 -		if (scaleCT > 1e-6)
 +		// now project using the context-tag constraints,
 +		// for all c,t: sum_p omega_pct < scaleC
 +		if (pool == null)
  		{
 -			// now project using the context-tag constraints,
 -			// for all c,t: sum_p omega_pct < scaleC
 -			if (pool == null)
 +			for (int ctx = 0; ctx < c.c.getNumContexts(); ++ctx)
  			{
 -				for (int ctx = 0; ctx < c.c.getNumContexts(); ++ctx)
 +				List<Edge> edges = c.c.getEdgesForContext(ctx);
 +				double toProject[] = new double[edges.size()];
 +				for(int tag=0;tag<c.K;tag++)
  				{
 -					List<Edge> edges = c.c.getEdgesForContext(ctx);
 -					double toProject[] = new double[edges.size()];
 -					for(int tag=0;tag<c.K;tag++)
 -					{
 -						for(int e=0; e<edges.size(); e++)
 -							toProject[e] = point[index(edges.get(e), tag, false)];
 -						long lbegin = System.currentTimeMillis();
 -						projectionContext.project(toProject);
 -						actualProjectionTime += System.currentTimeMillis() - lbegin;
 -						for(int e=0; e<edges.size(); e++)
 -							newPoint[index(edges.get(e), tag, false)] = toProject[e];
 -					}
 +					for(int e=0; e<edges.size(); e++)
 +						toProject[e] = point[index(edges.get(e), tag, false)];
 +					long lbegin = System.currentTimeMillis();
 +					projectionContext.project(toProject);
 +					actualProjectionTime += System.currentTimeMillis() - lbegin;
 +					for(int e=0; e<edges.size(); e++)
 +						newPoint[index(edges.get(e), tag, false)] = toProject[e];
  				}
  			}
 -			else
 +		}
 +		else
 +		{
 +			// do above in parallel using thread pool
 +			for (int ctx = 0; ctx < c.c.getNumContexts(); ++ctx)
  			{
 -				// do above in parallel using thread pool
 -				for (int ctx = 0; ctx < c.c.getNumContexts(); ++ctx)
 +				final int context = ctx;
 +				final double[] inPoint = point;
 +				Runnable task = new Runnable()
  				{
 -					final int context = ctx;
 -					final double[] inPoint = point;
 -					Runnable task = new Runnable()
 +					public void run()
  					{
 -						public void run()
 +						List<Edge> edges = c.c.getEdgesForContext(context);
 +						double toProject[] = new double[edges.size()];
 +						for(int tag=0;tag<c.K;tag++)
  						{
 -							List<Edge> edges = c.c.getEdgesForContext(context);
 -							double toProject[] = new double[edges.size()];
 -							for(int tag=0;tag<c.K;tag++)
 -							{
 -								for(int e=0; e<edges.size(); e++)
 -									toProject[e] = inPoint[index(edges.get(e), tag, false)];
 -								projectionContext.project(toProject);
 -								for(int e=0; e<edges.size(); e++)
 -									newPoint[index(edges.get(e), tag, false)] = toProject[e];
 -							}
 +							for(int e=0; e<edges.size(); e++)
 +								toProject[e] = inPoint[index(edges.get(e), tag, false)];
 +							projectionContext.project(toProject);
 +							for(int e=0; e<edges.size(); e++)
 +								newPoint[index(edges.get(e), tag, false)] = toProject[e];
  						}
 -					};
 -					tasks.add(pool.submit(task));
 -				}
 +					}
 +				};
 +				tasks.add(pool.submit(task));
  			}
  		}
 @@ -283,9 +277,8 @@ public class PhraseContextObjective extends ProjectedObjective  		double[] tmp = newPoint;
  		newPoint = point;
  		projectionTime += System.currentTimeMillis() - begin;
 -
 -		//System.out.println("\treturning " + Arrays.toString(tmp));
 +		//System.out.println("\t\treturning " + Arrays.toString(tmp));
  		return tmp;
  	}
 @@ -405,6 +398,6 @@ public class PhraseContextObjective extends ProjectedObjective  	// L - KL(q||p) - scalePT * l1lmax_phrase - scaleCT * l1lmax_context
  	public double primal()
  	{
 -		return loglikelihood() - KL_divergence() - scalePT * phrase_l1lmax() - scalePT * context_l1lmax();
 +		return loglikelihood() - KL_divergence() - scalePT * phrase_l1lmax() - scaleCT * context_l1lmax();
  	}
  }
\ No newline at end of file diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java index 33167c20..0e2ab4b9 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java @@ -1,5 +1,6 @@  package phrase;
 +import java.util.Arrays;
  import java.util.List;
  import optimization.gradientBasedMethods.ProjectedGradientDescent;
 @@ -155,7 +156,7 @@ public class PhraseObjective extends ProjectedObjective  	@Override
  	public String toString() {
 -		return "No need for pointless toString";
 +		return Arrays.toString(parameters);
  	}
  	public double [][]posterior(){
 diff --git a/gi/posterior-regularisation/train_pr_global.py b/gi/posterior-regularisation/train_pr_global.py index f2806b6e..8521bccb 100644 --- a/gi/posterior-regularisation/train_pr_global.py +++ b/gi/posterior-regularisation/train_pr_global.py @@ -45,7 +45,7 @@ print 'edges_phrase_to_context', edges_phrase_to_context  # Step 2: initialise the model parameters  # -num_tags = 5 +num_tags = 10  num_types = len(types)  num_phrases = len(edges_phrase_to_context)  num_contexts = len(edges_context_to_phrase) @@ -56,11 +56,11 @@ def normalise(a):      return a / float(sum(a))  # Pr(tag | phrase) -#tagDist = [normalise(random(num_tags)+1) for p in range(num_phrases)] -tagDist = [normalise(array(range(1,num_tags+1))) for p in range(num_phrases)] +tagDist = [normalise(random(num_tags)+1) for p in range(num_phrases)] +#tagDist = [normalise(array(range(1,num_tags+1))) for p in range(num_phrases)]  # Pr(context at pos i = w | tag) indexed by i, tag, word -contextWordDist = [[normalise(array(range(1,num_types+1))) for t in range(num_tags)] for i in range(4)] -#contextWordDist = [[normalise(random(num_types)+1) for t in range(num_tags)] for i in range(4)] +#contextWordDist = [[normalise(array(range(1,num_types+1))) for t in range(num_tags)] for i in range(4)] +contextWordDist = [[normalise(random(num_types)+1) for t in range(num_tags)] for i in range(4)]  # PR langrange multipliers  lamba = zeros(2 * num_edges * num_tags)  omega_offset = num_edges * num_tags @@ -99,6 +99,8 @@ for iteration in range(20):                  cz = sum(conditionals)                  conditionals /= cz +                #print 'dual', phrase, context, count, 'p =', conditionals +                  local_z = 0                  for t in range(num_tags):                      li = lamba_index[phrase,context] + t @@ -106,8 +108,8 @@ for iteration in range(20):                  logz += log(local_z) * count          #print 'ls', ls -        print 'lambda', list(ls) -        print 'dual', logz +        #print 'lambda', list(ls) +        #print 'dual', logz          return logz      def loglikelihood(): @@ -146,12 +148,12 @@ for iteration in range(20):              for t in range(num_tags):                  best = -1e500                  for phrase, count in pcs: -                    li = lamba_index[phrase,context] + t +                    li = omega_offset + lamba_index[phrase,context] + t                      s = expectations[li]                      if s > best: best = s                  ct_l1linf += best -        return llh, kl, pt_l1linf, ct_l1linf, llh + kl + delta * pt_l1linf + gamma * ct_l1linf +        return llh, kl, pt_l1linf, ct_l1linf, llh - kl - delta * pt_l1linf - gamma * ct_l1linf      def dual_deriv(ls):          # d/dl log(z) = E_q[phi] @@ -173,13 +175,13 @@ for iteration in range(20):                      scores[t] = conditionals[t] * exp(-ls[li] - ls[omega_offset + li])                  local_z = sum(scores) +                #print 'ddual', phrase, context, count, 'q =', scores / local_z +                  for t in range(num_tags): -                    if delta > 0: -                        deriv[lamba_index[phrase,context] + t] -= count * scores[t] / local_z -                    if gamma > 0: -                        deriv[omega_offset + lamba_index[phrase,context] + t] -= count * scores[t] / local_z +                    deriv[lamba_index[phrase,context] + t] -= count * scores[t] / local_z +                    deriv[omega_offset + lamba_index[phrase,context] + t] -= count * scores[t] / local_z -        print 'ddual', list(deriv) +        #print 'ddual', list(deriv)          return deriv      def constraints(ls): @@ -244,7 +246,7 @@ for iteration in range(20):      print 'Post lambda optimisation dual', dual(lamba), 'primal', primal(lamba)      # E-step -    llh = z = 0 +    llh = log_z = 0      for p, (phrase, ccs) in enumerate(edges_phrase_to_context):          for context, count in ccs:              conditionals = zeros(num_tags) @@ -257,20 +259,21 @@ for iteration in range(20):              conditionals /= cz              llh += log(cz) * count -            scores = zeros(num_tags) +            q = zeros(num_tags)              li = lamba_index[phrase, context]              for t in range(num_tags): -                scores[t] = conditionals[t] * exp(-lamba[li + t] - lamba[omega_offset + li + t]) -            z += count * sum(scores) +                q[t] = conditionals[t] * exp(-lamba[li + t] - lamba[omega_offset + li + t]) +            qz = sum(q) +            log_z += count * log(qz)              for t in range(num_tags): -                tagCounts[p][t] += count * scores[t] +                tagCounts[p][t] += count * q[t] / qz              for i in range(4):                  for t in range(num_tags): -                    contextWordCounts[i][t][types[context[i]]] += count * scores[t] +                    contextWordCounts[i][t][types[context[i]]] += count * q[t] / qz -    print 'iteration', iteration, 'llh', llh, 'logz', log(z) +    print 'iteration', iteration, 'llh', llh, 'logz', log_z      # M-step      for p in range(num_phrases):  | 
