diff options
author | desaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-06-28 23:14:21 +0000 |
---|---|---|
committer | desaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-06-28 23:14:21 +0000 |
commit | ad418214fe3b3fcd33d81225eb3d3fb08b67f88a (patch) | |
tree | 885ef102fd5508d4693ee3fe374b68a76a7f30fc /gi/posterior-regularisation/prjava/src/hmm | |
parent | f96bf4df7e4a34b42373723cbe38e6c7425e3239 (diff) |
add draft version of POS induction with HMM and L1 Linf constraints
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@47 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/hmm')
-rw-r--r-- | gi/posterior-regularisation/prjava/src/hmm/HMM.java | 576 | ||||
-rw-r--r-- | gi/posterior-regularisation/prjava/src/hmm/HMMObjective.java | 348 | ||||
-rw-r--r-- | gi/posterior-regularisation/prjava/src/hmm/POS.java | 126 |
3 files changed, 1050 insertions, 0 deletions
diff --git a/gi/posterior-regularisation/prjava/src/hmm/HMM.java b/gi/posterior-regularisation/prjava/src/hmm/HMM.java new file mode 100644 index 00000000..1c4d7659 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/hmm/HMM.java @@ -0,0 +1,576 @@ +package hmm;
+
+import java.io.PrintStream;
+import java.util.ArrayList;
+import java.util.Scanner;
+
+public class HMM {
+
+
+ //trans[i][j]=prob of going FROM state i TO state j
+ double [][]trans;
+ double [][]emit; //emit[i][w]=prob that state i emits symbol id w
+ double []pi; //pi[i]=prob that a sequence starts in state i
+ int [][]data; //data[s]=symbol ids of sentence s (the corpus handed to the constructor)
+ int [][]tagdata; //not referenced by any method of this class (train() uses its own parameter of the same name)
+
+ double logtrans[][]; //element-wise log of trans; built on the first viterbi() call, null before that
+
+ public HMMObjective o; //objective used by PREM(); must be assigned before PREM() is called
+
+ public static void main(String[] args) { //no-op: nothing is run from this class directly
+
+ }
+
+ /**
+  * Creates a model whose transition, emission and initial distributions are
+  * drawn at random and normalised to sum to one.
+  * @param n_state number of hidden states
+  * @param n_emit number of distinct observation symbols
+  * @param data corpus of symbol-id sequences; stored by reference, not copied
+  */
+ public HMM(int n_state,int n_emit,int [][]data){
+     this.data=data;
+
+     pi=new double [n_state];
+     emit=new double[n_state][n_emit];
+     trans=new double [n_state][n_state];
+
+     System.out.println(" random initial parameters");
+     //same draw order as always: transitions, then emissions, then initial probs
+     fillRand(trans);
+     fillRand(emit);
+     fillRand(pi);
+ }
+
+ /**
+  * Fills every row of {@code a} with random values and normalises each row
+  * separately so that it sums to one.
+  */
+ private void fillRand(double [][] a){
+     for(double [] row:a){
+         //the one-dimensional overload draws the values and normalises the row
+         fillRand(row);
+     }
+ }
+ /**
+  * Overwrites {@code a} with values from Math.random() and rescales it to sum to one.
+  */
+ private void fillRand(double []a){
+     int k=0;
+     while(k<a.length){
+         a[k]=Math.random();
+         k++;
+     }
+     l1normalize(a);
+ }
+
+ //log-likelihood accumulated by forwardBackward(); EM() resets it before each E-step
+ private double loglikely=0;
+
+ /**
+  * Runs one EM iteration over {@code data}: the E-step accumulates expected
+  * transition, emission and start counts from the forward-backward posteriors,
+  * the M-step renormalises them into {@code trans}, {@code emit} and {@code pi}.
+  * Prints progress dots and the corpus log-likelihood under the parameters
+  * that were in effect during the E-step.
+  */
+ public void EM(){
+     int n_state=trans.length;
+     double trans_exp_cnt[][]=new double [n_state][n_state];
+     double emit_exp_cnt[][]=new double[n_state][emit[0].length];
+     double start_exp_cnt[]=new double[n_state];
+     loglikely=0;
+
+     //E
+     for(int i=0;i<data.length;i++){
+
+         double [][][] post=forwardBackward(data[i]);
+         incrementExpCnt(post, data[i],
+             trans_exp_cnt,
+             emit_exp_cnt,
+             start_exp_cnt);
+
+         if(i%100==0){
+             System.out.print(".");
+         }
+         if(i%1000==0){
+             System.out.println(i);
+         }
+     }
+     System.out.println("Log likelihood: "+loglikely);
+
+     //M
+     addOneSmooth(emit_exp_cnt);
+     for(int i=0;i<n_state;i++){
+
+         //transition probs
+         double sum=0;
+         for(int j=0;j<trans[i].length;j++){
+             sum+=trans_exp_cnt[i][j];
+         }
+         //avoid NAN
+         if(sum==0){
+             sum=1;
+         }
+         for(int j=0;j<trans[i].length;j++){
+             trans[i][j]=trans_exp_cnt[i][j]/sum;
+         }
+
+         //emission probs
+         sum=0;
+         for(int j=0;j<emit[i].length;j++){
+             sum+=emit_exp_cnt[i][j];
+         }
+         //avoid NAN
+         if(sum==0){
+             sum=1;
+         }
+         for(int j=0;j<emit[i].length;j++){
+             emit[i][j]=emit_exp_cnt[i][j]/sum;
+         }
+     }
+
+     //initial probs: they do not depend on the state index of the loop above,
+     //so they are re-estimated once instead of once per state
+     for(int j=0;j<pi.length;j++){
+         pi[j]=start_exp_cnt[j];
+     }
+     l1normalize(pi);
+
+     //trans has changed: drop the cached log-transitions so viterbi() rebuilds them
+     logtrans=null;
+ }
+
+ private double [][][]forwardBackward(int [] seq){
+ double a[][]=new double [seq.length][trans.length];
+ double b[][]=new double [seq.length][trans.length];
+
+ int len=seq.length;
+ //initialize the first step
+ for(int i=0;i<trans.length;i++){
+ a[0][i]=emit[i][seq[0]]*pi[i];
+ b[len-1][i]=1;
+ }
+
+ //log of denominator for likelyhood
+ double c=Math.log(l1norm(a[0]));
+
+ l1normalize(a[0]);
+ l1normalize(b[len-1]);
+
+
+
+ //forward
+ for(int n=1;n<len;n++){
+ for(int i=0;i<trans.length;i++){
+ for(int j=0;j<trans.length;j++){
+ a[n][i]+=trans[j][i]*a[n-1][j];
+ }
+ a[n][i]*=emit[i][seq[n]];
+ }
+ c+=Math.log(l1norm(a[n]));
+ l1normalize(a[n]);
+ }
+
+ loglikely+=c;
+
+ //backward
+ for(int n=len-2;n>=0;n--){
+ for(int i=0;i<trans.length;i++){
+ for(int j=0;j<trans.length;j++){
+ b[n][i]+=trans[i][j]*b[n+1][j]*emit[j][seq[n+1]];
+ }
+ }
+ l1normalize(b[n]);
+ }
+
+
+ //expected transition
+ double p[][][]=new double [seq.length][trans.length][trans.length];
+ for(int n=0;n<len-1;n++){
+ for(int i=0;i<trans.length;i++){
+ for(int j=0;j<trans.length;j++){
+ p[n][i][j]=a[n][i]*trans[i][j]*emit[j][seq[n+1]]*b[n+1][j];
+
+ }
+ }
+
+ l1normalize(p[n]);
+ }
+ return p;
+ }
+
+ /**
+  * Adds the expected counts implied by one sentence's transition posteriors
+  * to the accumulators.
+  * State posteriors are obtained by marginalising: the row sums of post[n] give
+  * the state at position n, the column sums of post[len-2] give the state at
+  * the last position.
+  * A sentence with fewer than two tokens has no transition posterior to
+  * marginalise, so it contributes nothing (it used to fail with an
+  * ArrayIndexOutOfBoundsException on post[len-2]).
+  * @param post post[n][i][j]=posterior of state i at position n followed by state j
+  * @param seq symbol ids of the sentence; indexed in parallel with post
+  * @param trans_exp_cnt expected transition counts, updated in place
+  * @param emit_exp_cnt expected emission counts, updated in place
+  * @param start_exp_cnt expected initial-state counts, updated in place
+  */
+ private void incrementExpCnt(
+     double post[][][],int [] seq,
+     double trans_exp_cnt[][],
+     double emit_exp_cnt[][],
+     double start_exp_cnt[])
+ {
+     int len=post.length;
+     if(len<2){
+         return;
+     }
+     int n_state=trans.length;
+
+     for(int n=0;n<len;n++){
+         for(int i=0;i<n_state;i++){
+             double py=0;
+             for(int j=0;j<n_state;j++){
+                 py+=post[n][i][j];
+                 trans_exp_cnt[i][j]+=post[n][i][j];
+             }
+
+             emit_exp_cnt[i][seq[n]]+=py;
+
+             //the first state: the same marginal is also the start count
+             if(n==0){
+                 start_exp_cnt[i]+=py;
+             }
+         }
+     }
+
+     //the last state: column sums of the final transition slice
+     int last=len-1;
+     for(int i=0;i<n_state;i++){
+         double py=0;
+         for(int j=0;j<n_state;j++){
+             py+=post[last-1][j][i];
+         }
+         emit_exp_cnt[i][seq[last]]+=py;
+     }
+ }
+
+ /**
+  * Divides every entry of {@code a} by the sum of all entries, in place.
+  * An array whose entries sum to exactly zero is left untouched.
+  */
+ public void l1normalize(double [] a){
+     double total=0;
+     for(double v:a){
+         total+=v;
+     }
+     if(total!=0){
+         for(int k=0;k<a.length;k++){
+             a[k]/=total;
+         }
+     }
+ }
+
+ /**
+  * Divides every entry of the matrix {@code a} by the sum over the whole
+  * matrix (not per row), in place. A matrix summing to exactly zero is left untouched.
+  */
+ public void l1normalize(double [][] a){
+     double total=0;
+     for(double [] row:a){
+         for(double v:row){
+             total+=v;
+         }
+     }
+     if(total!=0){
+         for(double [] row:a){
+             for(int k=0;k<row.length;k++){
+                 row[k]/=total;
+             }
+         }
+     }
+ }
+
+ public void writeModel(String modelFilename){
+ PrintStream ps=io.FileUtil.openOutFile(modelFilename);
+ ps.println(trans.length);
+ ps.println("Initial Probabilities:");
+ for(int i=0;i<pi.length;i++){
+ ps.print(pi[i]+"\t");
+ }
+ ps.println();
+ ps.println("Transition Probabilities:");
+ for(int i=0;i<trans.length;i++){
+ for(int j=0;j<trans[i].length;j++){
+ ps.print(trans[i][j]+"\t");
+ }
+ ps.println();
+ }
+ ps.println("Emission Probabilities:");
+ ps.println(emit[0].length);
+ for(int i=0;i<trans.length;i++){
+ for(int j=0;j<emit[i].length;j++){
+ ps.println(emit[i][j]);
+ }
+ ps.println();
+ }
+ ps.close();
+ }
+
+ public HMM(){ //leaves every field at its default (null); readModel() can then fill pi, trans and emit
+
+ }
+
+ /**
+  * Loads pi, trans and emit from a text file in the layout produced by writeModel().
+  * Numbers are parsed with a fixed locale because writeModel() always prints
+  * '.' as the decimal separator, whereas Scanner otherwise follows the default
+  * locale of the machine. The fields are replaced only after the whole file
+  * has been parsed, and the scanner is closed even if parsing fails.
+  * @param modelFilename path of the model file to read
+  */
+ public void readModel(String modelFilename){
+     Scanner sc=io.FileUtil.openInFile(modelFilename);
+     try{
+         //fully qualified so that no extra import is needed in this file
+         sc.useLocale(java.util.Locale.US);
+
+         int n_state=sc.nextInt();
+         sc.nextLine(); //rest of the state-count line
+         sc.nextLine(); //"Initial Probabilities:" header
+         double []newPi=new double [n_state];
+         for(int i=0;i<n_state;i++){
+             newPi[i]=sc.nextDouble();
+         }
+         sc.nextLine(); //rest of the initial-probability line
+         sc.nextLine(); //"Transition Probabilities:" header
+         double [][]newTrans=new double[n_state][n_state];
+         for(int i=0;i<newTrans.length;i++){
+             for(int j=0;j<newTrans[i].length;j++){
+                 newTrans[i][j]=sc.nextDouble();
+             }
+         }
+         sc.nextLine(); //rest of the last transition row
+         sc.nextLine(); //"Emission Probabilities:" header
+
+         int n_obs=sc.nextInt();
+         double [][]newEmit=new double[n_state][n_obs];
+         for(int i=0;i<newEmit.length;i++){
+             for(int j=0;j<newEmit[i].length;j++){
+                 newEmit[i][j]=sc.nextDouble();
+             }
+         }
+
+         pi=newPi;
+         trans=newTrans;
+         emit=newEmit;
+         //the cached log-transitions belong to the previous model
+         logtrans=null;
+     }finally{
+         sc.close();
+     }
+ }
+
+ /**
+  * Finds the most probable state sequence for {@code seq} in log space.
+  * Ties are resolved in favour of the lowest state index.
+  * NOTE: the log-transition table is cached in {@code logtrans} and rebuilt
+  * only while that field is null; code that changes {@code trans} after the
+  * first call must reset {@code logtrans} to null.
+  * @param seq symbol ids of the sentence
+  * @return one state index per position; an empty array for an empty
+  *         sentence (which used to throw ArrayIndexOutOfBoundsException)
+  */
+ public int []viterbi(int [] seq){
+     if(seq.length==0){
+         return new int[0];
+     }
+     int n_state=trans.length;
+     double [][]p=new double [seq.length][n_state];
+     int backp[][]=new int [seq.length][n_state];
+
+     for(int i=0;i<n_state;i++){
+         p[0][i]=Math.log(emit[i][seq[0]]*pi[i]);
+     }
+
+     if(logtrans==null){
+         double built[][]=new double [n_state][n_state];
+         for(int i=0;i<n_state;i++){
+             for(int j=0;j<n_state;j++){
+                 built[i][j]=Math.log(trans[i][j]);
+             }
+         }
+         logtrans=built;
+     }
+     double a[][]=logtrans;
+
+     for(int n=1;n<seq.length;n++){
+         for(int i=0;i<n_state;i++){
+             //best predecessor of state i at position n
+             double maxprob=p[n-1][0]+a[0][i];
+             backp[n][i]=0;
+             for(int j=1;j<n_state;j++){
+                 double prob=p[n-1][j]+a[j][i];
+                 if(maxprob<prob){
+                     backp[n][i]=j;
+                     maxprob=prob;
+                 }
+             }
+             p[n][i]=maxprob+Math.log(emit[i][seq[n]]);
+         }
+     }
+
+     //best final state
+     int last=seq.length-1;
+     double maxprob=p[last][0];
+     int maxIdx=0;
+     for(int i=1;i<n_state;i++){
+         if(p[last][i]>maxprob){
+             maxprob=p[last][i];
+             maxIdx=i;
+         }
+     }
+
+     //follow the back-pointers from the end
+     int ans[]=new int [seq.length];
+     ans[last]=maxIdx;
+     for(int i=last-1;i>=0;i--){
+         ans[i]=backp[i+1][ans[i+1]];
+     }
+     return ans;
+ }
+
+ /**
+  * Returns the plain sum of the entries of {@code a}. No absolute value is
+  * taken, so this is the L1 norm only for non-negative input.
+  */
+ public double l1norm(double a[]){
+     double total=0;
+     for(double v:a){
+         total += v;
+     }
+     return total;
+ }
+
+ public double [][]getEmitProb(){ //hands out the internal emission table itself, not a copy
+ return emit;
+ }
+
+ /**
+  * Generates a symbol sequence from the model: starts from a state drawn from
+  * pi, then alternates emission and transition draws until the terminal
+  * symbol is emitted. The terminal symbol itself is not part of the result.
+  * Does not return if the terminal symbol is never drawn.
+  * @param terminalSym symbol id that ends the sequence
+  * @return the emitted symbols before the terminal one
+  */
+ public int [] sample(int terminalSym){
+     ArrayList<Integer > emitted=new ArrayList<Integer>();
+     int state=sample(pi);
+     for(int sym=sample(emit[state]);sym!=terminalSym;sym=sample(emit[state])){
+         emitted.add(sym);
+         //move on before the next emission is drawn in the loop update
+         state=sample(trans[state]);
+     }
+
+     int result[]=new int [emitted.size()];
+     int pos=0;
+     while(pos<result.length){
+         result[pos]=emitted.get(pos);
+         pos++;
+     }
+     return result;
+ }
+
+ /**
+  * Draws an index from the discrete distribution {@code p}: returns the first
+  * index whose cumulative sum reaches a Math.random() draw.
+  * Falls back to the last index when the cumulative sum never reaches the draw.
+  */
+ public int sample(double p[]){
+     double r=Math.random();
+     double cumulative=0;
+     int idx=0;
+     while(idx<p.length){
+         cumulative+=p[idx];
+         if(r<=cumulative){
+             return idx;
+         }
+         idx++;
+     }
+     return p.length-1;
+ }
+
+ /**
+  * Supervised estimation: counts starts, transitions and emissions from tagged
+  * data and renormalises them into pi, trans and emit. Emission counts are
+  * smoothed by addOneSmooth() first.
+  * Sentences with no tags are skipped (they used to fail on tagdata[i][0]).
+  * @param tagdata tagdata[i][j]=state id of token j in sentence i; must be
+  *        aligned with the {@code data} field given to the constructor
+  */
+ public void train(int tagdata[][]){
+     int n_state=trans.length;
+     double trans_exp_cnt[][]=new double [n_state][n_state];
+     double emit_exp_cnt[][]=new double[n_state][emit[0].length];
+     double start_exp_cnt[]=new double[n_state];
+
+     for(int i=0;i<tagdata.length;i++){
+         int []tags=tagdata[i];
+         if(tags.length==0){
+             //an empty sentence has no start state and no tokens to count
+             continue;
+         }
+         start_exp_cnt[tags[0]]++;
+
+         for(int j=0;j<tags.length;j++){
+             if(j+1<tags.length){
+                 trans_exp_cnt[ tags[j] ] [ tags[j+1] ]++;
+             }
+             emit_exp_cnt[tags[j]][data[i][j]]++;
+         }
+     }
+
+     //M
+     addOneSmooth(emit_exp_cnt);
+     for(int i=0;i<n_state;i++){
+
+         //transition probs
+         double sum=0;
+         for(int j=0;j<trans[i].length;j++){
+             sum+=trans_exp_cnt[i][j];
+         }
+         //avoid NAN
+         if(sum==0){
+             sum=1;
+         }
+         for(int j=0;j<trans[i].length;j++){
+             trans[i][j]=trans_exp_cnt[i][j]/sum;
+         }
+
+         //emission probs
+         sum=0;
+         for(int j=0;j<emit[i].length;j++){
+             sum+=emit_exp_cnt[i][j];
+         }
+         //avoid NAN
+         if(sum==0){
+             sum=1;
+         }
+         for(int j=0;j<emit[i].length;j++){
+             emit[i][j]=emit_exp_cnt[i][j]/sum;
+         }
+     }
+
+     //initial probs: independent of the state loop above, so set once
+     for(int j=0;j<pi.length;j++){
+         pi[j]=start_exp_cnt[j];
+     }
+     l1normalize(pi);
+
+     //trans has changed: drop the cached log-transitions so viterbi() rebuilds them
+     logtrans=null;
+ }
+
+ /**
+  * Additive smoothing: adds 0.01 (not 1, despite the name) to every cell of
+  * {@code a}, in place. Rows are not renormalised here.
+  */
+ private void addOneSmooth(double a[][]){
+     for(double [] row:a){
+         for(int k=0;k<row.length;k++){
+             row[k]+=0.01;
+         }
+     }
+ }
+
+ /**
+  * Runs one EM iteration with posterior regularisation: the objective {@code o}
+  * is optimised first, then the E-step uses the posteriors from
+  * o.forwardBackward() instead of the unconstrained ones. The M-step is the
+  * same as in EM(). {@code o} must have been assigned beforehand.
+  */
+ public void PREM(){
+
+     o.optimizeWithProjectedGradientDescent();
+
+     int n_state=trans.length;
+     double trans_exp_cnt[][]=new double [n_state][n_state];
+     double emit_exp_cnt[][]=new double[n_state][emit[0].length];
+     double start_exp_cnt[]=new double[n_state];
+
+     o.loglikelihood=0;
+     //E
+     for(int sentNum=0;sentNum<data.length;sentNum++){
+
+         double [][][] post=o.forwardBackward(sentNum);
+         incrementExpCnt(post, data[sentNum],
+             trans_exp_cnt,
+             emit_exp_cnt,
+             start_exp_cnt);
+
+         if(sentNum%100==0){
+             System.out.print(".");
+         }
+         if(sentNum%1000==0){
+             System.out.println(sentNum);
+         }
+     }
+
+     System.out.println("Log likelihood: "+o.getValue());
+
+     //M
+     addOneSmooth(emit_exp_cnt);
+     for(int i=0;i<n_state;i++){
+
+         //transition probs
+         double sum=0;
+         for(int j=0;j<trans[i].length;j++){
+             sum+=trans_exp_cnt[i][j];
+         }
+         //avoid NAN
+         if(sum==0){
+             sum=1;
+         }
+         for(int j=0;j<trans[i].length;j++){
+             trans[i][j]=trans_exp_cnt[i][j]/sum;
+         }
+
+         //emission probs
+         sum=0;
+         for(int j=0;j<emit[i].length;j++){
+             sum+=emit_exp_cnt[i][j];
+         }
+         //avoid NAN
+         if(sum==0){
+             sum=1;
+         }
+         for(int j=0;j<emit[i].length;j++){
+             emit[i][j]=emit_exp_cnt[i][j]/sum;
+         }
+     }
+
+     //initial probs: independent of the state loop above, so set once
+     for(int j=0;j<pi.length;j++){
+         pi[j]=start_exp_cnt[j];
+     }
+     l1normalize(pi);
+
+     //trans has changed: drop the cached log-transitions so viterbi() rebuilds them
+     logtrans=null;
+ }
+
+ /**
+  * Raises maxwt[state][word] to the largest state posterior observed for any
+  * token of that word in {@code d}; entries are only ever increased.
+  * State posteriors come from marginalising the transition posteriors of
+  * forwardBackward(), so sentences with fewer than two tokens are skipped
+  * (they used to fail with an ArrayIndexOutOfBoundsException on post[len-2]).
+  * Note that every forwardBackward() call also adds to {@code loglikely}.
+  * @param maxwt table indexed [state][word id], updated in place
+  * @param d sentences as arrays of word ids
+  */
+ public void computeMaxwt(double[][]maxwt, int[][] d){
+     int n_state=trans.length;
+
+     for(int sentNum=0;sentNum<d.length;sentNum++){
+         int []seq=d[sentNum];
+         double post[][][]=forwardBackward(seq);
+         int len=post.length;
+         if(len<2){
+             continue;
+         }
+
+         for(int n=0;n<len;n++){
+             for(int i=0;i<n_state;i++){
+                 //posterior of state i at position n = row sum
+                 double py=0;
+                 for(int j=0;j<n_state;j++){
+                     py+=post[n][i][j];
+                 }
+
+                 if(py>maxwt[i][seq[n]]){
+                     maxwt[i][seq[n]]=py;
+                 }
+             }
+         }
+
+         //the last state: column sums of the final transition slice
+         int last=len-1;
+         for(int i=0;i<n_state;i++){
+             double py=0;
+             for(int j=0;j<n_state;j++){
+                 py+=post[last-1][j][i];
+             }
+
+             if(py>maxwt[i][seq[last]]){
+                 maxwt[i][seq[last]]=py;
+             }
+         }
+     }
+ }
+
+}//end of class
diff --git a/gi/posterior-regularisation/prjava/src/hmm/HMMObjective.java b/gi/posterior-regularisation/prjava/src/hmm/HMMObjective.java new file mode 100644 index 00000000..551210c0 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/hmm/HMMObjective.java @@ -0,0 +1,348 @@ +package hmm;
+
+import gnu.trove.TIntArrayList;
+import optimization.gradientBasedMethods.ProjectedGradientDescent;
+import optimization.gradientBasedMethods.ProjectedObjective;
+import optimization.gradientBasedMethods.stats.OptimizerStats;
+import optimization.linesearch.ArmijoLineSearchMinimizationAlongProjectionArc;
+import optimization.linesearch.InterpolationPickFirstStep;
+import optimization.linesearch.LineSearchMethod;
+import optimization.projections.SimplexProjection;
+import optimization.stopCriteria.CompositeStopingCriteria;
+import optimization.stopCriteria.ProjectedGradientL2Norm;
+import optimization.stopCriteria.StopingCriteria;
+import optimization.stopCriteria.ValueDifference;
+
+public class HMMObjective extends ProjectedObjective{
+
+
+ private static final double GRAD_DIFF = 3; //threshold handed to the ProjectedGradientL2Norm stopping criterion
+ public static double INIT_STEP_SIZE=10; //initial step handed to InterpolationPickFirstStep
+ public static double VAL_DIFF=2000; //threshold handed to the ValueDifference stopping criterion
+
+ private HMM hmm; //model whose data, emit, trans and pi are read here
+ double[] newPoint ; //buffer that projectPoint() fills and returns
+
+ //posterior[sent num][tok num][tag]=index into lambda, or -1 when the word is not constrained
+ private int posteriorMap[][][];
+ //projection[word][tag].get(occurrence)=index into lambda; projection[word] is null for unconstrained words
+ private TIntArrayList projectionMap[][];
+
+ //Size of the simplex (value passed to the SimplexProjection constructor)
+ public double scale=10;
+ private SimplexProjection projection;
+
+ private int wordFreq[]; //wordFreq[w]=number of tokens of word id w in hmm.data
+ private static int MIN_FREQ=3; //only words occurring more than MIN_FREQ times are constrained
+ private int numWordsToProject=0; //written by buildMap() only; never read in this class
+
+ private int n_param; //number of lambda entries, i.e. length of the parameter and gradient vectors
+
+ public double loglikelihood; //accumulated by forwardBackward(); reset by updateFunction() and by HMM.PREM()
+
+ public HMMObjective(HMM h){
+ hmm=h;
+
+ countWords();
+ buildMap();
+
+ gradient=new double [n_param];
+ projection = new SimplexProjection(scale);
+ newPoint = new double[n_param];
+ setInitialParameters(new double[n_param]);
+
+ }
+
+ /**@brief counts how many tokens of each word id occur in hmm.data
+  *
+  * The table has one slot per emission symbol of the model.
+  */
+ private void countWords(){
+     wordFreq=new int [hmm.emit[0].length];
+     for(int [] sentence:hmm.data){
+         for(int word:sentence){
+             wordFreq[word]++;
+         }
+     }
+ }
+
+ /**@brief build posterior and projection indices
+  *
+  * Every (token, state) pair of a word that occurs more than MIN_FREQ times
+  * gets its own index into lambda. posteriorMap records that index per token
+  * (-1 for unconstrained words); projectionMap groups the indices of all
+  * tokens of one word by state. numWordsToProject ends up as the number of
+  * (word, state) groups over the whole corpus; it used to be reset at every
+  * sentence and therefore only reflected the last one.
+  */
+ private void buildMap(){
+     //number of sentences hidden states and words
+     int n_states=hmm.trans.length;
+     int n_words=hmm.emit[0].length;
+     int n_sents=hmm.data.length;
+
+     n_param=0;
+     numWordsToProject=0;
+     posteriorMap=new int[n_sents][][];
+     projectionMap=new TIntArrayList[n_words][];
+     for(int sentNum=0;sentNum<n_sents;sentNum++){
+         int [] data=hmm.data[sentNum];
+         posteriorMap[sentNum]=new int[data.length][n_states];
+         for(int i=0;i<data.length;i++){
+             int word=data[i];
+
+             //the frequency test depends on the word only, so do it once per token
+             if(wordFreq[word]<=MIN_FREQ){
+                 for(int state=0;state<n_states;state++){
+                     posteriorMap[sentNum][i][state]=-1;
+                 }
+                 continue;
+             }
+
+             if(projectionMap[word]==null){
+                 projectionMap[word]=new TIntArrayList[n_states];
+             }
+             for(int state=0;state<n_states;state++){
+                 posteriorMap[sentNum][i][state]=n_param;
+                 if(projectionMap[word][state]==null){
+                     projectionMap[word][state]=new TIntArrayList();
+                     numWordsToProject++;
+                 }
+                 projectionMap[word][state].add(n_param);
+                 n_param++;
+             }
+         }
+     }
+ }
+
+ /**
+  * Projects {@code point} group by group: for every constrained word and
+  * every state, the lambda entries of all occurrences of that word are
+  * gathered, handed to the simplex projection, and written into the shared
+  * {@code newPoint} buffer.
+  * @return the {@code newPoint} buffer (reused between calls)
+  */
+ @Override
+ public double[] projectPoint(double[] point) {
+     for(TIntArrayList[] perState:projectionMap){
+         //a null row means the word is not constrained
+         if(perState!=null){
+             for(TIntArrayList lambdaIdx:perState){
+                 int count=lambdaIdx.size();
+                 double[] group = new double[count];
+
+                 for (int k = 0; k < count; k++) {
+                     group[k] = point[lambdaIdx.get(k)];
+                 }
+
+                 projection.project(group);
+
+                 for (int k = 0; k < count; k++) {
+                     newPoint[lambdaIdx.get(k)]=group[k];
+                 }
+             }
+         }
+     }
+     return newPoint;
+ }
+
+ @Override
+ public double[] getGradient() {
+ // Counts the call and returns the gradient buffer as last filled by updateFunction(); nothing is recomputed here.
+ gradientCalls++;
+ return gradient;
+ }
+
+ @Override
+ public double getValue() {
+ // Counts the call and returns the loglikelihood field as it currently stands; nothing is recomputed here.
+ functionCalls++;
+ return loglikelihood;
+ }
+
+
+ @Override
+ public String toString() {
+ // TODO Auto-generated method stub
+ StringBuffer sb = new StringBuffer();
+ for (int i = 0; i < parameters.length; i++) {
+ sb.append(parameters[i]+" ");
+ if(i%100==0){
+ sb.append("\n");
+ }
+ }
+ sb.append("\n");
+ /*
+ for (int i = 0; i < gradient.length; i++) {
+ sb.append(gradient[i]+" ");
+ if(i%100==0){
+ sb.append("\n");
+ }
+ }
+ sb.append("\n");
+ */
+ return sb.toString();
+ }
+
+
+ /**
+  * Forward-backward for one training sentence in which the score of every
+  * (token, state) pair of a constrained word is multiplied by e^{-lambda},
+  * lambda being that pair's entry in {@code parameters}.
+  * Side effect: adds the log of the forward scaling constants to
+  * {@code loglikelihood}.
+  * @param sentNum index of the sentence in hmm.data
+  * @return p with p[n][i][j] = posterior of state i at token n followed by
+  *         state j at token n+1; each slice is normalised to sum to one and
+  *         the last slice p[len-1] is left all zero
+  */
+ public double [][][]forwardBackward(int sentNum){
+     int [] seq=hmm.data[sentNum];
+     int n_states=hmm.trans.length;
+     int len=seq.length;
+     double a[][]=new double [len][n_states];
+     double b[][]=new double [len][n_states];
+
+     //e^{-\lambda} per (token, state), 1 for tokens of unconstrained words.
+     //Computed once here instead of once per state pair in the loops below.
+     double factor[][]=new double [len][n_states];
+     for(int n=0;n<len;n++){
+         boolean constrained=
+             (projectionMap[seq[n]]!=null);
+         for(int i=0;i<n_states;i++){
+             if(constrained){
+                 factor[n][i]=
+                     Math.exp(- parameters[ posteriorMap[sentNum][n][i] ] );
+             }else{
+                 factor[n][i]=1;
+             }
+         }
+     }
+
+     //initialize the first step
+     for(int i=0;i<n_states;i++){
+         a[0][i]=hmm.emit[i][seq[0]]*hmm.pi[i];
+         a[0][i]*=factor[0][i];
+         b[len-1][i]=1;
+     }
+
+     loglikelihood+=Math.log(hmm.l1norm(a[0]));
+     hmm.l1normalize(a[0]);
+     hmm.l1normalize(b[len-1]);
+
+     //forward
+     for(int n=1;n<len;n++){
+         for(int i=0;i<n_states;i++){
+             for(int j=0;j<n_states;j++){
+                 a[n][i]+=hmm.trans[j][i]*a[n-1][j];
+             }
+             a[n][i]*=hmm.emit[i][seq[n]];
+             a[n][i]*=factor[n][i];
+         }
+         loglikelihood+=Math.log(hmm.l1norm(a[n]));
+         hmm.l1normalize(a[n]);
+     }
+
+     //backward
+     for(int n=len-2;n>=0;n--){
+         for(int i=0;i<n_states;i++){
+             for(int j=0;j<n_states;j++){
+                 b[n][i]+=hmm.trans[i][j]*b[n+1][j]*hmm.emit[j][seq[n+1]]*factor[n+1][j];
+             }
+         }
+         hmm.l1normalize(b[n]);
+     }
+
+     //expected transition
+     double p[][][]=new double [len][n_states][n_states];
+     for(int n=0;n<len-1;n++){
+         for(int i=0;i<n_states;i++){
+             for(int j=0;j<n_states;j++){
+                 p[n][i][j]=a[n][i]*hmm.trans[i][j]*
+                     hmm.emit[j][seq[n+1]]*b[n+1][j]*factor[n+1][j];
+             }
+         }
+
+         hmm.l1normalize(p[n]);
+     }
+     return p;
+ }
+
+ /**
+  * Optimises the lambda parameters with projected gradient descent: Armijo
+  * line search along the projection arc, at most 10 iterations, stopping on
+  * the gradient-norm and value-difference criteria. Prints the optimiser
+  * statistics and whether the run succeeded.
+  */
+ public void optimizeWithProjectedGradientDescent(){
+     InterpolationPickFirstStep firstStep = new InterpolationPickFirstStep(INIT_STEP_SIZE);
+     LineSearchMethod lineSearch =
+         new ArmijoLineSearchMinimizationAlongProjectionArc(firstStep);
+
+     OptimizerStats stats = new OptimizerStats();
+
+     ProjectedGradientDescent optimizer = new ProjectedGradientDescent(lineSearch);
+
+     //stop as soon as either criterion is met
+     StopingCriteria byGradient = new ProjectedGradientL2Norm(GRAD_DIFF);
+     StopingCriteria byValue = new ValueDifference(VAL_DIFF);
+     CompositeStopingCriteria stopRule = new CompositeStopingCriteria();
+     stopRule.add(byGradient);
+     stopRule.add(byValue);
+
+     optimizer.setMaxIterations(10);
+     //bring value and gradient up to date for the current parameters first
+     updateFunction();
+     boolean converged = optimizer.optimize(this,stats,stopRule);
+     System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
+     System.out.println(converged
+         ? "Ended optimization in " + optimizer.getCurrentIteration()
+         : "Failed to optimize");
+ }
+
+ @Override
+ public void setParameters(double[] params) { //every parameter change triggers a full pass over the corpus via updateFunction()
+ super.setParameters(params);
+ updateFunction();
+ }
+
+ /**
+  * Recomputes {@code loglikelihood} and {@code gradient} for the current
+  * parameters with one forward-backward pass per sentence. The gradient entry
+  * of a constrained (token, state) pair is minus the posterior of that state,
+  * obtained as a row sum of p[n] (column sum of p[len-2] for the last token).
+  * Sentences with fewer than two tokens still contribute to the likelihood
+  * but leave the gradient untouched, because no transition posterior exists
+  * to marginalise (they used to fail with an ArrayIndexOutOfBoundsException).
+  */
+ private void updateFunction(){
+
+     updateCalls++;
+     loglikelihood=0;
+
+     for(int sentNum=0;sentNum<hmm.data.length;sentNum++){
+         double [][][]p=forwardBackward(sentNum);
+         int len=p.length;
+         if(len<2){
+             continue;
+         }
+         int []seq=hmm.data[sentNum];
+
+         for(int n=0;n<len-1;n++){
+             //the constraint test depends on the word only, so do it once per token
+             if(projectionMap[seq[n]]==null){
+                 continue;
+             }
+             for(int i=0;i<p[n].length;i++){
+                 double posterior=0;
+                 for(int j=0;j<p[n][i].length;j++){
+                     posterior+=p[n][i][j];
+                 }
+                 gradient[posteriorMap[sentNum][n][i]]=-posterior;
+             }
+         }
+
+         //the last state
+         int n=len-2;
+         if(projectionMap[seq[n+1]]!=null){
+             for(int i=0;i<p[n].length;i++){
+                 double posterior=0;
+                 for(int j=0;j<p[n].length;j++){
+                     posterior+=p[n][j][i];
+                 }
+                 gradient[posteriorMap[sentNum][n+1][i]]=-posterior;
+             }
+         }
+     }
+
+ }
+
+}
diff --git a/gi/posterior-regularisation/prjava/src/hmm/POS.java b/gi/posterior-regularisation/prjava/src/hmm/POS.java new file mode 100644 index 00000000..722d38e2 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/hmm/POS.java @@ -0,0 +1,126 @@ +package hmm;
+
+import java.io.PrintStream;
+import java.util.HashMap;
+
+import data.Corpus;
+
+public class POS {
+
+ //public String trainFilename="../posdata/en_train.conll";
+ //public static String trainFilename="../posdata/small_train.txt";
+ public static String trainFilename="../posdata/en_test.conll"; // NOTE(review): same file as testFilename below; confirm this is intended
+// public static String trainFilename="../posdata/trial1.txt";
+
+ public static String testFilename="../posdata/en_test.conll"; //corpus that gets tagged after training
+ //public static String testFilename="../posdata/trial1.txt";
+
+ public static String predFilename="../posdata/en_test.predict.conll"; //output: one "token<TAB>tag" line per token, blank line between sentences
+ public static String modelFilename="../posdata/posModel.out"; //final model; checkpoints are written to this name plus the iteration index
+ public static final int ITER=20; //number of EM / PREM iterations
+ public static final int N_STATE=30; //number of hidden states for the unsupervised runs
+
+ public static void main(String[] args) { //args are ignored; exactly one of the three modes below is enabled
+ //POS p=new POS(); //unsupervised induction with plain EM
+ //POS p=new POS(true); //supervised training from tagged data
+ PRPOS();
+ }
+
+
+ /**
+  * Unsupervised run with plain EM: trains an HMM with N_STATE states on
+  * trainFilename for ITER iterations (checkpointing the model every 10th
+  * iteration), writes the final model, then tags testFilename with Viterbi
+  * and writes "token<TAB>state" lines to predFilename.
+  */
+ public POS(){
+     Corpus trainCorpus= new Corpus(trainFilename);
+     //size of vocabulary +1 for unknown tokens
+     HMM model =new HMM(N_STATE, trainCorpus.getVocabSize()+1,trainCorpus.getAllData());
+
+     int done=0;
+     while(done<ITER){
+         System.out.println("Iter"+done);
+         model.EM();
+         done++;
+         if(done%10==0){
+             //checkpoint name carries the zero-based iteration index
+             model.writeModel(modelFilename+(done-1));
+         }
+     }
+
+     model.writeModel(modelFilename);
+
+     Corpus testCorpus=new Corpus(testFilename,trainCorpus.vocab);
+
+     PrintStream out= io.FileUtil.openOutFile(predFilename);
+
+     int [][]sentences=testCorpus.getAllData();
+     for(int s=0;s<sentences.length;s++){
+         int []states=model.viterbi(sentences[s]);
+         String tokens[]=testCorpus.get(s);
+         for(int t=0;t<sentences[s].length;t++){
+             out.println(tokens[t]+"\t"+states[t]);
+         }
+         //blank line between sentences
+         out.println();
+     }
+     out.close();
+ }
+
+ /**
+  * POS induction with L1/Linf constraints: same pipeline as the plain EM
+  * run, but the model gets an HMMObjective and every iteration calls PREM().
+  * Trains on trainFilename, checkpoints every 10th iteration, writes the
+  * final model, then tags testFilename and writes predFilename.
+  */
+ public static void PRPOS(){
+     Corpus trainCorpus= new Corpus(trainFilename);
+     //size of vocabulary +1 for unknown tokens
+     HMM model =new HMM(N_STATE, trainCorpus.getVocabSize()+1,trainCorpus.getAllData());
+     model.o=new HMMObjective(model);
+
+     int done=0;
+     while(done<ITER){
+         System.out.println("Iter: "+done);
+         model.PREM();
+         done++;
+         if(done%10==0){
+             //checkpoint name carries the zero-based iteration index
+             model.writeModel(modelFilename+(done-1));
+         }
+     }
+
+     model.writeModel(modelFilename);
+
+     Corpus testCorpus=new Corpus(testFilename,trainCorpus.vocab);
+
+     PrintStream out= io.FileUtil.openOutFile(predFilename);
+
+     int [][]sentences=testCorpus.getAllData();
+     for(int s=0;s<sentences.length;s++){
+         int []states=model.viterbi(sentences[s]);
+         String tokens[]=testCorpus.get(s);
+         for(int t=0;t<sentences[s].length;t++){
+             out.println(tokens[t]+"\t"+states[t]);
+         }
+         //blank line between sentences
+         out.println();
+     }
+     out.close();
+ }
+
+
+ /**
+  * Supervised run: estimates the HMM from the tags in trainFilename (one
+  * state per tag), writes the model, then tags testFilename with Viterbi and
+  * writes "token<TAB>tag" lines to predFilename.
+  * @param supervised not read; it only distinguishes this constructor from the no-arg one
+  */
+ public POS(boolean supervised){
+     Corpus trainCorpus= new Corpus(trainFilename);
+     //size of vocabulary +1 for unknown tokens
+     HMM model =new HMM(trainCorpus.tagVocab.size() , trainCorpus.getVocabSize()+1,trainCorpus.getAllData());
+     model.train(trainCorpus.getTagData());
+
+     model.writeModel(modelFilename);
+
+     Corpus testCorpus=new Corpus(testFilename,trainCorpus.vocab);
+
+     //turn the serialised tag alphabet (tag -> id) into an id -> tag table;
+     //the extra last slot is set to the unknown token after the table is filled
+     HashMap<String, Integer>tagIds=
+         (HashMap<String, Integer>) io.SerializedObjects.readSerializedObject(Corpus.tagalphaFilename);
+     String [] idToTag=new String [tagIds.size()+1];
+     for(String tagName:tagIds.keySet()){
+         idToTag[tagIds.get(tagName)]=tagName;
+     }
+     idToTag[idToTag.length-1]=Corpus.UNK_TOK;
+
+     System.out.println(trainCorpus.vocab.get("<e>"));
+
+     PrintStream out= io.FileUtil.openOutFile(predFilename);
+
+     int [][]sentences=testCorpus.getAllData();
+     for(int s=0;s<sentences.length;s++){
+         int []states=model.viterbi(sentences[s]);
+         String tokens[]=testCorpus.get(s);
+         for(int t=0;t<sentences[s].length;t++){
+             out.println(tokens[t]+"\t"+idToTag[states[t]]);
+         }
+         //blank line between sentences
+         out.println();
+     }
+     out.close();
+ }
+}
|