diff options
author | trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-07 14:11:42 +0000 |
---|---|---|
committer | trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-07 14:11:42 +0000 |
commit | a15d666d23169dafdf01b7f5923570a9ba10787b (patch) | |
tree | 5f74648ef8e0d9a6c36c211d0d31a0465b2a295c /gi/posterior-regularisation | |
parent | 43d74920424e83c397321db549290f167e15db46 (diff) |
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@173 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation')
-rw-r--r-- | gi/posterior-regularisation/PhraseContextModel.java | 2 | ||||
l--------- | gi/posterior-regularisation/prjava.jar | 1 | ||||
-rw-r--r-- | gi/posterior-regularisation/prjava/prjava-20100707.jar | bin | 0 -> 933814 bytes | |||
-rw-r--r-- | gi/posterior-regularisation/prjava/src/io/FileUtil.java | 12 | ||||
-rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java | 110 | ||||
-rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java | 19 | ||||
-rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java | 19 | ||||
-rw-r--r-- | gi/posterior-regularisation/train_pr_global.py | 8 | ||||
-rw-r--r-- | gi/posterior-regularisation/train_pr_parallel.py | 25 |
9 files changed, 132 insertions, 64 deletions
diff --git a/gi/posterior-regularisation/PhraseContextModel.java b/gi/posterior-regularisation/PhraseContextModel.java index db152e73..85bcfb89 100644 --- a/gi/posterior-regularisation/PhraseContextModel.java +++ b/gi/posterior-regularisation/PhraseContextModel.java @@ -149,7 +149,7 @@ class PhraseContextModel System.out.println("Failed to optimize"); //System.out.println("Ended optimization in " + optimizer.getCurrentIteration()); - lambda = objective.getParameters(); + //lambda = objective.getParameters(); llh = objective.primal(); for (int i = 0; i < training.getNumPhrases(); ++i) diff --git a/gi/posterior-regularisation/prjava.jar b/gi/posterior-regularisation/prjava.jar new file mode 120000 index 00000000..7cd1a3ff --- /dev/null +++ b/gi/posterior-regularisation/prjava.jar @@ -0,0 +1 @@ +prjava/prjava-20100707.jar
\ No newline at end of file diff --git a/gi/posterior-regularisation/prjava/prjava-20100707.jar b/gi/posterior-regularisation/prjava/prjava-20100707.jar Binary files differnew file mode 100644 index 00000000..195374d9 --- /dev/null +++ b/gi/posterior-regularisation/prjava/prjava-20100707.jar diff --git a/gi/posterior-regularisation/prjava/src/io/FileUtil.java b/gi/posterior-regularisation/prjava/src/io/FileUtil.java index 7d9f2bc5..67ce571e 100644 --- a/gi/posterior-regularisation/prjava/src/io/FileUtil.java +++ b/gi/posterior-regularisation/prjava/src/io/FileUtil.java @@ -1,5 +1,7 @@ package io;
import java.util.*;
+import java.util.zip.GZIPInputStream;
+import java.util.zip.GZIPOutputStream;
import java.io.*;
public class FileUtil {
public static Scanner openInFile(String filename){
@@ -18,7 +20,10 @@ public class FileUtil { BufferedReader r=null;
try
{
- r=(new BufferedReader(new FileReader(new File(filename))));
+ if (filename.endsWith(".gz"))
+ r=(new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(new File(filename))))));
+ else
+ r=(new BufferedReader(new FileReader(new File(filename))));
}catch(IOException ioe){
System.out.println(ioe.getMessage());
}
@@ -29,7 +34,10 @@ public class FileUtil { PrintStream localps=null;
try
{
- localps=new PrintStream (new FileOutputStream(filename));
+ if (filename.endsWith(".gz"))
+ localps=new PrintStream (new GZIPOutputStream(new FileOutputStream(filename)));
+ else
+ localps=new PrintStream (new FileOutputStream(filename));
}catch(IOException ioe){
System.out.println(ioe.getMessage());
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index cd28c12e..731d03ac 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -1,11 +1,16 @@ package phrase;
import io.FileUtil;
+
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
+import java.util.zip.GZIPOutputStream;
public class PhraseCluster {
@@ -26,28 +31,46 @@ public class PhraseCluster { public static void main(String[] args) {
String input_fname = args[0];
int tags = Integer.parseInt(args[1]);
- String outputDir = args[2];
+ String output_fname = args[2];
int iterations = Integer.parseInt(args[3]);
double scale = Double.parseDouble(args[4]);
int threads = Integer.parseInt(args[5]);
+ boolean runEM = Boolean.parseBoolean(args[6]);
PhraseCorpus corpus = new PhraseCorpus(input_fname);
PhraseCluster cluster = new PhraseCluster(tags, corpus, scale, threads);
- PhraseObjective.ps = FileUtil.openOutFile(outputDir + "/phrase_stat.out");
+ //PhraseObjective.ps = FileUtil.openOutFile(outputDir + "/phrase_stat.out");
+ double last = 0;
for(int i=0;i<iterations;i++){
- double o = cluster.PREM();
- //double o = cluster.EM();
- PhraseObjective.ps.println("ITER: "+i+" objective: " + o);
+
+ double o;
+ if (runEM || i < 3)
+ o = cluster.EM();
+ else
+ o = cluster.PREM();
+ //PhraseObjective.ps.
+ System.out.println("ITER: "+i+" objective: " + o);
+ last = o;
+ }
+
+ if (runEM)
+ {
+ double l1lmax = cluster.posterior_l1lmax();
+ System.out.println("Final l1lmax term " + l1lmax + ", total PR objective " + (last - scale*l1lmax));
+ // nb. KL is 0 by definition
}
- PrintStream ps=io.FileUtil.openOutFile(outputDir + "/posterior.out");
+ PrintStream ps=io.FileUtil.openOutFile(output_fname);
cluster.displayPosterior(ps);
- ps.println();
- cluster.displayModelParam(ps);
ps.close();
- PhraseObjective.ps.close();
+
+ //PhraseObjective.ps.close();
+
+ //ps = io.FileUtil.openOutFile(outputDir + "/parameters.out");
+ //cluster.displayModelParam(ps);
+ //ps.close();
cluster.finish();
}
@@ -61,7 +84,7 @@ public class PhraseCluster { if (threads > 0)
pool = Executors.newFixedThreadPool(threads);
- emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
+ emit=new double [K][c.numContexts][n_words];
pi=new double[n_phrase][K];
for(double [][]i:emit){
@@ -82,7 +105,7 @@ public class PhraseCluster { }
public double EM(){
- double [][][]exp_emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
+ double [][][]exp_emit=new double [K][c.numContexts][n_words];
double [][]exp_pi=new double[n_phrase][K];
double loglikelihood=0;
@@ -93,7 +116,9 @@ public class PhraseCluster { for(int ctx=0;ctx<data.length;ctx++){
int context[]=data[ctx];
double p[]=posterior(phrase,context);
- loglikelihood+=Math.log(arr.F.l1norm(p));
+ double z = arr.F.l1norm(p);
+ assert z > 0;
+ loglikelihood+=Math.log(z);
arr.F.l1normalize(p);
int contextCnt=context[context.length-1];
@@ -132,7 +157,7 @@ public class PhraseCluster { if (pool != null)
return PREMParallel();
- double [][][]exp_emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
+ double [][][]exp_emit=new double [K][c.numContexts][n_words];
double [][]exp_pi=new double[n_phrase][K];
double loglikelihood=0;
@@ -142,7 +167,7 @@ public class PhraseCluster { PhraseObjective po=new PhraseObjective(this,phrase);
po.optimizeWithProjectedGradientDescent();
double [][] q=po.posterior();
- loglikelihood+=po.getValue();
+ loglikelihood+=po.llh;
primal+=po.primal();
for(int edge=0;edge<q.length;edge++){
int []context=c.data[phrase][edge];
@@ -184,7 +209,7 @@ public class PhraseCluster { final LinkedBlockingQueue<PhraseObjective> expectations
= new LinkedBlockingQueue<PhraseObjective>();
- double [][][]exp_emit=new double [K][PhraseCorpus.NUM_CONTEXT][n_words];
+ double [][][]exp_emit=new double [K][c.numContexts][n_words];
double [][]exp_pi=new double[n_phrase][K];
double loglikelihood=0;
@@ -220,7 +245,7 @@ public class PhraseCluster { int phrase = po.phrase;
//System.out.println("" + Thread.currentThread().getId() + " taken phrase " + phrase);
double [][] q=po.posterior();
- loglikelihood+=po.getValue();
+ loglikelihood+=po.llh;
primal+=po.primal();
for(int edge=0;edge<q.length;edge++){
int []context=c.data[phrase][edge];
@@ -295,18 +320,19 @@ public class PhraseCluster { // emit phrase
ps.print(c.phraseList[i]);
ps.print("\t");
- ps.print(c.getContextString(e));
- ps.print("||| C=" + e[e.length-1] + " |||");
-
+ ps.print(c.getContextString(e, true));
int t=arr.F.argmax(probs);
+ ps.println(" ||| C=" + t);
+
+ //ps.print("||| C=" + e[e.length-1] + " |||");
- ps.print(t+"||| [");
- for(t=0;t<K;t++){
- ps.print(probs[t]+", ");
- }
+ //ps.print(t+"||| [");
+ //for(t=0;t<K;t++){
+ // ps.print(probs[t]+", ");
+ //}
// for (int t = 0; t < numTags; ++t)
// System.out.print(" " + probs[t]);
- ps.println("]");
+ //ps.println("]");
}
}
}
@@ -329,14 +355,14 @@ public class PhraseCluster { ps.println("P(word|tag,position)");
for (int i = 0; i < K; ++i)
{
- ps.println(i);
- for(int position=0;position<PhraseCorpus.NUM_CONTEXT;position++){
- ps.println(position);
+ for(int position=0;position<c.numContexts;position++){
+ ps.println("tag " + i + " position " + position);
for(int word=0;word<emit[i][position].length;word++){
- if((word+1)%100==0){
- ps.println();
- }
- ps.print(c.wordList[word]+"="+emit[i][position][word]+"\t");
+ //if((word+1)%100==0){
+ // ps.println();
+ //}
+ if (emit[i][position][word] > 1e-10)
+ ps.print(c.wordList[word]+"="+emit[i][position][word]+"\t");
}
ps.println();
}
@@ -344,4 +370,26 @@ public class PhraseCluster { }
}
+
+ double posterior_l1lmax()
+ {
+ double sum=0;
+ for(int phrase=0;phrase<c.data.length;phrase++)
+ {
+ int [][] data = c.data[phrase];
+ double [] maxes = new double[K];
+ for(int ctx=0;ctx<data.length;ctx++)
+ {
+ int context[]=data[ctx];
+ double p[]=posterior(phrase,context);
+ arr.F.l1normalize(p);
+
+ for(int tag=0;tag<K;tag++)
+ maxes[tag] = Math.max(maxes[tag], p[tag]);
+ }
+ for(int tag=0;tag<K;tag++)
+ sum += maxes[tag];
+ }
+ return sum;
+ }
}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java index 99545371..b8f1f24a 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java @@ -9,11 +9,9 @@ import java.util.HashMap; import java.util.Scanner;
public class PhraseCorpus {
-
public static String LEX_FILENAME="../pdata/lex.out";
public static String DATA_FILENAME="../pdata/btec.con";
- public static int NUM_CONTEXT=4;
public HashMap<String,Integer>wordLex;
public HashMap<String,Integer>phraseLex;
@@ -23,6 +21,7 @@ public class PhraseCorpus { //data[phrase][num context][position]
public int data[][][];
+ public int numContexts;
public static void main(String[] args) {
// TODO Auto-generated method stub
@@ -40,6 +39,7 @@ public class PhraseCorpus { ArrayList<int[][]>dataList=new ArrayList<int[][]>();
String line=null;
+ numContexts = 0;
while((line=readLine(r))!=null){
@@ -54,7 +54,12 @@ public class PhraseCorpus { for(int i=0;i<toks.length;i+=2){
String ctx=toks[i];
String words[]=ctx.split(" ");
- int []context=new int [NUM_CONTEXT+1];
+ if (numContexts == 0)
+ numContexts = words.length - 1;
+ else
+ assert numContexts == words.length - 1;
+
+ int []context=new int [numContexts+1];
int idx=0;
for(String word:words){
if(word.equals("<PHRASE>")){
@@ -68,9 +73,7 @@ public class PhraseCorpus { String count=toks[i+1];
context[idx]=Integer.parseInt(count.trim().substring(2));
-
ctxList.add(context);
-
}
dataList.add(ctxList.toArray(new int [0][]));
@@ -157,13 +160,17 @@ public class PhraseCorpus { return dict;
}
- public String getContextString(int context[])
+ public String getContextString(int context[], boolean addPhraseMarker)
{
StringBuffer b = new StringBuffer();
for (int i=0;i<context.length-1;i++)
{
if (b.length() > 0)
b.append(" ");
+
+ if (i == context.length/2)
+ b.append("<PHRASE> ");
+
b.append(wordList[context[i]]);
}
return b.toString();
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java index 71c91b96..b7c62261 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java @@ -20,17 +20,17 @@ import optimization.util.MathUtils; public class PhraseObjective extends ProjectedObjective{
private static final double GRAD_DIFF = 0.002;
- public static double INIT_STEP_SIZE=1;
- public static double VAL_DIFF=0.001;
- private double c1=0.0001;
- private double c2=0.9;
+ public static double INIT_STEP_SIZE = 10;
+ public static double VAL_DIFF = 0.001; // FIXME needs to be tuned
+ //private double c1=0.0001; // wolf stuff
+ //private double c2=0.9;
private PhraseCluster c;
/**@brief
* for debugging purposes
*/
- public static PrintStream ps;
+ //public static PrintStream ps;
/**@brief current phrase being optimzed*/
public int phrase;
@@ -61,7 +61,7 @@ public class PhraseObjective extends ProjectedObjective{ /**@brief likelihood under p
*
*/
- private double llh;
+ public double llh;
public PhraseObjective(PhraseCluster cluster, int phraseIdx){
phrase=phraseIdx;
@@ -181,7 +181,7 @@ public class PhraseObjective extends ProjectedObjective{ boolean succed = optimizer.optimize(this,stats,compositeStop);
// System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
if(succed){
- System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
+ //System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
}else{
System.out.println("Failed to optimize");
}
@@ -208,6 +208,10 @@ public class PhraseObjective extends ProjectedObjective{ double kl=-loglikelihood
+MathUtils.dotProduct(parameters, gradient);
// ps.print(", "+kl);
+ //System.out.println("llh " + llh);
+ //System.out.println("kl " + kl);
+
+
l=l-kl;
double sum=0;
for(int tag=0;tag<c.K;tag++){
@@ -219,6 +223,7 @@ public class PhraseObjective extends ProjectedObjective{ }
sum+=max;
}
+ //System.out.println("l1lmax " + sum);
// ps.println(", "+sum);
l=l-c.scale*sum;
return l;
diff --git a/gi/posterior-regularisation/train_pr_global.py b/gi/posterior-regularisation/train_pr_global.py index 8b80c6bc..f2806b6e 100644 --- a/gi/posterior-regularisation/train_pr_global.py +++ b/gi/posterior-regularisation/train_pr_global.py @@ -290,10 +290,4 @@ for p, (phrase, ccs) in enumerate(edges_phrase_to_context): cz = sum(conditionals) conditionals /= cz - #scores = zeros(num_tags) - #li = lamba_index[phrase, context] - #for t in range(num_tags): - # scores[t] = conditionals[t] * exp(-lamba[li + t]) - - #print '%s\t%s ||| C=%d ||| %d |||' % (phrase, context, count, argmax(scores)), scores / sum(scores) - print '%s\t%s ||| C=%d ||| %d |||' % (phrase, context, count, argmax(conditionals)), conditionals + print '%s\t%s ||| C=%d |||' % (phrase, context, argmax(conditionals)), conditionals diff --git a/gi/posterior-regularisation/train_pr_parallel.py b/gi/posterior-regularisation/train_pr_parallel.py index 4de7f504..3b9cefed 100644 --- a/gi/posterior-regularisation/train_pr_parallel.py +++ b/gi/posterior-regularisation/train_pr_parallel.py @@ -41,7 +41,7 @@ for line in sys.stdin: # Step 2: initialise the model parameters # -num_tags = 5 +num_tags = 25 num_types = len(types) num_phrases = len(edges_phrase_to_context) num_contexts = len(edges_context_to_phrase) @@ -86,7 +86,7 @@ class GlobalDualObjective: self.posterior[index,t] = prob z = sum(self.posterior[index,:]) self.posterior[index,:] /= z - self.llh += log(z) + self.llh += log(z) * count index += 1 def objective(self, ls): @@ -192,7 +192,7 @@ class LocalDualObjective: self.posterior[i,t] = prob z = sum(self.posterior[i,:]) self.posterior[i,:] /= z - self.llh += log(z) + self.llh += log(z) * count def objective(self, ls): edges = edges_phrase_to_context[self.phraseId][1] @@ -243,9 +243,10 @@ class LocalDualObjective: gradient[t,i,t] -= count return gradient.reshape((num_tags, len(edges)*num_tags)) - def optimize(self): + def optimize(self, ls=None): edges = edges_phrase_to_context[self.phraseId][1] - ls = zeros(len(edges) * num_tags) + if ls == None: + ls = zeros(len(edges) * num_tags) #print '\tpre lambda optimisation dual', self.objective(ls) #, 'primal', primal(lamba) ls = scipy.optimize.fmin_slsqp(self.objective, ls, bounds=[(0, self.scale)] * len(edges) * num_tags, @@ -253,6 +254,7 @@ class LocalDualObjective: fprime=self.gradient, fprime_ieqcons=self.constraints_gradient, iprint=0) # =2 for verbose + #print '\tlambda', list(ls) #print '\tpost lambda optimisation dual', self.objective(ls) #, 'primal', primal(lamba) # returns llh, kl and l1lmax contribution @@ -263,8 +265,9 @@ class LocalDualObjective: lmax = max(lmax, self.q[i,t]) l1lmax += lmax - return self.llh, -self.objective(ls) + dot(ls, self.gradient(ls)), l1lmax + return self.llh, -self.objective(ls) + dot(ls, self.gradient(ls)), l1lmax, ls +ls = [None] * num_phrases for iteration in range(20): tagCounts = [zeros(num_tags) for p in range(num_phrases)] contextWordCounts = [[zeros(num_types) for t in range(num_tags)] for i in range(4)] @@ -275,11 +278,13 @@ for iteration in range(20): for p in range(num_phrases): o = LocalDualObjective(p, delta) #print '\toptimising lambda for phrase', p, '=', edges_phrase_to_context[p][0] - obj = o.optimize() - print '\tphrase', p, 'deltas', obj + #print '\toptimising lambda for phrase', p, 'ls', ls[p] + obj = o.optimize(ls[p]) + #print '\tphrase', p, 'deltas', obj llh += obj[0] kl += obj[1] l1lmax += obj[2] + ls[p] = obj[3] edges = edges_phrase_to_context[p][1] for j, (context, count) in enumerate(edges): @@ -305,7 +310,7 @@ for iteration in range(20): contextWordCounts[i][t][types[context[i]]] += count * o.q[index,t] index += 1 - print 'iteration', iteration, 'objective', (llh + kl + delta * l1lmax), 'llh', llh, 'kl', kl, 'l1lmax', l1lmax + print 'iteration', iteration, 'objective', (llh - kl - delta * l1lmax), 'llh', llh, 'kl', kl, 'l1lmax', l1lmax # M-step for p in range(num_phrases): @@ -325,4 +330,4 @@ for p, (phrase, ccs) in enumerate(edges_phrase_to_context): cz = sum(conditionals) conditionals /= cz - print '%s\t%s ||| C=%d ||| %d |||' % (phrase, context, count, argmax(conditionals)), conditionals + print '%s\t%s ||| C=%d |||' % (phrase, context, argmax(conditionals)), conditionals |