summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation')
-rw-r--r--gi/posterior-regularisation/prjava/src/arr/F.java70
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/C2F.java17
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java260
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java183
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java229
5 files changed, 759 insertions, 0 deletions
diff --git a/gi/posterior-regularisation/prjava/src/arr/F.java b/gi/posterior-regularisation/prjava/src/arr/F.java
new file mode 100644
index 00000000..c194496e
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/arr/F.java
@@ -0,0 +1,70 @@
+package arr;
+
+public class F {
+ public static void randomise(double probs[])
+ {
+ double z = 0;
+ for (int i = 0; i < probs.length; ++i)
+ {
+ probs[i] = 3 + Math.random();
+ z += probs[i];
+ }
+
+ for (int i = 0; i < probs.length; ++i)
+ probs[i] /= z;
+ }
+
+ public static void l1normalize(double [] a){
+ double sum=0;
+ for(int i=0;i<a.length;i++){
+ sum+=a[i];
+ }
+ if(sum==0){
+ return ;
+ }
+ for(int i=0;i<a.length;i++){
+ a[i]/=sum;
+ }
+ }
+
+ public static void l1normalize(double [][] a){
+ double sum=0;
+ for(int i=0;i<a.length;i++){
+ for(int j=0;j<a[i].length;j++){
+ sum+=a[i][j];
+ }
+ }
+ if(sum==0){
+ return;
+ }
+ for(int i=0;i<a.length;i++){
+ for(int j=0;j<a[i].length;j++){
+ a[i][j]/=sum;
+ }
+ }
+ }
+
+ public static double l1norm(double a[]){
+ double norm=0;
+ for(int i=0;i<a.length;i++){
+ norm += a[i];
+ }
+ return norm;
+ }
+
+ public static int argmax(double probs[])
+ {
+ double m = Double.NEGATIVE_INFINITY;
+ int mi = -1;
+ for (int i = 0; i < probs.length; ++i)
+ {
+ if (probs[i] > m)
+ {
+ m = probs[i];
+ mi = i;
+ }
+ }
+ return mi;
+ }
+
+}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/C2F.java b/gi/posterior-regularisation/prjava/src/phrase/C2F.java
new file mode 100644
index 00000000..2646d961
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/phrase/C2F.java
@@ -0,0 +1,17 @@
+package phrase;
+/**
+ * @brief context generates phrase
+ * @author desaic
+ *
+ */
+public class C2F {
+
+ /**
+ * @param args
+ */
+ public static void main(String[] args) {
+ // TODO Auto-generated method stub
+
+ }
+
+}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
new file mode 100644
index 00000000..8b1e0a8c
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -0,0 +1,260 @@
+package phrase;
+
+import io.FileUtil;
+
+import java.io.PrintStream;
+import java.util.Arrays;
+
+public class PhraseCluster {
+
+ /**@brief number of clusters*/
+ public int K;
+ private int n_phrase;
+ private int n_words;
+ public PhraseCorpus c;
+
+ /**@brief
+ * emit[tag][position][word]
+ */
+ private double emit[][][];
+ private double pi[][];
+
+ public static int ITER=20;
+ public static String postFilename="../pdata/posterior.out";
+ public static String phraseStatFilename="../pdata/phrase_stat.out";
+ private static int NUM_TAG=3;
+ public static void main(String[] args) {
+
+ PhraseCorpus c=new PhraseCorpus(PhraseCorpus.DATA_FILENAME);
+
+ PhraseCluster cluster=new PhraseCluster(NUM_TAG,c);
+ PhraseObjective.ps=FileUtil.openOutFile(phraseStatFilename);
+ for(int i=0;i<ITER;i++){
+ PhraseObjective.ps.println("ITER: "+i);
+ cluster.PREM();
+ // cluster.EM();
+ }
+
+ PrintStream ps=io.FileUtil.openOutFile(postFilename);
+ cluster.displayPosterior(ps);
+ ps.println();
+ cluster.displayModelParam(ps);
+ ps.close();
+ PhraseObjective.ps.close();
+ }
+
+ public PhraseCluster(int numCluster,PhraseCorpus corpus){
+ K=numCluster;
+ c=corpus;
+ n_words=c.wordLex.size();
+ n_phrase=c.data.length;
+
+ emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
+ pi=new double[n_phrase][K];
+
+ for(double [][]i:emit){
+ for(double []j:i){
+ arr.F.randomise(j);
+ }
+ }
+
+ for(double []j:pi){
+ arr.F.randomise(j);
+ }
+
+ pi[0]=new double[]{
+ 0.3,0.5,0.2
+ };
+
+ double temp[][]=new double[][]{
+ {0.11,0.16,0.19,0.11,0.1},
+ {0.10,0.15,0.18,0.1,0.11},
+ {0.09,0.07,0.12,0.14,0.13}
+ };
+
+ for(int tag=0;tag<3;tag++){
+ for(int word=0;word<4;word++){
+ for(int pos=0;pos<4;pos++){
+ emit[tag][pos][word]=temp[tag][word];
+ }
+ }
+ }
+
+ }
+
+ public void EM(){
+ double [][][]exp_emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
+ double [][]exp_pi=new double[n_phrase][K];
+
+ double loglikelihood=0;
+
+ //E
+ for(int phrase=0;phrase<c.data.length;phrase++){
+ int [][] data=c.data[phrase];
+ for(int ctx=0;ctx<data.length;ctx++){
+ int context[]=data[ctx];
+ double p[]=posterior(phrase,context);
+ loglikelihood+=Math.log(arr.F.l1norm(p));
+ arr.F.l1normalize(p);
+
+ int contextCnt=context[context.length-1];
+ //increment expected count
+ for(int tag=0;tag<K;tag++){
+ for(int pos=0;pos<context.length-1;pos++){
+ exp_emit[tag][pos][context[pos]]+=p[tag]*contextCnt;
+ }
+
+ exp_pi[phrase][tag]+=p[tag]*contextCnt;
+ }
+ }
+ }
+
+ System.out.println("Log likelihood: "+loglikelihood);
+
+ //M
+ for(double [][]i:exp_emit){
+ for(double []j:i){
+ arr.F.l1normalize(j);
+ }
+ }
+
+ emit=exp_emit;
+
+ for(double []j:exp_pi){
+ arr.F.l1normalize(j);
+ }
+
+ pi=exp_pi;
+ }
+
+ public void PREM(){
+ double [][][]exp_emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
+ double [][]exp_pi=new double[n_phrase][K];
+
+ double loglikelihood=0;
+ double primal=0;
+ //E
+ for(int phrase=0;phrase<c.data.length;phrase++){
+ PhraseObjective po=new PhraseObjective(this,phrase);
+ po.optimizeWithProjectedGradientDescent();
+ double [][] q=po.posterior();
+ loglikelihood+=po.getValue();
+ primal+=po.primal();
+ for(int edge=0;edge<q.length;edge++){
+ int []context=c.data[phrase][edge];
+ int contextCnt=context[context.length-1];
+ //increment expected count
+ for(int tag=0;tag<K;tag++){
+ for(int pos=0;pos<context.length-1;pos++){
+ exp_emit[tag][pos][context[pos]]+=q[edge][tag]*contextCnt;
+ }
+
+ exp_pi[phrase][tag]+=q[edge][tag]*contextCnt;
+ }
+ }
+ }
+
+ System.out.println("Log likelihood: "+loglikelihood);
+ System.out.println("Primal Objective: "+primal);
+
+ //M
+ for(double [][]i:exp_emit){
+ for(double []j:i){
+ arr.F.l1normalize(j);
+ }
+ }
+
+ emit=exp_emit;
+
+ for(double []j:exp_pi){
+ arr.F.l1normalize(j);
+ }
+
+ pi=exp_pi;
+ }
+
+ /**
+ *
+ * @param phrase index of phrase
+ * @param ctx array of context
+ * @return unnormalized posterior
+ */
+ public double[]posterior(int phrase, int[]ctx){
+ double[] prob=Arrays.copyOf(pi[phrase], K);
+
+ for(int tag=0;tag<K;tag++){
+ for(int c=0;c<ctx.length-1;c++){
+ int word=ctx[c];
+ prob[tag]*=emit[tag][c][word];
+ }
+ }
+
+ return prob;
+ }
+
+ public void displayPosterior(PrintStream ps)
+ {
+
+ c.buildList();
+
+ for (int i = 0; i < n_phrase; ++i)
+ {
+ int [][]data=c.data[i];
+ for (int[] e: data)
+ {
+ double probs[] = posterior(i, e);
+ arr.F.l1normalize(probs);
+
+ // emit phrase
+ ps.print(c.phraseList[i]);
+ ps.print("\t");
+ ps.print(c.getContextString(e));
+ ps.print("||| C=" + e[e.length-1] + " |||");
+
+ int t=arr.F.argmax(probs);
+
+ ps.print(t+"||| [");
+ for(t=0;t<K;t++){
+ ps.print(probs[t]+", ");
+ }
+ // for (int t = 0; t < numTags; ++t)
+ // System.out.print(" " + probs[t]);
+ ps.println("]");
+ }
+ }
+ }
+
+ public void displayModelParam(PrintStream ps)
+ {
+
+ c.buildList();
+
+ ps.println("P(tag|phrase)");
+ for (int i = 0; i < n_phrase; ++i)
+ {
+ ps.print(c.phraseList[i]);
+ for(int j=0;j<pi[i].length;j++){
+ ps.print("\t"+pi[i][j]);
+ }
+ ps.println();
+ }
+
+ ps.println("P(word|tag,position)");
+ for (int i = 0; i < K; ++i)
+ {
+ ps.println(i);
+ for(int position=0;position<PhraseCorpus.NUM_CONTEXT;position++){
+ ps.println(position);
+ for(int word=0;word<emit[i][position].length;word++){
+ if((word+1)%100==0){
+ ps.println();
+ }
+ ps.print(c.wordList[word]+"="+emit[i][position][word]+"\t");
+ }
+ ps.println();
+ }
+ ps.println();
+ }
+
+ }
+}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java
new file mode 100644
index 00000000..3902f665
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java
@@ -0,0 +1,183 @@
+package phrase;
+
+import java.io.BufferedInputStream;
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Scanner;
+
+public class PhraseCorpus {
+
+
+ public static String LEX_FILENAME="../pdata/lex.out";
+ //public static String DATA_FILENAME="../pdata/canned.con";
+ public static String DATA_FILENAME="../pdata/btec.con";
+ public static int NUM_CONTEXT=4;
+
+ public HashMap<String,Integer>wordLex;
+ public HashMap<String,Integer>phraseLex;
+
+ public String wordList[];
+ public String phraseList[];
+
+ //data[phrase][num context][position]
+ public int data[][][];
+
+ public static void main(String[] args) {
+ // TODO Auto-generated method stub
+ PhraseCorpus c=new PhraseCorpus(DATA_FILENAME);
+ c.saveLex(LEX_FILENAME);
+ c.loadLex(LEX_FILENAME);
+ c.saveLex(LEX_FILENAME);
+ }
+
+ public PhraseCorpus(String filename){
+ BufferedReader r=io.FileUtil.openBufferedReader(filename);
+
+ phraseLex=new HashMap<String,Integer>();
+ wordLex=new HashMap<String,Integer>();
+
+ ArrayList<int[][]>dataList=new ArrayList<int[][]>();
+ String line=null;
+
+ while((line=readLine(r))!=null){
+
+ String toks[]=line.split("\t");
+ String phrase=toks[0];
+ addLex(phrase,phraseLex);
+
+ toks=toks[1].split(" \\|\\|\\| ");
+
+ ArrayList <int[]>ctxList=new ArrayList<int[]>();
+
+ for(int i=0;i<toks.length;i+=2){
+ String ctx=toks[i];
+ String words[]=ctx.split(" ");
+ int []context=new int [NUM_CONTEXT+1];
+ int idx=0;
+ for(String word:words){
+ if(word.equals("<PHRASE>")){
+ continue;
+ }
+ addLex(word,wordLex);
+ context[idx]=wordLex.get(word);
+ idx++;
+ }
+
+ String count=toks[i+1];
+ context[idx]=Integer.parseInt(count.trim().substring(2));
+
+
+ ctxList.add(context);
+
+ }
+
+ dataList.add(ctxList.toArray(new int [0][]));
+
+ }
+ try{
+ r.close();
+ }catch(IOException ioe){
+ ioe.printStackTrace();
+ }
+ data=dataList.toArray(new int[0][][]);
+ }
+
+ private void addLex(String key, HashMap<String,Integer>lex){
+ Integer i=lex.get(key);
+ if(i==null){
+ lex.put(key, lex.size());
+ }
+ }
+
+ //for debugging
+ public void saveLex(String lexFilename){
+ PrintStream ps=io.FileUtil.openOutFile(lexFilename);
+ ps.println("Phrase Lexicon");
+ ps.println(phraseLex.size());
+ printDict(phraseLex,ps);
+
+ ps.println("Word Lexicon");
+ ps.println(wordLex.size());
+ printDict(wordLex,ps);
+ ps.close();
+ }
+
+ private static void printDict(HashMap<String,Integer>lex,PrintStream ps){
+ String []dict=buildList(lex);
+ for(int i=0;i<dict.length;i++){
+ ps.println(dict[i]);
+ }
+ }
+
+ public void loadLex(String lexFilename){
+ Scanner sc=io.FileUtil.openInFile(lexFilename);
+
+ sc.nextLine();
+ int size=sc.nextInt();
+ sc.nextLine();
+ String[]dict=new String[size];
+ for(int i=0;i<size;i++){
+ dict[i]=sc.nextLine();
+ }
+ phraseLex=buildMap(dict);
+
+ sc.nextLine();
+ size=sc.nextInt();
+ sc.nextLine();
+ dict=new String[size];
+ for(int i=0;i<size;i++){
+ dict[i]=sc.nextLine();
+ }
+ wordLex=buildMap(dict);
+ sc.close();
+ }
+
+ private HashMap<String, Integer> buildMap(String[]dict){
+ HashMap<String,Integer> map=new HashMap<String,Integer>();
+ for(int i=0;i<dict.length;i++){
+ map.put(dict[i], i);
+ }
+ return map;
+ }
+
+ public void buildList(){
+ if(wordList==null){
+ wordList=buildList(wordLex);
+ phraseList=buildList(phraseLex);
+ }
+ }
+
+ private static String[]buildList(HashMap<String,Integer>lex){
+ String dict[]=new String [lex.size()];
+ for(String key:lex.keySet()){
+ dict[lex.get(key)]=key;
+ }
+ return dict;
+ }
+
+ public String getContextString(int context[])
+ {
+ StringBuffer b = new StringBuffer();
+ for (int i=0;i<context.length-1;i++)
+ {
+ if (b.length() > 0)
+ b.append(" ");
+ b.append(wordList[context[i]]);
+ }
+ return b.toString();
+ }
+
+ public static String readLine(BufferedReader r){
+ try{
+ return r.readLine();
+ }
+ catch(IOException ioe){
+ ioe.printStackTrace();
+ }
+ return null;
+ }
+
+}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
new file mode 100644
index 00000000..e9e063d6
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
@@ -0,0 +1,229 @@
+package phrase;
+
+import java.io.PrintStream;
+import java.util.Arrays;
+
+import optimization.gradientBasedMethods.ProjectedGradientDescent;
+import optimization.gradientBasedMethods.ProjectedObjective;
+import optimization.gradientBasedMethods.stats.OptimizerStats;
+import optimization.linesearch.ArmijoLineSearchMinimizationAlongProjectionArc;
+import optimization.linesearch.InterpolationPickFirstStep;
+import optimization.linesearch.LineSearchMethod;
+import optimization.linesearch.WolfRuleLineSearch;
+import optimization.projections.SimplexProjection;
+import optimization.stopCriteria.CompositeStopingCriteria;
+import optimization.stopCriteria.ProjectedGradientL2Norm;
+import optimization.stopCriteria.StopingCriteria;
+import optimization.stopCriteria.ValueDifference;
+import optimization.util.MathUtils;
+
+public class PhraseObjective extends ProjectedObjective{
+
+ private static final double GRAD_DIFF = 0.002;
+ public static double INIT_STEP_SIZE=1;
+ public static double VAL_DIFF=0.001;
+ private double scale=5;
+ private double c1=0.0001;
+ private double c2=0.9;
+
+ private PhraseCluster c;
+
+ /**@brief
+ * for debugging purposes
+ */
+ public static PrintStream ps;
+
+ /**@brief current phrase being optimzed*/
+ public int phrase;
+
+ /**@brief un-regularized posterior
+ * unnormalized
+ * p[edge][tag]
+ * P(tag|edge) \propto P(tag|phrase)P(context|tag)
+ */
+ private double[][]p;
+
+ /**@brief regularized posterior
+ * q[edge][tag] propto p[edge][tag]*exp(-lambda)
+ */
+ private double q[][];
+ private int data[][];
+
+ /**@brief log likelihood of the associated phrase
+ *
+ */
+ private double loglikelihood;
+ private SimplexProjection projection;
+
+ double[] newPoint ;
+
+ private int n_param;
+
+ /**@brief likelihood under p
+ *
+ */
+ private double llh;
+
+ public PhraseObjective(PhraseCluster cluster, int phraseIdx){
+ phrase=phraseIdx;
+ c=cluster;
+ data=c.c.data[phrase];
+ n_param=data.length*c.K;
+ parameters=new double [n_param];
+ newPoint = new double[n_param];
+ gradient = new double[n_param];
+ initP();
+ projection=new SimplexProjection (scale);
+ q=new double [data.length][c.K];
+
+ setParameters(parameters);
+ }
+
+ private void initP(){
+ int countIdx=data[0].length-1;
+
+ p=new double[data.length][];
+ for(int edge=0;edge<data.length;edge++){
+ p[edge]=c.posterior(phrase,data[edge]);
+ }
+ for(int edge=0;edge<data.length;edge++){
+ llh+=Math.log
+ (data[edge][countIdx]*arr.F.l1norm(p[edge]));
+ arr.F.l1normalize(p[edge]);
+ }
+ }
+
+ @Override
+ public void setParameters(double[] params) {
+ super.setParameters(params);
+ updateFunction();
+ }
+
+ private void updateFunction(){
+ updateCalls++;
+ loglikelihood=0;
+ int countIdx=data[0].length-1;
+ for(int tag=0;tag<c.K;tag++){
+ for(int edge=0;edge<data.length;edge++){
+ q[edge][tag]=p[edge][tag]*
+ Math.exp(-parameters[tag*data.length+edge]/data[edge][countIdx]);
+ }
+ }
+
+ for(int edge=0;edge<data.length;edge++){
+ loglikelihood+=Math.log
+ (data[edge][countIdx]*arr.F.l1norm(q[edge]));
+ arr.F.l1normalize(q[edge]);
+ }
+
+ for(int tag=0;tag<c.K;tag++){
+ for(int edge=0;edge<data.length;edge++){
+ gradient[tag*data.length+edge]=-q[edge][tag];
+ }
+ }
+ }
+
+ @Override
+ // TODO Auto-generated method stub
+ public double[] projectPoint(double[] point) {
+ double toProject[]=new double[data.length];
+ for(int tag=0;tag<c.K;tag++){
+ for(int edge=0;edge<data.length;edge++){
+ toProject[edge]=point[tag*data.length+edge];
+ }
+ projection.project(toProject);
+ for(int edge=0;edge<data.length;edge++){
+ newPoint[tag*data.length+edge]=toProject[edge];
+ }
+ }
+ return newPoint;
+ }
+
+ @Override
+ public double[] getGradient() {
+ // TODO Auto-generated method stub
+ gradientCalls++;
+ return gradient;
+ }
+
+ @Override
+ public double getValue() {
+ // TODO Auto-generated method stub
+ functionCalls++;
+ return loglikelihood;
+ }
+
+ @Override
+ public String toString() {
+ // TODO Auto-generated method stub
+ return "";
+ }
+
+ public double [][]posterior(){
+ return q;
+ }
+
+ public void optimizeWithProjectedGradientDescent(){
+ LineSearchMethod ls =
+ new ArmijoLineSearchMinimizationAlongProjectionArc
+ (new InterpolationPickFirstStep(INIT_STEP_SIZE));
+ //LineSearchMethod ls = new WolfRuleLineSearch(
+ // (new InterpolationPickFirstStep(INIT_STEP_SIZE)), c1, c2);
+ OptimizerStats stats = new OptimizerStats();
+
+
+ ProjectedGradientDescent optimizer = new ProjectedGradientDescent(ls);
+ StopingCriteria stopGrad = new ProjectedGradientL2Norm(GRAD_DIFF);
+ StopingCriteria stopValue = new ValueDifference(VAL_DIFF);
+ CompositeStopingCriteria compositeStop = new CompositeStopingCriteria();
+ compositeStop.add(stopGrad);
+ compositeStop.add(stopValue);
+ optimizer.setMaxIterations(100);
+ updateFunction();
+ boolean succed = optimizer.optimize(this,stats,compositeStop);
+// System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
+ if(succed){
+ System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
+ }else{
+ System.out.println("Failed to optimize");
+ }
+
+ // ps.println(Arrays.toString(parameters));
+
+ // for(int edge=0;edge<data.length;edge++){
+ // ps.println(Arrays.toString(q[edge]));
+ // }
+
+ }
+
+ /**
+ * L - KL(q||p) -
+ * scale * \sum_{tag,phrase} max_i P(tag|i th occurrence of phrase)
+ * @return
+ */
+ public double primal()
+ {
+
+ double l=llh;
+
+// ps.print("Phrase "+phrase+": "+l);
+ double kl=-loglikelihood
+ +MathUtils.dotProduct(parameters, gradient);
+// ps.print(", "+kl);
+ l=l-kl;
+ double sum=0;
+ for(int tag=0;tag<c.K;tag++){
+ double max=0;
+ for(int edge=0;edge<data.length;edge++){
+ if(q[edge][tag]>max){
+ max=q[edge][tag];
+ }
+ }
+ sum+=max;
+ }
+// ps.println(", "+sum);
+ l=l-scale*sum;
+ return l;
+ }
+
+}