package hmm; import gnu.trove.TIntArrayList; import optimization.gradientBasedMethods.ProjectedGradientDescent; import optimization.gradientBasedMethods.ProjectedObjective; import optimization.gradientBasedMethods.stats.OptimizerStats; import optimization.linesearch.ArmijoLineSearchMinimizationAlongProjectionArc; import optimization.linesearch.InterpolationPickFirstStep; import optimization.linesearch.LineSearchMethod; import optimization.projections.SimplexProjection; import optimization.stopCriteria.CompositeStopingCriteria; import optimization.stopCriteria.ProjectedGradientL2Norm; import optimization.stopCriteria.StopingCriteria; import optimization.stopCriteria.ValueDifference; public class HMMObjective extends ProjectedObjective{ private static final double GRAD_DIFF = 3; public static double INIT_STEP_SIZE=10; public static double VAL_DIFF=1000; private HMM hmm; double[] newPoint ; //posterior[sent num][tok num][tag]=index into lambda private int posteriorMap[][][]; //projection[word][tag].get(occurence)=index into lambda private TIntArrayList projectionMap[][]; //Size of the simplex public double scale=10; private SimplexProjection projection; private int wordFreq[]; private static int MIN_FREQ=10; private int numWordsToProject=0; private int n_param; public double loglikelihood; public HMMObjective(HMM h){ hmm=h; countWords(); buildMap(); gradient=new double [n_param]; projection = new SimplexProjection(scale); newPoint = new double[n_param]; setInitialParameters(new double[n_param]); } /**@brief counts word frequency in the corpus * */ private void countWords(){ wordFreq=new int [hmm.emit[0].length]; for(int i=0;i<hmm.data.length;i++){ for(int j=0;j<hmm.data[i].length;j++){ wordFreq[hmm.data[i][j]]++; } } } /**@brief build posterior and projection indices * */ private void buildMap(){ //number of sentences hidden states and words int n_states=hmm.trans.length; int n_words=hmm.emit[0].length; int n_sents=hmm.data.length; n_param=0; posteriorMap=new int[n_sents][][]; projectionMap=new TIntArrayList[n_words][]; for(int sentNum=0;sentNum<n_sents;sentNum++){ int [] data=hmm.data[sentNum]; posteriorMap[sentNum]=new int[data.length][n_states]; numWordsToProject=0; for(int i=0;i<data.length;i++){ int word=data[i]; for(int state=0;state<n_states;state++){ if(wordFreq[word]>MIN_FREQ){ if(projectionMap[word]==null){ projectionMap[word]=new TIntArrayList[n_states]; } // if(posteriorMap[sentNum][i]==null){ // posteriorMap[sentNum][i]=new int[n_states]; // } posteriorMap[sentNum][i][state]=n_param; if(projectionMap[word][state]==null){ projectionMap[word][state]=new TIntArrayList(); numWordsToProject++; } projectionMap[word][state].add(n_param); n_param++; } else{ posteriorMap[sentNum][i][state]=-1; } } } } } @Override public double[] projectPoint(double[] point) { // TODO Auto-generated method stub for(int i=0;i<projectionMap.length;i++){ if(projectionMap[i]==null){ //this word is not constrained continue; } for(int j=0;j<projectionMap[i].length;j++){ TIntArrayList instances=projectionMap[i][j]; double[] toProject = new double[instances.size()]; for (int k = 0; k < toProject.length; k++) { // System.out.print(instances.get(k) + " "); toProject[k] = point[instances.get(k)]; } projection.project(toProject); for (int k = 0; k < toProject.length; k++) { newPoint[instances.get(k)]=toProject[k]; } } } return newPoint; } @Override public double[] getGradient() { // TODO Auto-generated method stub gradientCalls++; return gradient; } @Override public double getValue() { // TODO Auto-generated method stub functionCalls++; return loglikelihood; } @Override public String toString() { // TODO Auto-generated method stub StringBuffer sb = new StringBuffer(); for (int i = 0; i < parameters.length; i++) { sb.append(parameters[i]+" "); if(i%100==0){ sb.append("\n"); } } sb.append("\n"); /* for (int i = 0; i < gradient.length; i++) { sb.append(gradient[i]+" "); if(i%100==0){ sb.append("\n"); } } sb.append("\n"); */ return sb.toString(); } /** * @param seq * @return posterior probability of each transition */ public double [][][]forwardBackward(int sentNum){ int [] seq=hmm.data[sentNum]; int n_states=hmm.trans.length; double a[][]=new double [seq.length][n_states]; double b[][]=new double [seq.length][n_states]; int len=seq.length; boolean constrained= (projectionMap[seq[0]]!=null); //initialize the first step for(int i=0;i<n_states;i++){ a[0][i]=hmm.emit[i][seq[0]]*hmm.pi[i]; if(constrained){ a[0][i]*= Math.exp(- parameters[ posteriorMap[sentNum][0][i] ] ); } b[len-1][i]=1; } loglikelihood+=Math.log(hmm.l1norm(a[0])); hmm.l1normalize(a[0]); hmm.l1normalize(b[len-1]); //forward for(int n=1;n<len;n++){ constrained= (projectionMap[seq[n]]!=null); for(int i=0;i<n_states;i++){ for(int j=0;j<n_states;j++){ a[n][i]+=hmm.trans[j][i]*a[n-1][j]; } a[n][i]*=hmm.emit[i][seq[n]]; if(constrained){ a[n][i]*= Math.exp(- parameters[ posteriorMap[sentNum][n][i] ] ); } } loglikelihood+=Math.log(hmm.l1norm(a[n])); hmm.l1normalize(a[n]); } //temp variable for e^{-\lambda} double factor=1; //backward for(int n=len-2;n>=0;n--){ constrained= (projectionMap[seq[n+1]]!=null); for(int i=0;i<n_states;i++){ for(int j=0;j<n_states;j++){ if(constrained){ factor= Math.exp(- parameters[ posteriorMap[sentNum][n+1][j] ] ); }else{ factor=1; } b[n][i]+=hmm.trans[i][j]*b[n+1][j]*hmm.emit[j][seq[n+1]]*factor; } } hmm.l1normalize(b[n]); } //expected transition double p[][][]=new double [seq.length][n_states][n_states]; for(int n=0;n<len-1;n++){ constrained= (projectionMap[seq[n+1]]!=null); for(int i=0;i<n_states;i++){ for(int j=0;j<n_states;j++){ if(constrained){ factor= Math.exp(- parameters[ posteriorMap[sentNum][n+1][j] ] ); }else{ factor=1; } p[n][i][j]=a[n][i]*hmm.trans[i][j]* hmm.emit[j][seq[n+1]]*b[n+1][j]*factor; } } hmm.l1normalize(p[n]); } return p; } public void optimizeWithProjectedGradientDescent(){ LineSearchMethod ls = new ArmijoLineSearchMinimizationAlongProjectionArc (new InterpolationPickFirstStep(INIT_STEP_SIZE)); OptimizerStats stats = new OptimizerStats(); ProjectedGradientDescent optimizer = new ProjectedGradientDescent(ls); StopingCriteria stopGrad = new ProjectedGradientL2Norm(GRAD_DIFF); StopingCriteria stopValue = new ValueDifference(VAL_DIFF); CompositeStopingCriteria compositeStop = new CompositeStopingCriteria(); compositeStop.add(stopGrad); compositeStop.add(stopValue); optimizer.setMaxIterations(10); updateFunction(); boolean succed = optimizer.optimize(this,stats,compositeStop); System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1)); if(succed){ System.out.println("Ended optimization in " + optimizer.getCurrentIteration()); }else{ System.out.println("Failed to optimize"); } } @Override public void setParameters(double[] params) { super.setParameters(params); updateFunction(); } private void updateFunction(){ updateCalls++; loglikelihood=0; for(int sentNum=0;sentNum<hmm.data.length;sentNum++){ double [][][]p=forwardBackward(sentNum); for(int n=0;n<p.length-1;n++){ for(int i=0;i<p[n].length;i++){ if(projectionMap[hmm.data[sentNum][n]]!=null){ double posterior=0; for(int j=0;j<p[n][i].length;j++){ posterior+=p[n][i][j]; } gradient[posteriorMap[sentNum][n][i]]=-posterior; } } } //the last state int n=p.length-2; for(int i=0;i<p[n].length;i++){ if(projectionMap[hmm.data[sentNum][n+1]]!=null){ double posterior=0; for(int j=0;j<p[n].length;j++){ posterior+=p[n][j][i]; } gradient[posteriorMap[sentNum][n+1][i]]=-posterior; } } } } }