diff options
author | desaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-23 17:08:53 +0000 |
---|---|---|
committer | desaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-23 17:08:53 +0000 |
commit | 76ef39de737e7abc0a8fe989dfacb7885617e59f (patch) | |
tree | 77c6099236431c4488aa5ac95b6d680bfd5faf05 | |
parent | 7776119e54c477a27fb0617d8bf8b483ac78898e (diff) |
vb runnable from trainer
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@380 ec762483-ff6d-05da-a07a-a48fb63a330f
-rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/Trainer.java | 43 | ||||
-rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/VB.java | 97 |
2 files changed, 92 insertions, 48 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java index b51db919..cea6a20a 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java @@ -18,7 +18,7 @@ public class Trainer { public static void main(String[] args) { - + OptionParser parser = new OptionParser(); parser.accepts("help"); parser.accepts("in").withRequiredArg().ofType(File.class); @@ -107,6 +107,7 @@ public class Trainer PhraseCluster cluster = null; Agree2Sides agree2sides = null; Agree agree= null; + VB vbModel=null; if (options.has("agree-language")) agree2sides = new Agree2Sides(tags, corpus,corpus1); else if (options.has("agree-direction")) @@ -115,7 +116,11 @@ public class Trainer { cluster = new PhraseCluster(tags, corpus); if (threads > 0) cluster.useThreadPool(threads); - if (vb) cluster.initialiseVB(alphaEmit, alphaPi); + + if (vb) { + //cluster.initialiseVB(alphaEmit, alphaPi); + vbModel=new VB(tags,corpus); + } if (options.has("no-parameter-cache")) cluster.cacheLambda = false; if (options.has("start")) @@ -149,7 +154,7 @@ public class Trainer if (!vb) o = cluster.EM((i < skip) ? i+1 : 0); else - o = cluster.VBEM(alphaEmit, alphaPi); + o = vbModel.EM(); } else o = cluster.PREM(scale_phrase, scale_context, (i < skip) ? i+1 : 0); @@ -166,10 +171,8 @@ public class Trainer last = o; } - if (cluster == null && agree != null) + if (cluster == null) cluster = agree.model1; - else if (cluster == null && agree2sides != null) - cluster = agree2sides.model1; double pl1lmax = cluster.phrase_l1lmax(); double cl1lmax = cluster.context_l1lmax(); @@ -180,26 +183,20 @@ public class Trainer File outfile = (File) options.valueOf("out"); try { PrintStream ps = FileUtil.printstream(outfile); - List<Edge> test = corpus.getEdges(); - if (options.has("test")) // just use the training + List<Edge> test; + if (!options.has("test")) // just use the training + test = corpus.getEdges(); + else { // if --test supplied, load up the file - if (agree2sides == null) - { - infile = (File) options.valueOf("test"); - System.out.println("Reading testing concordance from " + infile); - test = corpus.readEdges(FileUtil.reader(infile)); - } - else - System.err.println("Can't run bilingual agreement model on different test data cf training (yet); --test ignored."); + infile = (File) options.valueOf("test"); + System.out.println("Reading testing concordance from " + infile); + test = corpus.readEdges(FileUtil.reader(infile)); } - - if (agree != null) - agree.displayPosterior(ps, test); - else if (agree2sides != null) - agree2sides.displayPosterior(ps); - else + if(vb){ + vbModel.displayPosterior(ps); + }else{ cluster.displayPosterior(ps, test); - + } ps.close(); } catch (IOException e) { System.err.println("Failed to open either testing file or output file"); diff --git a/gi/posterior-regularisation/prjava/src/phrase/VB.java b/gi/posterior-regularisation/prjava/src/phrase/VB.java index cc1c1c96..a858c883 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/VB.java +++ b/gi/posterior-regularisation/prjava/src/phrase/VB.java @@ -16,7 +16,7 @@ import phrase.Corpus.Edge; public class VB {
- public static int MAX_ITER=40;
+ public static int MAX_ITER=400;
/**@brief
* hyper param for beta
@@ -28,11 +28,13 @@ public class VB { * hyper param for theta
* where theta is dirichlet for z
*/
- public double alpha=0.000001;
+ public double alpha=0.0001;
/**@brief
* variational param for beta
*/
private double rho[][][];
+ private double digamma_rho[][][];
+ private double rho_sum[][];
/**@brief
* variational param for z
*/
@@ -41,8 +43,7 @@ public class VB { * variational param for theta
*/
private double gamma[];
-
- private static double VAL_DIFF_RATIO=0.001;
+ private static double VAL_DIFF_RATIO=0.005;
/**@brief
* objective for a single document
@@ -55,8 +56,8 @@ public class VB { private Corpus c;
public static void main(String[] args) {
- String in="../pdata/canned.con";
- //String in="../pdata/btec.con";
+ // String in="../pdata/canned.con";
+ String in="../pdata/btec.con";
String out="../pdata/vb.out";
int numCluster=25;
Corpus corpus = null;
@@ -118,6 +119,7 @@ public class VB { }
}
}
+
}
private void inference(int phraseID){
@@ -128,26 +130,21 @@ public class VB { phi[i][j]=1.0/K;
}
}
- gamma = new double[K];
- double digamma_gamma[]=new double[K];
- for(int i=0;i<gamma.length;i++){
- gamma[i] = alpha + 1.0/K;
+ if(gamma==null){
+ gamma=new double[K];
}
+ Arrays.fill(gamma,alpha+1.0/K);
- double rho_sum[][]=new double [K][n_positions];
- for(int i=0;i<K;i++){
- for(int pos=0;pos<n_positions;pos++){
- rho_sum[i][pos]=Gamma.digamma(arr.F.l1norm(rho[i][pos]));
- }
- }
- double gamma_sum=Gamma.digamma(arr.F.l1norm(gamma));
+ double digamma_gamma[]=new double[K];
+
+ double gamma_sum=digamma(arr.F.l1norm(gamma));
for(int i=0;i<K;i++){
- digamma_gamma[i]=Gamma.digamma(gamma[i]);
+ digamma_gamma[i]=digamma(gamma[i]);
}
double gammaSum[]=new double [K];
-
double prev_val=0;
obj=0;
+
for(int iter=0;iter<MAX_ITER;iter++){
prev_val=obj;
obj=0;
@@ -159,7 +156,7 @@ public class VB { double sum=0;
for(int pos=0;pos<n_positions;pos++){
int word=context.get(pos);
- sum+=Gamma.digamma(rho[i][pos][word])-rho_sum[i][pos];
+ sum+=digamma_rho[i][pos][word]-rho_sum[i][pos];
}
sum+= digamma_gamma[i]-gamma_sum;
phi[n][i]=sum;
@@ -183,11 +180,12 @@ public class VB { for(int i=0;i<K;i++){
gamma[i]=alpha+gammaSum[i];
}
- gamma_sum=Gamma.digamma(arr.F.l1norm(gamma));
+ gamma_sum=digamma(arr.F.l1norm(gamma));
for(int i=0;i<K;i++){
- digamma_gamma[i]=Gamma.digamma(gamma[i]);
+ digamma_gamma[i]=digamma(gamma[i]);
}
//compute objective for reporting
+
obj=0;
for(int i=0;i<K;i++){
@@ -209,13 +207,13 @@ public class VB { double beta_sum=0;
for(int pos=0;pos<n_positions;pos++){
int word=context.get(pos);
- beta_sum+=(Gamma.digamma(rho[i][pos][word])-rho_sum[i][pos]);
+ beta_sum+=(digamma(rho[i][pos][word])-rho_sum[i][pos]);
}
obj+=phi[n][i]*beta_sum;
}
}
- obj-=Gamma.logGamma(arr.F.l1norm(gamma));
+ obj-=log_gamma(arr.F.l1norm(gamma));
for(int i=0;i<K;i++){
obj+=Gamma.logGamma(gamma[i]);
obj-=(gamma[i]-1)*(digamma_gamma[i]-gamma_sum);
@@ -233,6 +231,26 @@ public class VB { */
public double EM(){
double emObj=0;
+ if(digamma_rho==null){
+ digamma_rho=new double[K][n_positions][n_words];
+ }
+ for(int i=0;i<K;i++){
+ for (int pos=0;pos<n_positions;pos++){
+ for(int j=0;j<n_words;j++){
+ digamma_rho[i][pos][j]= digamma(rho[i][pos][j]);
+ }
+ }
+ }
+
+ if(rho_sum==null){
+ rho_sum=new double [K][n_positions];
+ }
+ for(int i=0;i<K;i++){
+ for(int pos=0;pos<n_positions;pos++){
+ rho_sum[i][pos]=digamma(arr.F.l1norm(rho[i][pos]));
+ }
+ }
+
//E
double exp_rho[][][]=new double[K][n_positions][n_words];
@@ -248,7 +266,13 @@ public class VB { }
}
}
-
+/* if(d!=0 && d%100==0){
+ System.out.print(".");
+ }
+ if(d!=0 && d%1000==0){
+ System.out.println(d);
+ }
+*/
emObj+=obj;
}
@@ -313,5 +337,28 @@ public class VB { }
return(v);
}
+
+ double digamma(double x)
+ {
+ double p;
+ x=x+6;
+ p=1/(x*x);
+ p=(((0.004166666666667*p-0.003968253986254)*p+
+ 0.008333333333333)*p-0.083333333333333)*p;
+ p=p+Math.log(x)-0.5/x-1/(x-1)-1/(x-2)-1/(x-3)-1/(x-4)-1/(x-5)-1/(x-6);
+ return p;
+ }
+
+ double log_gamma(double x)
+ {
+ double z=1/(x*x);
+
+ x=x+6;
+ z=(((-0.000595238095238*z+0.000793650793651)
+ *z-0.002777777777778)*z+0.083333333333333)/x;
+ z=(x-0.5)*Math.log(x)-x+0.918938533204673+z-Math.log(x-1)-
+ Math.log(x-2)-Math.log(x-3)-Math.log(x-4)-Math.log(x-5)-Math.log(x-6);
+ return z;
+ }
}//End of class
|