summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/test
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2012-10-11 14:06:32 -0400
committerChris Dyer <cdyer@cs.cmu.edu>2012-10-11 14:06:32 -0400
commit9339c80d465545aec5a6dccfef7c83ca715bf11f (patch)
tree64c56d558331edad1db3832018c80e799551c39a /gi/posterior-regularisation/prjava/src/test
parent438dac41810b7c69fa10203ac5130d20efa2da9f (diff)
parentafd7da3b2338661657ad0c4e9eec681e014d37bf (diff)
Merge branch 'master' of https://github.com/redpony/cdec
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/test')
-rw-r--r--gi/posterior-regularisation/prjava/src/test/CorpusTest.java60
-rw-r--r--gi/posterior-regularisation/prjava/src/test/HMMModelStats.java105
-rw-r--r--gi/posterior-regularisation/prjava/src/test/IntDoublePair.java23
-rw-r--r--gi/posterior-regularisation/prjava/src/test/X2y2WithConstraints.java131
4 files changed, 0 insertions, 319 deletions
diff --git a/gi/posterior-regularisation/prjava/src/test/CorpusTest.java b/gi/posterior-regularisation/prjava/src/test/CorpusTest.java
deleted file mode 100644
index b4c3041f..00000000
--- a/gi/posterior-regularisation/prjava/src/test/CorpusTest.java
+++ /dev/null
@@ -1,60 +0,0 @@
-package test;
-
-import java.util.Arrays;
-import java.util.HashMap;
-
-import data.Corpus;
-import hmm.POS;
-
-public class CorpusTest {
-
- public static void main(String[] args) {
- Corpus c=new Corpus(POS.trainFilename);
-
-
- int idx=30;
-
-
- HashMap<String, Integer>vocab=
- (HashMap<String, Integer>) io.SerializedObjects.readSerializedObject(Corpus.alphaFilename);
-
- HashMap<String, Integer>tagVocab=
- (HashMap<String, Integer>) io.SerializedObjects.readSerializedObject(Corpus.tagalphaFilename);
-
-
- String [] dict=new String [vocab.size()+1];
- for(String key:vocab.keySet()){
- dict[vocab.get(key)]=key;
- }
- dict[dict.length-1]=Corpus.UNK_TOK;
-
- String [] tagdict=new String [tagVocab.size()+1];
- for(String key:tagVocab.keySet()){
- tagdict[tagVocab.get(key)]=key;
- }
- tagdict[tagdict.length-1]=Corpus.UNK_TOK;
-
- String[] sent=c.get(idx);
- int []data=c.getInt(idx);
-
-
- String []roundtrip=new String [sent.length];
- for(int i=0;i<sent.length;i++){
- roundtrip[i]=dict[data[i]];
- }
- System.out.println(Arrays.toString(sent));
- System.out.println(Arrays.toString(roundtrip));
-
- sent=c.tag.get(idx);
- data=c.tagData.get(idx);
-
-
- roundtrip=new String [sent.length];
- for(int i=0;i<sent.length;i++){
- roundtrip[i]=tagdict[data[i]];
- }
- System.out.println(Arrays.toString(sent));
- System.out.println(Arrays.toString(roundtrip));
- }
-
-}
diff --git a/gi/posterior-regularisation/prjava/src/test/HMMModelStats.java b/gi/posterior-regularisation/prjava/src/test/HMMModelStats.java
deleted file mode 100644
index d54525c8..00000000
--- a/gi/posterior-regularisation/prjava/src/test/HMMModelStats.java
+++ /dev/null
@@ -1,105 +0,0 @@
-package test;
-
-import hmm.HMM;
-import hmm.POS;
-
-import java.io.File;
-import java.io.FileNotFoundException;
-import java.io.IOException;
-import java.io.PrintStream;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
-
-import data.Corpus;
-
-public class HMMModelStats {
-
- public static String modelFilename="../posdata/posModel.out";
- public static String alphaFilename="../posdata/corpus.alphabet";
- public static String statsFilename="../posdata/model.stats";
-
- public static final int NUM_WORD=50;
-
- public static String testFilename="../posdata/en_test.conll";
-
- public static double [][]maxwt;
-
- public static void main(String[] args) {
- HashMap<String, Integer>vocab=
- (HashMap<String, Integer>) io.SerializedObjects.readSerializedObject(alphaFilename);
-
- Corpus test=new Corpus(testFilename,vocab);
-
- String [] dict=new String [vocab.size()+1];
- for(String key:vocab.keySet()){
- dict[vocab.get(key)]=key;
- }
- dict[dict.length-1]=Corpus.UNK_TOK;
-
- HMM hmm=new HMM();
- hmm.readModel(modelFilename);
-
-
-
- PrintStream ps = null;
- try {
- ps = io.FileUtil.printstream(new File(statsFilename));
- } catch (IOException e) {
- e.printStackTrace();
- System.exit(1);
- }
-
- double [][] emit=hmm.getEmitProb();
- for(int i=0;i<emit.length;i++){
- ArrayList<IntDoublePair>l=new ArrayList<IntDoublePair>();
- for(int j=0;j<emit[i].length;j++){
- l.add(new IntDoublePair(j,emit[i][j]));
- }
- Collections.sort(l);
- ps.println(i);
- for(int j=0;j<NUM_WORD;j++){
- if(j>=dict.length){
- break;
- }
- ps.print(dict[l.get(j).idx]+"\t");
- if((1+j)%10==0){
- ps.println();
- }
- }
- ps.println("\n");
- }
-
- checkMaxwt(hmm,ps,test.getAllData());
-
- int terminalSym=vocab.get(Corpus .END_SYM);
- //sample 10 sentences
- for(int i=0;i<10;i++){
- int []sent=hmm.sample(terminalSym);
- for(int j=0;j<sent.length;j++){
- ps.print(dict[sent[j]]+"\t");
- }
- ps.println();
- }
-
- ps.close();
-
- }
-
- public static void checkMaxwt(HMM hmm,PrintStream ps,int [][]data){
- double [][]emit=hmm.getEmitProb();
- maxwt=new double[emit.length][emit[0].length];
-
- hmm.computeMaxwt(maxwt,data);
- double sum=0;
- for(int i=0;i<maxwt.length;i++){
- for(int j=0;j<maxwt.length;j++){
- sum+=maxwt[i][j];
- }
- }
-
- ps.println("max w t P(w_i|t): "+sum);
-
- }
-
-}
diff --git a/gi/posterior-regularisation/prjava/src/test/IntDoublePair.java b/gi/posterior-regularisation/prjava/src/test/IntDoublePair.java
deleted file mode 100644
index 3f9f0ad7..00000000
--- a/gi/posterior-regularisation/prjava/src/test/IntDoublePair.java
+++ /dev/null
@@ -1,23 +0,0 @@
-package test;
-
-public class IntDoublePair implements Comparable{
- double val;
- int idx;
- public int compareTo(Object o){
- if(o instanceof IntDoublePair){
- IntDoublePair pair=(IntDoublePair)o;
- if(pair.val>val){
- return 1;
- }
- if(pair.val<val){
- return -1;
- }
- return 0;
- }
- return -1;
- }
- public IntDoublePair(int i,double v){
- val=v;
- idx=i;
- }
-}
diff --git a/gi/posterior-regularisation/prjava/src/test/X2y2WithConstraints.java b/gi/posterior-regularisation/prjava/src/test/X2y2WithConstraints.java
deleted file mode 100644
index 9059a59e..00000000
--- a/gi/posterior-regularisation/prjava/src/test/X2y2WithConstraints.java
+++ /dev/null
@@ -1,131 +0,0 @@
-package test;
-
-
-
-import optimization.gradientBasedMethods.ProjectedGradientDescent;
-import optimization.gradientBasedMethods.ProjectedObjective;
-import optimization.gradientBasedMethods.stats.OptimizerStats;
-import optimization.linesearch.ArmijoLineSearchMinimizationAlongProjectionArc;
-import optimization.linesearch.InterpolationPickFirstStep;
-import optimization.linesearch.LineSearchMethod;
-import optimization.projections.BoundsProjection;
-import optimization.projections.Projection;
-import optimization.projections.SimplexProjection;
-import optimization.stopCriteria.CompositeStopingCriteria;
-import optimization.stopCriteria.GradientL2Norm;
-import optimization.stopCriteria.ProjectedGradientL2Norm;
-import optimization.stopCriteria.StopingCriteria;
-import optimization.stopCriteria.ValueDifference;
-
-
-/**
- * @author javg
- *
- *
- *ax2+ b(y2 -displacement)
- */
-public class X2y2WithConstraints extends ProjectedObjective{
-
-
- double a, b;
- double dx;
- double dy;
- Projection projection;
-
-
- public X2y2WithConstraints(double a, double b, double[] params, double dx, double dy, Projection proj){
- //projection = new BoundsProjection(0.2,Double.MAX_VALUE);
- super();
- projection = proj;
- this.a = a;
- this.b = b;
- this.dx = dx;
- this.dy = dy;
- setInitialParameters(params);
- System.out.println("Function " +a+"(x-"+dx+")^2 + "+b+"(y-"+dy+")^2");
- System.out.println("Gradient " +(2*a)+"(x-"+dx+") ; "+(b*2)+"(y-"+dy+")");
- printParameters();
- projection.project(parameters);
- printParameters();
- gradient = new double[2];
- }
-
- public double getValue() {
- functionCalls++;
- return a*(parameters[0]-dx)*(parameters[0]-dx)+b*((parameters[1]-dy)*(parameters[1]-dy));
- }
-
- public double[] getGradient() {
- if(gradient == null){
- gradient = new double[2];
- }
- gradientCalls++;
- gradient[0]=2*a*(parameters[0]-dx);
- gradient[1]=2*b*(parameters[1]-dy);
- return gradient;
- }
-
-
- public double[] projectPoint(double[] point) {
- double[] newPoint = point.clone();
- projection.project(newPoint);
- return newPoint;
- }
-
- public void optimizeWithProjectedGradientDescent(LineSearchMethod ls, OptimizerStats stats, X2y2WithConstraints o){
- ProjectedGradientDescent optimizer = new ProjectedGradientDescent(ls);
- StopingCriteria stopGrad = new ProjectedGradientL2Norm(0.001);
- StopingCriteria stopValue = new ValueDifference(0.001);
- CompositeStopingCriteria compositeStop = new CompositeStopingCriteria();
- compositeStop.add(stopGrad);
- compositeStop.add(stopValue);
-
- optimizer.setMaxIterations(5);
- boolean succed = optimizer.optimize(o,stats,compositeStop);
- System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
- System.out.println("Solution: " + " x0 " + o.parameters[0]+ " x1 " + o.parameters[1]);
- if(succed){
- System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
- }else{
- System.out.println("Failed to optimize");
- }
- }
-
-
-
- public String toString(){
-
- return "P1: " + parameters[0] + " P2: " + parameters[1] + " value " + getValue() + " grad (" + getGradient()[0] + ":" + getGradient()[1]+")";
- }
-
- public static void main(String[] args) {
- double a = 1;
- double b=1;
- double x0 = 0;
- double y0 =1;
- double dx = 0.5;
- double dy = 0.2 ;
- double [] parameters = new double[2];
- parameters[0] = x0;
- parameters[1] = y0;
- X2y2WithConstraints o = new X2y2WithConstraints(a,b,parameters,dx,dy,
- new SimplexProjection(0.5)
- //new BoundsProjection(0.0,0.4)
- );
- System.out.println("Starting optimization " + " x0 " + o.parameters[0]+ " x1 " + o.parameters[1] + " a " + a + " b "+b );
- o.setDebugLevel(4);
-
- LineSearchMethod ls = new ArmijoLineSearchMinimizationAlongProjectionArc(new InterpolationPickFirstStep(1));
-
- OptimizerStats stats = new OptimizerStats();
- o.optimizeWithProjectedGradientDescent(ls, stats, o);
-
-// o = new x2y2WithConstraints(a,b,x0,y0,dx,dy);
-// stats = new OptimizerStats();
-// o.optimizeWithSpectralProjectedGradientDescent(stats, o);
- }
-
-
-
-
-}