diff options
Diffstat (limited to 'gi/posterior-regularisation')
4 files changed, 99 insertions, 23 deletions
| diff --git a/gi/posterior-regularisation/prjava/src/phrase/C2F.java b/gi/posterior-regularisation/prjava/src/phrase/C2F.java index a8e557f2..31fd4fda 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/C2F.java +++ b/gi/posterior-regularisation/prjava/src/phrase/C2F.java @@ -38,10 +38,10 @@ public class C2F {  		n_contexts=c.getNumContexts();
  		//number of words in a phrase to be considered
 -		//currently the first and last word
 -		//if the phrase has length 1 
 -		//use the same word for two positions
 -		n_positions=2;
 +		//currently the first and last word in source and target
 +		//if the phrase has length 1 in either dimension then
 +		//we use the same word for two positions
 +		n_positions=c.phraseEdges(c.getEdges().get(0).getPhrase()).size();
  		emit=new double [K][n_positions][n_words];
  		pi=new double[n_contexts][K];
 @@ -156,9 +156,13 @@ public class C2F {  		double[] prob=Arrays.copyOf(pi[edge.getContextId()], K);
  		TIntArrayList phrase = edge.getPhrase();
 +		TIntArrayList offsets = c.phraseEdges(phrase);
  		for(int tag=0;tag<K;tag++)
 -			prob[tag]*=emit[tag][0][phrase.get(0)]
 -			                        *emit[tag][1][phrase.get(phrase.size()-1)];
 +		{
 +			for (int i=0; i < offsets.size(); ++i)
 +				prob[tag]*=emit[tag][i][phrase.get(offsets.get(i))];
 +		}
 +			
  		return prob;
  	}
 diff --git a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java index d57f3c04..2de2797b 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java @@ -15,6 +15,14 @@ public class Corpus  	private List<Edge> edges = new ArrayList<Edge>();  	private List<List<Edge>> phraseToContext = new ArrayList<List<Edge>>();  	private List<List<Edge>> contextToPhrase = new ArrayList<List<Edge>>(); +	public int splitSentinel; +	public int phraseSentinel; +	 +	public Corpus() +	{ +		splitSentinel = wordLexicon.insert("<SPLIT>"); +		phraseSentinel = wordLexicon.insert("<PHRASE>");		 +	}  	public class Edge  	{ @@ -157,6 +165,11 @@ public class Corpus  		return b.toString();  	} +	public boolean isSentinel(int wordId) +	{ +		return wordId == splitSentinel || wordId == phraseSentinel; +	} +	  	static Corpus readFromFile(Reader in) throws IOException  	{  		Corpus c = new Corpus(); @@ -218,6 +231,19 @@ public class Corpus  		return c;  	} +	 +	TIntArrayList phraseEdges(TIntArrayList phrase) +	{ +		TIntArrayList r = new TIntArrayList(4); +		for (int p = 0; p < phrase.size(); ++p) +		{ +			if (p == 0 || phrase.get(p-1) == splitSentinel) 				 +				r.add(p); +			if (p == phrase.size() - 1 || phrase.get(p+1) == splitSentinel)  +				r.add(p); +		} +		return r; +	}  	public void printStats(PrintStream out)   	{ diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index a369b319..5efaf52e 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -69,16 +69,29 @@ public class PhraseCluster {  		pool = Executors.newFixedThreadPool(threads);
  	}
 -	public double EM()
 +	public double EM(boolean skipBigPhrases)
  	{
  		double [][][]exp_emit=new double [K][n_positions][n_words];
  		double [][]exp_pi=new double[n_phrases][K];
 +		if (skipBigPhrases)
 +		{
 +			for(double [][]i:exp_emit)
 +				for(double []j:i)
 +					Arrays.fill(j, 1e-100);
 +		}
 +		
  		double loglikelihood=0;
  		//E
  		for(int phrase=0; phrase < n_phrases; phrase++)
  		{
 +			if (skipBigPhrases && c.getPhrase(phrase).size() >= 2)
 +			{
 +				System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
 +				continue;
 +			}			
 +
  			List<Edge> contexts = c.getEdgesForPhrase(phrase);
  			for (int ctx=0; ctx<contexts.size(); ctx++)
 @@ -116,9 +129,10 @@ public class PhraseCluster {  		return loglikelihood;
  	}
 -	public double VBEM(double alphaEmit, double alphaPi)
 +	public double VBEM(double alphaEmit, double alphaPi, boolean skipBigPhrases)
  	{
  		// FIXME: broken - needs to be done entirely in log-space
 +		assert !skipBigPhrases : "FIXME: implement this!";
  		double [][][]exp_emit = new double [K][n_positions][n_words];
  		double [][]exp_pi = new double[n_phrases][K];
 @@ -217,24 +231,31 @@ public class PhraseCluster {  		return kl;
  	}
 -	public double PREM(double scalePT, double scaleCT)
 +	public double PREM(double scalePT, double scaleCT, boolean skipBigPhrases)
  	{
  		if (scaleCT == 0)
  		{
  			if (pool != null)
 -				return PREM_phrase_constraints_parallel(scalePT);
 +				return PREM_phrase_constraints_parallel(scalePT, skipBigPhrases);
  			else
 -				return PREM_phrase_constraints(scalePT);
 +				return PREM_phrase_constraints(scalePT, skipBigPhrases);
  		}
  		else
 -			return this.PREM_phrase_context_constraints(scalePT, scaleCT);
 +			return this.PREM_phrase_context_constraints(scalePT, scaleCT, skipBigPhrases);
  	}
 -	public double PREM_phrase_constraints(double scalePT)
 +	public double PREM_phrase_constraints(double scalePT, boolean skipBigPhrases)
  	{
  		double [][][]exp_emit=new double[K][n_positions][n_words];
  		double [][]exp_pi=new double[n_phrases][K];
 +		
 +		if (skipBigPhrases)
 +		{
 +			for(double [][]i:exp_emit)
 +				for(double []j:i)
 +					Arrays.fill(j, 1e-100);
 +		}
  		if (lambdaPT == null && cacheLambda)
  			lambdaPT = new double[n_phrases][];
 @@ -244,6 +265,12 @@ public class PhraseCluster {  		long start = System.currentTimeMillis();
  		//E
  		for(int phrase=0; phrase<n_phrases; phrase++){
 +			if (skipBigPhrases && c.getPhrase(phrase).size() >= 2)
 +			{
 +				System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
 +				continue;
 +			}
 +			
  			PhraseObjective po = new PhraseObjective(this, phrase, scalePT, (cacheLambda) ? lambdaPT[phrase] : null);
  			boolean ok = po.optimizeWithProjectedGradientDescent();
  			if (!ok) ++failures;
 @@ -292,7 +319,7 @@ public class PhraseCluster {  		return primal;
  	}
 -	public double PREM_phrase_constraints_parallel(final double scalePT)
 +	public double PREM_phrase_constraints_parallel(final double scalePT, boolean skipBigPhrases)
  	{
  		assert(pool != null);
 @@ -302,10 +329,17 @@ public class PhraseCluster {  		double [][][]exp_emit=new double [K][n_positions][n_words];
  		double [][]exp_pi=new double[n_phrases][K];
 +		if (skipBigPhrases)
 +		{
 +			for(double [][]i:exp_emit)
 +				for(double []j:i)
 +					Arrays.fill(j, 1e-100);
 +		}
 +		
  		double loglikelihood=0, kl=0, l1lmax=0, primal=0;
  		final AtomicInteger failures = new AtomicInteger(0);
  		final AtomicLong elapsed = new AtomicLong(0l);
 -		int iterations=0;
 +		int iterations=0, n=n_phrases;
  		long start = System.currentTimeMillis();
  		if (lambdaPT == null && cacheLambda)
 @@ -313,6 +347,12 @@ public class PhraseCluster {  		//E
  		for(int phrase=0;phrase<n_phrases;phrase++){
 +			if (skipBigPhrases && c.getPhrase(phrase).size() >= 2)
 +			{
 +				n -= 1;
 +				System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
 +				continue;
 +			}
  			final int p=phrase;
  			pool.execute(new Runnable() {
  				public void run() {
 @@ -337,7 +377,7 @@ public class PhraseCluster {  		}
  		// aggregate the expectations as they become available
 -		for(int count=0;count<n_phrases;count++) {
 +		for(int count=0;count<n;count++) {
  			try {
  				//System.out.println("" + Thread.currentThread().getId() + " reading queue #" + count);
 @@ -396,8 +436,10 @@ public class PhraseCluster {  		return primal;
  	}
 -	public double PREM_phrase_context_constraints(double scalePT, double scaleCT)
 +	public double PREM_phrase_context_constraints(double scalePT, double scaleCT, boolean skipBigPhrases)
  	{	
 +		assert !skipBigPhrases : "Not supported yet - FIXME!"; //FIXME
 +		
  		double[][][] exp_emit = new double [K][n_positions][n_words];
  		double[][] exp_pi = new double[n_phrases][K];
 @@ -454,7 +496,8 @@ public class PhraseCluster {  		TIntArrayList ctx = edge.getContext();
  		for(int tag=0;tag<K;tag++)
  			for(int c=0;c<n_positions;c++)
 -				prob[tag]*=emit[tag][c][ctx.get(c)];
 +				if (!this.c.isSentinel(ctx.get(c)))
 +					prob[tag]*=emit[tag][c][ctx.get(c)];
  		return prob;
  	}
 diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java index 20f6c905..a67c17a2 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java @@ -32,6 +32,7 @@ public class Trainer          parser.accepts("alpha-pi").withRequiredArg().ofType(Double.class).defaultsTo(0.01);          parser.accepts("agree");          parser.accepts("no-parameter-cache"); +        parser.accepts("skip-large-phrases").withRequiredArg().ofType(Integer.class).defaultsTo(5);          OptionSet options = parser.parse(args);          if (options.has("help") || !options.has("in")) @@ -55,6 +56,7 @@ public class Trainer  		boolean vb = options.has("variational-bayes");  		double alphaEmit = (vb) ? (Double) options.valueOf("alpha-emit") : 0;  		double alphaPi = (vb) ? (Double) options.valueOf("alpha-pi") : 0; +		int skip = (Integer) options.valueOf("skip-large-phrases");  		if (options.has("seed"))  			F.rng = new Random((Long) options.valueOf("seed")); @@ -80,6 +82,7 @@ public class Trainer  		if (!options.has("agree"))  			System.out.println("Running with " + tags + " tags " +  					"for " + em_iterations + " EM and " + pr_iterations + " PR iterations " + +					"skipping large phrases for first " + skip + " iterations " +   					"with scale " + scale_phrase + " phrase and " + scale_context + " context " +  					"and " + threads + " threads");  		else @@ -112,12 +115,12 @@ public class Trainer  				if (i < em_iterations)  				{  					if (!vb) -						o = cluster.EM(); +						o = cluster.EM(i < skip);  					else -						o = cluster.VBEM(alphaEmit, alphaPi);	 +						o = cluster.VBEM(alphaEmit, alphaPi, i < skip);	  				}  				else -					o = cluster.PREM(scale_phrase, scale_context); +					o = cluster.PREM(scale_phrase, scale_context, i < skip);  			}  			System.out.println("ITER: "+i+" objective: " + o); @@ -125,9 +128,9 @@ public class Trainer  			if (i != 0 && Math.abs((o - last) / o) < threshold)  			{  				last = o; -				if (i < em_iterations) +				if (i < Math.max(em_iterations, skip))  				{ -					i = em_iterations - 1; +					i = Math.max(em_iterations, skip) - 1;  					continue;  				}  				else | 
