diff options
author | trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-08 21:46:05 +0000 |
---|---|---|
committer | trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-08 21:46:05 +0000 |
commit | f2be77ccae455563e167607ee7527abbf8d96e60 (patch) | |
tree | e5e02f908fca912f6ed4f0582fff7ecef926f24e /gi/posterior-regularisation/prjava/src | |
parent | fab1e67814aca6abba75fb1bee086255bea8daf0 (diff) |
New context constraints.
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@190 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation/prjava/src')
6 files changed, 832 insertions, 221 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java new file mode 100644 index 00000000..d5e856ca --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java @@ -0,0 +1,221 @@ +package phrase; + +import gnu.trove.TIntArrayList; + +import java.io.*; +import java.util.*; +import java.util.regex.Pattern; + + +public class Corpus +{ + private Lexicon<String> wordLexicon = new Lexicon<String>(); + private Lexicon<TIntArrayList> phraseLexicon = new Lexicon<TIntArrayList>(); + private Lexicon<TIntArrayList> contextLexicon = new Lexicon<TIntArrayList>(); + private List<Edge> edges = new ArrayList<Edge>(); + private List<List<Edge>> phraseToContext = new ArrayList<List<Edge>>(); + private List<List<Edge>> contextToPhrase = new ArrayList<List<Edge>>(); + + public class Edge + { + Edge(int phraseId, int contextId, int count) + { + this.phraseId = phraseId; + this.contextId = contextId; + this.count = count; + } + public int getPhraseId() + { + return phraseId; + } + public TIntArrayList getPhrase() + { + return Corpus.this.getPhrase(phraseId); + } + public String getPhraseString() + { + return Corpus.this.getPhraseString(phraseId); + } + public int getContextId() + { + return contextId; + } + public TIntArrayList getContext() + { + return Corpus.this.getContext(contextId); + } + public String getContextString(boolean insertPhraseSentinel) + { + return Corpus.this.getContextString(contextId, insertPhraseSentinel); + } + public int getCount() + { + return count; + } + public boolean equals(Object other) + { + if (other instanceof Edge) + { + Edge oe = (Edge) other; + return oe.phraseId == phraseId && oe.contextId == contextId; + } + else return false; + } + public int hashCode() + { // this is how boost's hash_combine does it + int seed = phraseId; + seed ^= contextId + 0x9e3779b9 + (seed << 6) + (seed >> 2); + return seed; + } + public String toString() + { + return getPhraseString() + "\t" + getContextString(true); + } + + private int phraseId; + private int contextId; + private int count; + } + + List<Edge> getEdges() + { + return edges; + } + + int getNumEdges() + { + return edges.size(); + } + + int getNumPhrases() + { + return phraseLexicon.size(); + } + + int getNumContextPositions() + { + return contextLexicon.lookup(0).size(); + } + + List<Edge> getEdgesForPhrase(int phraseId) + { + return phraseToContext.get(phraseId); + } + + int getNumContexts() + { + return contextLexicon.size(); + } + + List<Edge> getEdgesForContext(int contextId) + { + return contextToPhrase.get(contextId); + } + + int getNumWords() + { + return wordLexicon.size(); + } + + String getWord(int wordId) + { + return wordLexicon.lookup(wordId); + } + + public TIntArrayList getPhrase(int phraseId) + { + return phraseLexicon.lookup(phraseId); + } + + public String getPhraseString(int phraseId) + { + StringBuffer b = new StringBuffer(); + for (int tid: getPhrase(phraseId).toNativeArray()) + { + if (b.length() > 0) + b.append(" "); + b.append(wordLexicon.lookup(tid)); + } + return b.toString(); + } + + public TIntArrayList getContext(int contextId) + { + return contextLexicon.lookup(contextId); + } + + public String getContextString(int contextId, boolean insertPhraseSentinel) + { + StringBuffer b = new StringBuffer(); + TIntArrayList c = getContext(contextId); + for (int i = 0; i < c.size(); ++i) + { + if (i > 0) b.append(" "); + if (i == c.size() / 2) b.append("<PHRASE> "); + b.append(wordLexicon.lookup(c.get(i))); + } + return b.toString(); + } + + static Corpus readFromFile(Reader in) throws IOException + { + Corpus c = new Corpus(); + + // read in line-by-line + BufferedReader bin = new BufferedReader(in); + String line; + Pattern separator = Pattern.compile(" \\|\\|\\| "); + + while ((line = bin.readLine()) != null) + { + // split into phrase and contexts + StringTokenizer st = new StringTokenizer(line, "\t"); + assert (st.hasMoreTokens()); + String phraseToks = st.nextToken(); + assert (st.hasMoreTokens()); + String rest = st.nextToken(); + assert (!st.hasMoreTokens()); + + // process phrase + st = new StringTokenizer(phraseToks, " "); + TIntArrayList ptoks = new TIntArrayList(); + while (st.hasMoreTokens()) + ptoks.add(c.wordLexicon.insert(st.nextToken())); + int phraseId = c.phraseLexicon.insert(ptoks); + if (phraseId == c.phraseToContext.size()) + c.phraseToContext.add(new ArrayList<Edge>()); + + // process contexts + String[] parts = separator.split(rest); + assert (parts.length % 2 == 0); + for (int i = 0; i < parts.length; i += 2) + { + // process pairs of strings - context and count + TIntArrayList ctx = new TIntArrayList(); + String ctxString = parts[i]; + String countString = parts[i + 1]; + StringTokenizer ctxStrtok = new StringTokenizer(ctxString, " "); + while (ctxStrtok.hasMoreTokens()) + { + String token = ctxStrtok.nextToken(); + if (!token.equals("<PHRASE>")) + ctx.add(c.wordLexicon.insert(token)); + } + int contextId = c.contextLexicon.insert(ctx); + if (contextId == c.contextToPhrase.size()) + c.contextToPhrase.add(new ArrayList<Edge>()); + + assert (countString.startsWith("C=")); + Edge e = c.new Edge(phraseId, contextId, + Integer.parseInt(countString.substring(2).trim())); + c.edges.add(e); + + // index the edge for fast phrase, context lookup + c.phraseToContext.get(phraseId).add(e); + c.contextToPhrase.get(contextId).add(e); + } + } + + return c; + } +} diff --git a/gi/posterior-regularisation/prjava/src/phrase/Lexicon.java b/gi/posterior-regularisation/prjava/src/phrase/Lexicon.java new file mode 100644 index 00000000..a386e4a3 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/phrase/Lexicon.java @@ -0,0 +1,34 @@ +package phrase; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class Lexicon<T> +{ + public int insert(T word) + { + Integer i = wordToIndex.get(word); + if (i == null) + { + i = indexToWord.size(); + wordToIndex.put(word, i); + indexToWord.add(word); + } + return i; + } + + public T lookup(int index) + { + return indexToWord.get(index); + } + + public int size() + { + return indexToWord.size(); + } + + private Map<T, Integer> wordToIndex = new HashMap<T, Integer>(); + private List<T> indexToWord = new ArrayList<T>(); +}
\ No newline at end of file diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index 731d03ac..e4db2a1a 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -1,44 +1,54 @@ package phrase;
+import gnu.trove.TIntArrayList;
import io.FileUtil;
-
-import java.io.FileOutputStream;
import java.io.IOException;
-import java.io.OutputStream;
import java.io.PrintStream;
import java.util.Arrays;
+import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
-import java.util.zip.GZIPOutputStream;
+
+import phrase.Corpus.Edge;
public class PhraseCluster {
public int K;
- public double scale;
- private int n_phrase;
- private int n_words;
- public PhraseCorpus c;
+ public double scalePT, scaleCT;
+ private int n_phrases, n_words, n_contexts, n_positions;
+ public Corpus c;
private ExecutorService pool;
- /**@brief
- * emit[tag][position][word]
- */
+ // emit[tag][position][word] = p(word | tag, position in context)
private double emit[][][];
+ // pi[phrase][tag] = p(tag | phrase)
private double pi[][];
-
- public static void main(String[] args) {
+ public static void main(String[] args)
+ {
String input_fname = args[0];
int tags = Integer.parseInt(args[1]);
String output_fname = args[2];
int iterations = Integer.parseInt(args[3]);
- double scale = Double.parseDouble(args[4]);
- int threads = Integer.parseInt(args[5]);
- boolean runEM = Boolean.parseBoolean(args[6]);
-
- PhraseCorpus corpus = new PhraseCorpus(input_fname);
- PhraseCluster cluster = new PhraseCluster(tags, corpus, scale, threads);
+ double scalePT = Double.parseDouble(args[4]);
+ double scaleCT = Double.parseDouble(args[5]);
+ int threads = Integer.parseInt(args[6]);
+ boolean runEM = Boolean.parseBoolean(args[7]);
+
+ assert(tags >= 2);
+ assert(scalePT >= 0);
+ assert(scaleCT >= 0);
+
+ Corpus corpus = null;
+ try {
+ corpus = Corpus.readFromFile(FileUtil.openBufferedReader(input_fname));
+ } catch (IOException e) {
+ System.err.println("Failed to open input file: " + input_fname);
+ e.printStackTrace();
+ System.exit(1);
+ }
+ PhraseCluster cluster = new PhraseCluster(tags, corpus, scalePT, scaleCT, threads);
//PhraseObjective.ps = FileUtil.openOutFile(outputDir + "/phrase_stat.out");
@@ -48,19 +58,25 @@ public class PhraseCluster { double o;
if (runEM || i < 3)
o = cluster.EM();
- else
- o = cluster.PREM();
+ else if (scaleCT == 0)
+ {
+ if (threads >= 1)
+ o = cluster.PREM_phrase_constraints_parallel();
+ else
+ o = cluster.PREM_phrase_constraints();
+ }
+ else
+ o = cluster.PREM_phrase_context_constraints();
+
//PhraseObjective.ps.
System.out.println("ITER: "+i+" objective: " + o);
last = o;
}
- if (runEM)
- {
- double l1lmax = cluster.posterior_l1lmax();
- System.out.println("Final l1lmax term " + l1lmax + ", total PR objective " + (last - scale*l1lmax));
- // nb. KL is 0 by definition
- }
+ double pl1lmax = cluster.phrase_l1lmax();
+ double cl1lmax = cluster.context_l1lmax();
+ System.out.println("Final posterior phrase l1lmax " + pl1lmax + " context l1lmax " + cl1lmax);
+ if (runEM) System.out.println("With PR objective " + (last - scalePT*pl1lmax - scaleCT*cl1lmax));
PrintStream ps=io.FileUtil.openOutFile(output_fname);
cluster.displayPosterior(ps);
@@ -75,17 +91,20 @@ public class PhraseCluster { cluster.finish();
}
- public PhraseCluster(int numCluster, PhraseCorpus corpus, double scale, int threads){
+ public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads){
K=numCluster;
c=corpus;
- n_words=c.wordLex.size();
- n_phrase=c.data.length;
- this.scale = scale;
- if (threads > 0)
+ n_words=c.getNumWords();
+ n_phrases=c.getNumPhrases();
+ n_contexts=c.getNumContexts();
+ n_positions=c.getNumContextPositions();
+ this.scalePT = scalep;
+ this.scaleCT = scalec;
+ if (threads > 0 && scalec <= 0)
pool = Executors.newFixedThreadPool(threads);
- emit=new double [K][c.numContexts][n_words];
- pi=new double[n_phrase][K];
+ emit=new double [K][n_positions][n_words];
+ pi=new double[n_phrases][K];
for(double [][]i:emit){
for(double []j:i){
@@ -105,30 +124,32 @@ public class PhraseCluster { }
public double EM(){
- double [][][]exp_emit=new double [K][c.numContexts][n_words];
- double [][]exp_pi=new double[n_phrase][K];
+ double [][][]exp_emit=new double [K][n_positions][n_words];
+ double [][]exp_pi=new double[n_phrases][K];
double loglikelihood=0;
//E
- for(int phrase=0;phrase<c.data.length;phrase++){
- int [][] data=c.data[phrase];
- for(int ctx=0;ctx<data.length;ctx++){
- int context[]=data[ctx];
- double p[]=posterior(phrase,context);
+ for(int phrase=0; phrase < n_phrases; phrase++){
+ List<Edge> contexts = c.getEdgesForPhrase(phrase);
+
+ for (int ctx=0; ctx<contexts.size(); ctx++){
+ Edge edge = contexts.get(ctx);
+ double p[]=posterior(edge);
double z = arr.F.l1norm(p);
assert z > 0;
loglikelihood+=Math.log(z);
arr.F.l1normalize(p);
- int contextCnt=context[context.length-1];
+ int count = edge.getCount();
//increment expected count
+ TIntArrayList context = edge.getContext();
for(int tag=0;tag<K;tag++){
- for(int pos=0;pos<context.length-1;pos++){
- exp_emit[tag][pos][context[pos]]+=p[tag]*contextCnt;
+ for(int pos=0;pos<n_positions;pos++){
+ exp_emit[tag][pos][context.get(pos)]+=p[tag]*count;
}
- exp_pi[phrase][tag]+=p[tag]*contextCnt;
+ exp_pi[phrase][tag]+=p[tag]*count;
}
}
}
@@ -153,29 +174,32 @@ public class PhraseCluster { return loglikelihood;
}
- public double PREM(){
- if (pool != null)
- return PREMParallel();
+ public double PREM_phrase_constraints(){
+ assert (scaleCT <= 0);
- double [][][]exp_emit=new double [K][c.numContexts][n_words];
- double [][]exp_pi=new double[n_phrase][K];
+ double [][][]exp_emit=new double[K][n_positions][n_words];
+ double [][]exp_pi=new double[n_phrases][K];
- double loglikelihood=0;
- double primal=0;
+ double loglikelihood=0, kl=0, l1lmax=0, primal=0;
//E
- for(int phrase=0;phrase<c.data.length;phrase++){
+ for(int phrase=0; phrase<n_phrases; phrase++){
PhraseObjective po=new PhraseObjective(this,phrase);
po.optimizeWithProjectedGradientDescent();
double [][] q=po.posterior();
- loglikelihood+=po.llh;
- primal+=po.primal();
+ loglikelihood += po.loglikelihood();
+ kl += po.KL_divergence();
+ l1lmax += po.l1lmax();
+ primal += po.primal();
+ List<Edge> edges = c.getEdgesForPhrase(phrase);
+
for(int edge=0;edge<q.length;edge++){
- int []context=c.data[phrase][edge];
- int contextCnt=context[context.length-1];
+ Edge e = edges.get(edge);
+ TIntArrayList context = e.getContext();
+ int contextCnt = e.getCount();
//increment expected count
for(int tag=0;tag<K;tag++){
- for(int pos=0;pos<context.length-1;pos++){
- exp_emit[tag][pos][context[pos]]+=q[edge][tag]*contextCnt;
+ for(int pos=0;pos<n_positions;pos++){
+ exp_emit[tag][pos][context.get(pos)]+=q[edge][tag]*contextCnt;
}
exp_pi[phrase][tag]+=q[edge][tag]*contextCnt;
@@ -183,8 +207,9 @@ public class PhraseCluster { }
}
- System.out.println("Log likelihood: "+loglikelihood);
- System.out.println("Primal Objective: "+primal);
+ System.out.println("\tllh: " + loglikelihood);
+ System.out.println("\tKL: " + kl);
+ System.out.println("\tphrase l1lmax: " + l1lmax);
//M
for(double [][]i:exp_emit){
@@ -204,18 +229,21 @@ public class PhraseCluster { return primal;
}
- public double PREMParallel(){
+ public double PREM_phrase_constraints_parallel()
+ {
assert(pool != null);
+ assert(scaleCT <= 0);
+
final LinkedBlockingQueue<PhraseObjective> expectations
= new LinkedBlockingQueue<PhraseObjective>();
- double [][][]exp_emit=new double [K][c.numContexts][n_words];
- double [][]exp_pi=new double[n_phrase][K];
+ double [][][]exp_emit=new double [K][n_positions][n_words];
+ double [][]exp_pi=new double[n_phrases][K];
- double loglikelihood=0;
- double primal=0;
+ double loglikelihood=0, kl=0, l1lmax=0, primal=0;
+
//E
- for(int phrase=0;phrase<c.data.length;phrase++){
+ for(int phrase=0;phrase<n_phrases;phrase++){
final int p=phrase;
pool.execute(new Runnable() {
public void run() {
@@ -235,7 +263,7 @@ public class PhraseCluster { }
// aggregate the expectations as they become available
- for(int count=0;count<c.data.length;count++) {
+ for(int count=0;count<n_phrases;count++) {
try {
//System.out.println("" + Thread.currentThread().getId() + " reading queue #" + count);
@@ -245,109 +273,139 @@ public class PhraseCluster { int phrase = po.phrase;
//System.out.println("" + Thread.currentThread().getId() + " taken phrase " + phrase);
double [][] q=po.posterior();
- loglikelihood+=po.llh;
- primal+=po.primal();
+ loglikelihood += po.loglikelihood();
+ kl += po.KL_divergence();
+ l1lmax += po.l1lmax();
+ primal += po.primal();
+
+ List<Edge> edges = c.getEdgesForPhrase(phrase);
for(int edge=0;edge<q.length;edge++){
- int []context=c.data[phrase][edge];
- int contextCnt=context[context.length-1];
+ Edge e = edges.get(edge);
+ TIntArrayList context = e.getContext();
+ int contextCnt = e.getCount();
//increment expected count
for(int tag=0;tag<K;tag++){
- for(int pos=0;pos<context.length-1;pos++){
- exp_emit[tag][pos][context[pos]]+=q[edge][tag]*contextCnt;
+ for(int pos=0;pos<n_positions;pos++){
+ exp_emit[tag][pos][context.get(pos)]+=q[edge][tag]*contextCnt;
}
exp_pi[phrase][tag]+=q[edge][tag]*contextCnt;
}
}
- } catch (InterruptedException e){
+ } catch (InterruptedException e)
+ {
System.err.println("M-step thread interrupted. Probably fatal!");
e.printStackTrace();
}
}
- System.out.println("Log likelihood: "+loglikelihood);
- System.out.println("Primal Objective: "+primal);
+ System.out.println("\tllh: " + loglikelihood);
+ System.out.println("\tKL: " + kl);
+ System.out.println("\tphrase l1lmax: " + l1lmax);
//M
- for(double [][]i:exp_emit){
- for(double []j:i){
+ for(double [][]i:exp_emit)
+ for(double []j:i)
arr.F.l1normalize(j);
+ emit=exp_emit;
+
+ for(double []j:exp_pi)
+ arr.F.l1normalize(j);
+ pi=exp_pi;
+
+ return primal;
+ }
+
+ public double PREM_phrase_context_constraints(){
+ assert (scaleCT > 0);
+
+ double [][][]exp_emit=new double [K][n_positions][n_words];
+ double [][]exp_pi=new double[n_phrases][K];
+
+ //E step
+ // TODO: cache the lambda values (the null below)
+ PhraseContextObjective pco = new PhraseContextObjective(this, null);
+ pco.optimizeWithProjectedGradientDescent();
+
+ //now extract expectations
+ List<Corpus.Edge> edges = c.getEdges();
+ for(int e = 0; e < edges.size(); ++e)
+ {
+ double [] q = pco.posterior(e);
+ Corpus.Edge edge = edges.get(e);
+
+ TIntArrayList context = edge.getContext();
+ int contextCnt = edge.getCount();
+ //increment expected count
+ for(int tag=0;tag<K;tag++)
+ {
+ for(int pos=0;pos<n_positions;pos++)
+ exp_emit[tag][pos][context.get(pos)]+=q[tag]*contextCnt;
+ exp_pi[edge.getPhraseId()][tag]+=q[tag]*contextCnt;
}
}
+ System.out.println("\tllh: " + pco.loglikelihood());
+ System.out.println("\tKL: " + pco.KL_divergence());
+ System.out.println("\tphrase l1lmax: " + pco.phrase_l1lmax());
+ System.out.println("\tcontext l1lmax: " + pco.context_l1lmax());
+
+ //M step
+ for(double [][]i:exp_emit)
+ for(double []j:i)
+ arr.F.l1normalize(j);
emit=exp_emit;
- for(double []j:exp_pi){
+ for(double []j:exp_pi)
arr.F.l1normalize(j);
- }
-
pi=exp_pi;
- return primal;
- }
-
+ return pco.primal();
+ }
+
/**
- *
* @param phrase index of phrase
* @param ctx array of context
* @return unnormalized posterior
*/
- public double[]posterior(int phrase, int[]ctx){
- double[] prob=Arrays.copyOf(pi[phrase], K);
+ public double[] posterior(Corpus.Edge edge)
+ {
+ double[] prob=Arrays.copyOf(pi[edge.getPhraseId()], K);
- for(int tag=0;tag<K;tag++){
- for(int c=0;c<ctx.length-1;c++){
- int word=ctx[c];
- prob[tag]*=emit[tag][c][word];
- }
- }
+ TIntArrayList ctx = edge.getContext();
+ for(int tag=0;tag<K;tag++)
+ for(int c=0;c<n_positions;c++)
+ prob[tag]*=emit[tag][c][ctx.get(c)];
return prob;
}
public void displayPosterior(PrintStream ps)
- {
-
- c.buildList();
-
- for (int i = 0; i < n_phrase; ++i)
+ {
+ for (Edge edge : c.getEdges())
{
- int [][]data=c.data[i];
- for (int[] e: data)
- {
- double probs[] = posterior(i, e);
- arr.F.l1normalize(probs);
+ double probs[] = posterior(edge);
+ arr.F.l1normalize(probs);
- // emit phrase
- ps.print(c.phraseList[i]);
- ps.print("\t");
- ps.print(c.getContextString(e, true));
- int t=arr.F.argmax(probs);
- ps.println(" ||| C=" + t);
-
- //ps.print("||| C=" + e[e.length-1] + " |||");
-
- //ps.print(t+"||| [");
- //for(t=0;t<K;t++){
- // ps.print(probs[t]+", ");
- //}
- // for (int t = 0; t < numTags; ++t)
- // System.out.print(" " + probs[t]);
- //ps.println("]");
- }
+ // emit phrase
+ ps.print(edge.getPhraseString());
+ ps.print("\t");
+ ps.print(edge.getContextString(true));
+ int t=arr.F.argmax(probs);
+ ps.println(" ||| C=" + t);
}
}
public void displayModelParam(PrintStream ps)
{
-
- c.buildList();
+ final double EPS = 1e-6;
ps.println("P(tag|phrase)");
- for (int i = 0; i < n_phrase; ++i)
+ for (int i = 0; i < n_phrases; ++i)
{
- ps.print(c.phraseList[i]);
+ ps.print(c.getPhrase(i));
for(int j=0;j<pi[i].length;j++){
- ps.print("\t"+pi[i][j]);
+ if (pi[i][j] > EPS)
+ ps.print("\t" + j + ": " + pi[i][j]);
}
ps.println();
}
@@ -355,14 +413,11 @@ public class PhraseCluster { ps.println("P(word|tag,position)");
for (int i = 0; i < K; ++i)
{
- for(int position=0;position<c.numContexts;position++){
+ for(int position=0;position<n_positions;position++){
ps.println("tag " + i + " position " + position);
for(int word=0;word<emit[i][position].length;word++){
- //if((word+1)%100==0){
- // ps.println();
- //}
- if (emit[i][position][word] > 1e-10)
- ps.print(c.wordList[word]+"="+emit[i][position][word]+"\t");
+ if (emit[i][position][word] > EPS)
+ ps.print(c.getWord(word)+"="+emit[i][position][word]+"\t");
}
ps.println();
}
@@ -371,19 +426,35 @@ public class PhraseCluster { }
- double posterior_l1lmax()
+ double phrase_l1lmax()
{
double sum=0;
- for(int phrase=0;phrase<c.data.length;phrase++)
+ for(int phrase=0; phrase<n_phrases; phrase++)
{
- int [][] data = c.data[phrase];
double [] maxes = new double[K];
- for(int ctx=0;ctx<data.length;ctx++)
+ for (Edge edge : c.getEdgesForPhrase(phrase))
{
- int context[]=data[ctx];
- double p[]=posterior(phrase,context);
+ double p[] = posterior(edge);
arr.F.l1normalize(p);
+ for(int tag=0;tag<K;tag++)
+ maxes[tag] = Math.max(maxes[tag], p[tag]);
+ }
+ for(int tag=0;tag<K;tag++)
+ sum += maxes[tag];
+ }
+ return sum;
+ }
+ double context_l1lmax()
+ {
+ double sum=0;
+ for(int context=0; context<n_contexts; context++)
+ {
+ double [] maxes = new double[K];
+ for (Edge edge : c.getEdgesForContext(context))
+ {
+ double p[] = posterior(edge);
+ arr.F.l1normalize(p);
for(int tag=0;tag<K;tag++)
maxes[tag] = Math.max(maxes[tag], p[tag]);
}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java new file mode 100644 index 00000000..3273f0ad --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java @@ -0,0 +1,302 @@ +package phrase;
+
+import java.io.PrintStream;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import optimization.gradientBasedMethods.ProjectedGradientDescent;
+import optimization.gradientBasedMethods.ProjectedObjective;
+import optimization.gradientBasedMethods.stats.OptimizerStats;
+import optimization.linesearch.ArmijoLineSearchMinimizationAlongProjectionArc;
+import optimization.linesearch.InterpolationPickFirstStep;
+import optimization.linesearch.LineSearchMethod;
+import optimization.linesearch.WolfRuleLineSearch;
+import optimization.projections.SimplexProjection;
+import optimization.stopCriteria.CompositeStopingCriteria;
+import optimization.stopCriteria.ProjectedGradientL2Norm;
+import optimization.stopCriteria.StopingCriteria;
+import optimization.stopCriteria.ValueDifference;
+import optimization.util.MathUtils;
+import phrase.Corpus.Edge;
+
+public class PhraseContextObjective extends ProjectedObjective
+{
+ private static final double GRAD_DIFF = 0.00002;
+ private static double INIT_STEP_SIZE = 10;
+ private static double VAL_DIFF = 1e-4; // FIXME needs to be tuned
+ private static int ITERATIONS = 100;
+
+ private PhraseCluster c;
+
+ // un-regularized unnormalized posterior, p[edge][tag]
+ // P(tag|edge) \propto P(tag|phrase)P(context|tag)
+ private double p[][];
+
+ // regularized unnormalized posterior
+ // q[edge][tag] propto p[edge][tag]*exp(-lambda)
+ private double q[][];
+ private List<Corpus.Edge> data;
+
+ // log likelihood under q
+ private double loglikelihood;
+ private SimplexProjection projectionPhrase;
+ private SimplexProjection projectionContext;
+
+ double[] newPoint;
+ private int n_param;
+
+ // likelihood under p
+ public double llh;
+
+ private Map<Corpus.Edge, Integer> edgeIndex;
+
+ public PhraseContextObjective(PhraseCluster cluster, double[] startingParameters)
+ {
+ c=cluster;
+ data=c.c.getEdges();
+ n_param=data.size()*c.K*2;
+
+ parameters = startingParameters;
+ if (parameters == null)
+ parameters = new double[n_param];
+
+ newPoint = new double[n_param];
+ gradient = new double[n_param];
+ initP();
+ projectionPhrase = new SimplexProjection(c.scalePT);
+ projectionContext = new SimplexProjection(c.scaleCT);
+ q=new double [data.size()][c.K];
+
+ edgeIndex = new HashMap<Edge, Integer>();
+ for (int e=0; e<data.size(); e++)
+ edgeIndex.put(data.get(e), e);
+
+ setParameters(parameters);
+ }
+
+ private void initP(){
+ p=new double[data.size()][];
+ for(int edge=0;edge<data.size();edge++)
+ {
+ p[edge]=c.posterior(data.get(edge));
+ llh += data.get(edge).getCount() * Math.log(arr.F.l1norm(p[edge]));
+ arr.F.l1normalize(p[edge]);
+ }
+ }
+
+ @Override
+ public void setParameters(double[] params) {
+ //System.out.println("setParameters " + Arrays.toString(parameters));
+ // TODO: test if params have changed and skip update otherwise
+ super.setParameters(params);
+ updateFunction();
+ }
+
+ private void updateFunction()
+ {
+ updateCalls++;
+ loglikelihood=0;
+
+ for (int e=0; e<data.size(); e++)
+ {
+ Edge edge = data.get(e);
+ int offset = edgeIndex.get(edge)*c.K*2;
+ for(int tag=0; tag<c.K; tag++)
+ {
+ int ip = offset + tag*2;
+ int ic = ip + 1;
+ q[e][tag] = p[e][tag]*
+ Math.exp((-parameters[ip]-parameters[ic]) / edge.getCount());
+ }
+ }
+
+ for(int edge=0;edge<data.size();edge++){
+ loglikelihood+=data.get(edge).getCount() * Math.log(arr.F.l1norm(q[edge]));
+ arr.F.l1normalize(q[edge]);
+ }
+
+ for (int e=0; e<data.size(); e++)
+ {
+ Edge edge = data.get(e);
+ int offset = edgeIndex.get(edge)*c.K*2;
+ for(int tag=0; tag<c.K; tag++)
+ {
+ int ip = offset + tag*2;
+ int ic = ip + 1;
+ gradient[ip]=-q[e][tag];
+ gradient[ic]=-q[e][tag];
+ }
+ }
+ //System.out.println("objective " + loglikelihood + " gradient: " + Arrays.toString(gradient));
+ }
+
+ @Override
+ public double[] projectPoint(double[] point)
+ {
+ //System.out.println("projectPoint: " + Arrays.toString(point));
+ Arrays.fill(newPoint, 0, newPoint.length, 0);
+ if (c.scalePT > 0)
+ {
+ // first project using the phrase-tag constraints,
+ // for all p,t: sum_c lambda_ptc < scaleP
+ for (int p = 0; p < c.c.getNumPhrases(); ++p)
+ {
+ List<Edge> edges = c.c.getEdgesForPhrase(p);
+ double toProject[] = new double[edges.size()];
+ for(int tag=0;tag<c.K;tag++)
+ {
+ for(int e=0; e<edges.size(); e++)
+ toProject[e] = point[index(edges.get(e), tag, true)];
+ projectionPhrase.project(toProject);
+ for(int e=0; e<edges.size(); e++)
+ newPoint[index(edges.get(e),tag, true)] = toProject[e];
+ }
+ }
+ }
+ //System.out.println("after PT " + Arrays.toString(newPoint));
+
+ if (c.scaleCT > 1e-6)
+ {
+ // now project using the context-tag constraints,
+ // for all c,t: sum_p omega_pct < scaleC
+ for (int ctx = 0; ctx < c.c.getNumContexts(); ++ctx)
+ {
+ List<Edge> edges = c.c.getEdgesForContext(ctx);
+ double toProject[] = new double[edges.size()];
+ for(int tag=0;tag<c.K;tag++)
+ {
+ for(int e=0; e<edges.size(); e++)
+ toProject[e] = point[index(edges.get(e), tag, false)];
+ projectionContext.project(toProject);
+ for(int e=0; e<edges.size(); e++)
+ newPoint[index(edges.get(e),tag, false)] = toProject[e];
+ }
+ }
+ }
+ double[] tmp = newPoint;
+ newPoint = point;
+
+ //System.out.println("\treturning " + Arrays.toString(tmp));
+ return tmp;
+ }
+
+ private int index(Edge edge, int tag, boolean phrase)
+ {
+ // NB if indexing changes must also change code in updateFunction and constructor
+ if (phrase)
+ return edgeIndex.get(edge)*c.K*2 + tag*2;
+ else
+ return edgeIndex.get(edge)*c.K*2 + tag*2 + 1;
+ }
+
+ @Override
+ public double[] getGradient() {
+ gradientCalls++;
+ return gradient;
+ }
+
+ @Override
+ public double getValue() {
+ functionCalls++;
+ return loglikelihood;
+ }
+
+ @Override
+ public String toString() {
+ return "No need for pointless toString";
+ }
+
+ public double []posterior(int edgeIndex){
+ return q[edgeIndex];
+ }
+
+ public double[] optimizeWithProjectedGradientDescent()
+ {
+ LineSearchMethod ls =
+ new ArmijoLineSearchMinimizationAlongProjectionArc
+ (new InterpolationPickFirstStep(INIT_STEP_SIZE));
+ //LineSearchMethod ls = new WolfRuleLineSearch(
+ // (new InterpolationPickFirstStep(INIT_STEP_SIZE)), c1, c2);
+ OptimizerStats stats = new OptimizerStats();
+
+
+ ProjectedGradientDescent optimizer = new ProjectedGradientDescent(ls);
+ StopingCriteria stopGrad = new ProjectedGradientL2Norm(GRAD_DIFF);
+ StopingCriteria stopValue = new ValueDifference(VAL_DIFF*(-llh));
+ CompositeStopingCriteria compositeStop = new CompositeStopingCriteria();
+ compositeStop.add(stopGrad);
+ compositeStop.add(stopValue);
+ optimizer.setMaxIterations(ITERATIONS);
+ updateFunction();
+ boolean succed = optimizer.optimize(this,stats,compositeStop);
+// System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
+ if(succed){
+ //System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
+ }else{
+ System.out.println("Failed to optimize");
+ }
+ // ps.println(Arrays.toString(parameters));
+
+ // for(int edge=0;edge<data.getSize();edge++){
+ // ps.println(Arrays.toString(q[edge]));
+ // }
+ //System.out.println(Arrays.toString(parameters));
+
+ return parameters;
+ }
+
+ double loglikelihood()
+ {
+ return llh;
+ }
+
+ double KL_divergence()
+ {
+ return -loglikelihood + MathUtils.dotProduct(parameters, gradient);
+ }
+
+ double phrase_l1lmax()
+ {
+ // \sum_{tag,phrase} max_{context} P(tag|context,phrase)
+ double sum=0;
+ for (int p = 0; p < c.c.getNumPhrases(); ++p)
+ {
+ List<Edge> edges = c.c.getEdgesForPhrase(p);
+ for(int tag=0;tag<c.K;tag++)
+ {
+ double max=0;
+ for (Edge edge: edges)
+ max = Math.max(max, q[edgeIndex.get(edge)][tag]);
+ sum+=max;
+ }
+ }
+ return sum;
+ }
+
+ double context_l1lmax()
+ {
+ // \sum_{tag,context} max_{phrase} P(tag|context,phrase)
+ double sum=0;
+ for (int ctx = 0; ctx < c.c.getNumContexts(); ++ctx)
+ {
+ List<Edge> edges = c.c.getEdgesForContext(ctx);
+ for(int tag=0; tag<c.K; tag++)
+ {
+ double max=0;
+ for (Edge edge: edges)
+ max = Math.max(max, q[edgeIndex.get(edge)][tag]);
+ sum+=max;
+ }
+ }
+ return sum;
+ }
+
+ // L - KL(q||p) - scalePT * l1lmax_phrase - scaleCT * l1lmax_context
+ public double primal()
+ {
+ return loglikelihood() - KL_divergence() - c.scalePT * phrase_l1lmax() - c.scalePT * context_l1lmax();
+ }
+
+}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java index b8f1f24a..11e948ff 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java @@ -8,11 +8,8 @@ import java.util.ArrayList; import java.util.HashMap;
import java.util.Scanner;
-public class PhraseCorpus {
-
- public static String LEX_FILENAME="../pdata/lex.out";
- public static String DATA_FILENAME="../pdata/btec.con";
-
+public class PhraseCorpus
+{
public HashMap<String,Integer>wordLex;
public HashMap<String,Integer>phraseLex;
@@ -21,16 +18,8 @@ public class PhraseCorpus { //data[phrase][num context][position]
public int data[][][];
- public int numContexts;
-
- public static void main(String[] args) {
- // TODO Auto-generated method stub
- PhraseCorpus c=new PhraseCorpus(DATA_FILENAME);
- c.saveLex(LEX_FILENAME);
- c.loadLex(LEX_FILENAME);
- c.saveLex(LEX_FILENAME);
- }
-
+ public int numContexts;
+
public PhraseCorpus(String filename){
BufferedReader r=io.FileUtil.openBufferedReader(filename);
@@ -185,5 +174,13 @@ public class PhraseCorpus { }
return null;
}
-
+
+ public static void main(String[] args) {
+ String LEX_FILENAME="../pdata/lex.out";
+ String DATA_FILENAME="../pdata/btec.con";
+ PhraseCorpus c=new PhraseCorpus(DATA_FILENAME);
+ c.saveLex(LEX_FILENAME);
+ c.loadLex(LEX_FILENAME);
+ c.saveLex(LEX_FILENAME);
+ }
}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java index 0fdc169b..015ef106 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java @@ -2,6 +2,7 @@ package phrase; import java.io.PrintStream;
import java.util.Arrays;
+import java.util.List;
import optimization.gradientBasedMethods.ProjectedGradientDescent;
import optimization.gradientBasedMethods.ProjectedObjective;
@@ -17,11 +18,12 @@ import optimization.stopCriteria.StopingCriteria; import optimization.stopCriteria.ValueDifference;
import optimization.util.MathUtils;
-public class PhraseObjective extends ProjectedObjective{
-
- private static final double GRAD_DIFF = 0.00002;
- public static double INIT_STEP_SIZE = 10;
- public static double VAL_DIFF = 0.000001; // FIXME needs to be tuned
+public class PhraseObjective extends ProjectedObjective
+{
+ static final double GRAD_DIFF = 0.00002;
+ static double INIT_STEP_SIZE = 10;
+ static double VAL_DIFF = 1e-6; // FIXME needs to be tuned
+ static int ITERATIONS = 100;
//private double c1=0.0001; // wolf stuff
//private double c2=0.9;
private static double lambda[][];
@@ -46,7 +48,7 @@ public class PhraseObjective extends ProjectedObjective{ * q[edge][tag] propto p[edge][tag]*exp(-lambda)
*/
private double q[][];
- private int data[][];
+ private List<Corpus.Edge> data;
/**@brief log likelihood of the associated phrase
*
@@ -66,14 +68,14 @@ public class PhraseObjective extends ProjectedObjective{ public PhraseObjective(PhraseCluster cluster, int phraseIdx){
phrase=phraseIdx;
c=cluster;
- data=c.c.data[phrase];
- n_param=data.length*c.K;
+ data=c.c.getEdgesForPhrase(phrase);
+ n_param=data.size()*c.K;
- if( lambda==null){
- lambda=new double[c.c.data.length][];
+ if (lambda==null){
+ lambda=new double[c.c.getNumPhrases()][];
}
- if(lambda[phrase]==null){
+ if (lambda[phrase]==null){
lambda[phrase]=new double[n_param];
}
@@ -81,22 +83,17 @@ public class PhraseObjective extends ProjectedObjective{ newPoint = new double[n_param];
gradient = new double[n_param];
initP();
- projection=new SimplexProjection(c.scale);
- q=new double [data.length][c.K];
+ projection=new SimplexProjection(c.scalePT);
+ q=new double [data.size()][c.K];
setParameters(parameters);
}
private void initP(){
- int countIdx=data[0].length-1;
-
- p=new double[data.length][];
- for(int edge=0;edge<data.length;edge++){
- p[edge]=c.posterior(phrase,data[edge]);
- }
- for(int edge=0;edge<data.length;edge++){
- llh+=Math.log
- (data[edge][countIdx]*arr.F.l1norm(p[edge]));
+ p=new double[data.size()][];
+ for(int edge=0;edge<data.size();edge++){
+ p[edge]=c.posterior(data.get(edge));
+ llh += data.get(edge).getCount() * Math.log(arr.F.l1norm(p[edge])); // Was bug here - count inside log!
arr.F.l1normalize(p[edge]);
}
}
@@ -110,37 +107,36 @@ public class PhraseObjective extends ProjectedObjective{ private void updateFunction(){
updateCalls++;
loglikelihood=0;
- int countIdx=data[0].length-1;
+
for(int tag=0;tag<c.K;tag++){
- for(int edge=0;edge<data.length;edge++){
+ for(int edge=0;edge<data.size();edge++){
q[edge][tag]=p[edge][tag]*
- Math.exp(-parameters[tag*data.length+edge]/data[edge][countIdx]);
+ Math.exp(-parameters[tag*data.size()+edge]/data.get(edge).getCount());
}
}
- for(int edge=0;edge<data.length;edge++){
- loglikelihood+=data[edge][countIdx] * Math.log(arr.F.l1norm(q[edge]));
+ for(int edge=0;edge<data.size();edge++){
+ loglikelihood+=data.get(edge).getCount() * Math.log(arr.F.l1norm(q[edge]));
arr.F.l1normalize(q[edge]);
}
for(int tag=0;tag<c.K;tag++){
- for(int edge=0;edge<data.length;edge++){
- gradient[tag*data.length+edge]=-q[edge][tag];
+ for(int edge=0;edge<data.size();edge++){
+ gradient[tag*data.size()+edge]=-q[edge][tag];
}
}
}
@Override
- // TODO Auto-generated method stub
public double[] projectPoint(double[] point) {
- double toProject[]=new double[data.length];
+ double toProject[]=new double[data.size()];
for(int tag=0;tag<c.K;tag++){
- for(int edge=0;edge<data.length;edge++){
- toProject[edge]=point[tag*data.length+edge];
+ for(int edge=0;edge<data.size();edge++){
+ toProject[edge]=point[tag*data.size()+edge];
}
projection.project(toProject);
- for(int edge=0;edge<data.length;edge++){
- newPoint[tag*data.length+edge]=toProject[edge];
+ for(int edge=0;edge<data.size();edge++){
+ newPoint[tag*data.size()+edge]=toProject[edge];
}
}
return newPoint;
@@ -148,22 +144,19 @@ public class PhraseObjective extends ProjectedObjective{ @Override
public double[] getGradient() {
- // TODO Auto-generated method stub
gradientCalls++;
return gradient;
}
@Override
public double getValue() {
- // TODO Auto-generated method stub
functionCalls++;
return loglikelihood;
}
@Override
public String toString() {
- // TODO Auto-generated method stub
- return "";
+ return "No need for pointless toString";
}
public double [][]posterior(){
@@ -185,7 +178,7 @@ public class PhraseObjective extends ProjectedObjective{ CompositeStopingCriteria compositeStop = new CompositeStopingCriteria();
compositeStop.add(stopGrad);
compositeStop.add(stopValue);
- optimizer.setMaxIterations(100);
+ optimizer.setMaxIterations(ITERATIONS);
updateFunction();
boolean succed = optimizer.optimize(this,stats,compositeStop);
// System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
@@ -197,45 +190,38 @@ public class PhraseObjective extends ProjectedObjective{ lambda[phrase]=parameters;
// ps.println(Arrays.toString(parameters));
- // for(int edge=0;edge<data.length;edge++){
+ // for(int edge=0;edge<data.getSize();edge++){
// ps.println(Arrays.toString(q[edge]));
// }
}
- /**
- * L - KL(q||p) -
- * scale * \sum_{tag,phrase} max_i P(tag|i th occurrence of phrase)
- * @return
- */
- public double primal()
+ public double KL_divergence()
+ {
+ return -loglikelihood + MathUtils.dotProduct(parameters, gradient);
+ }
+
+ public double loglikelihood()
+ {
+ return llh;
+ }
+
+ public double l1lmax()
{
-
- double l=llh;
-
-// ps.print("Phrase "+phrase+": "+l);
- double kl=-loglikelihood
- +MathUtils.dotProduct(parameters, gradient);
-// ps.print(", "+kl);
- //System.out.println("llh " + llh);
- //System.out.println("kl " + kl);
-
-
- l=l-kl;
double sum=0;
for(int tag=0;tag<c.K;tag++){
double max=0;
- for(int edge=0;edge<data.length;edge++){
- if(q[edge][tag]>max){
+ for(int edge=0;edge<data.size();edge++){
+ if(q[edge][tag]>max)
max=q[edge][tag];
- }
}
sum+=max;
}
- //System.out.println("l1lmax " + sum);
-// ps.println(", "+sum);
- l=l-c.scale*sum;
- return l;
+ return sum;
+ }
+
+ public double primal()
+ {
+ return loglikelihood() - KL_divergence() - c.scalePT * l1lmax();
}
-
}
|