summaryrefslogtreecommitdiff
path: root/gi
diff options
context:
space:
mode:
authordesaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-23 17:08:53 +0000
committerdesaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-23 17:08:53 +0000
commit76ef39de737e7abc0a8fe989dfacb7885617e59f (patch)
tree77c6099236431c4488aa5ac95b6d680bfd5faf05 /gi
parent7776119e54c477a27fb0617d8bf8b483ac78898e (diff)
vb runnable from trainer
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@380 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi')
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/Trainer.java43
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/VB.java97
2 files changed, 92 insertions, 48 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
index b51db919..cea6a20a 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
@@ -18,7 +18,7 @@ public class Trainer
{
public static void main(String[] args)
{
-
+
OptionParser parser = new OptionParser();
parser.accepts("help");
parser.accepts("in").withRequiredArg().ofType(File.class);
@@ -107,6 +107,7 @@ public class Trainer
PhraseCluster cluster = null;
Agree2Sides agree2sides = null;
Agree agree= null;
+ VB vbModel=null;
if (options.has("agree-language"))
agree2sides = new Agree2Sides(tags, corpus,corpus1);
else if (options.has("agree-direction"))
@@ -115,7 +116,11 @@ public class Trainer
{
cluster = new PhraseCluster(tags, corpus);
if (threads > 0) cluster.useThreadPool(threads);
- if (vb) cluster.initialiseVB(alphaEmit, alphaPi);
+
+ if (vb) {
+ //cluster.initialiseVB(alphaEmit, alphaPi);
+ vbModel=new VB(tags,corpus);
+ }
if (options.has("no-parameter-cache"))
cluster.cacheLambda = false;
if (options.has("start"))
@@ -149,7 +154,7 @@ public class Trainer
if (!vb)
o = cluster.EM((i < skip) ? i+1 : 0);
else
- o = cluster.VBEM(alphaEmit, alphaPi);
+ o = vbModel.EM();
}
else
o = cluster.PREM(scale_phrase, scale_context, (i < skip) ? i+1 : 0);
@@ -166,10 +171,8 @@ public class Trainer
last = o;
}
- if (cluster == null && agree != null)
+ if (cluster == null)
cluster = agree.model1;
- else if (cluster == null && agree2sides != null)
- cluster = agree2sides.model1;
double pl1lmax = cluster.phrase_l1lmax();
double cl1lmax = cluster.context_l1lmax();
@@ -180,26 +183,20 @@ public class Trainer
File outfile = (File) options.valueOf("out");
try {
PrintStream ps = FileUtil.printstream(outfile);
- List<Edge> test = corpus.getEdges();
- if (options.has("test")) // just use the training
+ List<Edge> test;
+ if (!options.has("test")) // just use the training
+ test = corpus.getEdges();
+ else
{ // if --test supplied, load up the file
- if (agree2sides == null)
- {
- infile = (File) options.valueOf("test");
- System.out.println("Reading testing concordance from " + infile);
- test = corpus.readEdges(FileUtil.reader(infile));
- }
- else
- System.err.println("Can't run bilingual agreement model on different test data cf training (yet); --test ignored.");
+ infile = (File) options.valueOf("test");
+ System.out.println("Reading testing concordance from " + infile);
+ test = corpus.readEdges(FileUtil.reader(infile));
}
-
- if (agree != null)
- agree.displayPosterior(ps, test);
- else if (agree2sides != null)
- agree2sides.displayPosterior(ps);
- else
+ if(vb){
+ vbModel.displayPosterior(ps);
+ }else{
cluster.displayPosterior(ps, test);
-
+ }
ps.close();
} catch (IOException e) {
System.err.println("Failed to open either testing file or output file");
diff --git a/gi/posterior-regularisation/prjava/src/phrase/VB.java b/gi/posterior-regularisation/prjava/src/phrase/VB.java
index cc1c1c96..a858c883 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/VB.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/VB.java
@@ -16,7 +16,7 @@ import phrase.Corpus.Edge;
public class VB {
- public static int MAX_ITER=40;
+ public static int MAX_ITER=400;
/**@brief
* hyper param for beta
@@ -28,11 +28,13 @@ public class VB {
* hyper param for theta
* where theta is dirichlet for z
*/
- public double alpha=0.000001;
+ public double alpha=0.0001;
/**@brief
* variational param for beta
*/
private double rho[][][];
+ private double digamma_rho[][][];
+ private double rho_sum[][];
/**@brief
* variational param for z
*/
@@ -41,8 +43,7 @@ public class VB {
* variational param for theta
*/
private double gamma[];
-
- private static double VAL_DIFF_RATIO=0.001;
+ private static double VAL_DIFF_RATIO=0.005;
/**@brief
* objective for a single document
@@ -55,8 +56,8 @@ public class VB {
private Corpus c;
public static void main(String[] args) {
- String in="../pdata/canned.con";
- //String in="../pdata/btec.con";
+ // String in="../pdata/canned.con";
+ String in="../pdata/btec.con";
String out="../pdata/vb.out";
int numCluster=25;
Corpus corpus = null;
@@ -118,6 +119,7 @@ public class VB {
}
}
}
+
}
private void inference(int phraseID){
@@ -128,26 +130,21 @@ public class VB {
phi[i][j]=1.0/K;
}
}
- gamma = new double[K];
- double digamma_gamma[]=new double[K];
- for(int i=0;i<gamma.length;i++){
- gamma[i] = alpha + 1.0/K;
+ if(gamma==null){
+ gamma=new double[K];
}
+ Arrays.fill(gamma,alpha+1.0/K);
- double rho_sum[][]=new double [K][n_positions];
- for(int i=0;i<K;i++){
- for(int pos=0;pos<n_positions;pos++){
- rho_sum[i][pos]=Gamma.digamma(arr.F.l1norm(rho[i][pos]));
- }
- }
- double gamma_sum=Gamma.digamma(arr.F.l1norm(gamma));
+ double digamma_gamma[]=new double[K];
+
+ double gamma_sum=digamma(arr.F.l1norm(gamma));
for(int i=0;i<K;i++){
- digamma_gamma[i]=Gamma.digamma(gamma[i]);
+ digamma_gamma[i]=digamma(gamma[i]);
}
double gammaSum[]=new double [K];
-
double prev_val=0;
obj=0;
+
for(int iter=0;iter<MAX_ITER;iter++){
prev_val=obj;
obj=0;
@@ -159,7 +156,7 @@ public class VB {
double sum=0;
for(int pos=0;pos<n_positions;pos++){
int word=context.get(pos);
- sum+=Gamma.digamma(rho[i][pos][word])-rho_sum[i][pos];
+ sum+=digamma_rho[i][pos][word]-rho_sum[i][pos];
}
sum+= digamma_gamma[i]-gamma_sum;
phi[n][i]=sum;
@@ -183,11 +180,12 @@ public class VB {
for(int i=0;i<K;i++){
gamma[i]=alpha+gammaSum[i];
}
- gamma_sum=Gamma.digamma(arr.F.l1norm(gamma));
+ gamma_sum=digamma(arr.F.l1norm(gamma));
for(int i=0;i<K;i++){
- digamma_gamma[i]=Gamma.digamma(gamma[i]);
+ digamma_gamma[i]=digamma(gamma[i]);
}
//compute objective for reporting
+
obj=0;
for(int i=0;i<K;i++){
@@ -209,13 +207,13 @@ public class VB {
double beta_sum=0;
for(int pos=0;pos<n_positions;pos++){
int word=context.get(pos);
- beta_sum+=(Gamma.digamma(rho[i][pos][word])-rho_sum[i][pos]);
+ beta_sum+=(digamma(rho[i][pos][word])-rho_sum[i][pos]);
}
obj+=phi[n][i]*beta_sum;
}
}
- obj-=Gamma.logGamma(arr.F.l1norm(gamma));
+ obj-=log_gamma(arr.F.l1norm(gamma));
for(int i=0;i<K;i++){
obj+=Gamma.logGamma(gamma[i]);
obj-=(gamma[i]-1)*(digamma_gamma[i]-gamma_sum);
@@ -233,6 +231,26 @@ public class VB {
*/
public double EM(){
double emObj=0;
+ if(digamma_rho==null){
+ digamma_rho=new double[K][n_positions][n_words];
+ }
+ for(int i=0;i<K;i++){
+ for (int pos=0;pos<n_positions;pos++){
+ for(int j=0;j<n_words;j++){
+ digamma_rho[i][pos][j]= digamma(rho[i][pos][j]);
+ }
+ }
+ }
+
+ if(rho_sum==null){
+ rho_sum=new double [K][n_positions];
+ }
+ for(int i=0;i<K;i++){
+ for(int pos=0;pos<n_positions;pos++){
+ rho_sum[i][pos]=digamma(arr.F.l1norm(rho[i][pos]));
+ }
+ }
+
//E
double exp_rho[][][]=new double[K][n_positions][n_words];
@@ -248,7 +266,13 @@ public class VB {
}
}
}
-
+/* if(d!=0 && d%100==0){
+ System.out.print(".");
+ }
+ if(d!=0 && d%1000==0){
+ System.out.println(d);
+ }
+*/
emObj+=obj;
}
@@ -313,5 +337,28 @@ public class VB {
}
return(v);
}
+
+ double digamma(double x)
+ {
+ double p;
+ x=x+6;
+ p=1/(x*x);
+ p=(((0.004166666666667*p-0.003968253986254)*p+
+ 0.008333333333333)*p-0.083333333333333)*p;
+ p=p+Math.log(x)-0.5/x-1/(x-1)-1/(x-2)-1/(x-3)-1/(x-4)-1/(x-5)-1/(x-6);
+ return p;
+ }
+
+ double log_gamma(double x)
+ {
+ double z=1/(x*x);
+
+ x=x+6;
+ z=(((-0.000595238095238*z+0.000793650793651)
+ *z-0.002777777777778)*z+0.083333333333333)/x;
+ z=(x-0.5)*Math.log(x)-x+0.918938533204673+z-Math.log(x-1)-
+ Math.log(x-2)-Math.log(x-3)-Math.log(x-4)-Math.log(x-5)-Math.log(x-6);
+ return z;
+ }
}//End of class