summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation')
-rw-r--r--gi/posterior-regularisation/prjava/lib/optimization-0.1.jarbin0 -> 120451 bytes
-rw-r--r--gi/posterior-regularisation/prjava/lib/trove-2.0.2.jarbin0 -> 737844 bytes
-rw-r--r--gi/posterior-regularisation/prjava/src/data/Corpus.java230
-rw-r--r--gi/posterior-regularisation/prjava/src/hmm/HMM.java576
-rw-r--r--gi/posterior-regularisation/prjava/src/hmm/HMMObjective.java348
-rw-r--r--gi/posterior-regularisation/prjava/src/hmm/POS.java126
-rw-r--r--gi/posterior-regularisation/prjava/src/io/FileUtil.java37
-rw-r--r--gi/posterior-regularisation/prjava/src/io/SerializedObjects.java83
-rw-r--r--gi/posterior-regularisation/prjava/src/test/CorpusTest.java60
-rw-r--r--gi/posterior-regularisation/prjava/src/test/HMMModelStats.java96
-rw-r--r--gi/posterior-regularisation/prjava/src/test/IntDoublePair.java23
-rw-r--r--gi/posterior-regularisation/prjava/src/test/X2y2WithConstraints.java131
12 files changed, 1710 insertions, 0 deletions
diff --git a/gi/posterior-regularisation/prjava/lib/optimization-0.1.jar b/gi/posterior-regularisation/prjava/lib/optimization-0.1.jar
new file mode 100644
index 00000000..2d14fce7
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/lib/optimization-0.1.jar
Binary files differ
diff --git a/gi/posterior-regularisation/prjava/lib/trove-2.0.2.jar b/gi/posterior-regularisation/prjava/lib/trove-2.0.2.jar
new file mode 100644
index 00000000..3e59fbf3
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/lib/trove-2.0.2.jar
Binary files differ
diff --git a/gi/posterior-regularisation/prjava/src/data/Corpus.java b/gi/posterior-regularisation/prjava/src/data/Corpus.java
new file mode 100644
index 00000000..f0da0b33
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/data/Corpus.java
@@ -0,0 +1,230 @@
+package data;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Scanner;
+
+public class Corpus {
+
+ public static final String alphaFilename="../posdata/corpus.alphabet";
+ public static final String tagalphaFilename="../posdata/corpus.tag.alphabet";
+
+// public static final String START_SYM="<s>";
+ public static final String END_SYM="<e>";
+ public static final String NUM_TOK="<NUM>";
+
+ public static final String UNK_TOK="<unk>";
+
+ private ArrayList<String[]>sent;
+ private ArrayList<int[]>data;
+
+ public ArrayList<String[]>tag;
+ public ArrayList<int[]>tagData;
+
+ public static boolean convertNumTok=true;
+
+ private HashMap<String,Integer>freq;
+ public HashMap<String,Integer>vocab;
+
+ public HashMap<String,Integer>tagVocab;
+ private int tagV;
+
+ private int V;
+
+ public static void main(String[] args) {
+ Corpus c=new Corpus("../posdata/en_test.conll");
+ System.out.println(
+ Arrays.toString(c.get(0))
+ );
+ System.out.println(
+ Arrays.toString(c.getInt(0))
+ );
+
+ System.out.println(
+ Arrays.toString(c.get(1))
+ );
+ System.out.println(
+ Arrays.toString(c.getInt(1))
+ );
+ }
+
+ public Corpus(String filename,HashMap<String,Integer>dict){
+ V=0;
+ tagV=0;
+ freq=new HashMap<String,Integer>();
+ tagVocab=new HashMap<String,Integer>();
+ vocab=dict;
+
+ sent=new ArrayList<String[]>();
+ tag=new ArrayList<String[]>();
+
+ Scanner sc=io.FileUtil.openInFile(filename);
+ ArrayList<String>s=new ArrayList<String>();
+ // s.add(START_SYM);
+ while(sc.hasNextLine()){
+ String line=sc.nextLine();
+ String toks[]=line.split("\t");
+ if(toks.length<2){
+ s.add(END_SYM);
+ sent.add(s.toArray(new String[0]));
+ s=new ArrayList<String>();
+ // s.add(START_SYM);
+ continue;
+ }
+ String tok=toks[1].toLowerCase();
+ s.add(tok);
+ }
+ sc.close();
+
+ buildData();
+ }
+
+ public Corpus(String filename){
+ V=0;
+ freq=new HashMap<String,Integer>();
+ vocab=new HashMap<String,Integer>();
+ tagVocab=new HashMap<String,Integer>();
+
+ sent=new ArrayList<String[]>();
+ tag=new ArrayList<String[]>();
+
+ System.out.println("Reading:"+filename);
+
+ Scanner sc=io.FileUtil.openInFile(filename);
+ ArrayList<String>s=new ArrayList<String>();
+ ArrayList<String>tags=new ArrayList<String>();
+ //s.add(START_SYM);
+ while(sc.hasNextLine()){
+ String line=sc.nextLine();
+ String toks[]=line.split("\t");
+ if(toks.length<2){
+ s.add(END_SYM);
+ tags.add(END_SYM);
+ if(s.size()>2){
+ sent.add(s.toArray(new String[0]));
+ tag.add(tags.toArray(new String [0]));
+ }
+ s=new ArrayList<String>();
+ tags=new ArrayList<String>();
+ // s.add(START_SYM);
+ continue;
+ }
+
+ String tok=toks[1].toLowerCase();
+ if(convertNumTok && tok.matches(".*\\d.*")){
+ tok=NUM_TOK;
+ }
+ s.add(tok);
+
+ if(toks.length>3){
+ tok=toks[3].toLowerCase();
+ }else{
+ tok="_";
+ }
+ tags.add(tok);
+
+ }
+ sc.close();
+
+ for(int i=0;i<sent.size();i++){
+ String[]toks=sent.get(i);
+ for(int j=0;j<toks.length;j++){
+ addVocab(toks[j]);
+ addTag(tag.get(i)[j]);
+ }
+ }
+
+ buildVocab();
+ buildData();
+ System.out.println(data.size()+"sentences, "+vocab.keySet().size()+" word types");
+ }
+
+ public String[] get(int idx){
+ return sent.get(idx);
+ }
+
+ private void addVocab(String s){
+ Integer integer=freq.get(s);
+ if(integer==null){
+ integer=0;
+ }
+ freq.put(s, integer+1);
+ }
+
+ public int tokIdx(String tok){
+ Integer integer=vocab.get(tok);
+ if(integer==null){
+ return V;
+ }
+ return integer;
+ }
+
+ public int tagIdx(String tok){
+ Integer integer=tagVocab.get(tok);
+ if(integer==null){
+ return tagV;
+ }
+ return integer;
+ }
+
+ private void buildData(){
+ data=new ArrayList<int[]>();
+ for(int i=0;i<sent.size();i++){
+ String s[]=sent.get(i);
+ data.add(new int [s.length]);
+ for(int j=0;j<s.length;j++){
+ data.get(i)[j]=tokIdx(s[j]);
+ }
+ }
+
+ tagData=new ArrayList<int[]>();
+ for(int i=0;i<tag.size();i++){
+ String s[]=tag.get(i);
+ tagData.add(new int [s.length]);
+ for(int j=0;j<s.length;j++){
+ tagData.get(i)[j]=tagIdx(s[j]);
+ }
+ }
+ }
+
+ public int [] getInt(int idx){
+ return data.get(idx);
+ }
+
+ /**
+ *
+ * @return size of vocabulary
+ */
+ public int getVocabSize(){
+ return V;
+ }
+
+ public int [][]getAllData(){
+ return data.toArray(new int [0][]);
+ }
+
+ public int [][]getTagData(){
+ return tagData.toArray(new int [0][]);
+ }
+
+ private void buildVocab(){
+ for (String key:freq.keySet()){
+ if(freq.get(key)>2){
+ vocab.put(key, V);
+ V++;
+ }
+ }
+ io.SerializedObjects.writeSerializedObject(vocab, alphaFilename);
+ io.SerializedObjects.writeSerializedObject(tagVocab,tagalphaFilename);
+ }
+
+ private void addTag(String tag){
+ Integer i=tagVocab.get(tag);
+ if(i==null){
+ tagVocab.put(tag, tagV);
+ tagV++;
+ }
+ }
+
+}
diff --git a/gi/posterior-regularisation/prjava/src/hmm/HMM.java b/gi/posterior-regularisation/prjava/src/hmm/HMM.java
new file mode 100644
index 00000000..1c4d7659
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/hmm/HMM.java
@@ -0,0 +1,576 @@
+package hmm;
+
+import java.io.PrintStream;
+import java.util.ArrayList;
+import java.util.Scanner;
+
+public class HMM {
+
+
+ //trans[i][j]=prob of going FROM i to j
+ double [][]trans;
+ double [][]emit;
+ double []pi;
+ int [][]data;
+ int [][]tagdata;
+
+ double logtrans[][];
+
+ public HMMObjective o;
+
+ public static void main(String[] args) {
+
+ }
+
+ public HMM(int n_state,int n_emit,int [][]data){
+ trans=new double [n_state][n_state];
+ emit=new double[n_state][n_emit];
+ pi=new double [n_state];
+ System.out.println(" random initial parameters");
+ fillRand(trans);
+ fillRand(emit);
+ fillRand(pi);
+
+ this.data=data;
+
+ }
+
+ private void fillRand(double [][] a){
+ for(int i=0;i<a.length;i++){
+ for(int j=0;j<a[i].length;j++){
+ a[i][j]=Math.random();
+ }
+ l1normalize(a[i]);
+ }
+ }
+ private void fillRand(double []a){
+ for(int i=0;i<a.length;i++){
+ a[i]=Math.random();
+ }
+ l1normalize(a);
+ }
+
+ private double loglikely=0;
+
+ public void EM(){
+ double trans_exp_cnt[][]=new double [trans.length][trans.length];
+ double emit_exp_cnt[][]=new double[trans.length][emit[0].length];
+ double start_exp_cnt[]=new double[trans.length];
+ loglikely=0;
+
+ //E
+ for(int i=0;i<data.length;i++){
+
+ double [][][] post=forwardBackward(data[i]);
+ incrementExpCnt(post, data[i],
+ trans_exp_cnt,
+ emit_exp_cnt,
+ start_exp_cnt);
+
+
+ if(i%100==0){
+ System.out.print(".");
+ }
+ if(i%1000==0){
+ System.out.println(i);
+ }
+
+ }
+ System.out.println("Log likelihood: "+loglikely);
+
+ //M
+ addOneSmooth(emit_exp_cnt);
+ for(int i=0;i<trans.length;i++){
+
+ //transition probs
+ double sum=0;
+ for(int j=0;j<trans.length;j++){
+ sum+=trans_exp_cnt[i][j];
+ }
+ //avoid NAN
+ if(sum==0){
+ sum=1;
+ }
+ for(int j=0;j<trans[i].length;j++){
+ trans[i][j]=trans_exp_cnt[i][j]/sum;
+ }
+
+ //emission probs
+
+ sum=0;
+ for(int j=0;j<emit[i].length;j++){
+ sum+=emit_exp_cnt[i][j];
+ }
+ //avoid NAN
+ if(sum==0){
+ sum=1;
+ }
+ for(int j=0;j<emit[i].length;j++){
+ emit[i][j]=emit_exp_cnt[i][j]/sum;
+ }
+
+
+ //initial probs
+ for(int j=0;j<pi.length;j++){
+ pi[j]=start_exp_cnt[j];
+ }
+ l1normalize(pi);
+ }
+ }
+
+ private double [][][]forwardBackward(int [] seq){
+ double a[][]=new double [seq.length][trans.length];
+ double b[][]=new double [seq.length][trans.length];
+
+ int len=seq.length;
+ //initialize the first step
+ for(int i=0;i<trans.length;i++){
+ a[0][i]=emit[i][seq[0]]*pi[i];
+ b[len-1][i]=1;
+ }
+
+ //log of denominator for likelyhood
+ double c=Math.log(l1norm(a[0]));
+
+ l1normalize(a[0]);
+ l1normalize(b[len-1]);
+
+
+
+ //forward
+ for(int n=1;n<len;n++){
+ for(int i=0;i<trans.length;i++){
+ for(int j=0;j<trans.length;j++){
+ a[n][i]+=trans[j][i]*a[n-1][j];
+ }
+ a[n][i]*=emit[i][seq[n]];
+ }
+ c+=Math.log(l1norm(a[n]));
+ l1normalize(a[n]);
+ }
+
+ loglikely+=c;
+
+ //backward
+ for(int n=len-2;n>=0;n--){
+ for(int i=0;i<trans.length;i++){
+ for(int j=0;j<trans.length;j++){
+ b[n][i]+=trans[i][j]*b[n+1][j]*emit[j][seq[n+1]];
+ }
+ }
+ l1normalize(b[n]);
+ }
+
+
+ //expected transition
+ double p[][][]=new double [seq.length][trans.length][trans.length];
+ for(int n=0;n<len-1;n++){
+ for(int i=0;i<trans.length;i++){
+ for(int j=0;j<trans.length;j++){
+ p[n][i][j]=a[n][i]*trans[i][j]*emit[j][seq[n+1]]*b[n+1][j];
+
+ }
+ }
+
+ l1normalize(p[n]);
+ }
+ return p;
+ }
+
+ private void incrementExpCnt(
+ double post[][][],int [] seq,
+ double trans_exp_cnt[][],
+ double emit_exp_cnt[][],
+ double start_exp_cnt[])
+ {
+
+ for(int n=0;n<post.length;n++){
+ for(int i=0;i<trans.length;i++){
+ double py=0;
+ for(int j=0;j<trans.length;j++){
+ py+=post[n][i][j];
+ trans_exp_cnt[i][j]+=post[n][i][j];
+ }
+
+ emit_exp_cnt[i][seq[n]]+=py;
+
+ }
+ }
+
+ //the first state
+ for(int i=0;i<trans.length;i++){
+ double py=0;
+ for(int j=0;j<trans.length;j++){
+ py+=post[0][i][j];
+ }
+ start_exp_cnt[i]+=py;
+ }
+
+
+ //the last state
+ int len=post.length;
+ for(int i=0;i<trans.length;i++){
+ double py=0;
+ for(int j=0;j<trans.length;j++){
+ py+=post[len-2][j][i];
+ }
+ emit_exp_cnt[i][seq[len-1]]+=py;
+ }
+ }
+
+ public void l1normalize(double [] a){
+ double sum=0;
+ for(int i=0;i<a.length;i++){
+ sum+=a[i];
+ }
+ if(sum==0){
+ return ;
+ }
+ for(int i=0;i<a.length;i++){
+ a[i]/=sum;
+ }
+ }
+
+ public void l1normalize(double [][] a){
+ double sum=0;
+ for(int i=0;i<a.length;i++){
+ for(int j=0;j<a[i].length;j++){
+ sum+=a[i][j];
+ }
+ }
+ if(sum==0){
+ return;
+ }
+ for(int i=0;i<a.length;i++){
+ for(int j=0;j<a[i].length;j++){
+ a[i][j]/=sum;
+ }
+ }
+ }
+
+ public void writeModel(String modelFilename){
+ PrintStream ps=io.FileUtil.openOutFile(modelFilename);
+ ps.println(trans.length);
+ ps.println("Initial Probabilities:");
+ for(int i=0;i<pi.length;i++){
+ ps.print(pi[i]+"\t");
+ }
+ ps.println();
+ ps.println("Transition Probabilities:");
+ for(int i=0;i<trans.length;i++){
+ for(int j=0;j<trans[i].length;j++){
+ ps.print(trans[i][j]+"\t");
+ }
+ ps.println();
+ }
+ ps.println("Emission Probabilities:");
+ ps.println(emit[0].length);
+ for(int i=0;i<trans.length;i++){
+ for(int j=0;j<emit[i].length;j++){
+ ps.println(emit[i][j]);
+ }
+ ps.println();
+ }
+ ps.close();
+ }
+
+ public HMM(){
+
+ }
+
+ public void readModel(String modelFilename){
+ Scanner sc=io.FileUtil.openInFile(modelFilename);
+
+ int n_state=sc.nextInt();
+ sc.nextLine();
+ sc.nextLine();
+ pi=new double [n_state];
+ for(int i=0;i<n_state;i++){
+ pi[i]=sc.nextDouble();
+ }
+ sc.nextLine();
+ sc.nextLine();
+ trans=new double[n_state][n_state];
+ for(int i=0;i<trans.length;i++){
+ for(int j=0;j<trans[i].length;j++){
+ trans[i][j]=sc.nextDouble();
+ }
+ }
+ sc.nextLine();
+ sc.nextLine();
+
+ int n_obs=sc.nextInt();
+ emit=new double[n_state][n_obs];
+ for(int i=0;i<trans.length;i++){
+ for(int j=0;j<emit[i].length;j++){
+ emit[i][j]=sc.nextDouble();
+ }
+ }
+ sc.close();
+ }
+
+ public int []viterbi(int [] seq){
+ double [][]p=new double [seq.length][trans.length];
+ int backp[][]=new int [seq.length][trans.length];
+
+ for(int i=0;i<trans.length;i++){
+ p[0][i]=Math.log(emit[i][seq[0]]*pi[i]);
+ }
+
+ double a[][]=logtrans;
+ if(logtrans==null){
+ a=new double [trans.length][trans.length];
+ for(int i=0;i<trans.length;i++){
+ for(int j=0;j<trans.length;j++){
+ a[i][j]=Math.log(trans[i][j]);
+ }
+ }
+ logtrans=a;
+ }
+
+ double maxprob=0;
+ for(int n=1;n<seq.length;n++){
+ for(int i=0;i<trans.length;i++){
+ maxprob=p[n-1][0]+a[0][i];
+ backp[n][i]=0;
+ for(int j=1;j<trans.length;j++){
+ double prob=p[n-1][j]+a[j][i];
+ if(maxprob<prob){
+ backp[n][i]=j;
+ maxprob=prob;
+ }
+ }
+ p[n][i]=maxprob+Math.log(emit[i][seq[n]]);
+ }
+ }
+
+ maxprob=p[seq.length-1][0];
+ int maxIdx=0;
+ for(int i=1;i<trans.length;i++){
+ if(p[seq.length-1][i]>maxprob){
+ maxprob=p[seq.length-1][i];
+ maxIdx=i;
+ }
+ }
+ int ans[]=new int [seq.length];
+ ans[seq.length-1]=maxIdx;
+ for(int i=seq.length-2;i>=0;i--){
+ ans[i]=backp[i+1][ans[i+1]];
+ }
+ return ans;
+ }
+
+ public double l1norm(double a[]){
+ double norm=0;
+ for(int i=0;i<a.length;i++){
+ norm += a[i];
+ }
+ return norm;
+ }
+
+ public double [][]getEmitProb(){
+ return emit;
+ }
+
+ public int [] sample(int terminalSym){
+ ArrayList<Integer > s=new ArrayList<Integer>();
+ int state=sample(pi);
+ int sym=sample(emit[state]);
+ while(sym!=terminalSym){
+ s.add(sym);
+ state=sample(trans[state]);
+ sym=sample(emit[state]);
+ }
+
+ int ans[]=new int [s.size()];
+ for(int i=0;i<ans.length;i++){
+ ans[i]=s.get(i);
+ }
+ return ans;
+ }
+
+ public int sample(double p[]){
+ double r=Math.random();
+ double sum=0;
+ for(int i=0;i<p.length;i++){
+ sum+=p[i];
+ if(sum>=r){
+ return i;
+ }
+ }
+ return p.length-1;
+ }
+
+ public void train(int tagdata[][]){
+ double trans_exp_cnt[][]=new double [trans.length][trans.length];
+ double emit_exp_cnt[][]=new double[trans.length][emit[0].length];
+ double start_exp_cnt[]=new double[trans.length];
+
+ for(int i=0;i<tagdata.length;i++){
+ start_exp_cnt[tagdata[i][0]]++;
+
+ for(int j=0;j<tagdata[i].length;j++){
+ if(j+1<tagdata[i].length){
+ trans_exp_cnt[ tagdata[i][j] ] [ tagdata[i][j+1] ]++;
+ }
+ emit_exp_cnt[tagdata[i][j]][data[i][j]]++;
+ }
+
+ }
+
+ //M
+ addOneSmooth(emit_exp_cnt);
+ for(int i=0;i<trans.length;i++){
+
+ //transition probs
+ double sum=0;
+ for(int j=0;j<trans.length;j++){
+ sum+=trans_exp_cnt[i][j];
+ }
+ if(sum==0){
+ sum=1;
+ }
+ for(int j=0;j<trans[i].length;j++){
+ trans[i][j]=trans_exp_cnt[i][j]/sum;
+ }
+
+ //emission probs
+
+ sum=0;
+ for(int j=0;j<emit[i].length;j++){
+ sum+=emit_exp_cnt[i][j];
+ }
+ if(sum==0){
+ sum=1;
+ }
+ for(int j=0;j<emit[i].length;j++){
+ emit[i][j]=emit_exp_cnt[i][j]/sum;
+ }
+
+
+ //initial probs
+ for(int j=0;j<pi.length;j++){
+ pi[j]=start_exp_cnt[j];
+ }
+ l1normalize(pi);
+ }
+ }
+
+ private void addOneSmooth(double a[][]){
+ for(int i=0;i<a.length;i++){
+ for(int j=0;j<a[i].length;j++){
+ a[i][j]+=0.01;
+ }
+ //l1normalize(a[i]);
+ }
+ }
+
+ public void PREM(){
+
+ o.optimizeWithProjectedGradientDescent();
+
+ double trans_exp_cnt[][]=new double [trans.length][trans.length];
+ double emit_exp_cnt[][]=new double[trans.length][emit[0].length];
+ double start_exp_cnt[]=new double[trans.length];
+
+ o.loglikelihood=0;
+ //E
+ for(int sentNum=0;sentNum<data.length;sentNum++){
+
+ double [][][] post=o.forwardBackward(sentNum);
+ incrementExpCnt(post, data[sentNum],
+ trans_exp_cnt,
+ emit_exp_cnt,
+ start_exp_cnt);
+
+
+ if(sentNum%100==0){
+ System.out.print(".");
+ }
+ if(sentNum%1000==0){
+ System.out.println(sentNum);
+ }
+
+ }
+
+ System.out.println("Log likelihood: "+o.getValue());
+
+ //M
+ addOneSmooth(emit_exp_cnt);
+ for(int i=0;i<trans.length;i++){
+
+ //transition probs
+ double sum=0;
+ for(int j=0;j<trans.length;j++){
+ sum+=trans_exp_cnt[i][j];
+ }
+ //avoid NAN
+ if(sum==0){
+ sum=1;
+ }
+ for(int j=0;j<trans[i].length;j++){
+ trans[i][j]=trans_exp_cnt[i][j]/sum;
+ }
+
+ //emission probs
+
+ sum=0;
+ for(int j=0;j<emit[i].length;j++){
+ sum+=emit_exp_cnt[i][j];
+ }
+ //avoid NAN
+ if(sum==0){
+ sum=1;
+ }
+ for(int j=0;j<emit[i].length;j++){
+ emit[i][j]=emit_exp_cnt[i][j]/sum;
+ }
+
+
+ //initial probs
+ for(int j=0;j<pi.length;j++){
+ pi[j]=start_exp_cnt[j];
+ }
+ l1normalize(pi);
+ }
+
+ }
+
+ public void computeMaxwt(double[][]maxwt, int[][] d){
+
+ for(int sentNum=0;sentNum<d.length;sentNum++){
+ double post[][][]=forwardBackward(d[sentNum]);
+
+ for(int n=0;n<post.length;n++){
+ for(int i=0;i<trans.length;i++){
+ double py=0;
+ for(int j=0;j<trans.length;j++){
+ py+=post[n][i][j];
+ }
+
+ if(py>maxwt[i][d[sentNum][n]]){
+ maxwt[i][d[sentNum][n]]=py;
+ }
+
+ }
+ }
+
+ //the last state
+ int len=post.length;
+ for(int i=0;i<trans.length;i++){
+ double py=0;
+ for(int j=0;j<trans.length;j++){
+ py+=post[len-2][j][i];
+ }
+
+ if(py>maxwt[i][d[sentNum][len-1]]){
+ maxwt[i][d[sentNum][len-1]]=py;
+ }
+
+ }
+
+ }
+
+ }
+
+}//end of class
diff --git a/gi/posterior-regularisation/prjava/src/hmm/HMMObjective.java b/gi/posterior-regularisation/prjava/src/hmm/HMMObjective.java
new file mode 100644
index 00000000..551210c0
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/hmm/HMMObjective.java
@@ -0,0 +1,348 @@
+package hmm;
+
+import gnu.trove.TIntArrayList;
+import optimization.gradientBasedMethods.ProjectedGradientDescent;
+import optimization.gradientBasedMethods.ProjectedObjective;
+import optimization.gradientBasedMethods.stats.OptimizerStats;
+import optimization.linesearch.ArmijoLineSearchMinimizationAlongProjectionArc;
+import optimization.linesearch.InterpolationPickFirstStep;
+import optimization.linesearch.LineSearchMethod;
+import optimization.projections.SimplexProjection;
+import optimization.stopCriteria.CompositeStopingCriteria;
+import optimization.stopCriteria.ProjectedGradientL2Norm;
+import optimization.stopCriteria.StopingCriteria;
+import optimization.stopCriteria.ValueDifference;
+
+public class HMMObjective extends ProjectedObjective{
+
+
+ private static final double GRAD_DIFF = 3;
+ public static double INIT_STEP_SIZE=10;
+ public static double VAL_DIFF=2000;
+
+ private HMM hmm;
+ double[] newPoint ;
+
+ //posterior[sent num][tok num][tag]=index into lambda
+ private int posteriorMap[][][];
+ //projection[word][tag].get(occurence)=index into lambda
+ private TIntArrayList projectionMap[][];
+
+ //Size of the simplex
+ public double scale=10;
+ private SimplexProjection projection;
+
+ private int wordFreq[];
+ private static int MIN_FREQ=3;
+ private int numWordsToProject=0;
+
+ private int n_param;
+
+ public double loglikelihood;
+
+ public HMMObjective(HMM h){
+ hmm=h;
+
+ countWords();
+ buildMap();
+
+ gradient=new double [n_param];
+ projection = new SimplexProjection(scale);
+ newPoint = new double[n_param];
+ setInitialParameters(new double[n_param]);
+
+ }
+
+ /**@brief counts word frequency in the corpus
+ *
+ */
+ private void countWords(){
+ wordFreq=new int [hmm.emit[0].length];
+ for(int i=0;i<hmm.data.length;i++){
+ for(int j=0;j<hmm.data[i].length;j++){
+ wordFreq[hmm.data[i][j]]++;
+ }
+ }
+ }
+
+ /**@brief build posterior and projection indices
+ *
+ */
+ private void buildMap(){
+ //number of sentences hidden states and words
+ int n_states=hmm.trans.length;
+ int n_words=hmm.emit[0].length;
+ int n_sents=hmm.data.length;
+
+ n_param=0;
+ posteriorMap=new int[n_sents][][];
+ projectionMap=new TIntArrayList[n_words][];
+ for(int sentNum=0;sentNum<n_sents;sentNum++){
+ int [] data=hmm.data[sentNum];
+ posteriorMap[sentNum]=new int[data.length][n_states];
+ numWordsToProject=0;
+ for(int i=0;i<data.length;i++){
+ int word=data[i];
+ for(int state=0;state<n_states;state++){
+ if(wordFreq[word]>MIN_FREQ){
+ if(projectionMap[word]==null){
+ projectionMap[word]=new TIntArrayList[n_states];
+ }
+
+ posteriorMap[sentNum][i][state]=n_param;
+ if(projectionMap[word][state]==null){
+ projectionMap[word][state]=new TIntArrayList();
+ numWordsToProject++;
+ }
+ projectionMap[word][state].add(n_param);
+ n_param++;
+ }else{
+
+ posteriorMap[sentNum][i][state]=-1;
+ }
+ }
+ }
+ }
+ }
+
+ @Override
+ public double[] projectPoint(double[] point) {
+ // TODO Auto-generated method stub
+ for(int i=0;i<projectionMap.length;i++){
+
+ if(projectionMap[i]==null){
+ //this word is not constrained
+ continue;
+ }
+
+ for(int j=0;j<projectionMap[i].length;j++){
+ TIntArrayList instances=projectionMap[i][j];
+ double[] toProject = new double[instances.size()];
+
+ for (int k = 0; k < toProject.length; k++) {
+ // System.out.print(instances.get(k) + " ");
+ toProject[k] = point[instances.get(k)];
+ }
+
+ projection.project(toProject);
+ for (int k = 0; k < toProject.length; k++) {
+ newPoint[instances.get(k)]=toProject[k];
+ }
+ }
+ }
+ return newPoint;
+ }
+
+ @Override
+ public double[] getGradient() {
+ // TODO Auto-generated method stub
+ gradientCalls++;
+ return gradient;
+ }
+
+ @Override
+ public double getValue() {
+ // TODO Auto-generated method stub
+ functionCalls++;
+ return loglikelihood;
+ }
+
+
+ @Override
+ public String toString() {
+ // TODO Auto-generated method stub
+ StringBuffer sb = new StringBuffer();
+ for (int i = 0; i < parameters.length; i++) {
+ sb.append(parameters[i]+" ");
+ if(i%100==0){
+ sb.append("\n");
+ }
+ }
+ sb.append("\n");
+ /*
+ for (int i = 0; i < gradient.length; i++) {
+ sb.append(gradient[i]+" ");
+ if(i%100==0){
+ sb.append("\n");
+ }
+ }
+ sb.append("\n");
+ */
+ return sb.toString();
+ }
+
+
+ /**
+ * @param seq
+ * @return posterior probability of each transition
+ */
+ public double [][][]forwardBackward(int sentNum){
+ int [] seq=hmm.data[sentNum];
+ int n_states=hmm.trans.length;
+ double a[][]=new double [seq.length][n_states];
+ double b[][]=new double [seq.length][n_states];
+
+ int len=seq.length;
+
+ boolean constrained=
+ (projectionMap[seq[0]]!=null);
+
+ //initialize the first step
+ for(int i=0;i<n_states;i++){
+ a[0][i]=hmm.emit[i][seq[0]]*hmm.pi[i];
+ if(constrained){
+ a[0][i]*=
+ Math.exp(- parameters[ posteriorMap[sentNum][0][i] ] );
+ }
+ b[len-1][i]=1;
+ }
+
+ loglikelihood+=Math.log(hmm.l1norm(a[0]));
+ hmm.l1normalize(a[0]);
+ hmm.l1normalize(b[len-1]);
+
+ //forward
+ for(int n=1;n<len;n++){
+
+ constrained=
+ (projectionMap[seq[n]]!=null);
+
+ for(int i=0;i<n_states;i++){
+ for(int j=0;j<n_states;j++){
+ a[n][i]+=hmm.trans[j][i]*a[n-1][j];
+ }
+ a[n][i]*=hmm.emit[i][seq[n]];
+
+ if(constrained){
+ a[n][i]*=
+ Math.exp(- parameters[ posteriorMap[sentNum][n][i] ] );
+ }
+
+ }
+ loglikelihood+=Math.log(hmm.l1norm(a[n]));
+ hmm.l1normalize(a[n]);
+ }
+
+ //temp variable for e^{-\lambda}
+ double factor=1;
+ //backward
+ for(int n=len-2;n>=0;n--){
+
+ constrained=
+ (projectionMap[seq[n+1]]!=null);
+
+ for(int i=0;i<n_states;i++){
+ for(int j=0;j<n_states;j++){
+
+ if(constrained){
+ factor=
+ Math.exp(- parameters[ posteriorMap[sentNum][n+1][j] ] );
+ }else{
+ factor=1;
+ }
+
+ b[n][i]+=hmm.trans[i][j]*b[n+1][j]*hmm.emit[j][seq[n+1]]*factor;
+
+ }
+ }
+ hmm.l1normalize(b[n]);
+ }
+
+ //expected transition
+ double p[][][]=new double [seq.length][n_states][n_states];
+ for(int n=0;n<len-1;n++){
+
+ constrained=
+ (projectionMap[seq[n+1]]!=null);
+
+ for(int i=0;i<n_states;i++){
+ for(int j=0;j<n_states;j++){
+
+ if(constrained){
+ factor=
+ Math.exp(- parameters[ posteriorMap[sentNum][n+1][j] ] );
+ }else{
+ factor=1;
+ }
+
+ p[n][i][j]=a[n][i]*hmm.trans[i][j]*
+ hmm.emit[j][seq[n+1]]*b[n+1][j]*factor;
+
+ }
+ }
+
+ hmm.l1normalize(p[n]);
+ }
+ return p;
+ }
+
+ public void optimizeWithProjectedGradientDescent(){
+ LineSearchMethod ls =
+ new ArmijoLineSearchMinimizationAlongProjectionArc
+ (new InterpolationPickFirstStep(INIT_STEP_SIZE));
+
+ OptimizerStats stats = new OptimizerStats();
+
+
+ ProjectedGradientDescent optimizer = new ProjectedGradientDescent(ls);
+ StopingCriteria stopGrad = new ProjectedGradientL2Norm(GRAD_DIFF);
+ StopingCriteria stopValue = new ValueDifference(VAL_DIFF);
+ CompositeStopingCriteria compositeStop = new CompositeStopingCriteria();
+ compositeStop.add(stopGrad);
+ compositeStop.add(stopValue);
+
+ optimizer.setMaxIterations(10);
+ updateFunction();
+ boolean succed = optimizer.optimize(this,stats,compositeStop);
+ System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
+ if(succed){
+ System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
+ }else{
+ System.out.println("Failed to optimize");
+ }
+ }
+
+ @Override
+ public void setParameters(double[] params) {
+ super.setParameters(params);
+ updateFunction();
+ }
+
+ private void updateFunction(){
+
+ updateCalls++;
+ loglikelihood=0;
+
+ for(int sentNum=0;sentNum<hmm.data.length;sentNum++){
+ double [][][]p=forwardBackward(sentNum);
+
+ for(int n=0;n<p.length-1;n++){
+ for(int i=0;i<p[n].length;i++){
+ if(projectionMap[hmm.data[sentNum][n]]!=null){
+ double posterior=0;
+ for(int j=0;j<p[n][i].length;j++){
+ posterior+=p[n][i][j];
+ }
+ gradient[posteriorMap[sentNum][n][i]]=-posterior;
+ }
+ }
+ }
+
+ //the last state
+ int n=p.length-2;
+ for(int i=0;i<p[n].length;i++){
+ if(projectionMap[hmm.data[sentNum][n+1]]!=null){
+
+ double posterior=0;
+ for(int j=0;j<p[n].length;j++){
+ posterior+=p[n][j][i];
+ }
+ gradient[posteriorMap[sentNum][n+1][i]]=-posterior;
+
+ }
+ }
+ }
+
+ }
+
+}
diff --git a/gi/posterior-regularisation/prjava/src/hmm/POS.java b/gi/posterior-regularisation/prjava/src/hmm/POS.java
new file mode 100644
index 00000000..722d38e2
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/hmm/POS.java
@@ -0,0 +1,126 @@
+package hmm;
+
+import java.io.PrintStream;
+import java.util.HashMap;
+
+import data.Corpus;
+
+public class POS {
+
+ //public String trainFilename="../posdata/en_train.conll";
+ //public static String trainFilename="../posdata/small_train.txt";
+ public static String trainFilename="../posdata/en_test.conll";
+// public static String trainFilename="../posdata/trial1.txt";
+
+ public static String testFilename="../posdata/en_test.conll";
+ //public static String testFilename="../posdata/trial1.txt";
+
+ public static String predFilename="../posdata/en_test.predict.conll";
+ public static String modelFilename="../posdata/posModel.out";
+ public static final int ITER=20;
+ public static final int N_STATE=30;
+
+ public static void main(String[] args) {
+ //POS p=new POS();
+ //POS p=new POS(true);
+ PRPOS();
+ }
+
+
+ public POS(){
+ Corpus c= new Corpus(trainFilename);
+ //size of vocabulary +1 for unknown tokens
+ HMM hmm =new HMM(N_STATE, c.getVocabSize()+1,c.getAllData());
+ for(int i=0;i<ITER;i++){
+ System.out.println("Iter"+i);
+ hmm.EM();
+ if((i+1)%10==0){
+ hmm.writeModel(modelFilename+i);
+ }
+ }
+
+ hmm.writeModel(modelFilename);
+
+ Corpus test=new Corpus(testFilename,c.vocab);
+
+ PrintStream ps= io.FileUtil.openOutFile(predFilename);
+
+ int [][]data=test.getAllData();
+ for(int i=0;i<data.length;i++){
+ int []tag=hmm.viterbi(data[i]);
+ String sent[]=test.get(i);
+ for(int j=0;j<data[i].length;j++){
+ ps.println(sent[j]+"\t"+tag[j]);
+ }
+ ps.println();
+ }
+ ps.close();
+ }
+
+ //POS induction with L1/Linf constraints
+ public static void PRPOS(){
+ Corpus c= new Corpus(trainFilename);
+ //size of vocabulary +1 for unknown tokens
+ HMM hmm =new HMM(N_STATE, c.getVocabSize()+1,c.getAllData());
+ hmm.o=new HMMObjective(hmm);
+ for(int i=0;i<ITER;i++){
+ System.out.println("Iter: "+i);
+ hmm.PREM();
+ if((i+1)%10==0){
+ hmm.writeModel(modelFilename+i);
+ }
+ }
+
+ hmm.writeModel(modelFilename);
+
+ Corpus test=new Corpus(testFilename,c.vocab);
+
+ PrintStream ps= io.FileUtil.openOutFile(predFilename);
+
+ int [][]data=test.getAllData();
+ for(int i=0;i<data.length;i++){
+ int []tag=hmm.viterbi(data[i]);
+ String sent[]=test.get(i);
+ for(int j=0;j<data[i].length;j++){
+ ps.println(sent[j]+"\t"+tag[j]);
+ }
+ ps.println();
+ }
+ ps.close();
+ }
+
+
+ public POS(boolean supervised){
+ Corpus c= new Corpus(trainFilename);
+ //size of vocabulary +1 for unknown tokens
+ HMM hmm =new HMM(c.tagVocab.size() , c.getVocabSize()+1,c.getAllData());
+ hmm.train(c.getTagData());
+
+ hmm.writeModel(modelFilename);
+
+ Corpus test=new Corpus(testFilename,c.vocab);
+
+ HashMap<String, Integer>tagVocab=
+ (HashMap<String, Integer>) io.SerializedObjects.readSerializedObject(Corpus.tagalphaFilename);
+ String [] tagdict=new String [tagVocab.size()+1];
+ for(String key:tagVocab.keySet()){
+ tagdict[tagVocab.get(key)]=key;
+ }
+ tagdict[tagdict.length-1]=Corpus.UNK_TOK;
+
+ System.out.println(c.vocab.get("<e>"));
+
+ PrintStream ps= io.FileUtil.openOutFile(predFilename);
+
+ int [][]data=test.getAllData();
+ for(int i=0;i<data.length;i++){
+ int []tag=hmm.viterbi(data[i]);
+ String sent[]=test.get(i);
+ for(int j=0;j<data[i].length;j++){
+ ps.println(sent[j]+"\t"+tagdict[tag[j]]);
+ }
+ ps.println();
+ }
+ ps.close();
+ }
+}
diff --git a/gi/posterior-regularisation/prjava/src/io/FileUtil.java b/gi/posterior-regularisation/prjava/src/io/FileUtil.java
new file mode 100644
index 00000000..bfa58213
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/io/FileUtil.java
@@ -0,0 +1,37 @@
+package io;
+import java.util.*;
+import java.io.*;
+public class FileUtil {
+ public static Scanner openInFile(String filename){
+ Scanner localsc=null;
+ try
+ {
+ localsc=new Scanner (new FileInputStream(filename));
+
+ }catch(IOException ioe){
+ System.out.println(ioe.getMessage());
+ }
+ return localsc;
+ }
+ public static PrintStream openOutFile(String filename){
+ PrintStream localps=null;
+ try
+ {
+ localps=new PrintStream (new FileOutputStream(filename));
+
+ }catch(IOException ioe){
+ System.out.println(ioe.getMessage());
+ }
+ return localps;
+ }
+ public static FileInputStream openInputStream(String infilename){
+ FileInputStream fis=null;
+ try {
+ fis =(new FileInputStream(infilename));
+
+ } catch (IOException ioe) {
+ System.out.println(ioe.getMessage());
+ }
+ return fis;
+ }
+}
diff --git a/gi/posterior-regularisation/prjava/src/io/SerializedObjects.java b/gi/posterior-regularisation/prjava/src/io/SerializedObjects.java
new file mode 100644
index 00000000..d1631b51
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/io/SerializedObjects.java
@@ -0,0 +1,83 @@
+package io;
+
+
+
+import java.io.BufferedInputStream;
+import java.io.BufferedOutputStream;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.ObjectInput;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutput;
+import java.io.ObjectOutputStream;
+import java.io.OutputStream;
+
+public class SerializedObjects
+{
+ public static void writeSerializedObject(Object object, String outFile)
+ {
+ ObjectOutput output = null;
+ try{
+ //use buffering
+ OutputStream file = new FileOutputStream(outFile);
+ OutputStream buffer = new BufferedOutputStream( file );
+ output = new ObjectOutputStream( buffer );
+ output.writeObject(object);
+ buffer.close();
+ file.close();
+ }
+ catch(IOException ex){
+ ex.printStackTrace();
+ }
+ finally{
+ try {
+ if (output != null) {
+ //flush and close "output" and its underlying streams
+ output.close();
+ }
+ }
+ catch (IOException ex ){
+ ex.printStackTrace();
+ }
+ }
+ }
+
+ public static Object readSerializedObject(String inputFile)
+ {
+ ObjectInput input = null;
+ Object recoveredObject=null;
+ try{
+ //use buffering
+ InputStream file = new FileInputStream(inputFile);
+ InputStream buffer = new BufferedInputStream(file);
+ input = new ObjectInputStream(buffer);
+ //deserialize the List
+ recoveredObject = input.readObject();
+ }
+ catch(IOException ex){
+ ex.printStackTrace();
+ }
+ catch (ClassNotFoundException ex){
+ ex.printStackTrace();
+ }
+ catch(Exception ex)
+ {
+ ex.printStackTrace();
+ }
+ finally{
+ try {
+ if ( input != null ) {
+ //close "input" and its underlying streams
+ input.close();
+ }
+ }
+ catch (IOException ex){
+ ex.printStackTrace();
+ }
+ }
+ return recoveredObject;
+ }
+
+} \ No newline at end of file
diff --git a/gi/posterior-regularisation/prjava/src/test/CorpusTest.java b/gi/posterior-regularisation/prjava/src/test/CorpusTest.java
new file mode 100644
index 00000000..b4c3041f
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/test/CorpusTest.java
@@ -0,0 +1,60 @@
+package test;
+
+import java.util.Arrays;
+import java.util.HashMap;
+
+import data.Corpus;
+import hmm.POS;
+
+public class CorpusTest {
+
+ public static void main(String[] args) {
+ Corpus c=new Corpus(POS.trainFilename);
+
+
+ int idx=30;
+
+
+ HashMap<String, Integer>vocab=
+ (HashMap<String, Integer>) io.SerializedObjects.readSerializedObject(Corpus.alphaFilename);
+
+ HashMap<String, Integer>tagVocab=
+ (HashMap<String, Integer>) io.SerializedObjects.readSerializedObject(Corpus.tagalphaFilename);
+
+
+ String [] dict=new String [vocab.size()+1];
+ for(String key:vocab.keySet()){
+ dict[vocab.get(key)]=key;
+ }
+ dict[dict.length-1]=Corpus.UNK_TOK;
+
+ String [] tagdict=new String [tagVocab.size()+1];
+ for(String key:tagVocab.keySet()){
+ tagdict[tagVocab.get(key)]=key;
+ }
+ tagdict[tagdict.length-1]=Corpus.UNK_TOK;
+
+ String[] sent=c.get(idx);
+ int []data=c.getInt(idx);
+
+
+ String []roundtrip=new String [sent.length];
+ for(int i=0;i<sent.length;i++){
+ roundtrip[i]=dict[data[i]];
+ }
+ System.out.println(Arrays.toString(sent));
+ System.out.println(Arrays.toString(roundtrip));
+
+ sent=c.tag.get(idx);
+ data=c.tagData.get(idx);
+
+
+ roundtrip=new String [sent.length];
+ for(int i=0;i<sent.length;i++){
+ roundtrip[i]=tagdict[data[i]];
+ }
+ System.out.println(Arrays.toString(sent));
+ System.out.println(Arrays.toString(roundtrip));
+ }
+
+}
diff --git a/gi/posterior-regularisation/prjava/src/test/HMMModelStats.java b/gi/posterior-regularisation/prjava/src/test/HMMModelStats.java
new file mode 100644
index 00000000..26d7abec
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/test/HMMModelStats.java
@@ -0,0 +1,96 @@
+package test;
+
+import hmm.HMM;
+import hmm.POS;
+
+import java.io.PrintStream;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+
+import data.Corpus;
+
+public class HMMModelStats {
+
+ public static String modelFilename="../posdata/posModel.out";
+ public static String alphaFilename="../posdata/corpus.alphabet";
+ public static String statsFilename="../posdata/model.stats";
+
+ public static final int NUM_WORD=50;
+
+ public static String testFilename="../posdata/en_test.conll";
+
+ public static double [][]maxwt;
+
+ public static void main(String[] args) {
+ HashMap<String, Integer>vocab=
+ (HashMap<String, Integer>) io.SerializedObjects.readSerializedObject(alphaFilename);
+
+ Corpus test=new Corpus(testFilename,vocab);
+
+ String [] dict=new String [vocab.size()+1];
+ for(String key:vocab.keySet()){
+ dict[vocab.get(key)]=key;
+ }
+ dict[dict.length-1]=Corpus.UNK_TOK;
+
+ HMM hmm=new HMM();
+ hmm.readModel(modelFilename);
+
+
+
+ PrintStream ps=io.FileUtil.openOutFile(statsFilename);
+
+ double [][] emit=hmm.getEmitProb();
+ for(int i=0;i<emit.length;i++){
+ ArrayList<IntDoublePair>l=new ArrayList<IntDoublePair>();
+ for(int j=0;j<emit[i].length;j++){
+ l.add(new IntDoublePair(j,emit[i][j]));
+ }
+ Collections.sort(l);
+ ps.println(i);
+ for(int j=0;j<NUM_WORD;j++){
+ if(j>=dict.length){
+ break;
+ }
+ ps.print(dict[l.get(j).idx]+"\t");
+ if((1+j)%10==0){
+ ps.println();
+ }
+ }
+ ps.println("\n");
+ }
+
+ checkMaxwt(hmm,ps,test.getAllData());
+
+ int terminalSym=vocab.get(Corpus .END_SYM);
+ //sample 10 sentences
+ for(int i=0;i<10;i++){
+ int []sent=hmm.sample(terminalSym);
+ for(int j=0;j<sent.length;j++){
+ ps.print(dict[sent[j]]+"\t");
+ }
+ ps.println();
+ }
+
+ ps.close();
+
+ }
+
+ public static void checkMaxwt(HMM hmm,PrintStream ps,int [][]data){
+ double [][]emit=hmm.getEmitProb();
+ maxwt=new double[emit.length][emit[0].length];
+
+ hmm.computeMaxwt(maxwt,data);
+ double sum=0;
+ for(int i=0;i<maxwt.length;i++){
+ for(int j=0;j<maxwt.length;j++){
+ sum+=maxwt[i][j];
+ }
+ }
+
+ ps.println("max w t P(w_i|t)"+sum);
+
+ }
+
+}
diff --git a/gi/posterior-regularisation/prjava/src/test/IntDoublePair.java b/gi/posterior-regularisation/prjava/src/test/IntDoublePair.java
new file mode 100644
index 00000000..3f9f0ad7
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/test/IntDoublePair.java
@@ -0,0 +1,23 @@
+package test;
+
+public class IntDoublePair implements Comparable{
+ double val;
+ int idx;
+ public int compareTo(Object o){
+ if(o instanceof IntDoublePair){
+ IntDoublePair pair=(IntDoublePair)o;
+ if(pair.val>val){
+ return 1;
+ }
+ if(pair.val<val){
+ return -1;
+ }
+ return 0;
+ }
+ return -1;
+ }
+ public IntDoublePair(int i,double v){
+ val=v;
+ idx=i;
+ }
+}
diff --git a/gi/posterior-regularisation/prjava/src/test/X2y2WithConstraints.java b/gi/posterior-regularisation/prjava/src/test/X2y2WithConstraints.java
new file mode 100644
index 00000000..9059a59e
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/test/X2y2WithConstraints.java
@@ -0,0 +1,131 @@
+package test;
+
+
+
+import optimization.gradientBasedMethods.ProjectedGradientDescent;
+import optimization.gradientBasedMethods.ProjectedObjective;
+import optimization.gradientBasedMethods.stats.OptimizerStats;
+import optimization.linesearch.ArmijoLineSearchMinimizationAlongProjectionArc;
+import optimization.linesearch.InterpolationPickFirstStep;
+import optimization.linesearch.LineSearchMethod;
+import optimization.projections.BoundsProjection;
+import optimization.projections.Projection;
+import optimization.projections.SimplexProjection;
+import optimization.stopCriteria.CompositeStopingCriteria;
+import optimization.stopCriteria.GradientL2Norm;
+import optimization.stopCriteria.ProjectedGradientL2Norm;
+import optimization.stopCriteria.StopingCriteria;
+import optimization.stopCriteria.ValueDifference;
+
+
+/**
+ * @author javg
+ *
+ *
+ *ax2+ b(y2 -displacement)
+ */
+public class X2y2WithConstraints extends ProjectedObjective{
+
+
+ double a, b;
+ double dx;
+ double dy;
+ Projection projection;
+
+
+ public X2y2WithConstraints(double a, double b, double[] params, double dx, double dy, Projection proj){
+ //projection = new BoundsProjection(0.2,Double.MAX_VALUE);
+ super();
+ projection = proj;
+ this.a = a;
+ this.b = b;
+ this.dx = dx;
+ this.dy = dy;
+ setInitialParameters(params);
+ System.out.println("Function " +a+"(x-"+dx+")^2 + "+b+"(y-"+dy+")^2");
+ System.out.println("Gradient " +(2*a)+"(x-"+dx+") ; "+(b*2)+"(y-"+dy+")");
+ printParameters();
+ projection.project(parameters);
+ printParameters();
+ gradient = new double[2];
+ }
+
+ public double getValue() {
+ functionCalls++;
+ return a*(parameters[0]-dx)*(parameters[0]-dx)+b*((parameters[1]-dy)*(parameters[1]-dy));
+ }
+
+ public double[] getGradient() {
+ if(gradient == null){
+ gradient = new double[2];
+ }
+ gradientCalls++;
+ gradient[0]=2*a*(parameters[0]-dx);
+ gradient[1]=2*b*(parameters[1]-dy);
+ return gradient;
+ }
+
+
+ public double[] projectPoint(double[] point) {
+ double[] newPoint = point.clone();
+ projection.project(newPoint);
+ return newPoint;
+ }
+
+ public void optimizeWithProjectedGradientDescent(LineSearchMethod ls, OptimizerStats stats, X2y2WithConstraints o){
+ ProjectedGradientDescent optimizer = new ProjectedGradientDescent(ls);
+ StopingCriteria stopGrad = new ProjectedGradientL2Norm(0.001);
+ StopingCriteria stopValue = new ValueDifference(0.001);
+ CompositeStopingCriteria compositeStop = new CompositeStopingCriteria();
+ compositeStop.add(stopGrad);
+ compositeStop.add(stopValue);
+
+ optimizer.setMaxIterations(5);
+ boolean succed = optimizer.optimize(o,stats,compositeStop);
+ System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
+ System.out.println("Solution: " + " x0 " + o.parameters[0]+ " x1 " + o.parameters[1]);
+ if(succed){
+ System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
+ }else{
+ System.out.println("Failed to optimize");
+ }
+ }
+
+
+
+ public String toString(){
+
+ return "P1: " + parameters[0] + " P2: " + parameters[1] + " value " + getValue() + " grad (" + getGradient()[0] + ":" + getGradient()[1]+")";
+ }
+
+ public static void main(String[] args) {
+ double a = 1;
+ double b=1;
+ double x0 = 0;
+ double y0 =1;
+ double dx = 0.5;
+ double dy = 0.2 ;
+ double [] parameters = new double[2];
+ parameters[0] = x0;
+ parameters[1] = y0;
+ X2y2WithConstraints o = new X2y2WithConstraints(a,b,parameters,dx,dy,
+ new SimplexProjection(0.5)
+ //new BoundsProjection(0.0,0.4)
+ );
+ System.out.println("Starting optimization " + " x0 " + o.parameters[0]+ " x1 " + o.parameters[1] + " a " + a + " b "+b );
+ o.setDebugLevel(4);
+
+ LineSearchMethod ls = new ArmijoLineSearchMinimizationAlongProjectionArc(new InterpolationPickFirstStep(1));
+
+ OptimizerStats stats = new OptimizerStats();
+ o.optimizeWithProjectedGradientDescent(ls, stats, o);
+
+// o = new x2y2WithConstraints(a,b,x0,y0,dx,dy);
+// stats = new OptimizerStats();
+// o.optimizeWithSpectralProjectedGradientDescent(stats, o);
+ }
+
+
+
+
+}