summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/test')
-rw-r--r--gi/posterior-regularisation/prjava/src/test/CorpusTest.java60
-rw-r--r--gi/posterior-regularisation/prjava/src/test/HMMModelStats.java96
-rw-r--r--gi/posterior-regularisation/prjava/src/test/IntDoublePair.java23
-rw-r--r--gi/posterior-regularisation/prjava/src/test/X2y2WithConstraints.java131
4 files changed, 310 insertions, 0 deletions
diff --git a/gi/posterior-regularisation/prjava/src/test/CorpusTest.java b/gi/posterior-regularisation/prjava/src/test/CorpusTest.java
new file mode 100644
index 00000000..b4c3041f
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/test/CorpusTest.java
@@ -0,0 +1,60 @@
+package test;
+
+import java.util.Arrays;
+import java.util.HashMap;
+
+import data.Corpus;
+import hmm.POS;
+
+public class CorpusTest {
+
+ public static void main(String[] args) {
+ Corpus c=new Corpus(POS.trainFilename);
+
+
+ int idx=30;
+
+
+ HashMap<String, Integer>vocab=
+ (HashMap<String, Integer>) io.SerializedObjects.readSerializedObject(Corpus.alphaFilename);
+
+ HashMap<String, Integer>tagVocab=
+ (HashMap<String, Integer>) io.SerializedObjects.readSerializedObject(Corpus.tagalphaFilename);
+
+
+ String [] dict=new String [vocab.size()+1];
+ for(String key:vocab.keySet()){
+ dict[vocab.get(key)]=key;
+ }
+ dict[dict.length-1]=Corpus.UNK_TOK;
+
+ String [] tagdict=new String [tagVocab.size()+1];
+ for(String key:tagVocab.keySet()){
+ tagdict[tagVocab.get(key)]=key;
+ }
+ tagdict[tagdict.length-1]=Corpus.UNK_TOK;
+
+ String[] sent=c.get(idx);
+ int []data=c.getInt(idx);
+
+
+ String []roundtrip=new String [sent.length];
+ for(int i=0;i<sent.length;i++){
+ roundtrip[i]=dict[data[i]];
+ }
+ System.out.println(Arrays.toString(sent));
+ System.out.println(Arrays.toString(roundtrip));
+
+ sent=c.tag.get(idx);
+ data=c.tagData.get(idx);
+
+
+ roundtrip=new String [sent.length];
+ for(int i=0;i<sent.length;i++){
+ roundtrip[i]=tagdict[data[i]];
+ }
+ System.out.println(Arrays.toString(sent));
+ System.out.println(Arrays.toString(roundtrip));
+ }
+
+}
diff --git a/gi/posterior-regularisation/prjava/src/test/HMMModelStats.java b/gi/posterior-regularisation/prjava/src/test/HMMModelStats.java
new file mode 100644
index 00000000..26d7abec
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/test/HMMModelStats.java
@@ -0,0 +1,96 @@
+package test;
+
+import hmm.HMM;
+import hmm.POS;
+
+import java.io.PrintStream;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+
+import data.Corpus;
+
+public class HMMModelStats {
+
+ public static String modelFilename="../posdata/posModel.out";
+ public static String alphaFilename="../posdata/corpus.alphabet";
+ public static String statsFilename="../posdata/model.stats";
+
+ public static final int NUM_WORD=50;
+
+ public static String testFilename="../posdata/en_test.conll";
+
+ public static double [][]maxwt;
+
+ public static void main(String[] args) {
+ HashMap<String, Integer>vocab=
+ (HashMap<String, Integer>) io.SerializedObjects.readSerializedObject(alphaFilename);
+
+ Corpus test=new Corpus(testFilename,vocab);
+
+ String [] dict=new String [vocab.size()+1];
+ for(String key:vocab.keySet()){
+ dict[vocab.get(key)]=key;
+ }
+ dict[dict.length-1]=Corpus.UNK_TOK;
+
+ HMM hmm=new HMM();
+ hmm.readModel(modelFilename);
+
+
+
+ PrintStream ps=io.FileUtil.openOutFile(statsFilename);
+
+ double [][] emit=hmm.getEmitProb();
+ for(int i=0;i<emit.length;i++){
+ ArrayList<IntDoublePair>l=new ArrayList<IntDoublePair>();
+ for(int j=0;j<emit[i].length;j++){
+ l.add(new IntDoublePair(j,emit[i][j]));
+ }
+ Collections.sort(l);
+ ps.println(i);
+ for(int j=0;j<NUM_WORD;j++){
+ if(j>=dict.length){
+ break;
+ }
+ ps.print(dict[l.get(j).idx]+"\t");
+ if((1+j)%10==0){
+ ps.println();
+ }
+ }
+ ps.println("\n");
+ }
+
+ checkMaxwt(hmm,ps,test.getAllData());
+
+ int terminalSym=vocab.get(Corpus .END_SYM);
+ //sample 10 sentences
+ for(int i=0;i<10;i++){
+ int []sent=hmm.sample(terminalSym);
+ for(int j=0;j<sent.length;j++){
+ ps.print(dict[sent[j]]+"\t");
+ }
+ ps.println();
+ }
+
+ ps.close();
+
+ }
+
+ public static void checkMaxwt(HMM hmm,PrintStream ps,int [][]data){
+ double [][]emit=hmm.getEmitProb();
+ maxwt=new double[emit.length][emit[0].length];
+
+ hmm.computeMaxwt(maxwt,data);
+ double sum=0;
+ for(int i=0;i<maxwt.length;i++){
+ for(int j=0;j<maxwt.length;j++){
+ sum+=maxwt[i][j];
+ }
+ }
+
+ ps.println("max w t P(w_i|t)"+sum);
+
+ }
+
+}
diff --git a/gi/posterior-regularisation/prjava/src/test/IntDoublePair.java b/gi/posterior-regularisation/prjava/src/test/IntDoublePair.java
new file mode 100644
index 00000000..3f9f0ad7
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/test/IntDoublePair.java
@@ -0,0 +1,23 @@
+package test;
+
+public class IntDoublePair implements Comparable{
+ double val;
+ int idx;
+ public int compareTo(Object o){
+ if(o instanceof IntDoublePair){
+ IntDoublePair pair=(IntDoublePair)o;
+ if(pair.val>val){
+ return 1;
+ }
+ if(pair.val<val){
+ return -1;
+ }
+ return 0;
+ }
+ return -1;
+ }
+ public IntDoublePair(int i,double v){
+ val=v;
+ idx=i;
+ }
+}
diff --git a/gi/posterior-regularisation/prjava/src/test/X2y2WithConstraints.java b/gi/posterior-regularisation/prjava/src/test/X2y2WithConstraints.java
new file mode 100644
index 00000000..9059a59e
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/test/X2y2WithConstraints.java
@@ -0,0 +1,131 @@
+package test;
+
+
+
+import optimization.gradientBasedMethods.ProjectedGradientDescent;
+import optimization.gradientBasedMethods.ProjectedObjective;
+import optimization.gradientBasedMethods.stats.OptimizerStats;
+import optimization.linesearch.ArmijoLineSearchMinimizationAlongProjectionArc;
+import optimization.linesearch.InterpolationPickFirstStep;
+import optimization.linesearch.LineSearchMethod;
+import optimization.projections.BoundsProjection;
+import optimization.projections.Projection;
+import optimization.projections.SimplexProjection;
+import optimization.stopCriteria.CompositeStopingCriteria;
+import optimization.stopCriteria.GradientL2Norm;
+import optimization.stopCriteria.ProjectedGradientL2Norm;
+import optimization.stopCriteria.StopingCriteria;
+import optimization.stopCriteria.ValueDifference;
+
+
+/**
+ * @author javg
+ *
+ *
+ *ax2+ b(y2 -displacement)
+ */
+public class X2y2WithConstraints extends ProjectedObjective{
+
+
+ double a, b;
+ double dx;
+ double dy;
+ Projection projection;
+
+
+ public X2y2WithConstraints(double a, double b, double[] params, double dx, double dy, Projection proj){
+ //projection = new BoundsProjection(0.2,Double.MAX_VALUE);
+ super();
+ projection = proj;
+ this.a = a;
+ this.b = b;
+ this.dx = dx;
+ this.dy = dy;
+ setInitialParameters(params);
+ System.out.println("Function " +a+"(x-"+dx+")^2 + "+b+"(y-"+dy+")^2");
+ System.out.println("Gradient " +(2*a)+"(x-"+dx+") ; "+(b*2)+"(y-"+dy+")");
+ printParameters();
+ projection.project(parameters);
+ printParameters();
+ gradient = new double[2];
+ }
+
+ public double getValue() {
+ functionCalls++;
+ return a*(parameters[0]-dx)*(parameters[0]-dx)+b*((parameters[1]-dy)*(parameters[1]-dy));
+ }
+
+ public double[] getGradient() {
+ if(gradient == null){
+ gradient = new double[2];
+ }
+ gradientCalls++;
+ gradient[0]=2*a*(parameters[0]-dx);
+ gradient[1]=2*b*(parameters[1]-dy);
+ return gradient;
+ }
+
+
+ public double[] projectPoint(double[] point) {
+ double[] newPoint = point.clone();
+ projection.project(newPoint);
+ return newPoint;
+ }
+
+ public void optimizeWithProjectedGradientDescent(LineSearchMethod ls, OptimizerStats stats, X2y2WithConstraints o){
+ ProjectedGradientDescent optimizer = new ProjectedGradientDescent(ls);
+ StopingCriteria stopGrad = new ProjectedGradientL2Norm(0.001);
+ StopingCriteria stopValue = new ValueDifference(0.001);
+ CompositeStopingCriteria compositeStop = new CompositeStopingCriteria();
+ compositeStop.add(stopGrad);
+ compositeStop.add(stopValue);
+
+ optimizer.setMaxIterations(5);
+ boolean succed = optimizer.optimize(o,stats,compositeStop);
+ System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
+ System.out.println("Solution: " + " x0 " + o.parameters[0]+ " x1 " + o.parameters[1]);
+ if(succed){
+ System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
+ }else{
+ System.out.println("Failed to optimize");
+ }
+ }
+
+
+
+ public String toString(){
+
+ return "P1: " + parameters[0] + " P2: " + parameters[1] + " value " + getValue() + " grad (" + getGradient()[0] + ":" + getGradient()[1]+")";
+ }
+
+ public static void main(String[] args) {
+ double a = 1;
+ double b=1;
+ double x0 = 0;
+ double y0 =1;
+ double dx = 0.5;
+ double dy = 0.2 ;
+ double [] parameters = new double[2];
+ parameters[0] = x0;
+ parameters[1] = y0;
+ X2y2WithConstraints o = new X2y2WithConstraints(a,b,parameters,dx,dy,
+ new SimplexProjection(0.5)
+ //new BoundsProjection(0.0,0.4)
+ );
+ System.out.println("Starting optimization " + " x0 " + o.parameters[0]+ " x1 " + o.parameters[1] + " a " + a + " b "+b );
+ o.setDebugLevel(4);
+
+ LineSearchMethod ls = new ArmijoLineSearchMinimizationAlongProjectionArc(new InterpolationPickFirstStep(1));
+
+ OptimizerStats stats = new OptimizerStats();
+ o.optimizeWithProjectedGradientDescent(ls, stats, o);
+
+// o = new x2y2WithConstraints(a,b,x0,y0,dx,dy);
+// stats = new OptimizerStats();
+// o.optimizeWithSpectralProjectedGradientDescent(stats, o);
+ }
+
+
+
+
+}