From 33994330b8395c4c44ad0ddc1e678372404c3566 Mon Sep 17 00:00:00 2001 From: desaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f> Date: Mon, 5 Jul 2010 15:26:42 +0000 Subject: forget to add files git-svn-id: https://ws10smt.googlecode.com/svn/trunk@126 ec762483-ff6d-05da-a07a-a48fb63a330f --- gi/posterior-regularisation/prjava/src/arr/F.java | 70 ++++++ .../prjava/src/phrase/C2F.java | 17 ++ .../prjava/src/phrase/PhraseCluster.java | 260 +++++++++++++++++++++ .../prjava/src/phrase/PhraseCorpus.java | 183 +++++++++++++++ .../prjava/src/phrase/PhraseObjective.java | 229 ++++++++++++++++++ 5 files changed, 759 insertions(+) create mode 100644 gi/posterior-regularisation/prjava/src/arr/F.java create mode 100644 gi/posterior-regularisation/prjava/src/phrase/C2F.java create mode 100644 gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java create mode 100644 gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java create mode 100644 gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java (limited to 'gi/posterior-regularisation/prjava/src') diff --git a/gi/posterior-regularisation/prjava/src/arr/F.java b/gi/posterior-regularisation/prjava/src/arr/F.java new file mode 100644 index 00000000..c194496e --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/arr/F.java @@ -0,0 +1,70 @@ +package arr; + +public class F { + public static void randomise(double probs[]) + { + double z = 0; + for (int i = 0; i < probs.length; ++i) + { + probs[i] = 3 + Math.random(); + z += probs[i]; + } + + for (int i = 0; i < probs.length; ++i) + probs[i] /= z; + } + + public static void l1normalize(double [] a){ + double sum=0; + for(int i=0;i<a.length;i++){ + sum+=a[i]; + } + if(sum==0){ + return ; + } + for(int i=0;i<a.length;i++){ + a[i]/=sum; + } + } + + public static void l1normalize(double [][] a){ + double sum=0; + for(int i=0;i<a.length;i++){ + for(int j=0;j<a[i].length;j++){ + sum+=a[i][j]; + } + } + if(sum==0){ + return; + } + for(int i=0;i<a.length;i++){ + for(int j=0;j<a[i].length;j++){ + a[i][j]/=sum; + } + } + } + + public static double l1norm(double a[]){ + double norm=0; + for(int i=0;i<a.length;i++){ + norm += a[i]; + } + return norm; + } + + public static int argmax(double probs[]) + { + double m = Double.NEGATIVE_INFINITY; + int mi = -1; + for (int i = 0; i < probs.length; ++i) + { + if (probs[i] > m) + { + m = probs[i]; + mi = i; + } + } + return mi; + } + +} diff --git a/gi/posterior-regularisation/prjava/src/phrase/C2F.java b/gi/posterior-regularisation/prjava/src/phrase/C2F.java new file mode 100644 index 00000000..2646d961 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/phrase/C2F.java @@ -0,0 +1,17 @@ +package phrase; +/** + * @brief context generates phrase + * @author desaic + * + */ +public class C2F { + + /** + * @param args + */ + public static void main(String[] args) { + // TODO Auto-generated method stub + + } + +} diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java new file mode 100644 index 00000000..8b1e0a8c --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -0,0 +1,260 @@ +package phrase; + +import io.FileUtil; + +import java.io.PrintStream; +import java.util.Arrays; + +public class PhraseCluster { + + /**@brief number of clusters*/ + public int K; + private int n_phrase; + private int n_words; + public PhraseCorpus c; + + /**@brief + * emit[tag][position][word] + */ + private double emit[][][]; + private double pi[][]; + + public static int ITER=20; + public static String postFilename="../pdata/posterior.out"; + public static String phraseStatFilename="../pdata/phrase_stat.out"; + private static int NUM_TAG=3; + public static void main(String[] args) { + + PhraseCorpus c=new PhraseCorpus(PhraseCorpus.DATA_FILENAME); + + PhraseCluster cluster=new PhraseCluster(NUM_TAG,c); + PhraseObjective.ps=FileUtil.openOutFile(phraseStatFilename); + for(int i=0;i<ITER;i++){ + PhraseObjective.ps.println("ITER: "+i); + cluster.PREM(); + // cluster.EM(); + } + + PrintStream ps=io.FileUtil.openOutFile(postFilename); + cluster.displayPosterior(ps); + ps.println(); + cluster.displayModelParam(ps); + ps.close(); + PhraseObjective.ps.close(); + } + + public PhraseCluster(int numCluster,PhraseCorpus corpus){ + K=numCluster; + c=corpus; + n_words=c.wordLex.size(); + n_phrase=c.data.length; + + emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words]; + pi=new double[n_phrase][K]; + + for(double [][]i:emit){ + for(double []j:i){ + arr.F.randomise(j); + } + } + + for(double []j:pi){ + arr.F.randomise(j); + } + + pi[0]=new double[]{ + 0.3,0.5,0.2 + }; + + double temp[][]=new double[][]{ + {0.11,0.16,0.19,0.11,0.1}, + {0.10,0.15,0.18,0.1,0.11}, + {0.09,0.07,0.12,0.14,0.13} + }; + + for(int tag=0;tag<3;tag++){ + for(int word=0;word<4;word++){ + for(int pos=0;pos<4;pos++){ + emit[tag][pos][word]=temp[tag][word]; + } + } + } + + } + + public void EM(){ + double [][][]exp_emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words]; + double [][]exp_pi=new double[n_phrase][K]; + + double loglikelihood=0; + + //E + for(int phrase=0;phrase<c.data.length;phrase++){ + int [][] data=c.data[phrase]; + for(int ctx=0;ctx<data.length;ctx++){ + int context[]=data[ctx]; + double p[]=posterior(phrase,context); + loglikelihood+=Math.log(arr.F.l1norm(p)); + arr.F.l1normalize(p); + + int contextCnt=context[context.length-1]; + //increment expected count + for(int tag=0;tag<K;tag++){ + for(int pos=0;pos<context.length-1;pos++){ + exp_emit[tag][pos][context[pos]]+=p[tag]*contextCnt; + } + + exp_pi[phrase][tag]+=p[tag]*contextCnt; + } + } + } + + System.out.println("Log likelihood: "+loglikelihood); + + //M + for(double [][]i:exp_emit){ + for(double []j:i){ + arr.F.l1normalize(j); + } + } + + emit=exp_emit; + + for(double []j:exp_pi){ + arr.F.l1normalize(j); + } + + pi=exp_pi; + } + + public void PREM(){ + double [][][]exp_emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words]; + double [][]exp_pi=new double[n_phrase][K]; + + double loglikelihood=0; + double primal=0; + //E + for(int phrase=0;phrase<c.data.length;phrase++){ + PhraseObjective po=new PhraseObjective(this,phrase); + po.optimizeWithProjectedGradientDescent(); + double [][] q=po.posterior(); + loglikelihood+=po.getValue(); + primal+=po.primal(); + for(int edge=0;edge<q.length;edge++){ + int []context=c.data[phrase][edge]; + int contextCnt=context[context.length-1]; + //increment expected count + for(int tag=0;tag<K;tag++){ + for(int pos=0;pos<context.length-1;pos++){ + exp_emit[tag][pos][context[pos]]+=q[edge][tag]*contextCnt; + } + + exp_pi[phrase][tag]+=q[edge][tag]*contextCnt; + } + } + } + + System.out.println("Log likelihood: "+loglikelihood); + System.out.println("Primal Objective: "+primal); + + //M + for(double [][]i:exp_emit){ + for(double []j:i){ + arr.F.l1normalize(j); + } + } + + emit=exp_emit; + + for(double []j:exp_pi){ + arr.F.l1normalize(j); + } + + pi=exp_pi; + } + + /** + * + * @param phrase index of phrase + * @param ctx array of context + * @return unnormalized posterior + */ + public double[]posterior(int phrase, int[]ctx){ + double[] prob=Arrays.copyOf(pi[phrase], K); + + for(int tag=0;tag<K;tag++){ + for(int c=0;c<ctx.length-1;c++){ + int word=ctx[c]; + prob[tag]*=emit[tag][c][word]; + } + } + + return prob; + } + + public void displayPosterior(PrintStream ps) + { + + c.buildList(); + + for (int i = 0; i < n_phrase; ++i) + { + int [][]data=c.data[i]; + for (int[] e: data) + { + double probs[] = posterior(i, e); + arr.F.l1normalize(probs); + + // emit phrase + ps.print(c.phraseList[i]); + ps.print("\t"); + ps.print(c.getContextString(e)); + ps.print("||| C=" + e[e.length-1] + " |||"); + + int t=arr.F.argmax(probs); + + ps.print(t+"||| ["); + for(t=0;t<K;t++){ + ps.print(probs[t]+", "); + } + // for (int t = 0; t < numTags; ++t) + // System.out.print(" " + probs[t]); + ps.println("]"); + } + } + } + + public void displayModelParam(PrintStream ps) + { + + c.buildList(); + + ps.println("P(tag|phrase)"); + for (int i = 0; i < n_phrase; ++i) + { + ps.print(c.phraseList[i]); + for(int j=0;j<pi[i].length;j++){ + ps.print("\t"+pi[i][j]); + } + ps.println(); + } + + ps.println("P(word|tag,position)"); + for (int i = 0; i < K; ++i) + { + ps.println(i); + for(int position=0;position<PhraseCorpus.NUM_CONTEXT;position++){ + ps.println(position); + for(int word=0;word<emit[i][position].length;word++){ + if((word+1)%100==0){ + ps.println(); + } + ps.print(c.wordList[word]+"="+emit[i][position][word]+"\t"); + } + ps.println(); + } + ps.println(); + } + + } +} diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java new file mode 100644 index 00000000..3902f665 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java @@ -0,0 +1,183 @@ +package phrase; + +import java.io.BufferedInputStream; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.PrintStream; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Scanner; + +public class PhraseCorpus { + + + public static String LEX_FILENAME="../pdata/lex.out"; + //public static String DATA_FILENAME="../pdata/canned.con"; + public static String DATA_FILENAME="../pdata/btec.con"; + public static int NUM_CONTEXT=4; + + public HashMap<String,Integer>wordLex; + public HashMap<String,Integer>phraseLex; + + public String wordList[]; + public String phraseList[]; + + //data[phrase][num context][position] + public int data[][][]; + + public static void main(String[] args) { + // TODO Auto-generated method stub + PhraseCorpus c=new PhraseCorpus(DATA_FILENAME); + c.saveLex(LEX_FILENAME); + c.loadLex(LEX_FILENAME); + c.saveLex(LEX_FILENAME); + } + + public PhraseCorpus(String filename){ + BufferedReader r=io.FileUtil.openBufferedReader(filename); + + phraseLex=new HashMap<String,Integer>(); + wordLex=new HashMap<String,Integer>(); + + ArrayList<int[][]>dataList=new ArrayList<int[][]>(); + String line=null; + + while((line=readLine(r))!=null){ + + String toks[]=line.split("\t"); + String phrase=toks[0]; + addLex(phrase,phraseLex); + + toks=toks[1].split(" \\|\\|\\| "); + + ArrayList <int[]>ctxList=new ArrayList<int[]>(); + + for(int i=0;i<toks.length;i+=2){ + String ctx=toks[i]; + String words[]=ctx.split(" "); + int []context=new int [NUM_CONTEXT+1]; + int idx=0; + for(String word:words){ + if(word.equals("<PHRASE>")){ + continue; + } + addLex(word,wordLex); + context[idx]=wordLex.get(word); + idx++; + } + + String count=toks[i+1]; + context[idx]=Integer.parseInt(count.trim().substring(2)); + + + ctxList.add(context); + + } + + dataList.add(ctxList.toArray(new int [0][])); + + } + try{ + r.close(); + }catch(IOException ioe){ + ioe.printStackTrace(); + } + data=dataList.toArray(new int[0][][]); + } + + private void addLex(String key, HashMap<String,Integer>lex){ + Integer i=lex.get(key); + if(i==null){ + lex.put(key, lex.size()); + } + } + + //for debugging + public void saveLex(String lexFilename){ + PrintStream ps=io.FileUtil.openOutFile(lexFilename); + ps.println("Phrase Lexicon"); + ps.println(phraseLex.size()); + printDict(phraseLex,ps); + + ps.println("Word Lexicon"); + ps.println(wordLex.size()); + printDict(wordLex,ps); + ps.close(); + } + + private static void printDict(HashMap<String,Integer>lex,PrintStream ps){ + String []dict=buildList(lex); + for(int i=0;i<dict.length;i++){ + ps.println(dict[i]); + } + } + + public void loadLex(String lexFilename){ + Scanner sc=io.FileUtil.openInFile(lexFilename); + + sc.nextLine(); + int size=sc.nextInt(); + sc.nextLine(); + String[]dict=new String[size]; + for(int i=0;i<size;i++){ + dict[i]=sc.nextLine(); + } + phraseLex=buildMap(dict); + + sc.nextLine(); + size=sc.nextInt(); + sc.nextLine(); + dict=new String[size]; + for(int i=0;i<size;i++){ + dict[i]=sc.nextLine(); + } + wordLex=buildMap(dict); + sc.close(); + } + + private HashMap<String, Integer> buildMap(String[]dict){ + HashMap<String,Integer> map=new HashMap<String,Integer>(); + for(int i=0;i<dict.length;i++){ + map.put(dict[i], i); + } + return map; + } + + public void buildList(){ + if(wordList==null){ + wordList=buildList(wordLex); + phraseList=buildList(phraseLex); + } + } + + private static String[]buildList(HashMap<String,Integer>lex){ + String dict[]=new String [lex.size()]; + for(String key:lex.keySet()){ + dict[lex.get(key)]=key; + } + return dict; + } + + public String getContextString(int context[]) + { + StringBuffer b = new StringBuffer(); + for (int i=0;i<context.length-1;i++) + { + if (b.length() > 0) + b.append(" "); + b.append(wordList[context[i]]); + } + return b.toString(); + } + + public static String readLine(BufferedReader r){ + try{ + return r.readLine(); + } + catch(IOException ioe){ + ioe.printStackTrace(); + } + return null; + } + +} diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java new file mode 100644 index 00000000..e9e063d6 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java @@ -0,0 +1,229 @@ +package phrase; + +import java.io.PrintStream; +import java.util.Arrays; + +import optimization.gradientBasedMethods.ProjectedGradientDescent; +import optimization.gradientBasedMethods.ProjectedObjective; +import optimization.gradientBasedMethods.stats.OptimizerStats; +import optimization.linesearch.ArmijoLineSearchMinimizationAlongProjectionArc; +import optimization.linesearch.InterpolationPickFirstStep; +import optimization.linesearch.LineSearchMethod; +import optimization.linesearch.WolfRuleLineSearch; +import optimization.projections.SimplexProjection; +import optimization.stopCriteria.CompositeStopingCriteria; +import optimization.stopCriteria.ProjectedGradientL2Norm; +import optimization.stopCriteria.StopingCriteria; +import optimization.stopCriteria.ValueDifference; +import optimization.util.MathUtils; + +public class PhraseObjective extends ProjectedObjective{ + + private static final double GRAD_DIFF = 0.002; + public static double INIT_STEP_SIZE=1; + public static double VAL_DIFF=0.001; + private double scale=5; + private double c1=0.0001; + private double c2=0.9; + + private PhraseCluster c; + + /**@brief + * for debugging purposes + */ + public static PrintStream ps; + + /**@brief current phrase being optimzed*/ + public int phrase; + + /**@brief un-regularized posterior + * unnormalized + * p[edge][tag] + * P(tag|edge) \propto P(tag|phrase)P(context|tag) + */ + private double[][]p; + + /**@brief regularized posterior + * q[edge][tag] propto p[edge][tag]*exp(-lambda) + */ + private double q[][]; + private int data[][]; + + /**@brief log likelihood of the associated phrase + * + */ + private double loglikelihood; + private SimplexProjection projection; + + double[] newPoint ; + + private int n_param; + + /**@brief likelihood under p + * + */ + private double llh; + + public PhraseObjective(PhraseCluster cluster, int phraseIdx){ + phrase=phraseIdx; + c=cluster; + data=c.c.data[phrase]; + n_param=data.length*c.K; + parameters=new double [n_param]; + newPoint = new double[n_param]; + gradient = new double[n_param]; + initP(); + projection=new SimplexProjection (scale); + q=new double [data.length][c.K]; + + setParameters(parameters); + } + + private void initP(){ + int countIdx=data[0].length-1; + + p=new double[data.length][]; + for(int edge=0;edge<data.length;edge++){ + p[edge]=c.posterior(phrase,data[edge]); + } + for(int edge=0;edge<data.length;edge++){ + llh+=Math.log + (data[edge][countIdx]*arr.F.l1norm(p[edge])); + arr.F.l1normalize(p[edge]); + } + } + + @Override + public void setParameters(double[] params) { + super.setParameters(params); + updateFunction(); + } + + private void updateFunction(){ + updateCalls++; + loglikelihood=0; + int countIdx=data[0].length-1; + for(int tag=0;tag<c.K;tag++){ + for(int edge=0;edge<data.length;edge++){ + q[edge][tag]=p[edge][tag]* + Math.exp(-parameters[tag*data.length+edge]/data[edge][countIdx]); + } + } + + for(int edge=0;edge<data.length;edge++){ + loglikelihood+=Math.log + (data[edge][countIdx]*arr.F.l1norm(q[edge])); + arr.F.l1normalize(q[edge]); + } + + for(int tag=0;tag<c.K;tag++){ + for(int edge=0;edge<data.length;edge++){ + gradient[tag*data.length+edge]=-q[edge][tag]; + } + } + } + + @Override + // TODO Auto-generated method stub + public double[] projectPoint(double[] point) { + double toProject[]=new double[data.length]; + for(int tag=0;tag<c.K;tag++){ + for(int edge=0;edge<data.length;edge++){ + toProject[edge]=point[tag*data.length+edge]; + } + projection.project(toProject); + for(int edge=0;edge<data.length;edge++){ + newPoint[tag*data.length+edge]=toProject[edge]; + } + } + return newPoint; + } + + @Override + public double[] getGradient() { + // TODO Auto-generated method stub + gradientCalls++; + return gradient; + } + + @Override + public double getValue() { + // TODO Auto-generated method stub + functionCalls++; + return loglikelihood; + } + + @Override + public String toString() { + // TODO Auto-generated method stub + return ""; + } + + public double [][]posterior(){ + return q; + } + + public void optimizeWithProjectedGradientDescent(){ + LineSearchMethod ls = + new ArmijoLineSearchMinimizationAlongProjectionArc + (new InterpolationPickFirstStep(INIT_STEP_SIZE)); + //LineSearchMethod ls = new WolfRuleLineSearch( + // (new InterpolationPickFirstStep(INIT_STEP_SIZE)), c1, c2); + OptimizerStats stats = new OptimizerStats(); + + + ProjectedGradientDescent optimizer = new ProjectedGradientDescent(ls); + StopingCriteria stopGrad = new ProjectedGradientL2Norm(GRAD_DIFF); + StopingCriteria stopValue = new ValueDifference(VAL_DIFF); + CompositeStopingCriteria compositeStop = new CompositeStopingCriteria(); + compositeStop.add(stopGrad); + compositeStop.add(stopValue); + optimizer.setMaxIterations(100); + updateFunction(); + boolean succed = optimizer.optimize(this,stats,compositeStop); +// System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1)); + if(succed){ + System.out.println("Ended optimization in " + optimizer.getCurrentIteration()); + }else{ + System.out.println("Failed to optimize"); + } + + // ps.println(Arrays.toString(parameters)); + + // for(int edge=0;edge<data.length;edge++){ + // ps.println(Arrays.toString(q[edge])); + // } + + } + + /** + * L - KL(q||p) - + * scale * \sum_{tag,phrase} max_i P(tag|i th occurrence of phrase) + * @return + */ + public double primal() + { + + double l=llh; + +// ps.print("Phrase "+phrase+": "+l); + double kl=-loglikelihood + +MathUtils.dotProduct(parameters, gradient); +// ps.print(", "+kl); + l=l-kl; + double sum=0; + for(int tag=0;tag<c.K;tag++){ + double max=0; + for(int edge=0;edge<data.length;edge++){ + if(q[edge][tag]>max){ + max=q[edge][tag]; + } + } + sum+=max; + } +// ps.println(", "+sum); + l=l-scale*sum; + return l; + } + +} -- cgit v1.2.3