package phrase;

import gnu.trove.TIntArrayList;
import org.apache.commons.math.special.Gamma;

import java.io.PrintStream;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

import phrase.Corpus.Edge;

public class PhraseCluster {

	public int K;
	private int n_phrases, n_words, n_contexts, n_positions;
	public Corpus c;
	public ExecutorService pool;

	double[] lambdaPTCT;
	double[][] lambdaPT;
	boolean cacheLambda = true;

	// emit[tag][position][word] = p(word | tag, position in context)
	double emit[][][];
	// pi[phrase][tag] = p(tag | phrase)
	double pi[][];

	public PhraseCluster(int numCluster, Corpus corpus)
	{
		K = numCluster;
		c = corpus;
		n_words = c.getNumWords();
		n_phrases = c.getNumPhrases();
		n_contexts = c.getNumContexts();
		n_positions = c.getNumContextPositions();

		emit = new double[K][n_positions][n_words];
		pi = new double[n_phrases][K];

		// random initialisation of the emission and prior multinomials
		for(double[][] i: emit)
			for(double[] j: i)
				arr.F.randomise(j, true);

		for(double[] j: pi)
			arr.F.randomise(j, true);
	}

	public void initialiseVB(double alphaEmit, double alphaPi)
	{
		assert alphaEmit > 0;
		assert alphaPi > 0;

		for(double[][] i: emit)
			for(double[] j: i)
				digammaNormalize(j, alphaEmit);

		for(double[] j: pi)
			digammaNormalize(j, alphaPi);
	}

	void useThreadPool(int threads)
	{
		assert threads > 0;
		pool = Executors.newFixedThreadPool(threads);
	}

	public double EM(boolean skipBigPhrases)
	{
		double[][][] exp_emit = new double[K][n_positions][n_words];
		double[][] exp_pi = new double[n_phrases][K];

		if (skipBigPhrases)
		{
			for(double[][] i: exp_emit)
				for(double[] j: i)
					Arrays.fill(j, 1e-100);
		}

		double loglikelihood = 0;

		//E
		for(int phrase = 0; phrase < n_phrases; phrase++)
		{
			if (skipBigPhrases && c.getPhrase(phrase).size() >= 2)
			{
				System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
				continue;
			}

			List<Edge> contexts = c.getEdgesForPhrase(phrase);

			for (int ctx = 0; ctx < contexts.size(); ctx++)
			{
				Edge edge = contexts.get(ctx);
				double[] p = posterior(edge);
				double z = arr.F.l1norm(p);
				assert z > 0;
				loglikelihood += edge.getCount() * Math.log(z);
				arr.F.l1normalize(p);

				int count = edge.getCount();
				//increment expected count
				TIntArrayList context = edge.getContext();
				for(int tag = 0; tag < K; tag++)
				{
					for(int pos = 0; pos < n_positions; pos++)
						exp_emit[tag][pos][context.get(pos)] += p[tag] * count;
					exp_pi[phrase][tag] += p[tag] * count;
				}
			}
		}

		//M
		for(double[][] i: exp_emit)
			for(double[] j: i)
				arr.F.l1normalize(j);
		emit = exp_emit;

		for(double[] j: exp_pi)
			arr.F.l1normalize(j);
		pi = exp_pi;

		return loglikelihood;
	}

	// NB: only this method's E-step survives in the extracted source; its name, signature
	// and digamma-based M-step are reconstructed as the variational-Bayes analogue of EM()
	// (cf. initialiseVB above).
	public double VBEM(double alphaEmit, double alphaPi)
	{
		double[][][] exp_emit = new double[K][n_positions][n_words];
		double[][] exp_pi = new double[n_phrases][K];

		double loglikelihood = 0;

		//E
		for(int phrase = 0; phrase < n_phrases; phrase++)
		{
			List<Edge> contexts = c.getEdgesForPhrase(phrase);

			for (int ctx = 0; ctx < contexts.size(); ctx++)
			{
				Edge edge = contexts.get(ctx);
				double[] p = posterior(edge);
				double z = arr.F.l1norm(p);
				assert z > 0;
				loglikelihood += edge.getCount() * Math.log(z);
				arr.F.l1normalize(p);

				int count = edge.getCount();
				//increment expected count
				TIntArrayList context = edge.getContext();
				for(int tag = 0; tag < K; tag++)
				{
					for(int pos = 0; pos < n_positions; pos++)
						exp_emit[tag][pos][context.get(pos)] += p[tag] * count;
					exp_pi[phrase][tag] += p[tag] * count;
				}
			}
		}

		//M: digamma-normalised updates instead of plain normalisation
		for(double[][] i: exp_emit)
			for(double[] j: i)
				digammaNormalize(j, alphaEmit);
		emit = exp_emit;

		for(double[] j: exp_pi)
			digammaNormalize(j, alphaPi);
		pi = exp_pi;

		return loglikelihood;
	}

	// ... (intervening code, including the digammaNormalize() helper used above, was lost
	// in extraction; the following method's header is reconstructed) ...

	public double PREM_phrase_constraints_parallel(final double scalePT, boolean skipBigPhrases)
	{
		final LinkedBlockingQueue<PhraseObjective> expectations
			= new LinkedBlockingQueue<PhraseObjective>();

		double[][][] exp_emit = new double[K][n_positions][n_words];
		double[][] exp_pi = new double[n_phrases][K];

		if (skipBigPhrases)
		{
			for(double[][] i: exp_emit)
				for(double[] j: i)
					Arrays.fill(j, 1e-100);
		}

		double loglikelihood = 0, kl = 0, l1lmax = 0, primal = 0;
		final AtomicInteger failures = new AtomicInteger(0);
		final AtomicLong elapsed = new AtomicLong(0l);
		int iterations = 0, n = n_phrases;
		long start = System.currentTimeMillis();

		if (lambdaPT == null && cacheLambda)
			lambdaPT = new double[n_phrases][];

		//E
		for(int phrase = 0; phrase < n_phrases; phrase++)
		{
			if (skipBigPhrases && c.getPhrase(phrase).size() >= 2)
			{
				n -= 1;
				System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K);
				continue;
			}

			final int p = phrase;
			pool.execute(new Runnable() {
				public void run() {
					try {
						//System.out.println("" + Thread.currentThread().getId() + " optimising lambda for " + p);
						long start = System.currentTimeMillis();
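						// Per-phrase PR subproblem: optimise this phrase's lambda parameters for the
						// scalePT posterior-regularisation constraint by projected gradient descent,
						// starting from the lambda cached in the previous iteration when cacheLambda is set.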
						PhraseObjective po = new PhraseObjective(PhraseCluster.this, p, scalePT,
								(cacheLambda) ? lambdaPT[p] : null);
						boolean ok = po.optimizeWithProjectedGradientDescent();
						if (!ok) failures.incrementAndGet();
						long end = System.currentTimeMillis();
						elapsed.addAndGet(end - start);
						//System.out.println("" + Thread.currentThread().getId() + " done optimising lambda for " + p);
						expectations.put(po);
						//System.out.println("" + Thread.currentThread().getId() + " added to queue " + p);
					} catch (InterruptedException e) {
						System.err.println(Thread.currentThread().getId() + " Local e-step thread interrupted; will cause deadlock.");
						e.printStackTrace();
					}
				}
			});
		}

		// aggregate the expectations as they become available
		// (NB: most of this loop was lost in extraction and is reconstructed; the
		// PhraseObjective accessor names below are assumptions)
		for(int count = 0; count < n; count++)
		{
			try {
				// block until a worker has finished optimising some phrase
				PhraseObjective po = expectations.take();
				int phrase = po.phrase;
				if (cacheLambda) lambdaPT[phrase] = po.getParameters();
				iterations += po.getNumberUpdateCalls();

				double[][] q = po.posterior();
				loglikelihood += po.loglikelihood();
				kl += po.KL_divergence();
				l1lmax += po.l1lmax();
				primal += po.primal(scalePT);

				List<Edge> edges = c.getEdgesForPhrase(phrase);
				for(int edge = 0; edge < q.length; edge++)
				{
					Edge e = edges.get(edge);
					TIntArrayList context = e.getContext();
					int contextCnt = e.getCount();
					//increment expected count
					for(int tag = 0; tag < K; tag++)
					{
						for(int pos = 0; pos < n_positions; pos++)
							exp_emit[tag][pos][context.get(pos)] += q[edge][tag] * contextCnt;
						exp_pi[phrase][tag] += q[edge][tag] * contextCnt;
					}
				}
			} catch (InterruptedException e) {
				System.err.println("Aggregation thread interrupted. Probably fatal!");
				e.printStackTrace();
			}
		}

		long end = System.currentTimeMillis();

		if (failures.get() > 0)
			System.out.println("WARNING: failed to converge in " + failures.get() + "/" + n_phrases + " cases");
		System.out.println("\tmean iters: " + iterations/(double)n_phrases + " walltime " + (end-start)/1000.0 + " threads " + elapsed.get() / 1000.0);
		System.out.println("\tllh: " + loglikelihood);
		System.out.println("\tKL: " + kl);
		System.out.println("\tphrase l1lmax: " + l1lmax);

		//M
		for(double[][] i: exp_emit)
			for(double[] j: i)
				arr.F.l1normalize(j);
		emit = exp_emit;

		for(double[] j: exp_pi)
			arr.F.l1normalize(j);
		pi = exp_pi;

		return primal;
	}

	public double PREM_phrase_context_constraints(double scalePT, double scaleCT, boolean skipBigPhrases)
	{
		assert !skipBigPhrases : "Not supported yet - FIXME!"; //FIXME

		double[][][] exp_emit = new double[K][n_positions][n_words];
		double[][] exp_pi = new double[n_phrases][K];

		//E step
		PhraseContextObjective pco = new PhraseContextObjective(this, lambdaPTCT, pool, scalePT, scaleCT);
		boolean ok = pco.optimizeWithProjectedGradientDescent();
		if (cacheLambda) lambdaPTCT = pco.getParameters();

		//now extract expectations
		List<Edge> edges = c.getEdges();
		for(int e = 0; e < edges.size(); ++e)
		{
			double[] q = pco.posterior(e);
			Corpus.Edge edge = edges.get(e);

			TIntArrayList context = edge.getContext();
			int contextCnt = edge.getCount();
			//increment expected count
			for(int tag = 0; tag < K; tag++)
			{
				for(int pos = 0; pos < n_positions; pos++)
					exp_emit[tag][pos][context.get(pos)] += q[tag] * contextCnt;
				exp_pi[edge.getPhraseId()][tag] += q[tag] * contextCnt; // getPhraseId() assumed; the original index expression was lost
			}
		}

		//M step (reconstructed: the original M-step and return value were lost in extraction)
		for(double[][] i: exp_emit)
			for(double[] j: i)
				arr.F.l1normalize(j);
		emit = exp_emit;

		for(double[] j: exp_pi)
			arr.F.l1normalize(j);
		pi = exp_pi;

		return pco.primal();
	}

	// ... (intervening code lost in extraction; the following method's header and the start
	// of its P(tag|phrase) display loop are reconstructed) ...

	public void displayModelParam(PrintStream ps)
	{
		final double EPS = 1e-6;

		ps.println("P(tag|phrase)");
		for (int i = 0; i < n_phrases; ++i)
		{
			ps.print(c.getPhrase(i));
			for(int j = 0; j < pi[i].length; j++)
			{
				if (pi[i][j] > EPS)
					ps.print("\t" + j + ": " + pi[i][j]);
			}
			ps.println();
		}

		ps.println("P(word|tag,position)");
		for (int i = 0; i < K; ++i)
		{
			for(int position = 0; position < n_positions; position++)
			{
				ps.println("tag " + i + " position " + position);
				for(int word = 0; word < n_words; word++)
				{
					if (emit[i][position][word] > EPS)
						ps.print(c.getWord(word) + "=" + emit[i][position][word] + "\t");
				}
				ps.println();
			}
			ps.println();
		}
	}

	double phrase_l1lmax()
	{
		double sum = 0;
		for(int phrase = 0; phrase < n_phrases; phrase++)