summary | refs | log | tree | commit | diff
path: root/gi/posterior-regularisation/prjava/src/phrase/VB.java
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/phrase/VB.java')
-rw-r--r-- gi/posterior-regularisation/prjava/src/phrase/VB.java 97
1 files changed, 72 insertions, 25 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/VB.java b/gi/posterior-regularisation/prjava/src/phrase/VB.java
index cc1c1c96..a858c883 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/VB.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/VB.java
@@ -16,7 +16,7 @@ import phrase.Corpus.Edge;
public class VB {
- public static int MAX_ITER=40;
+ public static int MAX_ITER=400;
/**@brief
* hyper param for beta
@@ -28,11 +28,13 @@ public class VB {
* hyper param for theta
* where theta is dirichlet for z
*/
- public double alpha=0.000001;
+ public double alpha=0.0001;
/**@brief
* variational param for beta
*/
private double rho[][][];
+ private double digamma_rho[][][];
+ private double rho_sum[][];
/**@brief
* variational param for z
*/
@@ -41,8 +43,7 @@ public class VB {
* variational param for theta
*/
private double gamma[];
-
- private static double VAL_DIFF_RATIO=0.001;
+ private static double VAL_DIFF_RATIO=0.005;
/**@brief
* objective for a single document
@@ -55,8 +56,8 @@ public class VB {
private Corpus c;
public static void main(String[] args) {
- String in="../pdata/canned.con";
- //String in="../pdata/btec.con";
+ // String in="../pdata/canned.con";
+ String in="../pdata/btec.con";
String out="../pdata/vb.out";
int numCluster=25;
Corpus corpus = null;
@@ -118,6 +119,7 @@ public class VB {
}
}
}
+
}
private void inference(int phraseID){
@@ -128,26 +130,21 @@ public class VB {
phi[i][j]=1.0/K;
}
}
- gamma = new double[K];
- double digamma_gamma[]=new double[K];
- for(int i=0;i<gamma.length;i++){
- gamma[i] = alpha + 1.0/K;
+ if(gamma==null){
+ gamma=new double[K];
}
+ Arrays.fill(gamma,alpha+1.0/K);
- double rho_sum[][]=new double [K][n_positions];
- for(int i=0;i<K;i++){
- for(int pos=0;pos<n_positions;pos++){
- rho_sum[i][pos]=Gamma.digamma(arr.F.l1norm(rho[i][pos]));
- }
- }
- double gamma_sum=Gamma.digamma(arr.F.l1norm(gamma));
+ double digamma_gamma[]=new double[K];
+
+ double gamma_sum=digamma(arr.F.l1norm(gamma));
for(int i=0;i<K;i++){
- digamma_gamma[i]=Gamma.digamma(gamma[i]);
+ digamma_gamma[i]=digamma(gamma[i]);
}
double gammaSum[]=new double [K];
-
double prev_val=0;
obj=0;
+
for(int iter=0;iter<MAX_ITER;iter++){
prev_val=obj;
obj=0;
@@ -159,7 +156,7 @@ public class VB {
double sum=0;
for(int pos=0;pos<n_positions;pos++){
int word=context.get(pos);
- sum+=Gamma.digamma(rho[i][pos][word])-rho_sum[i][pos];
+ sum+=digamma_rho[i][pos][word]-rho_sum[i][pos];
}
sum+= digamma_gamma[i]-gamma_sum;
phi[n][i]=sum;
@@ -183,11 +180,12 @@ public class VB {
for(int i=0;i<K;i++){
gamma[i]=alpha+gammaSum[i];
}
- gamma_sum=Gamma.digamma(arr.F.l1norm(gamma));
+ gamma_sum=digamma(arr.F.l1norm(gamma));
for(int i=0;i<K;i++){
- digamma_gamma[i]=Gamma.digamma(gamma[i]);
+ digamma_gamma[i]=digamma(gamma[i]);
}
//compute objective for reporting
+
obj=0;
for(int i=0;i<K;i++){
@@ -209,13 +207,13 @@ public class VB {
double beta_sum=0;
for(int pos=0;pos<n_positions;pos++){
int word=context.get(pos);
- beta_sum+=(Gamma.digamma(rho[i][pos][word])-rho_sum[i][pos]);
+ beta_sum+=(digamma(rho[i][pos][word])-rho_sum[i][pos]);
}
obj+=phi[n][i]*beta_sum;
}
}
- obj-=Gamma.logGamma(arr.F.l1norm(gamma));
+ obj-=log_gamma(arr.F.l1norm(gamma));
for(int i=0;i<K;i++){
obj+=Gamma.logGamma(gamma[i]);
obj-=(gamma[i]-1)*(digamma_gamma[i]-gamma_sum);
@@ -233,6 +231,26 @@ public class VB {
*/
public double EM(){
double emObj=0;
+ if(digamma_rho==null){
+ digamma_rho=new double[K][n_positions][n_words];
+ }
+ for(int i=0;i<K;i++){
+ for (int pos=0;pos<n_positions;pos++){
+ for(int j=0;j<n_words;j++){
+ digamma_rho[i][pos][j]= digamma(rho[i][pos][j]);
+ }
+ }
+ }
+
+ if(rho_sum==null){
+ rho_sum=new double [K][n_positions];
+ }
+ for(int i=0;i<K;i++){
+ for(int pos=0;pos<n_positions;pos++){
+ rho_sum[i][pos]=digamma(arr.F.l1norm(rho[i][pos]));
+ }
+ }
+
//E
double exp_rho[][][]=new double[K][n_positions][n_words];
@@ -248,7 +266,13 @@ public class VB {
}
}
}
-
+/* if(d!=0 && d%100==0){
+ System.out.print(".");
+ }
+ if(d!=0 && d%1000==0){
+ System.out.println(d);
+ }
+*/
emObj+=obj;
}
@@ -313,5 +337,28 @@ public class VB {
}
return(v);
}
+
+ /**
+  * Approximates the digamma function psi(x), the derivative of ln Gamma(x).
+  *
+  * The argument is shifted up by 6 and the asymptotic series
+  * ln(x) - 1/(2x) - 1/(12x^2) + 1/(120x^4) - 1/(252x^6) + 1/(240x^8)
+  * is evaluated there; the six 1/(x-k) terms then undo the shift using
+  * the recurrence psi(x) = psi(x+1) - 1/x.
+  *
+  * @param x evaluation point; intended for x &gt; 0 (the recurrence terms
+  *          divide by x, x+1, ..., x+5)
+  * @return approximate value of psi(x)
+  */
+ double digamma(double x)
+ {
+     double p;
+     // Shift first: every term below is evaluated at the shifted argument.
+     x=x+6;
+     p=1/(x*x);
+     // Horner evaluation in p=1/x^2 with coefficients 1/240, 1/252, 1/120, 1/12.
+     // 1/252 = 0.003968253968254.
+     p=(((0.004166666666667*p-0.003968253968254)*p+
+         0.008333333333333)*p-0.083333333333333)*p;
+     // Add the leading terms and walk the recurrence back down by 6.
+     p=p+Math.log(x)-0.5/x-1/(x-1)-1/(x-2)-1/(x-3)-1/(x-4)-1/(x-5)-1/(x-6);
+     return p;
+ }
+
+ /**
+  * Approximates ln Gamma(x).
+  *
+  * The argument is shifted up by 6 and Stirling's series
+  * (x-0.5)ln(x) - x + 0.5*ln(2*pi) + 1/(12x) - 1/(360x^3) + 1/(1260x^5) - 1/(1680x^7)
+  * is evaluated there; the six log(x-k) terms then undo the shift using
+  * ln Gamma(x) = ln Gamma(x+1) - ln(x).
+  *
+  * @param x evaluation point; intended for x &gt; 0 (logs of x, x+1, ..., x+5
+  *          are taken)
+  * @return approximate value of ln Gamma(x)
+  */
+ double log_gamma(double x)
+ {
+     double z;
+
+     x=x+6;
+     // z must be 1/x^2 of the SHIFTED argument; computing it before the
+     // shift evaluates the correction series at the wrong point and
+     // diverges for small x.
+     z=1/(x*x);
+     // Horner evaluation with coefficients 1/1680, 1/1260, 1/360, 1/12.
+     z=(((-0.000595238095238*z+0.000793650793651)
+         *z-0.002777777777778)*z+0.083333333333333)/x;
+     // 0.918938533204673 = 0.5*ln(2*pi).
+     z=(x-0.5)*Math.log(x)-x+0.918938533204673+z-Math.log(x-1)-
+         Math.log(x-2)-Math.log(x-3)-Math.log(x-4)-Math.log(x-5)-Math.log(x-6);
+     return z;
+ }
}//End of class