From 30bbb07a467490007ba6959c9734578ba0dbe24b Mon Sep 17 00:00:00 2001 From: desaicwtf Date: Mon, 5 Jul 2010 15:26:42 +0000 Subject: forget to add files git-svn-id: https://ws10smt.googlecode.com/svn/trunk@126 ec762483-ff6d-05da-a07a-a48fb63a330f --- gi/posterior-regularisation/prjava/src/arr/F.java | 70 ++++++ .../prjava/src/phrase/C2F.java | 17 ++ .../prjava/src/phrase/PhraseCluster.java | 260 +++++++++++++++++++++ .../prjava/src/phrase/PhraseCorpus.java | 183 +++++++++++++++ .../prjava/src/phrase/PhraseObjective.java | 229 ++++++++++++++++++ 5 files changed, 759 insertions(+) create mode 100644 gi/posterior-regularisation/prjava/src/arr/F.java create mode 100644 gi/posterior-regularisation/prjava/src/phrase/C2F.java create mode 100644 gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java create mode 100644 gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java create mode 100644 gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java (limited to 'gi/posterior-regularisation/prjava') diff --git a/gi/posterior-regularisation/prjava/src/arr/F.java b/gi/posterior-regularisation/prjava/src/arr/F.java new file mode 100644 index 00000000..c194496e --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/arr/F.java @@ -0,0 +1,70 @@ +package arr; + +public class F { + public static void randomise(double probs[]) + { + double z = 0; + for (int i = 0; i < probs.length; ++i) + { + probs[i] = 3 + Math.random(); + z += probs[i]; + } + + for (int i = 0; i < probs.length; ++i) + probs[i] /= z; + } + + public static void l1normalize(double [] a){ + double sum=0; + for(int i=0;i m) + { + m = probs[i]; + mi = i; + } + } + return mi; + } + +} diff --git a/gi/posterior-regularisation/prjava/src/phrase/C2F.java b/gi/posterior-regularisation/prjava/src/phrase/C2F.java new file mode 100644 index 00000000..2646d961 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/phrase/C2F.java @@ -0,0 +1,17 @@ +package phrase; +/** + * @brief context generates phrase + * @author desaic + * + */ +public class C2F { + + /** + * @param args + */ + public static void main(String[] args) { + // TODO Auto-generated method stub + + } + +} diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java new file mode 100644 index 00000000..8b1e0a8c --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -0,0 +1,260 @@ +package phrase; + +import io.FileUtil; + +import java.io.PrintStream; +import java.util.Arrays; + +public class PhraseCluster { + + /**@brief number of clusters*/ + public int K; + private int n_phrase; + private int n_words; + public PhraseCorpus c; + + /**@brief + * emit[tag][position][word] + */ + private double emit[][][]; + private double pi[][]; + + public static int ITER=20; + public static String postFilename="../pdata/posterior.out"; + public static String phraseStatFilename="../pdata/phrase_stat.out"; + private static int NUM_TAG=3; + public static void main(String[] args) { + + PhraseCorpus c=new PhraseCorpus(PhraseCorpus.DATA_FILENAME); + + PhraseCluster cluster=new PhraseCluster(NUM_TAG,c); + PhraseObjective.ps=FileUtil.openOutFile(phraseStatFilename); + for(int i=0;iwordLex; + public HashMapphraseLex; + + public String wordList[]; + public String phraseList[]; + + //data[phrase][num context][position] + public int data[][][]; + + public static void main(String[] args) { + // TODO Auto-generated method stub + PhraseCorpus c=new PhraseCorpus(DATA_FILENAME); + c.saveLex(LEX_FILENAME); + c.loadLex(LEX_FILENAME); + c.saveLex(LEX_FILENAME); + } + + public PhraseCorpus(String filename){ + BufferedReader r=io.FileUtil.openBufferedReader(filename); + + phraseLex=new HashMap(); + wordLex=new HashMap(); + + ArrayListdataList=new ArrayList(); + String line=null; + + while((line=readLine(r))!=null){ + + String toks[]=line.split("\t"); + String phrase=toks[0]; + addLex(phrase,phraseLex); + + toks=toks[1].split(" \\|\\|\\| "); + + ArrayList ctxList=new ArrayList(); + + for(int i=0;i")){ + continue; + } + addLex(word,wordLex); + context[idx]=wordLex.get(word); + idx++; + } + + String count=toks[i+1]; + context[idx]=Integer.parseInt(count.trim().substring(2)); + + + ctxList.add(context); + + } + + dataList.add(ctxList.toArray(new int [0][])); + + } + try{ + r.close(); + }catch(IOException ioe){ + ioe.printStackTrace(); + } + data=dataList.toArray(new int[0][][]); + } + + private void addLex(String key, HashMaplex){ + Integer i=lex.get(key); + if(i==null){ + lex.put(key, lex.size()); + } + } + + //for debugging + public void saveLex(String lexFilename){ + PrintStream ps=io.FileUtil.openOutFile(lexFilename); + ps.println("Phrase Lexicon"); + ps.println(phraseLex.size()); + printDict(phraseLex,ps); + + ps.println("Word Lexicon"); + ps.println(wordLex.size()); + printDict(wordLex,ps); + ps.close(); + } + + private static void printDict(HashMaplex,PrintStream ps){ + String []dict=buildList(lex); + for(int i=0;i buildMap(String[]dict){ + HashMap map=new HashMap(); + for(int i=0;ilex){ + String dict[]=new String [lex.size()]; + for(String key:lex.keySet()){ + dict[lex.get(key)]=key; + } + return dict; + } + + public String getContextString(int context[]) + { + StringBuffer b = new StringBuffer(); + for (int i=0;i 0) + b.append(" "); + b.append(wordList[context[i]]); + } + return b.toString(); + } + + public static String readLine(BufferedReader r){ + try{ + return r.readLine(); + } + catch(IOException ioe){ + ioe.printStackTrace(); + } + return null; + } + +} diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java new file mode 100644 index 00000000..e9e063d6 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java @@ -0,0 +1,229 @@ +package phrase; + +import java.io.PrintStream; +import java.util.Arrays; + +import optimization.gradientBasedMethods.ProjectedGradientDescent; +import optimization.gradientBasedMethods.ProjectedObjective; +import optimization.gradientBasedMethods.stats.OptimizerStats; +import optimization.linesearch.ArmijoLineSearchMinimizationAlongProjectionArc; +import optimization.linesearch.InterpolationPickFirstStep; +import optimization.linesearch.LineSearchMethod; +import optimization.linesearch.WolfRuleLineSearch; +import optimization.projections.SimplexProjection; +import optimization.stopCriteria.CompositeStopingCriteria; +import optimization.stopCriteria.ProjectedGradientL2Norm; +import optimization.stopCriteria.StopingCriteria; +import optimization.stopCriteria.ValueDifference; +import optimization.util.MathUtils; + +public class PhraseObjective extends ProjectedObjective{ + + private static final double GRAD_DIFF = 0.002; + public static double INIT_STEP_SIZE=1; + public static double VAL_DIFF=0.001; + private double scale=5; + private double c1=0.0001; + private double c2=0.9; + + private PhraseCluster c; + + /**@brief + * for debugging purposes + */ + public static PrintStream ps; + + /**@brief current phrase being optimzed*/ + public int phrase; + + /**@brief un-regularized posterior + * unnormalized + * p[edge][tag] + * P(tag|edge) \propto P(tag|phrase)P(context|tag) + */ + private double[][]p; + + /**@brief regularized posterior + * q[edge][tag] propto p[edge][tag]*exp(-lambda) + */ + private double q[][]; + private int data[][]; + + /**@brief log likelihood of the associated phrase + * + */ + private double loglikelihood; + private SimplexProjection projection; + + double[] newPoint ; + + private int n_param; + + /**@brief likelihood under p + * + */ + private double llh; + + public PhraseObjective(PhraseCluster cluster, int phraseIdx){ + phrase=phraseIdx; + c=cluster; + data=c.c.data[phrase]; + n_param=data.length*c.K; + parameters=new double [n_param]; + newPoint = new double[n_param]; + gradient = new double[n_param]; + initP(); + projection=new SimplexProjection (scale); + q=new double [data.length][c.K]; + + setParameters(parameters); + } + + private void initP(){ + int countIdx=data[0].length-1; + + p=new double[data.length][]; + for(int edge=0;edgemax){ + max=q[edge][tag]; + } + } + sum+=max; + } +// ps.println(", "+sum); + l=l-scale*sum; + return l; + } + +} -- cgit v1.2.3