Diffstat (limited to 'gi/posterior-regularisation')
6 files changed, 832 insertions, 221 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java
new file mode 100644
index 00000000..d5e856ca
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java
@@ -0,0 +1,221 @@
+package phrase;
+
+import gnu.trove.TIntArrayList;
+
+import java.io.*;
+import java.util.*;
+import java.util.regex.Pattern;
+
+
+public class Corpus
+{
+	private Lexicon<String> wordLexicon = new Lexicon<String>();
+	private Lexicon<TIntArrayList> phraseLexicon = new Lexicon<TIntArrayList>();
+	private Lexicon<TIntArrayList> contextLexicon = new Lexicon<TIntArrayList>();
+	private List<Edge> edges = new ArrayList<Edge>();
+	private List<List<Edge>> phraseToContext = new ArrayList<List<Edge>>();
+	private List<List<Edge>> contextToPhrase = new ArrayList<List<Edge>>();
+
+	public class Edge
+	{
+		Edge(int phraseId, int contextId, int count)
+		{
+			this.phraseId = phraseId;
+			this.contextId = contextId;
+			this.count = count;
+		}
+		public int getPhraseId()
+		{
+			return phraseId;
+		}
+		public TIntArrayList getPhrase()
+		{
+			return Corpus.this.getPhrase(phraseId);
+		}
+		public String getPhraseString()
+		{
+			return Corpus.this.getPhraseString(phraseId);
+		}
+		public int getContextId()
+		{
+			return contextId;
+		}
+		public TIntArrayList getContext()
+		{
+			return Corpus.this.getContext(contextId);
+		}
+		public String getContextString(boolean insertPhraseSentinel)
+		{
+			return Corpus.this.getContextString(contextId, insertPhraseSentinel);
+		}
+		public int getCount()
+		{
+			return count;
+		}
+		public boolean equals(Object other)
+		{
+			if (other instanceof Edge)
+			{
+				Edge oe = (Edge) other;
+				return oe.phraseId == phraseId && oe.contextId == contextId;
+			}
+			else return false;
+		}
+		public int hashCode()
+		{   // this is how boost's hash_combine does it
+			int seed = phraseId;
+			seed ^= contextId + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+			return seed;
+		}
+		public String toString()
+		{
+			return getPhraseString() + "\t" + getContextString(true);
+		}
+
+		private int phraseId;
+		private int contextId;
+		private int count;
+	}
+
+	List<Edge> getEdges()
+	{
+		return edges;
+	}
+
+	int getNumEdges()
+	{
+		return edges.size();
+	}
+
+	int getNumPhrases()
+	{
+		return phraseLexicon.size();
+	}
+
+	int getNumContextPositions()
+	{
+		return contextLexicon.lookup(0).size();
+	}
+
+	List<Edge> getEdgesForPhrase(int phraseId)
+	{
+		return phraseToContext.get(phraseId);
+	}
+
+	int getNumContexts()
+	{
+		return contextLexicon.size();
+	}
+
+	List<Edge> getEdgesForContext(int contextId)
+	{
+		return contextToPhrase.get(contextId);
+	}
+
+	int getNumWords()
+	{
+		return wordLexicon.size();
+	}
+
+	String getWord(int wordId)
+	{
+		return wordLexicon.lookup(wordId);
+	}
+
+	public TIntArrayList getPhrase(int phraseId)
+	{
+		return phraseLexicon.lookup(phraseId);
+	}
+
+	public String getPhraseString(int phraseId)
+	{
+		StringBuffer b = new StringBuffer();
+		for (int tid: getPhrase(phraseId).toNativeArray())
+		{
+			if (b.length() > 0)
+				b.append(" ");
+			b.append(wordLexicon.lookup(tid));
+		}
+		return b.toString();
+	}
+
+	public TIntArrayList getContext(int contextId)
+	{
+		return contextLexicon.lookup(contextId);
+	}
+
+	public String getContextString(int contextId, boolean insertPhraseSentinel)
+	{
+		StringBuffer b = new StringBuffer();
+		TIntArrayList c = getContext(contextId);
+		for (int i = 0; i < c.size(); ++i)
+		{
+			if (i > 0) b.append(" ");
+			if (i == c.size() / 2) b.append("<PHRASE> ");
+			b.append(wordLexicon.lookup(c.get(i)));
+		}
+		return b.toString();
+	}
+
+	static Corpus readFromFile(Reader in) throws IOException
+	{
+		Corpus c = new Corpus();
+
+		// read in line-by-line
+		BufferedReader bin = new BufferedReader(in);
+		String line;
+		Pattern separator = Pattern.compile(" \\|\\|\\| ");
+
+		while ((line = bin.readLine()) != null)
+		{
+			// split into phrase and contexts
+			StringTokenizer st = new StringTokenizer(line, "\t");
+			assert (st.hasMoreTokens());
+			String phraseToks = st.nextToken();
+			assert (st.hasMoreTokens());
+			String rest = st.nextToken();
+			assert (!st.hasMoreTokens());
+
+			// process phrase
+			st = new StringTokenizer(phraseToks, " ");
+			TIntArrayList ptoks = new TIntArrayList();
+			while (st.hasMoreTokens())
+				ptoks.add(c.wordLexicon.insert(st.nextToken()));
+			int phraseId = c.phraseLexicon.insert(ptoks);
+			if (phraseId == c.phraseToContext.size())
+				c.phraseToContext.add(new ArrayList<Edge>());
+
+			// process contexts
+			String[] parts = separator.split(rest);
+			assert (parts.length % 2 == 0);
+			for (int i = 0; i < parts.length; i += 2)
+			{
+				// process pairs of strings - context and count
+				TIntArrayList ctx = new TIntArrayList();
+				String ctxString = parts[i];
+				String countString = parts[i + 1];
+				StringTokenizer ctxStrtok = new StringTokenizer(ctxString, " ");
+				while (ctxStrtok.hasMoreTokens())
+				{
+					String token = ctxStrtok.nextToken();
+					if (!token.equals("<PHRASE>"))
+						ctx.add(c.wordLexicon.insert(token));
+				}
+				int contextId = c.contextLexicon.insert(ctx);
+				if (contextId == c.contextToPhrase.size())
+					c.contextToPhrase.add(new ArrayList<Edge>());
+
+				assert (countString.startsWith("C="));
+				Edge e = c.new Edge(phraseId, contextId,
+						Integer.parseInt(countString.substring(2).trim()));
+				c.edges.add(e);
+
+				// index the edge for fast phrase, context lookup
+				c.phraseToContext.get(phraseId).add(e);
+				c.contextToPhrase.get(contextId).add(e);
+			}
+		}
+
+		return c;
+	}
+}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Lexicon.java b/gi/posterior-regularisation/prjava/src/phrase/Lexicon.java
new file mode 100644
index 00000000..a386e4a3
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/phrase/Lexicon.java
@@ -0,0 +1,34 @@
+package phrase;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+public class Lexicon<T>
+{
+	public int insert(T word)
+	{
+		Integer i = wordToIndex.get(word);
+		if (i == null)
+		{
+			i = indexToWord.size();
+			wordToIndex.put(word, i);
+			indexToWord.add(word);
+		}
+		return i;
+	}
+
+	public T lookup(int index)
+	{
+		return indexToWord.get(index);
+	}
+
+	public int size()
+	{
+		return indexToWord.size();
+	}
+
+	private Map<T, Integer> wordToIndex = new HashMap<T, Integer>();
+	private List<T> indexToWord = new ArrayList<T>();
+}
\ No newline at end of file
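The new Corpus replaces the old array-based PhraseCorpus representation: each input line holds a phrase, a tab, then alternating context/count fields separated by " ||| ", with counts written as "C=<n>" and each context marking the phrase slot with the <PHRASE> sentinel. A minimal driver sketch follows (the CorpusDemo class and input file are hypothetical, not part of this commit; it must live in package phrase because the accessors above are package-private):

	package phrase;

	import java.io.FileReader;

	// Hypothetical demo: load a corpus file and dump its (phrase, context) edges.
	public class CorpusDemo
	{
		public static void main(String[] args) throws Exception
		{
			// expected line format:
			// <phrase words>\t<context words> ||| C=<count> ||| <context words> ||| C=<count> ...
			Corpus corpus = Corpus.readFromFile(new FileReader(args[0]));
			System.out.println(corpus.getNumPhrases() + " phrases, "
					+ corpus.getNumContexts() + " contexts, "
					+ corpus.getNumEdges() + " edges, "
					+ corpus.getNumWords() + " word types");
			for (Corpus.Edge e : corpus.getEdges())
				System.out.println(e + " ||| C=" + e.getCount());
		}
	}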
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index 731d03ac..e4db2a1a 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -1,44 +1,54 @@
 package phrase;
 +import gnu.trove.TIntArrayList;
  import io.FileUtil;
 -
 -import java.io.FileOutputStream;
  import java.io.IOException;
 -import java.io.OutputStream;
  import java.io.PrintStream;
  import java.util.Arrays;
 +import java.util.List;
  import java.util.concurrent.ExecutorService;
  import java.util.concurrent.Executors;
  import java.util.concurrent.LinkedBlockingQueue;
 -import java.util.zip.GZIPOutputStream;
 +
 +import phrase.Corpus.Edge;
  public class PhraseCluster {
  	public int K;
 -	public double scale;
 -	private int n_phrase;
 -	private int n_words;
 -	public PhraseCorpus c;
 +	public double scalePT, scaleCT;
 +	private int n_phrases, n_words, n_contexts, n_positions;
 +	public Corpus c;
  	private ExecutorService pool; 
 -	/**@brief
 -	 * emit[tag][position][word]
 -	 */
 +	// emit[tag][position][word] = p(word | tag, position in context)
  	private double emit[][][];
 +	// pi[phrase][tag] = p(tag | phrase)
  	private double pi[][];
 -
 -	public static void main(String[] args) {
 +	public static void main(String[] args) 
 +	{
  		String input_fname = args[0];
  		int tags = Integer.parseInt(args[1]);
  		String output_fname = args[2];
  		int iterations = Integer.parseInt(args[3]);
 -		double scale = Double.parseDouble(args[4]);
 -		int threads = Integer.parseInt(args[5]);
 -		boolean runEM = Boolean.parseBoolean(args[6]);
 -		
 -		PhraseCorpus corpus = new PhraseCorpus(input_fname);
 -		PhraseCluster cluster = new PhraseCluster(tags, corpus, scale, threads);
 +		double scalePT = Double.parseDouble(args[4]);
 +		double scaleCT = Double.parseDouble(args[5]);
 +		int threads = Integer.parseInt(args[6]);
 +		boolean runEM = Boolean.parseBoolean(args[7]);
 +		
 +		assert(tags >= 2);
 +		assert(scalePT >= 0);
 +		assert(scaleCT >= 0);
 +		
 +		Corpus corpus = null;
 +		try {
 +			corpus = Corpus.readFromFile(FileUtil.openBufferedReader(input_fname));
 +		} catch (IOException e) {
 +			System.err.println("Failed to open input file: " + input_fname);
 +			e.printStackTrace();
 +			System.exit(1);
 +		}
 +		PhraseCluster cluster = new PhraseCluster(tags, corpus, scalePT, scaleCT, threads);
  		//PhraseObjective.ps = FileUtil.openOutFile(outputDir + "/phrase_stat.out");
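The revised main() takes eight positional arguments rather than six: input file, number of tags, output file, iteration count, the phrase-tag scale scalePT, the new context-tag scale scaleCT, thread count, and the runEM flag. A hypothetical invocation (file names are placeholders):

	java phrase.PhraseCluster contexts.txt 10 posteriors.out 30 5 0 4 false

With scaleCT == 0 only the phrase constraints are active, so the per-phrase optimization can run in parallel; any scaleCT > 0 routes through the joint PREM_phrase_context_constraints solver instead (note the thread pool is only created when scalec <= 0).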
@@ -48,19 +58,25 @@ public class PhraseCluster {
 			double o;
  			if (runEM || i < 3) 
  				o = cluster.EM();
 -			else
 -				o = cluster.PREM();
 +			else if (scaleCT == 0)
 +			{
 +				if (threads >= 1)
 +					o = cluster.PREM_phrase_constraints_parallel();
 +				else
 +					o = cluster.PREM_phrase_constraints();
 +			}
 +			else 
 +				o = cluster.PREM_phrase_context_constraints();
 +			
  			//PhraseObjective.ps.
  			System.out.println("ITER: "+i+" objective: " + o);
  			last = o;
  		}
 -		if (runEM)
 -		{
 -			double l1lmax = cluster.posterior_l1lmax();
 -			System.out.println("Final l1lmax term " + l1lmax + ", total PR objective " + (last - scale*l1lmax));
 -			// nb. KL is 0 by definition
 -		}
 +		double pl1lmax = cluster.phrase_l1lmax();
 +		double cl1lmax = cluster.context_l1lmax();
 +		System.out.println("Final posterior phrase l1lmax " + pl1lmax + " context l1lmax " + cl1lmax);
 +		if (runEM) System.out.println("With PR objective " + (last - scalePT*pl1lmax - scaleCT*cl1lmax));
  		PrintStream ps=io.FileUtil.openOutFile(output_fname);
  		cluster.displayPosterior(ps);
@@ -75,17 +91,20 @@ public class PhraseCluster {
 		cluster.finish();
  	}
 -	public PhraseCluster(int numCluster, PhraseCorpus corpus, double scale, int threads){
 +	public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads){
  		K=numCluster;
  		c=corpus;
 -		n_words=c.wordLex.size();
 -		n_phrase=c.data.length;
 -		this.scale = scale;
 -		if (threads > 0)
 +		n_words=c.getNumWords();
 +		n_phrases=c.getNumPhrases();
 +		n_contexts=c.getNumContexts();
 +		n_positions=c.getNumContextPositions();
 +		this.scalePT = scalep;
 +		this.scaleCT = scalec;
 +		if (threads > 0 && scalec <= 0)
  			pool = Executors.newFixedThreadPool(threads);
 -		emit=new double [K][c.numContexts][n_words];
 -		pi=new double[n_phrase][K];
 +		emit=new double [K][n_positions][n_words];
 +		pi=new double[n_phrases][K];
  		for(double [][]i:emit){
  			for(double []j:i){
@@ -105,30 +124,32 @@ public class PhraseCluster {
 	}
  	public double EM(){
 -		double [][][]exp_emit=new double [K][c.numContexts][n_words];
 -		double [][]exp_pi=new double[n_phrase][K];
 +		double [][][]exp_emit=new double [K][n_positions][n_words];
 +		double [][]exp_pi=new double[n_phrases][K];
  		double loglikelihood=0;
  		//E
 -		for(int phrase=0;phrase<c.data.length;phrase++){
 -			int [][] data=c.data[phrase];
 -			for(int ctx=0;ctx<data.length;ctx++){
 -				int context[]=data[ctx];
 -				double p[]=posterior(phrase,context);
 +		for(int phrase=0; phrase < n_phrases; phrase++){
 +			List<Edge> contexts = c.getEdgesForPhrase(phrase);
 +
 +			for (int ctx=0; ctx<contexts.size(); ctx++){
 +				Edge edge = contexts.get(ctx);
 +				double p[]=posterior(edge);
  				double z = arr.F.l1norm(p);
  				assert z > 0;
  				loglikelihood+=Math.log(z);
  				arr.F.l1normalize(p);
 -				int contextCnt=context[context.length-1];
 +				int count = edge.getCount();
  				//increment expected count
 +				TIntArrayList context = edge.getContext();
  				for(int tag=0;tag<K;tag++){
 -					for(int pos=0;pos<context.length-1;pos++){
 -						exp_emit[tag][pos][context[pos]]+=p[tag]*contextCnt;
 +					for(int pos=0;pos<n_positions;pos++){
 +						exp_emit[tag][pos][context.get(pos)]+=p[tag]*count;
  					}
 -					exp_pi[phrase][tag]+=p[tag]*contextCnt;
 +					exp_pi[phrase][tag]+=p[tag]*count;
  				}
  			}
  		}
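For reference, the E step above scores each edge with the unnormalized posterior from posterior(edge); in the notation of the emit/pi comments at the top of the class this is (a sketch, not in the source):

	P(t \mid \mathrm{phrase}, \mathrm{context}) \propto \pi[\mathrm{phrase}][t] \prod_{pos} \mathrm{emit}[t][pos][\mathrm{context}[pos]]

which is then l1-normalized, and each edge contributes its count times this posterior to the expected counts for both emit and pi.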
@@ -153,29 +174,32 @@ public class PhraseCluster {
 		return loglikelihood;
  	}
 -	public double PREM(){
 -		if (pool != null)
 -			return PREMParallel();
 +	public double PREM_phrase_constraints(){
 +		assert (scaleCT <= 0);
 -		double [][][]exp_emit=new double [K][c.numContexts][n_words];
 -		double [][]exp_pi=new double[n_phrase][K];
 +		double [][][]exp_emit=new double[K][n_positions][n_words];
 +		double [][]exp_pi=new double[n_phrases][K];
 -		double loglikelihood=0;
 -		double primal=0;
 +		double loglikelihood=0, kl=0, l1lmax=0, primal=0;
  		//E
 -		for(int phrase=0;phrase<c.data.length;phrase++){
 +		for(int phrase=0; phrase<n_phrases; phrase++){
  			PhraseObjective po=new PhraseObjective(this,phrase);
  			po.optimizeWithProjectedGradientDescent();
  			double [][] q=po.posterior();
 -			loglikelihood+=po.llh;
 -			primal+=po.primal();
 +			loglikelihood += po.loglikelihood();
 +			kl += po.KL_divergence();
 +			l1lmax += po.l1lmax();
 +			primal += po.primal();
 +			List<Edge> edges = c.getEdgesForPhrase(phrase);
 +
  			for(int edge=0;edge<q.length;edge++){
 -				int []context=c.data[phrase][edge];
 -				int contextCnt=context[context.length-1];
 +				Edge e = edges.get(edge);
 +				TIntArrayList context = e.getContext();
 +				int contextCnt = e.getCount();
  				//increment expected count
  				for(int tag=0;tag<K;tag++){
 -					for(int pos=0;pos<context.length-1;pos++){
 -						exp_emit[tag][pos][context[pos]]+=q[edge][tag]*contextCnt;
 +					for(int pos=0;pos<n_positions;pos++){
 +						exp_emit[tag][pos][context.get(pos)]+=q[edge][tag]*contextCnt;
  					}
  					exp_pi[phrase][tag]+=q[edge][tag]*contextCnt;
@@ -183,8 +207,9 @@ public class PhraseCluster {
 			}
  		}
 -		System.out.println("Log likelihood: "+loglikelihood);
 -		System.out.println("Primal Objective: "+primal);
 +		System.out.println("\tllh:            " + loglikelihood);
 +		System.out.println("\tKL:             " + kl);
 +		System.out.println("\tphrase l1lmax:  " + l1lmax);
  		//M
  		for(double [][]i:exp_emit){
@@ -204,18 +229,21 @@ public class PhraseCluster {
 		return primal;
  	}
 -	public double PREMParallel(){
 +	public double PREM_phrase_constraints_parallel()
 +	{
  		assert(pool != null);
 +		assert(scaleCT <= 0);
 +		
  		final LinkedBlockingQueue<PhraseObjective> expectations 
  			= new LinkedBlockingQueue<PhraseObjective>();
 -		double [][][]exp_emit=new double [K][c.numContexts][n_words];
 -		double [][]exp_pi=new double[n_phrase][K];
 +		double [][][]exp_emit=new double [K][n_positions][n_words];
 +		double [][]exp_pi=new double[n_phrases][K];
 -		double loglikelihood=0;
 -		double primal=0;
 +		double loglikelihood=0, kl=0, l1lmax=0, primal=0;
 +
  		//E
 -		for(int phrase=0;phrase<c.data.length;phrase++){
 +		for(int phrase=0;phrase<n_phrases;phrase++){
  			final int p=phrase;
  			pool.execute(new Runnable() {
  				public void run() {
@@ -235,7 +263,7 @@ public class PhraseCluster {
 		}
  		// aggregate the expectations as they become available
 -		for(int count=0;count<c.data.length;count++) {
 +		for(int count=0;count<n_phrases;count++) {
  			try {
  				//System.out.println("" + Thread.currentThread().getId() + " reading queue #" + count);
@@ -245,109 +273,139 @@ public class PhraseCluster {
 				int phrase = po.phrase;
  				//System.out.println("" + Thread.currentThread().getId() + " taken phrase " + phrase);
  				double [][] q=po.posterior();
 -				loglikelihood+=po.llh;
 -				primal+=po.primal();
 +				loglikelihood += po.loglikelihood();
 +				kl += po.KL_divergence();
 +				l1lmax += po.l1lmax();
 +				primal += po.primal();
 +				
 +				List<Edge> edges = c.getEdgesForPhrase(phrase);
  				for(int edge=0;edge<q.length;edge++){
 -					int []context=c.data[phrase][edge];
 -					int contextCnt=context[context.length-1];
 +					Edge e = edges.get(edge);
 +					TIntArrayList context = e.getContext();
 +					int contextCnt = e.getCount();
  					//increment expected count
  					for(int tag=0;tag<K;tag++){
 -						for(int pos=0;pos<context.length-1;pos++){
 -							exp_emit[tag][pos][context[pos]]+=q[edge][tag]*contextCnt;
 +						for(int pos=0;pos<n_positions;pos++){
 +							exp_emit[tag][pos][context.get(pos)]+=q[edge][tag]*contextCnt;
  						}
  						exp_pi[phrase][tag]+=q[edge][tag]*contextCnt;
  					}
  				}
 -			} catch (InterruptedException e){
 +			} catch (InterruptedException e)
 +			{
  				System.err.println("M-step thread interrupted. Probably fatal!");
  				e.printStackTrace();
  			}
  		}
 -		System.out.println("Log likelihood: "+loglikelihood);
 -		System.out.println("Primal Objective: "+primal);
 +		System.out.println("\tllh:            " + loglikelihood);
 +		System.out.println("\tKL:             " + kl);
 +		System.out.println("\tphrase l1lmax:  " + l1lmax);
  		//M
 -		for(double [][]i:exp_emit){
 -			for(double []j:i){
 +		for(double [][]i:exp_emit)
 +			for(double []j:i)
  				arr.F.l1normalize(j);
 +		emit=exp_emit;
 +		
 +		for(double []j:exp_pi)
 +			arr.F.l1normalize(j);
 +		pi=exp_pi;
 +		
 +		return primal;
 +	}
 +
 +	public double PREM_phrase_context_constraints(){
 +		assert (scaleCT > 0);
 +		
 +		double [][][]exp_emit=new double [K][n_positions][n_words];
 +		double [][]exp_pi=new double[n_phrases][K];
 +
 +		//E step
 +		// TODO: cache the lambda values (the null below)
 +		PhraseContextObjective pco = new PhraseContextObjective(this, null);
 +		pco.optimizeWithProjectedGradientDescent();
 +
 +		//now extract expectations
 +		List<Corpus.Edge> edges = c.getEdges();
 +		for(int e = 0; e < edges.size(); ++e)
 +		{
 +			double [] q = pco.posterior(e);
 +			Corpus.Edge edge = edges.get(e);
 +
 +			TIntArrayList context = edge.getContext();
 +			int contextCnt = edge.getCount();
 +			//increment expected count
 +			for(int tag=0;tag<K;tag++)
 +			{
 +				for(int pos=0;pos<n_positions;pos++)
 +					exp_emit[tag][pos][context.get(pos)]+=q[tag]*contextCnt;
 +				exp_pi[edge.getPhraseId()][tag]+=q[tag]*contextCnt;
  			}
  		}
 +		System.out.println("\tllh:            " + pco.loglikelihood());
 +		System.out.println("\tKL:             " + pco.KL_divergence());
 +		System.out.println("\tphrase l1lmax:  " + pco.phrase_l1lmax());
 +		System.out.println("\tcontext l1lmax: " + pco.context_l1lmax());
 +		
 +		//M step
 +		for(double [][]i:exp_emit)
 +			for(double []j:i)
 +				arr.F.l1normalize(j);
  		emit=exp_emit;
 -		for(double []j:exp_pi){
 +		for(double []j:exp_pi)
  			arr.F.l1normalize(j);
 -		}
 -		
  		pi=exp_pi;
 -		return primal;
 -	}
 -	
 +		return pco.primal();
 +	}	
 +		
  	/**
 -	 * 
  	 * @param phrase index of phrase
  	 * @param ctx array of context
  	 * @return unnormalized posterior
  	 */
 -	public double[]posterior(int phrase, int[]ctx){
 -		double[] prob=Arrays.copyOf(pi[phrase], K);
 +	public double[] posterior(Corpus.Edge edge) 
 +	{
 +		double[] prob=Arrays.copyOf(pi[edge.getPhraseId()], K);
 -		for(int tag=0;tag<K;tag++){
 -			for(int c=0;c<ctx.length-1;c++){
 -				int word=ctx[c];
 -				prob[tag]*=emit[tag][c][word];
 -			}
 -		}
 +		TIntArrayList ctx = edge.getContext();
 +		for(int tag=0;tag<K;tag++)
 +			for(int c=0;c<n_positions;c++)
 +				prob[tag]*=emit[tag][c][ctx.get(c)];
  		return prob;
  	}
  	public void displayPosterior(PrintStream ps)
 -	{
 -		
 -		c.buildList();
 -		
 -		for (int i = 0; i < n_phrase; ++i)
 +	{	
 +		for (Edge edge : c.getEdges())
  		{
 -			int [][]data=c.data[i];
 -			for (int[] e: data)
 -			{
 -				double probs[] = posterior(i, e);
 -				arr.F.l1normalize(probs);
 +			double probs[] = posterior(edge);
 +			arr.F.l1normalize(probs);
 -				// emit phrase
 -				ps.print(c.phraseList[i]);
 -				ps.print("\t");
 -				ps.print(c.getContextString(e, true));
 -				int t=arr.F.argmax(probs);
 -				ps.println(" ||| C=" + t);
 -
 -				//ps.print("||| C=" + e[e.length-1] + " |||");
 -				
 -				//ps.print(t+"||| [");
 -				//for(t=0;t<K;t++){
 -				//	ps.print(probs[t]+", ");
 -				//}
 -				// for (int t = 0; t < numTags; ++t)
 -				// System.out.print(" " + probs[t]);
 -				//ps.println("]");
 -			}
 +			// emit phrase
 +			ps.print(edge.getPhraseString());
 +			ps.print("\t");
 +			ps.print(edge.getContextString(true));
 +			int t=arr.F.argmax(probs);
 +			ps.println(" ||| C=" + t);
  		}
  	}
  	public void displayModelParam(PrintStream ps)
  	{
 -		
 -		c.buildList();
 +		final double EPS = 1e-6;
  		ps.println("P(tag|phrase)");
 -		for (int i = 0; i < n_phrase; ++i)
 +		for (int i = 0; i < n_phrases; ++i)
  		{
 -			ps.print(c.phraseList[i]);
 +			ps.print(c.getPhrase(i));
  			for(int j=0;j<pi[i].length;j++){
 -				ps.print("\t"+pi[i][j]);
 +				if (pi[i][j] > EPS)
 +					ps.print("\t" + j + ": " + pi[i][j]);
  			}
  			ps.println();
  		}
@@ -355,14 +413,11 @@ public class PhraseCluster {
 		ps.println("P(word|tag,position)");
  		for (int i = 0; i < K; ++i)
  		{
 -			for(int position=0;position<c.numContexts;position++){
 +			for(int position=0;position<n_positions;position++){
  				ps.println("tag " + i + " position " + position);
  				for(int word=0;word<emit[i][position].length;word++){
 -					//if((word+1)%100==0){
 -					//	ps.println();
 -					//}
 -					if (emit[i][position][word] > 1e-10)
 -						ps.print(c.wordList[word]+"="+emit[i][position][word]+"\t");
 +					if (emit[i][position][word] > EPS)
 +						ps.print(c.getWord(word)+"="+emit[i][position][word]+"\t");
  				}
  				ps.println();
  			}
@@ -371,19 +426,35 @@ public class PhraseCluster {
 	}
 -	double posterior_l1lmax()
 +	double phrase_l1lmax()
  	{
  		double sum=0;
 -		for(int phrase=0;phrase<c.data.length;phrase++)
 +		for(int phrase=0; phrase<n_phrases; phrase++)
  		{
 -			int [][] data = c.data[phrase];
  			double [] maxes = new double[K];
 -			for(int ctx=0;ctx<data.length;ctx++)
 +			for (Edge edge : c.getEdgesForPhrase(phrase))
  			{
 -				int context[]=data[ctx];
 -				double p[]=posterior(phrase,context);
 +				double p[] = posterior(edge);
  				arr.F.l1normalize(p);
 +				for(int tag=0;tag<K;tag++)
 +					maxes[tag] = Math.max(maxes[tag], p[tag]);
 +			}
 +			for(int tag=0;tag<K;tag++)
 +				sum += maxes[tag];
 +		}
 +		return sum;
 +	}
 +	double context_l1lmax()
 +	{
 +		double sum=0;
 +		for(int context=0; context<n_contexts; context++)
 +		{
 +			double [] maxes = new double[K];
 +			for (Edge edge : c.getEdgesForContext(context))
 +			{
 +				double p[] = posterior(edge);
 +				arr.F.l1normalize(p);
  				for(int tag=0;tag<K;tag++)
  					maxes[tag] = Math.max(maxes[tag], p[tag]);
  			}
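The two new sparsity summaries generalize the old posterior_l1lmax: phrase_l1lmax sums, over every (phrase, tag) pair, the maximum posterior across that phrase's contexts, and context_l1lmax does the transpose. As a sketch in our own notation, with q the normalized posterior and E the edge set:

	\ell_{1,\infty}^{PT} = \sum_{p}\sum_{t} \max_{c:(p,c)\in E} q(t \mid p,c), \qquad
	\ell_{1,\infty}^{CT} = \sum_{c}\sum_{t} \max_{p:(p,c)\in E} q(t \mid p,c)

These are the quantities scaled by scalePT and scaleCT in the PR objective printed at the end of training.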
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
new file mode 100644
index 00000000..3273f0ad
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
@@ -0,0 +1,302 @@
+package phrase;
 +
 +import java.io.PrintStream;
 +import java.util.Arrays;
 +import java.util.HashMap;
 +import java.util.List;
 +import java.util.Map;
 +
 +import optimization.gradientBasedMethods.ProjectedGradientDescent;
 +import optimization.gradientBasedMethods.ProjectedObjective;
 +import optimization.gradientBasedMethods.stats.OptimizerStats;
 +import optimization.linesearch.ArmijoLineSearchMinimizationAlongProjectionArc;
 +import optimization.linesearch.InterpolationPickFirstStep;
 +import optimization.linesearch.LineSearchMethod;
 +import optimization.linesearch.WolfRuleLineSearch;
 +import optimization.projections.SimplexProjection;
 +import optimization.stopCriteria.CompositeStopingCriteria;
 +import optimization.stopCriteria.ProjectedGradientL2Norm;
 +import optimization.stopCriteria.StopingCriteria;
 +import optimization.stopCriteria.ValueDifference;
 +import optimization.util.MathUtils;
 +import phrase.Corpus.Edge;
 +
 +public class PhraseContextObjective extends ProjectedObjective
 +{
 +	private static final double GRAD_DIFF = 0.00002;
 +	private static double INIT_STEP_SIZE = 10;
 +	private static double VAL_DIFF = 1e-4; // FIXME needs to be tuned
 +	private static int ITERATIONS = 100;
 +	
 +	private PhraseCluster c;
 +	
 +	// un-regularized  unnormalized posterior, p[edge][tag]
 +	// P(tag|edge) \propto P(tag|phrase)P(context|tag)
 +	private double p[][];
 +
 +	// regularized unnormalized posterior 
 +	// q[edge][tag] propto p[edge][tag]*exp(-lambda)
 +	private double q[][];
 +	private List<Corpus.Edge> data;
 +	
 +	// log likelihood under q
 +	private double loglikelihood;
 +	private SimplexProjection projectionPhrase;
 +	private SimplexProjection projectionContext;
 +	
 +	double[] newPoint;
 +	private int n_param;
 +	
 +	// likelihood under p
 +	public double llh;
 +	
 +	private Map<Corpus.Edge, Integer> edgeIndex;
 +	
 +	public PhraseContextObjective(PhraseCluster cluster, double[] startingParameters)
 +	{
 +		c=cluster;
 +		data=c.c.getEdges();
 +		n_param=data.size()*c.K*2;
 +		
 +		parameters = startingParameters;
 +		if (parameters == null)
 +			parameters = new double[n_param];
 +		
 +		newPoint = new double[n_param];
 +		gradient = new double[n_param];
 +		initP();
 +		projectionPhrase = new SimplexProjection(c.scalePT);
 +		projectionContext = new SimplexProjection(c.scaleCT);
 +		q=new double [data.size()][c.K];
 +		
 +		edgeIndex = new HashMap<Edge, Integer>();
 +		for (int e=0; e<data.size(); e++)
 +			edgeIndex.put(data.get(e), e);
 +
 +		setParameters(parameters);
 +	}
 +
 +	private void initP(){
 +		p=new double[data.size()][];
 +		for(int edge=0;edge<data.size();edge++)
 +		{
 +			p[edge]=c.posterior(data.get(edge));
 +			llh += data.get(edge).getCount() * Math.log(arr.F.l1norm(p[edge]));
 +			arr.F.l1normalize(p[edge]);
 +		}
 +	}
 +	
 +	@Override
 +	public void setParameters(double[] params) {
 +		//System.out.println("setParameters " + Arrays.toString(parameters));
 +		// TODO: test if params have changed and skip update otherwise
 +		super.setParameters(params);
 +		updateFunction();
 +	}
 +	
 +	private void updateFunction()
 +	{
 +		updateCalls++;
 +		loglikelihood=0;
 +
 +		for (int e=0; e<data.size(); e++) 
 +		{
 +			Edge edge = data.get(e);
 +			int offset = edgeIndex.get(edge)*c.K*2;
 +			for(int tag=0; tag<c.K; tag++)
 +			{
 +				int ip = offset + tag*2;
 +				int ic = ip + 1;
 +				q[e][tag] = p[e][tag]*
 +					Math.exp((-parameters[ip]-parameters[ic]) / edge.getCount());
 +			}
 +		}
 +	
 +		for(int edge=0;edge<data.size();edge++){
 +			loglikelihood+=data.get(edge).getCount() * Math.log(arr.F.l1norm(q[edge]));
 +			arr.F.l1normalize(q[edge]);
 +		}
 +		
 +		for (int e=0; e<data.size(); e++) 
 +		{
 +			Edge edge = data.get(e);
 +			int offset = edgeIndex.get(edge)*c.K*2;
 +			for(int tag=0; tag<c.K; tag++)
 +			{
 +				int ip = offset + tag*2;
 +				int ic = ip + 1;		
 +				gradient[ip]=-q[e][tag];
 +				gradient[ic]=-q[e][tag];
 +			}
 +		}
 +		//System.out.println("objective " + loglikelihood + " gradient: " + Arrays.toString(gradient));
 +	}
 +	
 +	@Override
 +	public double[] projectPoint(double[] point) 
 +	{
 +		//System.out.println("projectPoint: " + Arrays.toString(point));
 +		Arrays.fill(newPoint, 0, newPoint.length, 0);
 +		if (c.scalePT > 0)
 +		{
 +			// first project using the phrase-tag constraints,
 +			// for all p,t: sum_c lambda_ptc < scaleP 
 +			for (int p = 0; p < c.c.getNumPhrases(); ++p)
 +			{
 +				List<Edge> edges = c.c.getEdgesForPhrase(p);
 +				double toProject[] = new double[edges.size()];
 +				for(int tag=0;tag<c.K;tag++)
 +				{
 +					for(int e=0; e<edges.size(); e++)
 +						toProject[e] = point[index(edges.get(e), tag, true)];
 +					projectionPhrase.project(toProject);
 +					for(int e=0; e<edges.size(); e++)
 +						newPoint[index(edges.get(e),tag, true)] = toProject[e];
 +				}
 +			}
 +		}
 +		//System.out.println("after PT " + Arrays.toString(newPoint));
 +	
 +		if (c.scaleCT > 1e-6)
 +		{
 +			// now project using the context-tag constraints,
 +			// for all c,t: sum_p omega_pct < scaleC
 +			for (int ctx = 0; ctx < c.c.getNumContexts(); ++ctx)
 +			{
 +				List<Edge> edges = c.c.getEdgesForContext(ctx);
 +				double toProject[] = new double[edges.size()];
 +				for(int tag=0;tag<c.K;tag++)
 +				{
 +					for(int e=0; e<edges.size(); e++)
 +						toProject[e] = point[index(edges.get(e), tag, false)];
 +					projectionContext.project(toProject);
 +					for(int e=0; e<edges.size(); e++)
 +						newPoint[index(edges.get(e),tag, false)] = toProject[e];
 +				}
 +			}
 +		}
 +		double[] tmp = newPoint;
 +		newPoint = point;
 +		
 +		//System.out.println("\treturning " + Arrays.toString(tmp));
 +		return tmp;
 +	}
 +	
 +	private int index(Edge edge, int tag, boolean phrase)
 +	{
 +		// NB if indexing changes must also change code in updateFunction and constructor
 +		if (phrase)
 +			return edgeIndex.get(edge)*c.K*2 + tag*2;
 +		else
 +			return edgeIndex.get(edge)*c.K*2 + tag*2 + 1;
 +	}
 +
 +	@Override
 +	public double[] getGradient() {
 +		gradientCalls++;
 +		return gradient;
 +	}
 +
 +	@Override
 +	public double getValue() {
 +		functionCalls++;
 +		return loglikelihood;
 +	}
 +
 +	@Override
 +	public String toString() {
 +		return "No need for pointless toString";
 +	}
 +
 +	public double []posterior(int edgeIndex){
 +		return q[edgeIndex];
 +	}
 +	
 +	public double[] optimizeWithProjectedGradientDescent()
 +	{
 +		LineSearchMethod ls =
 +			new ArmijoLineSearchMinimizationAlongProjectionArc
 +				(new InterpolationPickFirstStep(INIT_STEP_SIZE));
 +		//LineSearchMethod  ls = new WolfRuleLineSearch(
 +		//		(new InterpolationPickFirstStep(INIT_STEP_SIZE)), c1, c2);
 +		OptimizerStats stats = new OptimizerStats();
 +		
 +		
 +		ProjectedGradientDescent optimizer = new ProjectedGradientDescent(ls);
 +		StopingCriteria stopGrad = new ProjectedGradientL2Norm(GRAD_DIFF);
 +		StopingCriteria stopValue = new ValueDifference(VAL_DIFF*(-llh));
 +		CompositeStopingCriteria compositeStop = new CompositeStopingCriteria();
 +		compositeStop.add(stopGrad);
 +		compositeStop.add(stopValue);
 +		optimizer.setMaxIterations(ITERATIONS);
 +		updateFunction();
 +		boolean succed = optimizer.optimize(this,stats,compositeStop);
 +//		System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
 +		if(succed){
 +			//System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
 +		}else{
 +			System.out.println("Failed to optimize");
 +		}
 +		//	ps.println(Arrays.toString(parameters));
 +		
 +		//	for(int edge=0;edge<data.getSize();edge++){
 +		//	ps.println(Arrays.toString(q[edge]));
 +		//	}
 +		//System.out.println(Arrays.toString(parameters));
 +		
 +		return parameters;
 +	}
 +	
 +	double loglikelihood()
 +	{
 +		return llh;
 +	}
 +	
 +	double KL_divergence()
 +	{
 +		return -loglikelihood + MathUtils.dotProduct(parameters, gradient);
 +	}
 +	
 +	double phrase_l1lmax()
 +	{
 +		// \sum_{tag,phrase} max_{context} P(tag|context,phrase)
 +		double sum=0;
 +		for (int p = 0; p < c.c.getNumPhrases(); ++p)
 +		{
 +			List<Edge> edges = c.c.getEdgesForPhrase(p);
 +			for(int tag=0;tag<c.K;tag++)
 +			{
 +				double max=0;
 +				for (Edge edge: edges)
 +					max = Math.max(max, q[edgeIndex.get(edge)][tag]);
 +				sum+=max;
 +			}	
 +		}
 +		return sum;
 +	}
 +	
 +	double context_l1lmax()
 +	{
 +		// \sum_{tag,context} max_{phrase} P(tag|context,phrase)
 +		double sum=0;
 +		for (int ctx = 0; ctx < c.c.getNumContexts(); ++ctx)
 +		{
 +			List<Edge> edges = c.c.getEdgesForContext(ctx);
 +			for(int tag=0; tag<c.K; tag++)
 +			{
 +				double max=0;
 +				for (Edge edge: edges)
 +					max = Math.max(max, q[edgeIndex.get(edge)][tag]);
 +				sum+=max;
 +			}	
 +		}
 +		return sum;
 +	}
 +	
 +	// L - KL(q||p) - scalePT * l1lmax_phrase - scaleCT * l1lmax_context
 +	public double primal()
 +	{
+		return loglikelihood() - KL_divergence() - c.scalePT * phrase_l1lmax() - c.scaleCT * context_l1lmax();
 +	}
 +	
 +}
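PhraseContextObjective packs two dual variables per (edge, tag) into one parameter vector: the phrase-constraint multiplier at even offsets and the context-constraint multiplier at odd offsets (see index()), so the regularized posterior is, schematically:

	q[e][t] \propto p[e][t] \, \exp(-(\lambda_{e,t} + \omega_{e,t}) / \mathrm{count}(e))

and projectPoint applies a scalePT-scaled simplex projection to each phrase-tag block and a scaleCT-scaled one to each context-tag block. The primal it reports follows the comment above primal():

	\mathrm{primal} = L - \mathrm{KL}(q \,\|\, p) - \mathrm{scalePT} \cdot \ell_{1,\infty}^{PT} - \mathrm{scaleCT} \cdot \ell_{1,\infty}^{CT}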
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java
index b8f1f24a..11e948ff 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java
@@ -8,11 +8,8 @@ import java.util.ArrayList;
 import java.util.HashMap;
  import java.util.Scanner;
 -public class PhraseCorpus {
 -	
 -	public static String LEX_FILENAME="../pdata/lex.out";
 -	public static String DATA_FILENAME="../pdata/btec.con";
 -	
 +public class PhraseCorpus 
 +{
  	public HashMap<String,Integer>wordLex;
  	public HashMap<String,Integer>phraseLex;
@@ -21,16 +18,8 @@ public class PhraseCorpus {
 	//data[phrase][num context][position]
  	public int data[][][];
 -	public int numContexts;
 -	
 -	public static void main(String[] args) {
 -		// TODO Auto-generated method stub
 -		PhraseCorpus c=new PhraseCorpus(DATA_FILENAME);
 -		c.saveLex(LEX_FILENAME);
 -		c.loadLex(LEX_FILENAME);
 -		c.saveLex(LEX_FILENAME);
 -	}
 -	
 +	public int numContexts;	
 +
  	public PhraseCorpus(String filename){
  		BufferedReader r=io.FileUtil.openBufferedReader(filename);
@@ -185,5 +174,13 @@ public class PhraseCorpus {
 		}
  		return null;
  	}
 -	
 +
 +	public static void main(String[] args) {
 +		String LEX_FILENAME="../pdata/lex.out";
 +		String DATA_FILENAME="../pdata/btec.con";
 +		PhraseCorpus c=new PhraseCorpus(DATA_FILENAME);
 +		c.saveLex(LEX_FILENAME);
 +		c.loadLex(LEX_FILENAME);
 +		c.saveLex(LEX_FILENAME);
 +	}
  }
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
index 0fdc169b..015ef106 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
@@ -2,6 +2,7 @@ package phrase;
 import java.io.PrintStream;
  import java.util.Arrays;
 +import java.util.List;
  import optimization.gradientBasedMethods.ProjectedGradientDescent;
  import optimization.gradientBasedMethods.ProjectedObjective;
@@ -17,11 +18,12 @@ import optimization.stopCriteria.StopingCriteria;
 import optimization.stopCriteria.ValueDifference;
  import optimization.util.MathUtils;
 -public class PhraseObjective extends ProjectedObjective{
 -
 -	private static final double GRAD_DIFF = 0.00002;
 -	public static double INIT_STEP_SIZE = 10;
 -	public static double VAL_DIFF = 0.000001; // FIXME needs to be tuned
 +public class PhraseObjective extends ProjectedObjective
 +{
 +	static final double GRAD_DIFF = 0.00002;
 +	static double INIT_STEP_SIZE = 10;
 +	static double VAL_DIFF = 1e-6; // FIXME needs to be tuned
 +	static int ITERATIONS = 100;
  	//private double c1=0.0001; // wolf stuff
  	//private double c2=0.9;
  	private static double lambda[][];
@@ -46,7 +48,7 @@ public class PhraseObjective extends ProjectedObjective{
 	 * q[edge][tag] propto p[edge][tag]*exp(-lambda)
  	 */
  	private double q[][];
 -	private int data[][];
 +	private List<Corpus.Edge> data;
  	/**@brief log likelihood of the associated phrase
  	 * 
@@ -66,14 +68,14 @@ public class PhraseObjective extends ProjectedObjective{
 	public PhraseObjective(PhraseCluster cluster, int phraseIdx){
  		phrase=phraseIdx;
  		c=cluster;
 -		data=c.c.data[phrase];
 -		n_param=data.length*c.K;
 +		data=c.c.getEdgesForPhrase(phrase);
 +		n_param=data.size()*c.K;
 -		if( lambda==null){
 -			lambda=new double[c.c.data.length][];
 +		if (lambda==null){
 +			lambda=new double[c.c.getNumPhrases()][];
  		}
 -		if(lambda[phrase]==null){
 +		if (lambda[phrase]==null){
  			lambda[phrase]=new double[n_param];
  		}
@@ -81,22 +83,17 @@ public class PhraseObjective extends ProjectedObjective{
 		newPoint  = new double[n_param];
  		gradient = new double[n_param];
  		initP();
 -		projection=new SimplexProjection(c.scale);
 -		q=new double [data.length][c.K];
 +		projection=new SimplexProjection(c.scalePT);
 +		q=new double [data.size()][c.K];
  		setParameters(parameters);
  	}
  	private void initP(){
 -		int countIdx=data[0].length-1;
 -		
 -		p=new double[data.length][];
 -		for(int edge=0;edge<data.length;edge++){
 -			p[edge]=c.posterior(phrase,data[edge]);
 -		}
 -		for(int edge=0;edge<data.length;edge++){
 -			llh+=Math.log
 -				(data[edge][countIdx]*arr.F.l1norm(p[edge]));
 +		p=new double[data.size()][];
 +		for(int edge=0;edge<data.size();edge++){
 +			p[edge]=c.posterior(data.get(edge));
 +			llh += data.get(edge).getCount() * Math.log(arr.F.l1norm(p[edge])); // Was bug here - count inside log!
  			arr.F.l1normalize(p[edge]);
  		}
  	}
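The "count inside log" remark above refers to the old initP(), which computed llh += Math.log(count * l1norm(p)) where the intended contribution is count * Math.log(l1norm(p)); e.g. with count = 2 and l1norm(p) = 0.5 the former gives log(1.0) = 0 while the correct value is 2 log(0.5) ≈ -1.386.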
@@ -110,37 +107,36 @@ public class PhraseObjective extends ProjectedObjective{
 	private void updateFunction(){
  		updateCalls++;
  		loglikelihood=0;
 -		int countIdx=data[0].length-1;
 +
  		for(int tag=0;tag<c.K;tag++){
 -			for(int edge=0;edge<data.length;edge++){
 +			for(int edge=0;edge<data.size();edge++){
  				q[edge][tag]=p[edge][tag]*
 -					Math.exp(-parameters[tag*data.length+edge]/data[edge][countIdx]);
 +					Math.exp(-parameters[tag*data.size()+edge]/data.get(edge).getCount());
  			}
  		}
 -		for(int edge=0;edge<data.length;edge++){
 -			loglikelihood+=data[edge][countIdx] * Math.log(arr.F.l1norm(q[edge]));
 +		for(int edge=0;edge<data.size();edge++){
 +			loglikelihood+=data.get(edge).getCount() * Math.log(arr.F.l1norm(q[edge]));
  			arr.F.l1normalize(q[edge]);
  		}
  		for(int tag=0;tag<c.K;tag++){
 -			for(int edge=0;edge<data.length;edge++){
 -				gradient[tag*data.length+edge]=-q[edge][tag];
 +			for(int edge=0;edge<data.size();edge++){
 +				gradient[tag*data.size()+edge]=-q[edge][tag];
  			}
  		}
  	}
  	@Override
 -	// TODO Auto-generated method stub
  	public double[] projectPoint(double[] point) {
 -		double toProject[]=new double[data.length];
 +		double toProject[]=new double[data.size()];
  		for(int tag=0;tag<c.K;tag++){
 -			for(int edge=0;edge<data.length;edge++){
 -				toProject[edge]=point[tag*data.length+edge];
 +			for(int edge=0;edge<data.size();edge++){
 +				toProject[edge]=point[tag*data.size()+edge];
  			}
  			projection.project(toProject);
 -			for(int edge=0;edge<data.length;edge++){
 -				newPoint[tag*data.length+edge]=toProject[edge];
 +			for(int edge=0;edge<data.size();edge++){
 +				newPoint[tag*data.size()+edge]=toProject[edge];
  			}
  		}
  		return newPoint;
@@ -148,22 +144,19 @@ public class PhraseObjective extends ProjectedObjective{
 	@Override
  	public double[] getGradient() {
 -		// TODO Auto-generated method stub
  		gradientCalls++;
  		return gradient;
  	}
  	@Override
  	public double getValue() {
 -		// TODO Auto-generated method stub
  		functionCalls++;
  		return loglikelihood;
  	}
  	@Override
  	public String toString() {
 -		// TODO Auto-generated method stub
 -		return "";
 +		return "No need for pointless toString";
  	}
  	public double [][]posterior(){
@@ -185,7 +178,7 @@ public class PhraseObjective extends ProjectedObjective{
 		CompositeStopingCriteria compositeStop = new CompositeStopingCriteria();
  		compositeStop.add(stopGrad);
  		compositeStop.add(stopValue);
 -		optimizer.setMaxIterations(100);
 +		optimizer.setMaxIterations(ITERATIONS);
  		updateFunction();
  		boolean succed = optimizer.optimize(this,stats,compositeStop);
  //		System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
@@ -197,45 +190,38 @@ public class PhraseObjective extends ProjectedObjective{
 		lambda[phrase]=parameters;
  		//	ps.println(Arrays.toString(parameters));
 -		//	for(int edge=0;edge<data.length;edge++){
 +		//	for(int edge=0;edge<data.getSize();edge++){
  		//	ps.println(Arrays.toString(q[edge]));
  		//	}
  	}
 -	/**
 -	 * L - KL(q||p) -
 -	 * 	 scale * \sum_{tag,phrase} max_i P(tag|i th occurrence of phrase)
 -	 * @return
 -	 */
 -	public double primal()
 +	public double KL_divergence()
 +	{
 +		return -loglikelihood + MathUtils.dotProduct(parameters, gradient);
 +	}
 +	
 +	public double loglikelihood()
 +	{
 +		return llh;
 +	}
 +	
 +	public double l1lmax()
  	{
 -		
 -		double l=llh;
 -		
 -//		ps.print("Phrase "+phrase+": "+l);
 -		double kl=-loglikelihood
 -			+MathUtils.dotProduct(parameters, gradient);
 -//		ps.print(", "+kl);
 -		//System.out.println("llh " + llh);
 -		//System.out.println("kl " + kl);
 -		
 -
 -		l=l-kl;
  		double sum=0;
  		for(int tag=0;tag<c.K;tag++){
  			double max=0;
 -			for(int edge=0;edge<data.length;edge++){
 -				if(q[edge][tag]>max){
 +			for(int edge=0;edge<data.size();edge++){
 +				if(q[edge][tag]>max)
  					max=q[edge][tag];
 -				}
  			}
  			sum+=max;
  		}
 -		//System.out.println("l1lmax " + sum);
 -//		ps.println(", "+sum);
 -		l=l-c.scale*sum;
 -		return l;
 +		return sum;
 +	}
 +
 +	public double primal()
 +	{
 +		return loglikelihood() - KL_divergence() - c.scalePT * l1lmax();	
  	}
 -	
  }
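A closing note on KL_divergence(): since q[e][t] \propto p[e][t] exp(-\lambda_{e,t}/count(e)) and gradient[i] = -q for the matching posterior entry, expanding the count-weighted KL gives (a derivation sketch, not in the source):

	\sum_e \mathrm{count}_e \, \mathrm{KL}(q_e \,\|\, p_e) = -\sum_e \mathrm{count}_e \log Z_e + \lambda \cdot (-q) = -\mathrm{loglikelihood} + \lambda^\top \nabla

which is exactly the expression -loglikelihood + MathUtils.dotProduct(parameters, gradient) used by both objective classes.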
