// NOTE(review): this file appears to have lost its original line breaks, and some loop
// headers/conditions look truncated (for example "for(int i=0;i= 1)" in main below).
// Everything after the first "//" on a physical line is lexically a comment, so only the
// code ahead of that marker is re-flowed here; the remainder is kept exactly as found.
// Restore the original line structure from version control before relying on this file.
package phrase;
import gnu.trove.TIntArrayList;
import io.FileUtil;
import java.io.IOException;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import phrase.Corpus.Edge;

/**
 * Phrase clustering model over a {@link Corpus}.
 *
 * <p>Per the in-file comments, the model keeps two parameter tables:
 * {@code emit[tag][position][word] = p(word | tag, position in context)} and
 * {@code pi[phrase][tag] = p(tag | phrase)}. {@code main} builds an instance from
 * command-line arguments, calls one of the training steps on it inside a loop,
 * prints the objective per iteration, and finally writes the posterior to a file.
 */
public class PhraseCluster {
    /** Number of tags (clusters); the constructor sets it from numCluster and sizes emit and pi with it. */
    public int K;
    /**
     * Scale factors stored from the constructor's scalep and scalec arguments.
     * Presumably the weights of the phrase-side and context-side terms - confirm.
     */
    public double scalePT, scaleCT;
    /**
     * Corpus dimensions cached by the constructor from c.getNumPhrases(), c.getNumWords(),
     * c.getNumContexts() and c.getNumContextPositions().
     */
    private int n_phrases, n_words, n_contexts, n_positions;
    /** The corpus handed to the constructor. */
    public Corpus c;
    /**
     * Thread pool. The constructor creates it with Executors.newFixedThreadPool(threads) only
     * when threads > 0 and scalec <= 0; finish() shuts it down when it is non-null.
     */
    private ExecutorService pool;
    // The rest of this line is preserved exactly as found. It holds the emit and pi
    // parameter arrays and the opening of main, which reads its arguments as:
    //   args[0] input file name            args[1] number of tags (asserted >= 2)
    //   args[2] output file name           args[3] iterations
    //   args[4] scalePT (asserted >= 0)    args[5] scaleCT (asserted >= 0)
    //   args[6] threads                    args[7] runEM (parsed with Boolean.parseBoolean)
    // If reading the corpus throws IOException, main prints "Failed to open input file: ..."
    // and the stack trace to stderr, then exits with status 1.
    // emit[tag][position][word] = p(word | tag, position in context) private double emit[][][]; // pi[phrase][tag] = p(tag | phrase) private double pi[][]; public static void main(String[] args) { String input_fname = args[0]; int tags = Integer.parseInt(args[1]); String output_fname = args[2]; int iterations = Integer.parseInt(args[3]); double scalePT = Double.parseDouble(args[4]); double scaleCT = Double.parseDouble(args[5]); int threads = Integer.parseInt(args[6]); boolean runEM = Boolean.parseBoolean(args[7]); assert(tags >= 2); assert(scalePT >= 0); assert(scaleCT >= 0); Corpus corpus = null; try { corpus = Corpus.readFromFile(FileUtil.openBufferedReader(input_fname)); } catch (IOException e) { System.err.println("Failed to open input file: " + input_fname); e.printStackTrace(); System.exit(1); } PhraseCluster cluster = new PhraseCluster(tags, corpus, scalePT, scaleCT, threads); //PhraseObjective.ps = FileUtil.openOutFile(outputDir + "/phrase_stat.out"); double last = 0; for(int i=0;i= 1) o = cluster.PREM_phrase_constraints_parallel(); else o = cluster.PREM_phrase_constraints(); } else o = cluster.PREM_phrase_context_constraints(); //PhraseObjective.ps. 
// Tail of main (the method and its iteration loop are opened on the previous line).
// Print the objective value o obtained for iteration i and remember it in last.
System.out.println("ITER: "+i+" objective: " + o);
last = o;
}
// After the loop: ask the cluster for its phrase and context l1lmax values and print both.
double pl1lmax = cluster.phrase_l1lmax();
double cl1lmax = cluster.context_l1lmax();
System.out.println("Final posterior phrase l1lmax " + pl1lmax + " context l1lmax " + cl1lmax);
// Only when runEM is true, additionally print last minus the two l1lmax values,
// each multiplied by its scale factor (scalePT for phrases, scaleCT for contexts).
if (runEM) System.out.println("With PR objective " + (last - scalePT*pl1lmax - scaleCT*cl1lmax));
// Open output_fname, let the cluster write its posterior there, and close the stream.
PrintStream ps=io.FileUtil.openOutFile(output_fname);
cluster.displayPosterior(ps);
ps.close();
// NOTE(review): the remainder of this line is preserved exactly as found. The original
// line breaks are missing, so lexically all of it is one comment, and several loop
// headers look truncated (e.g. "for (int ctx=0; ctx 0;"). It contains the end of main
// (cluster.finish()), the constructor, finish(), EM() and the PREM_* training methods.
// Restore it from version control rather than editing it in place.
//PhraseObjective.ps.close(); //ps = io.FileUtil.openOutFile(outputDir + "/parameters.out"); //cluster.displayModelParam(ps); //ps.close(); cluster.finish(); } public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads){ K=numCluster; c=corpus; n_words=c.getNumWords(); n_phrases=c.getNumPhrases(); n_contexts=c.getNumContexts(); n_positions=c.getNumContextPositions(); this.scalePT = scalep; this.scaleCT = scalec; if (threads > 0 && scalec <= 0) pool = Executors.newFixedThreadPool(threads); emit=new double [K][n_positions][n_words]; pi=new double[n_phrases][K]; for(double [][]i:emit){ for(double []j:i){ arr.F.randomise(j); } } for(double []j:pi){ arr.F.randomise(j); } } public void finish() { if (pool != null) pool.shutdown(); } public double EM(){ double [][][]exp_emit=new double [K][n_positions][n_words]; double [][]exp_pi=new double[n_phrases][K]; double loglikelihood=0; //E for(int phrase=0; phrase < n_phrases; phrase++){ List contexts = c.getEdgesForPhrase(phrase); for (int ctx=0; ctx 0; loglikelihood+=Math.log(z); arr.F.l1normalize(p); int count = edge.getCount(); //increment expected count TIntArrayList context = edge.getContext(); for(int tag=0;tag edges = c.getEdgesForPhrase(phrase); for(int edge=0;edge expectations = new LinkedBlockingQueue(); double [][][]exp_emit=new double [K][n_positions][n_words]; double [][]exp_pi=new double[n_phrases][K]; double loglikelihood=0, kl=0, l1lmax=0, primal=0; //E for(int phrase=0;phrase edges = c.getEdgesForPhrase(phrase); for(int edge=0;edge 0); double 
// Continues the declaration started by the trailing "double" on the previous line.
// exp_emit is allocated as [K][n_positions][n_words] and exp_pi as [n_phrases][K], the same
// dimensions the constructor gives emit and pi. Presumably they accumulate the expected
// counts for the parameter update - the code that fills them is in the collapsed remainder.
[][][]exp_emit=new double [K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
// NOTE(review): the remainder of this line is preserved exactly as found. The original
// line breaks are missing, so lexically all of it is one comment; loop headers and
// conditions look truncated (e.g. "for(int tag=0;tag EPS)"), and the last method,
// phrase_l1lmax(), stops mid-loop where the visible source ends. It builds a
// PhraseContextObjective with a null second argument, calls
// optimizeWithProjectedGradientDescent() on it, and reads pco.posterior(e) for each
// edge from c.getEdges(). Restore it from version control rather than editing in place.
//E step // TODO: cache the lambda values (the null below) PhraseContextObjective pco = new PhraseContextObjective(this, null); pco.optimizeWithProjectedGradientDescent(); //now extract expectations List edges = c.getEdges(); for(int e = 0; e < edges.size(); ++e) { double [] q = pco.posterior(e); Corpus.Edge edge = edges.get(e); TIntArrayList context = edge.getContext(); int contextCnt = edge.getCount(); //increment expected count for(int tag=0;tag EPS) ps.print("\t" + j + ": " + pi[i][j]); } ps.println(); } ps.println("P(word|tag,position)"); for (int i = 0; i < K; ++i) { for(int position=0;position EPS) ps.print(c.getWord(word)+"="+emit[i][position][word]+"\t"); } ps.println(); } ps.println(); } } double phrase_l1lmax() { double sum=0; for(int phrase=0; phrase