summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-08 21:46:05 +0000
committertrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-08 21:46:05 +0000
commitf2be77ccae455563e167607ee7527abbf8d96e60 (patch)
treee5e02f908fca912f6ed4f0582fff7ecef926f24e
parentfab1e67814aca6abba75fb1bee086255bea8daf0 (diff)
New context constraints.
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@190 ec762483-ff6d-05da-a07a-a48fb63a330f
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/Corpus.java221
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/Lexicon.java34
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java347
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java302
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java29
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java120
6 files changed, 832 insertions, 221 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java
new file mode 100644
index 00000000..d5e856ca
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java
@@ -0,0 +1,221 @@
+package phrase;
+
+import gnu.trove.TIntArrayList;
+
+import java.io.*;
+import java.util.*;
+import java.util.regex.Pattern;
+
+
+public class Corpus
+{
+ private Lexicon<String> wordLexicon = new Lexicon<String>();
+ private Lexicon<TIntArrayList> phraseLexicon = new Lexicon<TIntArrayList>();
+ private Lexicon<TIntArrayList> contextLexicon = new Lexicon<TIntArrayList>();
+ private List<Edge> edges = new ArrayList<Edge>();
+ private List<List<Edge>> phraseToContext = new ArrayList<List<Edge>>();
+ private List<List<Edge>> contextToPhrase = new ArrayList<List<Edge>>();
+
+ public class Edge
+ {
+ Edge(int phraseId, int contextId, int count)
+ {
+ this.phraseId = phraseId;
+ this.contextId = contextId;
+ this.count = count;
+ }
+ public int getPhraseId()
+ {
+ return phraseId;
+ }
+ public TIntArrayList getPhrase()
+ {
+ return Corpus.this.getPhrase(phraseId);
+ }
+ public String getPhraseString()
+ {
+ return Corpus.this.getPhraseString(phraseId);
+ }
+ public int getContextId()
+ {
+ return contextId;
+ }
+ public TIntArrayList getContext()
+ {
+ return Corpus.this.getContext(contextId);
+ }
+ public String getContextString(boolean insertPhraseSentinel)
+ {
+ return Corpus.this.getContextString(contextId, insertPhraseSentinel);
+ }
+ public int getCount()
+ {
+ return count;
+ }
+ public boolean equals(Object other)
+ {
+ if (other instanceof Edge)
+ {
+ Edge oe = (Edge) other;
+ return oe.phraseId == phraseId && oe.contextId == contextId;
+ }
+ else return false;
+ }
+ public int hashCode()
+ { // this is how boost's hash_combine does it
+ int seed = phraseId;
+ seed ^= contextId + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+ return seed;
+ }
+ public String toString()
+ {
+ return getPhraseString() + "\t" + getContextString(true);
+ }
+
+ private int phraseId;
+ private int contextId;
+ private int count;
+ }
+
+ List<Edge> getEdges()
+ {
+ return edges;
+ }
+
+ int getNumEdges()
+ {
+ return edges.size();
+ }
+
+ int getNumPhrases()
+ {
+ return phraseLexicon.size();
+ }
+
+ int getNumContextPositions()
+ {
+ return contextLexicon.lookup(0).size();
+ }
+
+ List<Edge> getEdgesForPhrase(int phraseId)
+ {
+ return phraseToContext.get(phraseId);
+ }
+
+ int getNumContexts()
+ {
+ return contextLexicon.size();
+ }
+
+ List<Edge> getEdgesForContext(int contextId)
+ {
+ return contextToPhrase.get(contextId);
+ }
+
+ int getNumWords()
+ {
+ return wordLexicon.size();
+ }
+
+ String getWord(int wordId)
+ {
+ return wordLexicon.lookup(wordId);
+ }
+
+ public TIntArrayList getPhrase(int phraseId)
+ {
+ return phraseLexicon.lookup(phraseId);
+ }
+
+ public String getPhraseString(int phraseId)
+ {
+ StringBuffer b = new StringBuffer();
+ for (int tid: getPhrase(phraseId).toNativeArray())
+ {
+ if (b.length() > 0)
+ b.append(" ");
+ b.append(wordLexicon.lookup(tid));
+ }
+ return b.toString();
+ }
+
+ public TIntArrayList getContext(int contextId)
+ {
+ return contextLexicon.lookup(contextId);
+ }
+
+ public String getContextString(int contextId, boolean insertPhraseSentinel)
+ {
+ StringBuffer b = new StringBuffer();
+ TIntArrayList c = getContext(contextId);
+ for (int i = 0; i < c.size(); ++i)
+ {
+ if (i > 0) b.append(" ");
+ if (i == c.size() / 2) b.append("<PHRASE> ");
+ b.append(wordLexicon.lookup(c.get(i)));
+ }
+ return b.toString();
+ }
+
+ static Corpus readFromFile(Reader in) throws IOException
+ {
+ Corpus c = new Corpus();
+
+ // read in line-by-line
+ BufferedReader bin = new BufferedReader(in);
+ String line;
+ Pattern separator = Pattern.compile(" \\|\\|\\| ");
+
+ while ((line = bin.readLine()) != null)
+ {
+ // split into phrase and contexts
+ StringTokenizer st = new StringTokenizer(line, "\t");
+ assert (st.hasMoreTokens());
+ String phraseToks = st.nextToken();
+ assert (st.hasMoreTokens());
+ String rest = st.nextToken();
+ assert (!st.hasMoreTokens());
+
+ // process phrase
+ st = new StringTokenizer(phraseToks, " ");
+ TIntArrayList ptoks = new TIntArrayList();
+ while (st.hasMoreTokens())
+ ptoks.add(c.wordLexicon.insert(st.nextToken()));
+ int phraseId = c.phraseLexicon.insert(ptoks);
+ if (phraseId == c.phraseToContext.size())
+ c.phraseToContext.add(new ArrayList<Edge>());
+
+ // process contexts
+ String[] parts = separator.split(rest);
+ assert (parts.length % 2 == 0);
+ for (int i = 0; i < parts.length; i += 2)
+ {
+ // process pairs of strings - context and count
+ TIntArrayList ctx = new TIntArrayList();
+ String ctxString = parts[i];
+ String countString = parts[i + 1];
+ StringTokenizer ctxStrtok = new StringTokenizer(ctxString, " ");
+ while (ctxStrtok.hasMoreTokens())
+ {
+ String token = ctxStrtok.nextToken();
+ if (!token.equals("<PHRASE>"))
+ ctx.add(c.wordLexicon.insert(token));
+ }
+ int contextId = c.contextLexicon.insert(ctx);
+ if (contextId == c.contextToPhrase.size())
+ c.contextToPhrase.add(new ArrayList<Edge>());
+
+ assert (countString.startsWith("C="));
+ Edge e = c.new Edge(phraseId, contextId,
+ Integer.parseInt(countString.substring(2).trim()));
+ c.edges.add(e);
+
+ // index the edge for fast phrase, context lookup
+ c.phraseToContext.get(phraseId).add(e);
+ c.contextToPhrase.get(contextId).add(e);
+ }
+ }
+
+ return c;
+ }
+}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Lexicon.java b/gi/posterior-regularisation/prjava/src/phrase/Lexicon.java
new file mode 100644
index 00000000..a386e4a3
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/phrase/Lexicon.java
@@ -0,0 +1,34 @@
+package phrase;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+public class Lexicon<T>
+{
+ public int insert(T word)
+ {
+ Integer i = wordToIndex.get(word);
+ if (i == null)
+ {
+ i = indexToWord.size();
+ wordToIndex.put(word, i);
+ indexToWord.add(word);
+ }
+ return i;
+ }
+
+ public T lookup(int index)
+ {
+ return indexToWord.get(index);
+ }
+
+ public int size()
+ {
+ return indexToWord.size();
+ }
+
+ private Map<T, Integer> wordToIndex = new HashMap<T, Integer>();
+ private List<T> indexToWord = new ArrayList<T>();
+} \ No newline at end of file
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index 731d03ac..e4db2a1a 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -1,44 +1,54 @@
package phrase;
+import gnu.trove.TIntArrayList;
import io.FileUtil;
-
-import java.io.FileOutputStream;
import java.io.IOException;
-import java.io.OutputStream;
import java.io.PrintStream;
import java.util.Arrays;
+import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
-import java.util.zip.GZIPOutputStream;
+
+import phrase.Corpus.Edge;
public class PhraseCluster {
public int K;
- public double scale;
- private int n_phrase;
- private int n_words;
- public PhraseCorpus c;
+ public double scalePT, scaleCT;
+ private int n_phrases, n_words, n_contexts, n_positions;
+ public Corpus c;
private ExecutorService pool;
- /**@brief
- * emit[tag][position][word]
- */
+ // emit[tag][position][word] = p(word | tag, position in context)
private double emit[][][];
+ // pi[phrase][tag] = p(tag | phrase)
private double pi[][];
-
- public static void main(String[] args) {
+ public static void main(String[] args)
+ {
String input_fname = args[0];
int tags = Integer.parseInt(args[1]);
String output_fname = args[2];
int iterations = Integer.parseInt(args[3]);
- double scale = Double.parseDouble(args[4]);
- int threads = Integer.parseInt(args[5]);
- boolean runEM = Boolean.parseBoolean(args[6]);
-
- PhraseCorpus corpus = new PhraseCorpus(input_fname);
- PhraseCluster cluster = new PhraseCluster(tags, corpus, scale, threads);
+ double scalePT = Double.parseDouble(args[4]);
+ double scaleCT = Double.parseDouble(args[5]);
+ int threads = Integer.parseInt(args[6]);
+ boolean runEM = Boolean.parseBoolean(args[7]);
+
+ assert(tags >= 2);
+ assert(scalePT >= 0);
+ assert(scaleCT >= 0);
+
+ Corpus corpus = null;
+ try {
+ corpus = Corpus.readFromFile(FileUtil.openBufferedReader(input_fname));
+ } catch (IOException e) {
+ System.err.println("Failed to open input file: " + input_fname);
+ e.printStackTrace();
+ System.exit(1);
+ }
+ PhraseCluster cluster = new PhraseCluster(tags, corpus, scalePT, scaleCT, threads);
//PhraseObjective.ps = FileUtil.openOutFile(outputDir + "/phrase_stat.out");
@@ -48,19 +58,25 @@ public class PhraseCluster {
double o;
if (runEM || i < 3)
o = cluster.EM();
- else
- o = cluster.PREM();
+ else if (scaleCT == 0)
+ {
+ if (threads >= 1)
+ o = cluster.PREM_phrase_constraints_parallel();
+ else
+ o = cluster.PREM_phrase_constraints();
+ }
+ else
+ o = cluster.PREM_phrase_context_constraints();
+
//PhraseObjective.ps.
System.out.println("ITER: "+i+" objective: " + o);
last = o;
}
- if (runEM)
- {
- double l1lmax = cluster.posterior_l1lmax();
- System.out.println("Final l1lmax term " + l1lmax + ", total PR objective " + (last - scale*l1lmax));
- // nb. KL is 0 by definition
- }
+ double pl1lmax = cluster.phrase_l1lmax();
+ double cl1lmax = cluster.context_l1lmax();
+ System.out.println("Final posterior phrase l1lmax " + pl1lmax + " context l1lmax " + cl1lmax);
+ if (runEM) System.out.println("With PR objective " + (last - scalePT*pl1lmax - scaleCT*cl1lmax));
PrintStream ps=io.FileUtil.openOutFile(output_fname);
cluster.displayPosterior(ps);
@@ -75,17 +91,20 @@ public class PhraseCluster {
cluster.finish();
}
- public PhraseCluster(int numCluster, PhraseCorpus corpus, double scale, int threads){
+ public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads){
K=numCluster;
c=corpus;
- n_words=c.wordLex.size();
- n_phrase=c.data.length;
- this.scale = scale;
- if (threads > 0)
+ n_words=c.getNumWords();
+ n_phrases=c.getNumPhrases();
+ n_contexts=c.getNumContexts();
+ n_positions=c.getNumContextPositions();
+ this.scalePT = scalep;
+ this.scaleCT = scalec;
+ if (threads > 0 && scalec <= 0)
pool = Executors.newFixedThreadPool(threads);
- emit=new double [K][c.numContexts][n_words];
- pi=new double[n_phrase][K];
+ emit=new double [K][n_positions][n_words];
+ pi=new double[n_phrases][K];
for(double [][]i:emit){
for(double []j:i){
@@ -105,30 +124,32 @@ public class PhraseCluster {
}
public double EM(){
- double [][][]exp_emit=new double [K][c.numContexts][n_words];
- double [][]exp_pi=new double[n_phrase][K];
+ double [][][]exp_emit=new double [K][n_positions][n_words];
+ double [][]exp_pi=new double[n_phrases][K];
double loglikelihood=0;
//E
- for(int phrase=0;phrase<c.data.length;phrase++){
- int [][] data=c.data[phrase];
- for(int ctx=0;ctx<data.length;ctx++){
- int context[]=data[ctx];
- double p[]=posterior(phrase,context);
+ for(int phrase=0; phrase < n_phrases; phrase++){
+ List<Edge> contexts = c.getEdgesForPhrase(phrase);
+
+ for (int ctx=0; ctx<contexts.size(); ctx++){
+ Edge edge = contexts.get(ctx);
+ double p[]=posterior(edge);
double z = arr.F.l1norm(p);
assert z > 0;
loglikelihood+=Math.log(z);
arr.F.l1normalize(p);
- int contextCnt=context[context.length-1];
+ int count = edge.getCount();
//increment expected count
+ TIntArrayList context = edge.getContext();
for(int tag=0;tag<K;tag++){
- for(int pos=0;pos<context.length-1;pos++){
- exp_emit[tag][pos][context[pos]]+=p[tag]*contextCnt;
+ for(int pos=0;pos<n_positions;pos++){
+ exp_emit[tag][pos][context.get(pos)]+=p[tag]*count;
}
- exp_pi[phrase][tag]+=p[tag]*contextCnt;
+ exp_pi[phrase][tag]+=p[tag]*count;
}
}
}
@@ -153,29 +174,32 @@ public class PhraseCluster {
return loglikelihood;
}
- public double PREM(){
- if (pool != null)
- return PREMParallel();
+ public double PREM_phrase_constraints(){
+ assert (scaleCT <= 0);
- double [][][]exp_emit=new double [K][c.numContexts][n_words];
- double [][]exp_pi=new double[n_phrase][K];
+ double [][][]exp_emit=new double[K][n_positions][n_words];
+ double [][]exp_pi=new double[n_phrases][K];
- double loglikelihood=0;
- double primal=0;
+ double loglikelihood=0, kl=0, l1lmax=0, primal=0;
//E
- for(int phrase=0;phrase<c.data.length;phrase++){
+ for(int phrase=0; phrase<n_phrases; phrase++){
PhraseObjective po=new PhraseObjective(this,phrase);
po.optimizeWithProjectedGradientDescent();
double [][] q=po.posterior();
- loglikelihood+=po.llh;
- primal+=po.primal();
+ loglikelihood += po.loglikelihood();
+ kl += po.KL_divergence();
+ l1lmax += po.l1lmax();
+ primal += po.primal();
+ List<Edge> edges = c.getEdgesForPhrase(phrase);
+
for(int edge=0;edge<q.length;edge++){
- int []context=c.data[phrase][edge];
- int contextCnt=context[context.length-1];
+ Edge e = edges.get(edge);
+ TIntArrayList context = e.getContext();
+ int contextCnt = e.getCount();
//increment expected count
for(int tag=0;tag<K;tag++){
- for(int pos=0;pos<context.length-1;pos++){
- exp_emit[tag][pos][context[pos]]+=q[edge][tag]*contextCnt;
+ for(int pos=0;pos<n_positions;pos++){
+ exp_emit[tag][pos][context.get(pos)]+=q[edge][tag]*contextCnt;
}
exp_pi[phrase][tag]+=q[edge][tag]*contextCnt;
@@ -183,8 +207,9 @@ public class PhraseCluster {
}
}
- System.out.println("Log likelihood: "+loglikelihood);
- System.out.println("Primal Objective: "+primal);
+ System.out.println("\tllh: " + loglikelihood);
+ System.out.println("\tKL: " + kl);
+ System.out.println("\tphrase l1lmax: " + l1lmax);
//M
for(double [][]i:exp_emit){
@@ -204,18 +229,21 @@ public class PhraseCluster {
return primal;
}
- public double PREMParallel(){
+ public double PREM_phrase_constraints_parallel()
+ {
assert(pool != null);
+ assert(scaleCT <= 0);
+
final LinkedBlockingQueue<PhraseObjective> expectations
= new LinkedBlockingQueue<PhraseObjective>();
- double [][][]exp_emit=new double [K][c.numContexts][n_words];
- double [][]exp_pi=new double[n_phrase][K];
+ double [][][]exp_emit=new double [K][n_positions][n_words];
+ double [][]exp_pi=new double[n_phrases][K];
- double loglikelihood=0;
- double primal=0;
+ double loglikelihood=0, kl=0, l1lmax=0, primal=0;
+
//E
- for(int phrase=0;phrase<c.data.length;phrase++){
+ for(int phrase=0;phrase<n_phrases;phrase++){
final int p=phrase;
pool.execute(new Runnable() {
public void run() {
@@ -235,7 +263,7 @@ public class PhraseCluster {
}
// aggregate the expectations as they become available
- for(int count=0;count<c.data.length;count++) {
+ for(int count=0;count<n_phrases;count++) {
try {
//System.out.println("" + Thread.currentThread().getId() + " reading queue #" + count);
@@ -245,109 +273,139 @@ public class PhraseCluster {
int phrase = po.phrase;
//System.out.println("" + Thread.currentThread().getId() + " taken phrase " + phrase);
double [][] q=po.posterior();
- loglikelihood+=po.llh;
- primal+=po.primal();
+ loglikelihood += po.loglikelihood();
+ kl += po.KL_divergence();
+ l1lmax += po.l1lmax();
+ primal += po.primal();
+
+ List<Edge> edges = c.getEdgesForPhrase(phrase);
for(int edge=0;edge<q.length;edge++){
- int []context=c.data[phrase][edge];
- int contextCnt=context[context.length-1];
+ Edge e = edges.get(edge);
+ TIntArrayList context = e.getContext();
+ int contextCnt = e.getCount();
//increment expected count
for(int tag=0;tag<K;tag++){
- for(int pos=0;pos<context.length-1;pos++){
- exp_emit[tag][pos][context[pos]]+=q[edge][tag]*contextCnt;
+ for(int pos=0;pos<n_positions;pos++){
+ exp_emit[tag][pos][context.get(pos)]+=q[edge][tag]*contextCnt;
}
exp_pi[phrase][tag]+=q[edge][tag]*contextCnt;
}
}
- } catch (InterruptedException e){
+ } catch (InterruptedException e)
+ {
System.err.println("M-step thread interrupted. Probably fatal!");
e.printStackTrace();
}
}
- System.out.println("Log likelihood: "+loglikelihood);
- System.out.println("Primal Objective: "+primal);
+ System.out.println("\tllh: " + loglikelihood);
+ System.out.println("\tKL: " + kl);
+ System.out.println("\tphrase l1lmax: " + l1lmax);
//M
- for(double [][]i:exp_emit){
- for(double []j:i){
+ for(double [][]i:exp_emit)
+ for(double []j:i)
arr.F.l1normalize(j);
+ emit=exp_emit;
+
+ for(double []j:exp_pi)
+ arr.F.l1normalize(j);
+ pi=exp_pi;
+
+ return primal;
+ }
+
+ public double PREM_phrase_context_constraints(){
+ assert (scaleCT > 0);
+
+ double [][][]exp_emit=new double [K][n_positions][n_words];
+ double [][]exp_pi=new double[n_phrases][K];
+
+ //E step
+ // TODO: cache the lambda values (the null below)
+ PhraseContextObjective pco = new PhraseContextObjective(this, null);
+ pco.optimizeWithProjectedGradientDescent();
+
+ //now extract expectations
+ List<Corpus.Edge> edges = c.getEdges();
+ for(int e = 0; e < edges.size(); ++e)
+ {
+ double [] q = pco.posterior(e);
+ Corpus.Edge edge = edges.get(e);
+
+ TIntArrayList context = edge.getContext();
+ int contextCnt = edge.getCount();
+ //increment expected count
+ for(int tag=0;tag<K;tag++)
+ {
+ for(int pos=0;pos<n_positions;pos++)
+ exp_emit[tag][pos][context.get(pos)]+=q[tag]*contextCnt;
+ exp_pi[edge.getPhraseId()][tag]+=q[tag]*contextCnt;
}
}
+ System.out.println("\tllh: " + pco.loglikelihood());
+ System.out.println("\tKL: " + pco.KL_divergence());
+ System.out.println("\tphrase l1lmax: " + pco.phrase_l1lmax());
+ System.out.println("\tcontext l1lmax: " + pco.context_l1lmax());
+
+ //M step
+ for(double [][]i:exp_emit)
+ for(double []j:i)
+ arr.F.l1normalize(j);
emit=exp_emit;
- for(double []j:exp_pi){
+ for(double []j:exp_pi)
arr.F.l1normalize(j);
- }
-
pi=exp_pi;
- return primal;
- }
-
+ return pco.primal();
+ }
+
/**
- *
* @param phrase index of phrase
* @param ctx array of context
* @return unnormalized posterior
*/
- public double[]posterior(int phrase, int[]ctx){
- double[] prob=Arrays.copyOf(pi[phrase], K);
+ public double[] posterior(Corpus.Edge edge)
+ {
+ double[] prob=Arrays.copyOf(pi[edge.getPhraseId()], K);
- for(int tag=0;tag<K;tag++){
- for(int c=0;c<ctx.length-1;c++){
- int word=ctx[c];
- prob[tag]*=emit[tag][c][word];
- }
- }
+ TIntArrayList ctx = edge.getContext();
+ for(int tag=0;tag<K;tag++)
+ for(int c=0;c<n_positions;c++)
+ prob[tag]*=emit[tag][c][ctx.get(c)];
return prob;
}
public void displayPosterior(PrintStream ps)
- {
-
- c.buildList();
-
- for (int i = 0; i < n_phrase; ++i)
+ {
+ for (Edge edge : c.getEdges())
{
- int [][]data=c.data[i];
- for (int[] e: data)
- {
- double probs[] = posterior(i, e);
- arr.F.l1normalize(probs);
+ double probs[] = posterior(edge);
+ arr.F.l1normalize(probs);
- // emit phrase
- ps.print(c.phraseList[i]);
- ps.print("\t");
- ps.print(c.getContextString(e, true));
- int t=arr.F.argmax(probs);
- ps.println(" ||| C=" + t);
-
- //ps.print("||| C=" + e[e.length-1] + " |||");
-
- //ps.print(t+"||| [");
- //for(t=0;t<K;t++){
- // ps.print(probs[t]+", ");
- //}
- // for (int t = 0; t < numTags; ++t)
- // System.out.print(" " + probs[t]);
- //ps.println("]");
- }
+ // emit phrase
+ ps.print(edge.getPhraseString());
+ ps.print("\t");
+ ps.print(edge.getContextString(true));
+ int t=arr.F.argmax(probs);
+ ps.println(" ||| C=" + t);
}
}
public void displayModelParam(PrintStream ps)
{
-
- c.buildList();
+ final double EPS = 1e-6;
ps.println("P(tag|phrase)");
- for (int i = 0; i < n_phrase; ++i)
+ for (int i = 0; i < n_phrases; ++i)
{
- ps.print(c.phraseList[i]);
+ ps.print(c.getPhrase(i));
for(int j=0;j<pi[i].length;j++){
- ps.print("\t"+pi[i][j]);
+ if (pi[i][j] > EPS)
+ ps.print("\t" + j + ": " + pi[i][j]);
}
ps.println();
}
@@ -355,14 +413,11 @@ public class PhraseCluster {
ps.println("P(word|tag,position)");
for (int i = 0; i < K; ++i)
{
- for(int position=0;position<c.numContexts;position++){
+ for(int position=0;position<n_positions;position++){
ps.println("tag " + i + " position " + position);
for(int word=0;word<emit[i][position].length;word++){
- //if((word+1)%100==0){
- // ps.println();
- //}
- if (emit[i][position][word] > 1e-10)
- ps.print(c.wordList[word]+"="+emit[i][position][word]+"\t");
+ if (emit[i][position][word] > EPS)
+ ps.print(c.getWord(word)+"="+emit[i][position][word]+"\t");
}
ps.println();
}
@@ -371,19 +426,35 @@ public class PhraseCluster {
}
- double posterior_l1lmax()
+ double phrase_l1lmax()
{
double sum=0;
- for(int phrase=0;phrase<c.data.length;phrase++)
+ for(int phrase=0; phrase<n_phrases; phrase++)
{
- int [][] data = c.data[phrase];
double [] maxes = new double[K];
- for(int ctx=0;ctx<data.length;ctx++)
+ for (Edge edge : c.getEdgesForPhrase(phrase))
{
- int context[]=data[ctx];
- double p[]=posterior(phrase,context);
+ double p[] = posterior(edge);
arr.F.l1normalize(p);
+ for(int tag=0;tag<K;tag++)
+ maxes[tag] = Math.max(maxes[tag], p[tag]);
+ }
+ for(int tag=0;tag<K;tag++)
+ sum += maxes[tag];
+ }
+ return sum;
+ }
+ double context_l1lmax()
+ {
+ double sum=0;
+ for(int context=0; context<n_contexts; context++)
+ {
+ double [] maxes = new double[K];
+ for (Edge edge : c.getEdgesForContext(context))
+ {
+ double p[] = posterior(edge);
+ arr.F.l1normalize(p);
for(int tag=0;tag<K;tag++)
maxes[tag] = Math.max(maxes[tag], p[tag]);
}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
new file mode 100644
index 00000000..3273f0ad
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
@@ -0,0 +1,302 @@
+package phrase;
+
+import java.io.PrintStream;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import optimization.gradientBasedMethods.ProjectedGradientDescent;
+import optimization.gradientBasedMethods.ProjectedObjective;
+import optimization.gradientBasedMethods.stats.OptimizerStats;
+import optimization.linesearch.ArmijoLineSearchMinimizationAlongProjectionArc;
+import optimization.linesearch.InterpolationPickFirstStep;
+import optimization.linesearch.LineSearchMethod;
+import optimization.linesearch.WolfRuleLineSearch;
+import optimization.projections.SimplexProjection;
+import optimization.stopCriteria.CompositeStopingCriteria;
+import optimization.stopCriteria.ProjectedGradientL2Norm;
+import optimization.stopCriteria.StopingCriteria;
+import optimization.stopCriteria.ValueDifference;
+import optimization.util.MathUtils;
+import phrase.Corpus.Edge;
+
+public class PhraseContextObjective extends ProjectedObjective
+{
+ private static final double GRAD_DIFF = 0.00002;
+ private static double INIT_STEP_SIZE = 10;
+ private static double VAL_DIFF = 1e-4; // FIXME needs to be tuned
+ private static int ITERATIONS = 100;
+
+ private PhraseCluster c;
+
+ // un-regularized unnormalized posterior, p[edge][tag]
+ // P(tag|edge) \propto P(tag|phrase)P(context|tag)
+ private double p[][];
+
+ // regularized unnormalized posterior
+ // q[edge][tag] propto p[edge][tag]*exp(-lambda)
+ private double q[][];
+ private List<Corpus.Edge> data;
+
+ // log likelihood under q
+ private double loglikelihood;
+ private SimplexProjection projectionPhrase;
+ private SimplexProjection projectionContext;
+
+ double[] newPoint;
+ private int n_param;
+
+ // likelihood under p
+ public double llh;
+
+ private Map<Corpus.Edge, Integer> edgeIndex;
+
+ public PhraseContextObjective(PhraseCluster cluster, double[] startingParameters)
+ {
+ c=cluster;
+ data=c.c.getEdges();
+ n_param=data.size()*c.K*2;
+
+ parameters = startingParameters;
+ if (parameters == null)
+ parameters = new double[n_param];
+
+ newPoint = new double[n_param];
+ gradient = new double[n_param];
+ initP();
+ projectionPhrase = new SimplexProjection(c.scalePT);
+ projectionContext = new SimplexProjection(c.scaleCT);
+ q=new double [data.size()][c.K];
+
+ edgeIndex = new HashMap<Edge, Integer>();
+ for (int e=0; e<data.size(); e++)
+ edgeIndex.put(data.get(e), e);
+
+ setParameters(parameters);
+ }
+
+ private void initP(){
+ p=new double[data.size()][];
+ for(int edge=0;edge<data.size();edge++)
+ {
+ p[edge]=c.posterior(data.get(edge));
+ llh += data.get(edge).getCount() * Math.log(arr.F.l1norm(p[edge]));
+ arr.F.l1normalize(p[edge]);
+ }
+ }
+
+ @Override
+ public void setParameters(double[] params) {
+ //System.out.println("setParameters " + Arrays.toString(parameters));
+ // TODO: test if params have changed and skip update otherwise
+ super.setParameters(params);
+ updateFunction();
+ }
+
+ private void updateFunction()
+ {
+ updateCalls++;
+ loglikelihood=0;
+
+ for (int e=0; e<data.size(); e++)
+ {
+ Edge edge = data.get(e);
+ int offset = edgeIndex.get(edge)*c.K*2;
+ for(int tag=0; tag<c.K; tag++)
+ {
+ int ip = offset + tag*2;
+ int ic = ip + 1;
+ q[e][tag] = p[e][tag]*
+ Math.exp((-parameters[ip]-parameters[ic]) / edge.getCount());
+ }
+ }
+
+ for(int edge=0;edge<data.size();edge++){
+ loglikelihood+=data.get(edge).getCount() * Math.log(arr.F.l1norm(q[edge]));
+ arr.F.l1normalize(q[edge]);
+ }
+
+ for (int e=0; e<data.size(); e++)
+ {
+ Edge edge = data.get(e);
+ int offset = edgeIndex.get(edge)*c.K*2;
+ for(int tag=0; tag<c.K; tag++)
+ {
+ int ip = offset + tag*2;
+ int ic = ip + 1;
+ gradient[ip]=-q[e][tag];
+ gradient[ic]=-q[e][tag];
+ }
+ }
+ //System.out.println("objective " + loglikelihood + " gradient: " + Arrays.toString(gradient));
+ }
+
+ @Override
+ public double[] projectPoint(double[] point)
+ {
+ //System.out.println("projectPoint: " + Arrays.toString(point));
+ Arrays.fill(newPoint, 0, newPoint.length, 0);
+ if (c.scalePT > 0)
+ {
+ // first project using the phrase-tag constraints,
+ // for all p,t: sum_c lambda_ptc < scaleP
+ for (int p = 0; p < c.c.getNumPhrases(); ++p)
+ {
+ List<Edge> edges = c.c.getEdgesForPhrase(p);
+ double toProject[] = new double[edges.size()];
+ for(int tag=0;tag<c.K;tag++)
+ {
+ for(int e=0; e<edges.size(); e++)
+ toProject[e] = point[index(edges.get(e), tag, true)];
+ projectionPhrase.project(toProject);
+ for(int e=0; e<edges.size(); e++)
+ newPoint[index(edges.get(e),tag, true)] = toProject[e];
+ }
+ }
+ }
+ //System.out.println("after PT " + Arrays.toString(newPoint));
+
+ if (c.scaleCT > 1e-6)
+ {
+ // now project using the context-tag constraints,
+ // for all c,t: sum_p omega_pct < scaleC
+ for (int ctx = 0; ctx < c.c.getNumContexts(); ++ctx)
+ {
+ List<Edge> edges = c.c.getEdgesForContext(ctx);
+ double toProject[] = new double[edges.size()];
+ for(int tag=0;tag<c.K;tag++)
+ {
+ for(int e=0; e<edges.size(); e++)
+ toProject[e] = point[index(edges.get(e), tag, false)];
+ projectionContext.project(toProject);
+ for(int e=0; e<edges.size(); e++)
+ newPoint[index(edges.get(e),tag, false)] = toProject[e];
+ }
+ }
+ }
+ double[] tmp = newPoint;
+ newPoint = point;
+
+ //System.out.println("\treturning " + Arrays.toString(tmp));
+ return tmp;
+ }
+
+ private int index(Edge edge, int tag, boolean phrase)
+ {
+ // NB if indexing changes must also change code in updateFunction and constructor
+ if (phrase)
+ return edgeIndex.get(edge)*c.K*2 + tag*2;
+ else
+ return edgeIndex.get(edge)*c.K*2 + tag*2 + 1;
+ }
+
+ @Override
+ public double[] getGradient() {
+ gradientCalls++;
+ return gradient;
+ }
+
+ @Override
+ public double getValue() {
+ functionCalls++;
+ return loglikelihood;
+ }
+
+ @Override
+ public String toString() {
+ return "No need for pointless toString";
+ }
+
+ public double []posterior(int edgeIndex){
+ return q[edgeIndex];
+ }
+
+ public double[] optimizeWithProjectedGradientDescent()
+ {
+ LineSearchMethod ls =
+ new ArmijoLineSearchMinimizationAlongProjectionArc
+ (new InterpolationPickFirstStep(INIT_STEP_SIZE));
+ //LineSearchMethod ls = new WolfRuleLineSearch(
+ // (new InterpolationPickFirstStep(INIT_STEP_SIZE)), c1, c2);
+ OptimizerStats stats = new OptimizerStats();
+
+
+ ProjectedGradientDescent optimizer = new ProjectedGradientDescent(ls);
+ StopingCriteria stopGrad = new ProjectedGradientL2Norm(GRAD_DIFF);
+ StopingCriteria stopValue = new ValueDifference(VAL_DIFF*(-llh));
+ CompositeStopingCriteria compositeStop = new CompositeStopingCriteria();
+ compositeStop.add(stopGrad);
+ compositeStop.add(stopValue);
+ optimizer.setMaxIterations(ITERATIONS);
+ updateFunction();
+ boolean succed = optimizer.optimize(this,stats,compositeStop);
+// System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
+ if(succed){
+ //System.out.println("Ended optimization in " + optimizer.getCurrentIteration());
+ }else{
+ System.out.println("Failed to optimize");
+ }
+ // ps.println(Arrays.toString(parameters));
+
+ // for(int edge=0;edge<data.getSize();edge++){
+ // ps.println(Arrays.toString(q[edge]));
+ // }
+ //System.out.println(Arrays.toString(parameters));
+
+ return parameters;
+ }
+
+ double loglikelihood()
+ {
+ return llh;
+ }
+
+ double KL_divergence()
+ {
+ return -loglikelihood + MathUtils.dotProduct(parameters, gradient);
+ }
+
+ double phrase_l1lmax()
+ {
+ // \sum_{tag,phrase} max_{context} P(tag|context,phrase)
+ double sum=0;
+ for (int p = 0; p < c.c.getNumPhrases(); ++p)
+ {
+ List<Edge> edges = c.c.getEdgesForPhrase(p);
+ for(int tag=0;tag<c.K;tag++)
+ {
+ double max=0;
+ for (Edge edge: edges)
+ max = Math.max(max, q[edgeIndex.get(edge)][tag]);
+ sum+=max;
+ }
+ }
+ return sum;
+ }
+
+ double context_l1lmax()
+ {
+ // \sum_{tag,context} max_{phrase} P(tag|context,phrase)
+ double sum=0;
+ for (int ctx = 0; ctx < c.c.getNumContexts(); ++ctx)
+ {
+ List<Edge> edges = c.c.getEdgesForContext(ctx);
+ for(int tag=0; tag<c.K; tag++)
+ {
+ double max=0;
+ for (Edge edge: edges)
+ max = Math.max(max, q[edgeIndex.get(edge)][tag]);
+ sum+=max;
+ }
+ }
+ return sum;
+ }
+
+ // L - KL(q||p) - scalePT * l1lmax_phrase - scaleCT * l1lmax_context
+ public double primal()
+ {
+ return loglikelihood() - KL_divergence() - c.scalePT * phrase_l1lmax() - c.scalePT * context_l1lmax();
+ }
+
+}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java
index b8f1f24a..11e948ff 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java
@@ -8,11 +8,8 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.Scanner;
-public class PhraseCorpus {
-
- public static String LEX_FILENAME="../pdata/lex.out";
- public static String DATA_FILENAME="../pdata/btec.con";
-
+public class PhraseCorpus
+{
public HashMap<String,Integer>wordLex;
public HashMap<String,Integer>phraseLex;
@@ -21,16 +18,8 @@ public class PhraseCorpus {
//data[phrase][num context][position]
public int data[][][];
- public int numContexts;
-
- public static void main(String[] args) {
- // TODO Auto-generated method stub
- PhraseCorpus c=new PhraseCorpus(DATA_FILENAME);
- c.saveLex(LEX_FILENAME);
- c.loadLex(LEX_FILENAME);
- c.saveLex(LEX_FILENAME);
- }
-
+ public int numContexts;
+
public PhraseCorpus(String filename){
BufferedReader r=io.FileUtil.openBufferedReader(filename);
@@ -185,5 +174,13 @@ public class PhraseCorpus {
}
return null;
}
-
+
+ public static void main(String[] args) {
+ String LEX_FILENAME="../pdata/lex.out";
+ String DATA_FILENAME="../pdata/btec.con";
+ PhraseCorpus c=new PhraseCorpus(DATA_FILENAME);
+ c.saveLex(LEX_FILENAME);
+ c.loadLex(LEX_FILENAME);
+ c.saveLex(LEX_FILENAME);
+ }
}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
index 0fdc169b..015ef106 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
@@ -2,6 +2,7 @@ package phrase;
import java.io.PrintStream;
import java.util.Arrays;
+import java.util.List;
import optimization.gradientBasedMethods.ProjectedGradientDescent;
import optimization.gradientBasedMethods.ProjectedObjective;
@@ -17,11 +18,12 @@ import optimization.stopCriteria.StopingCriteria;
import optimization.stopCriteria.ValueDifference;
import optimization.util.MathUtils;
-public class PhraseObjective extends ProjectedObjective{
-
- private static final double GRAD_DIFF = 0.00002;
- public static double INIT_STEP_SIZE = 10;
- public static double VAL_DIFF = 0.000001; // FIXME needs to be tuned
+public class PhraseObjective extends ProjectedObjective
+{
+ static final double GRAD_DIFF = 0.00002;
+ static double INIT_STEP_SIZE = 10;
+ static double VAL_DIFF = 1e-6; // FIXME needs to be tuned
+ static int ITERATIONS = 100;
//private double c1=0.0001; // wolf stuff
//private double c2=0.9;
private static double lambda[][];
@@ -46,7 +48,7 @@ public class PhraseObjective extends ProjectedObjective{
* q[edge][tag] propto p[edge][tag]*exp(-lambda)
*/
private double q[][];
- private int data[][];
+ private List<Corpus.Edge> data;
/**@brief log likelihood of the associated phrase
*
@@ -66,14 +68,14 @@ public class PhraseObjective extends ProjectedObjective{
public PhraseObjective(PhraseCluster cluster, int phraseIdx){
phrase=phraseIdx;
c=cluster;
- data=c.c.data[phrase];
- n_param=data.length*c.K;
+ data=c.c.getEdgesForPhrase(phrase);
+ n_param=data.size()*c.K;
- if( lambda==null){
- lambda=new double[c.c.data.length][];
+ if (lambda==null){
+ lambda=new double[c.c.getNumPhrases()][];
}
- if(lambda[phrase]==null){
+ if (lambda[phrase]==null){
lambda[phrase]=new double[n_param];
}
@@ -81,22 +83,17 @@ public class PhraseObjective extends ProjectedObjective{
newPoint = new double[n_param];
gradient = new double[n_param];
initP();
- projection=new SimplexProjection(c.scale);
- q=new double [data.length][c.K];
+ projection=new SimplexProjection(c.scalePT);
+ q=new double [data.size()][c.K];
setParameters(parameters);
}
private void initP(){
- int countIdx=data[0].length-1;
-
- p=new double[data.length][];
- for(int edge=0;edge<data.length;edge++){
- p[edge]=c.posterior(phrase,data[edge]);
- }
- for(int edge=0;edge<data.length;edge++){
- llh+=Math.log
- (data[edge][countIdx]*arr.F.l1norm(p[edge]));
+ p=new double[data.size()][];
+ for(int edge=0;edge<data.size();edge++){
+ p[edge]=c.posterior(data.get(edge));
+ llh += data.get(edge).getCount() * Math.log(arr.F.l1norm(p[edge])); // Was bug here - count inside log!
arr.F.l1normalize(p[edge]);
}
}
@@ -110,37 +107,36 @@ public class PhraseObjective extends ProjectedObjective{
private void updateFunction(){
updateCalls++;
loglikelihood=0;
- int countIdx=data[0].length-1;
+
for(int tag=0;tag<c.K;tag++){
- for(int edge=0;edge<data.length;edge++){
+ for(int edge=0;edge<data.size();edge++){
q[edge][tag]=p[edge][tag]*
- Math.exp(-parameters[tag*data.length+edge]/data[edge][countIdx]);
+ Math.exp(-parameters[tag*data.size()+edge]/data.get(edge).getCount());
}
}
- for(int edge=0;edge<data.length;edge++){
- loglikelihood+=data[edge][countIdx] * Math.log(arr.F.l1norm(q[edge]));
+ for(int edge=0;edge<data.size();edge++){
+ loglikelihood+=data.get(edge).getCount() * Math.log(arr.F.l1norm(q[edge]));
arr.F.l1normalize(q[edge]);
}
for(int tag=0;tag<c.K;tag++){
- for(int edge=0;edge<data.length;edge++){
- gradient[tag*data.length+edge]=-q[edge][tag];
+ for(int edge=0;edge<data.size();edge++){
+ gradient[tag*data.size()+edge]=-q[edge][tag];
}
}
}
@Override
- // TODO Auto-generated method stub
public double[] projectPoint(double[] point) {
- double toProject[]=new double[data.length];
+ double toProject[]=new double[data.size()];
for(int tag=0;tag<c.K;tag++){
- for(int edge=0;edge<data.length;edge++){
- toProject[edge]=point[tag*data.length+edge];
+ for(int edge=0;edge<data.size();edge++){
+ toProject[edge]=point[tag*data.size()+edge];
}
projection.project(toProject);
- for(int edge=0;edge<data.length;edge++){
- newPoint[tag*data.length+edge]=toProject[edge];
+ for(int edge=0;edge<data.size();edge++){
+ newPoint[tag*data.size()+edge]=toProject[edge];
}
}
return newPoint;
@@ -148,22 +144,19 @@ public class PhraseObjective extends ProjectedObjective{
@Override
public double[] getGradient() {
- // TODO Auto-generated method stub
gradientCalls++;
return gradient;
}
@Override
public double getValue() {
- // TODO Auto-generated method stub
functionCalls++;
return loglikelihood;
}
@Override
public String toString() {
- // TODO Auto-generated method stub
- return "";
+ return "No need for pointless toString";
}
public double [][]posterior(){
@@ -185,7 +178,7 @@ public class PhraseObjective extends ProjectedObjective{
CompositeStopingCriteria compositeStop = new CompositeStopingCriteria();
compositeStop.add(stopGrad);
compositeStop.add(stopValue);
- optimizer.setMaxIterations(100);
+ optimizer.setMaxIterations(ITERATIONS);
updateFunction();
boolean succed = optimizer.optimize(this,stats,compositeStop);
// System.out.println("Ended optimzation Projected Gradient Descent\n" + stats.prettyPrint(1));
@@ -197,45 +190,38 @@ public class PhraseObjective extends ProjectedObjective{
lambda[phrase]=parameters;
// ps.println(Arrays.toString(parameters));
- // for(int edge=0;edge<data.length;edge++){
+ // for(int edge=0;edge<data.getSize();edge++){
// ps.println(Arrays.toString(q[edge]));
// }
}
- /**
- * L - KL(q||p) -
- * scale * \sum_{tag,phrase} max_i P(tag|i th occurrence of phrase)
- * @return
- */
- public double primal()
+ public double KL_divergence()
+ {
+ return -loglikelihood + MathUtils.dotProduct(parameters, gradient);
+ }
+
+ public double loglikelihood()
+ {
+ return llh;
+ }
+
+ public double l1lmax()
{
-
- double l=llh;
-
-// ps.print("Phrase "+phrase+": "+l);
- double kl=-loglikelihood
- +MathUtils.dotProduct(parameters, gradient);
-// ps.print(", "+kl);
- //System.out.println("llh " + llh);
- //System.out.println("kl " + kl);
-
-
- l=l-kl;
double sum=0;
for(int tag=0;tag<c.K;tag++){
double max=0;
- for(int edge=0;edge<data.length;edge++){
- if(q[edge][tag]>max){
+ for(int edge=0;edge<data.size();edge++){
+ if(q[edge][tag]>max)
max=q[edge][tag];
- }
}
sum+=max;
}
- //System.out.println("l1lmax " + sum);
-// ps.println(", "+sum);
- l=l-c.scale*sum;
- return l;
+ return sum;
+ }
+
+ public double primal()
+ {
+ return loglikelihood() - KL_divergence() - c.scalePT * l1lmax();
}
-
}