From a034f92b1fe0c6368ebb140bc691f0718dd23a23 Mon Sep 17 00:00:00 2001
From: "trevor.cohn"
Date: Thu, 8 Jul 2010 21:46:05 +0000
Subject: New context constraints.

git-svn-id: https://ws10smt.googlecode.com/svn/trunk@190 ec762483-ff6d-05da-a07a-a48fb63a330f
---
 .../prjava/src/phrase/Corpus.java                 | 221 +++++++++++++
 .../prjava/src/phrase/Lexicon.java                |  34 ++
 .../prjava/src/phrase/PhraseCluster.java          | 347 +++++++++++++--------
 .../prjava/src/phrase/PhraseContextObjective.java | 302 ++++++++++++++++++
 .../prjava/src/phrase/PhraseCorpus.java           |  29 +-
 .../prjava/src/phrase/PhraseObjective.java        | 120 ++++---
 6 files changed, 832 insertions(+), 221 deletions(-)
 create mode 100644 gi/posterior-regularisation/prjava/src/phrase/Corpus.java
 create mode 100644 gi/posterior-regularisation/prjava/src/phrase/Lexicon.java
 create mode 100644 gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java

diff --git a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java
new file mode 100644
index 00000000..d5e856ca
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java
@@ -0,0 +1,221 @@
+package phrase;
+
+import gnu.trove.TIntArrayList;
+
+import java.io.*;
+import java.util.*;
+import java.util.regex.Pattern;
+
+
+public class Corpus
+{
+    private Lexicon<String> wordLexicon = new Lexicon<String>();
+    private Lexicon<TIntArrayList> phraseLexicon = new Lexicon<TIntArrayList>();
+    private Lexicon<TIntArrayList> contextLexicon = new Lexicon<TIntArrayList>();
+    private List<Edge> edges = new ArrayList<Edge>();
+    private List<List<Edge>> phraseToContext = new ArrayList<List<Edge>>();
+    private List<List<Edge>> contextToPhrase = new ArrayList<List<Edge>>();
+
+    public class Edge
+    {
+        Edge(int phraseId, int contextId, int count)
+        {
+            this.phraseId = phraseId;
+            this.contextId = contextId;
+            this.count = count;
+        }
+        public int getPhraseId()
+        {
+            return phraseId;
+        }
+        public TIntArrayList getPhrase()
+        {
+            return Corpus.this.getPhrase(phraseId);
+        }
+        public String getPhraseString()
+        {
+            return Corpus.this.getPhraseString(phraseId);
+        }
+        public int getContextId()
+        {
+            return contextId;
+        }
+        public TIntArrayList getContext()
+        {
+            return Corpus.this.getContext(contextId);
+        }
+        public String getContextString(boolean insertPhraseSentinel)
+        {
+            return Corpus.this.getContextString(contextId, insertPhraseSentinel);
+        }
+        public int getCount()
+        {
+            return count;
+        }
+        public boolean equals(Object other)
+        {
+            if (other instanceof Edge)
+            {
+                Edge oe = (Edge) other;
+                return oe.phraseId == phraseId && oe.contextId == contextId;
+            }
+            else return false;
+        }
+        public int hashCode()
+        {   // this is how boost's hash_combine does it
+            int seed = phraseId;
+            seed ^= contextId + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+            return seed;
+        }
+        public String toString()
+        {
+            return getPhraseString() + "\t" + getContextString(true);
+        }
+
+        private int phraseId;
+        private int contextId;
+        private int count;
+    }
+
+    List<Edge> getEdges()
+    {
+        return edges;
+    }
+
+    int getNumEdges()
+    {
+        return edges.size();
+    }
+
+    int getNumPhrases()
+    {
+        return phraseLexicon.size();
+    }
+
+    int getNumContextPositions()
+    {
+        return contextLexicon.lookup(0).size();
+    }
+
+    List<Edge> getEdgesForPhrase(int phraseId)
+    {
+        return phraseToContext.get(phraseId);
+    }
+
+    int getNumContexts()
+    {
+        return contextLexicon.size();
+    }
+
+    List<Edge> getEdgesForContext(int contextId)
+    {
+        return contextToPhrase.get(contextId);
+    }
+
+    int getNumWords()
+    {
+        return wordLexicon.size();
+    }
+
+    String getWord(int wordId)
+    {
+        return wordLexicon.lookup(wordId);
+    }
+
+    public TIntArrayList getPhrase(int phraseId)
+    {
+        return phraseLexicon.lookup(phraseId);
+    }
+
+    public String getPhraseString(int phraseId)
+    {
+        StringBuffer b = new StringBuffer();
+        for (int tid: getPhrase(phraseId).toNativeArray())
+        {
+            if (b.length() > 0)
+                b.append(" ");
+            b.append(wordLexicon.lookup(tid));
+        }
+        return b.toString();
+    }
+
+    public TIntArrayList getContext(int contextId)
+    {
+        return contextLexicon.lookup(contextId);
+    }
+
+    public String getContextString(int contextId, boolean insertPhraseSentinel)
+    {
+        StringBuffer b = new StringBuffer();
+        TIntArrayList c = getContext(contextId);
+        for (int i = 0; i < c.size(); ++i)
+        {
+            if (i > 0) b.append(" ");
+            if (insertPhraseSentinel && i == c.size() / 2) b.append("<PHRASE> ");
+            b.append(wordLexicon.lookup(c.get(i)));
+        }
+        return b.toString();
+    }
+
+    static Corpus readFromFile(Reader in) throws IOException
+    {
+        Corpus c = new Corpus();
+
+        // read in line-by-line
+        BufferedReader bin = new BufferedReader(in);
+        String line;
+        Pattern separator = Pattern.compile(" \\|\\|\\| ");
+
+        while ((line = bin.readLine()) != null)
+        {
+            // split into phrase and contexts
+            StringTokenizer st = new StringTokenizer(line, "\t");
+            assert (st.hasMoreTokens());
+            String phraseToks = st.nextToken();
+            assert (st.hasMoreTokens());
+            String rest = st.nextToken();
+            assert (!st.hasMoreTokens());
+
+            // process phrase
+            st = new StringTokenizer(phraseToks, " ");
+            TIntArrayList ptoks = new TIntArrayList();
+            while (st.hasMoreTokens())
+                ptoks.add(c.wordLexicon.insert(st.nextToken()));
+            int phraseId = c.phraseLexicon.insert(ptoks);
+            if (phraseId == c.phraseToContext.size())
+                c.phraseToContext.add(new ArrayList<Edge>());
+
+            // process contexts
+            String[] parts = separator.split(rest);
+            assert (parts.length % 2 == 0);
+            for (int i = 0; i < parts.length; i += 2)
+            {
+                // process pairs of strings - context and count
+                TIntArrayList ctx = new TIntArrayList();
+                String ctxString = parts[i];
+                String countString = parts[i + 1];
+                StringTokenizer ctxStrtok = new StringTokenizer(ctxString, " ");
+                while (ctxStrtok.hasMoreTokens())
+                {
+                    String token = ctxStrtok.nextToken();
+                    if (!token.equals("<PHRASE>"))
+                        ctx.add(c.wordLexicon.insert(token));
+                }
+                int contextId = c.contextLexicon.insert(ctx);
+                if (contextId == c.contextToPhrase.size())
+                    c.contextToPhrase.add(new ArrayList<Edge>());
+
+                assert (countString.startsWith("C="));
+                Edge e = c.new Edge(phraseId, contextId,
+                        Integer.parseInt(countString.substring(2).trim()));
+                c.edges.add(e);
+
+                // index the edge for fast phrase, context lookup
+                c.phraseToContext.get(phraseId).add(e);
+                c.contextToPhrase.get(contextId).add(e);
+            }
+        }
+
+        return c;
+    }
+}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Lexicon.java b/gi/posterior-regularisation/prjava/src/phrase/Lexicon.java
new file mode 100644
index 00000000..a386e4a3
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/phrase/Lexicon.java
@@ -0,0 +1,34 @@
+package phrase;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+public class Lexicon<T>
+{
+    public int insert(T word)
+    {
+        Integer i = wordToIndex.get(word);
+        if (i == null)
+        {
+            i = indexToWord.size();
+            wordToIndex.put(word, i);
+            indexToWord.add(word);
+        }
+        return i;
+    }
+
+    public T lookup(int index)
+    {
+        return indexToWord.get(index);
+    }
+
+    public int size()
+    {
+        return indexToWord.size();
+    }
+
+    private Map<T, Integer> wordToIndex = new HashMap<T, Integer>();
+    private List<T> indexToWord = new ArrayList<T>();
+}
\ No newline at end of file
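As an aside, the intended usage of the new Corpus and Lexicon classes is easiest to see with a small driver. The sketch below is not part of the patch; the class name CorpusDemo and the sample input line are hypothetical. It uses only the API added above, and it sits in the phrase package because readFromFile and the count accessors are package-private. The input follows the phrase TAB context ||| C=count format that readFromFile parses.

    package phrase;

    import java.io.FileReader;

    // Hypothetical driver exercising the Corpus/Lexicon API introduced by this commit.
    public class CorpusDemo
    {
        public static void main(String[] args) throws Exception
        {
            // args[0] names a phrase-context file, e.g. a line like:
            // "red house\tthe big <PHRASE> on fire ||| C=3"
            Corpus c = Corpus.readFromFile(new FileReader(args[0]));

            System.out.println(c.getNumPhrases() + " phrases, "
                + c.getNumContexts() + " contexts, "
                + c.getNumEdges() + " edges, "
                + c.getNumContextPositions() + " context positions");

            // each edge pairs a phrase with one of its contexts, plus a count
            for (Corpus.Edge e : c.getEdges())
                System.out.println(e.getPhraseString() + " ||| "
                    + e.getContextString(true) + " ||| C=" + e.getCount());
        }
    }

Note that both the phrase and the context are interned through a Lexicon, so each distinct (phrase, context) pair is stored once as an Edge, indexed from both directions; this is what lets the two constraint sets below iterate cheaply over "all contexts of a phrase" and "all phrases of a context".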
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index 731d03ac..e4db2a1a 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -1,44 +1,54 @@
 package phrase;
 
+import gnu.trove.TIntArrayList;
 import io.FileUtil;
-
-import java.io.FileOutputStream;
 import java.io.IOException;
-import java.io.OutputStream;
 import java.io.PrintStream;
 import java.util.Arrays;
+import java.util.List;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.LinkedBlockingQueue;
-import java.util.zip.GZIPOutputStream;
+
+import phrase.Corpus.Edge;
 
 public class PhraseCluster {
     
     public int K;
-    public double scale;
-    private int n_phrase;
-    private int n_words;
-    public PhraseCorpus c;
+    public double scalePT, scaleCT;
+    private int n_phrases, n_words, n_contexts, n_positions;
+    public Corpus c;
     private ExecutorService pool;
     
-    /**@brief
-     * emit[tag][position][word]
-     */
+    // emit[tag][position][word] = p(word | tag, position in context)
     private double emit[][][];
+    // pi[phrase][tag] = p(tag | phrase)
    private double pi[][];
+    static double EPS = 1e-6;
-    
-    public static void main(String[] args) {
+
+    public static void main(String[] args)
+    {
        String input_fname = args[0];
        int tags = Integer.parseInt(args[1]);
        String output_fname = args[2];
        int iterations = Integer.parseInt(args[3]);
-       double scale = Double.parseDouble(args[4]);
-       int threads = Integer.parseInt(args[5]);
-       boolean runEM = Boolean.parseBoolean(args[6]);
-       
-       PhraseCorpus corpus = new PhraseCorpus(input_fname);
-       PhraseCluster cluster = new PhraseCluster(tags, corpus, scale, threads);
+       double scalePT = Double.parseDouble(args[4]);
+       double scaleCT = Double.parseDouble(args[5]);
+       int threads = Integer.parseInt(args[6]);
+       boolean runEM = Boolean.parseBoolean(args[7]);
+
+       assert(tags >= 2);
+       assert(scalePT >= 0);
+       assert(scaleCT >= 0);
+
+       Corpus corpus = null;
+       try {
+           corpus = Corpus.readFromFile(FileUtil.openBufferedReader(input_fname));
+       } catch (IOException e) {
+           System.err.println("Failed to open input file: " + input_fname);
+           e.printStackTrace();
+           System.exit(1);
+       }
+       PhraseCluster cluster = new PhraseCluster(tags, corpus, scalePT, scaleCT, threads);
       
        //PhraseObjective.ps = FileUtil.openOutFile(outputDir + "/phrase_stat.out");
       
@@ -48,19 +58,25 @@ public class PhraseCluster {
            double o;
            if (runEM || i < 3)
                o = cluster.EM();
-           else
-               o = cluster.PREM();
+           else if (scaleCT == 0)
+           {
+               if (threads >= 1)
+                   o = cluster.PREM_phrase_constraints_parallel();
+               else
+                   o = cluster.PREM_phrase_constraints();
+           }
+           else
+               o = cluster.PREM_phrase_context_constraints();
+
            //PhraseObjective.ps.
            System.out.println("ITER: "+i+" objective: " + o);
            last = o;
        }
       
-       if (runEM)
-       {
-           double l1lmax = cluster.posterior_l1lmax();
-           System.out.println("Final l1lmax term " + l1lmax + ", total PR objective " + (last - scale*l1lmax));
-           // nb. KL is 0 by definition
-       }
+       double pl1lmax = cluster.phrase_l1lmax();
+       double cl1lmax = cluster.context_l1lmax();
+       System.out.println("Final posterior phrase l1lmax " + pl1lmax + " context l1lmax " + cl1lmax);
+       if (runEM) System.out.println("With PR objective " + (last - scalePT*pl1lmax - scaleCT*cl1lmax));
       
        PrintStream ps=io.FileUtil.openOutFile(output_fname);
        cluster.displayPosterior(ps);
@@ -75,17 +91,20 @@ public class PhraseCluster {
            cluster.finish();
    }
 
-   public PhraseCluster(int numCluster, PhraseCorpus corpus, double scale, int threads){
+   public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads){
        K=numCluster;
        c=corpus;
-       n_words=c.wordLex.size();
-       n_phrase=c.data.length;
-       
-       this.scale = scale;
-       if (threads > 0)
+       n_words=c.getNumWords();
+       n_phrases=c.getNumPhrases();
+       n_contexts=c.getNumContexts();
+       n_positions=c.getNumContextPositions();
+       this.scalePT = scalep;
+       this.scaleCT = scalec;
+       if (threads > 0 && scalec <= 0)
            pool = Executors.newFixedThreadPool(threads);
       
-       emit=new double [K][c.numContexts][n_words];
-       pi=new double[n_phrase][K];
+       emit=new double [K][n_positions][n_words];
+       pi=new double[n_phrases][K];
       
        for(double [][]i:emit){
            for(double []j:i){
@@ -105,30 +124,32 @@ public class PhraseCluster {
    }
    
    public double EM(){
-       double [][][]exp_emit=new double [K][c.numContexts][n_words];
-       double [][]exp_pi=new double[n_phrase][K];
+       double [][][]exp_emit=new double [K][n_positions][n_words];
+       double [][]exp_pi=new double[n_phrases][K];
       
        double loglikelihood=0;
       
        //E
-       for(int phrase=0;phrase<n_phrase;phrase++){
-           int [][]data=c.data[phrase];
-           for(int ctx=0;ctx<data.length;ctx++){
-               int context[]=data[ctx];
-               double p[]=posterior(phrase,context);
+       for(int phrase=0; phrase < n_phrases; ++phrase){
+           List<Edge> contexts = c.getEdgesForPhrase(phrase);
+
+           for (int ctx=0; ctx<contexts.size(); ctx++){
+               Edge edge = contexts.get(ctx);
+               double p[]=posterior(edge);
                double z = arr.F.l1norm(p);
                assert z > 0;
                loglikelihood+=Math.log(z);
                arr.F.l1normalize(p);
               
-               int contextCnt=context[context.length-1];
+               int count = edge.getCount();
                //increment expected count
+               TIntArrayList context = edge.getContext();
                for(int tag=0;tag<K;tag++){
-                   for(int pos=0;pos<context.length-1;pos++){
-                       exp_emit[tag][pos][context[pos]]+=p[tag]*contextCnt;
-                   }
-                   exp_pi[phrase][tag]+=p[tag]*contextCnt;
+                   for(int pos=0;pos<n_positions;pos++){
+                       exp_emit[tag][pos][context.get(pos)]+=p[tag]*count;
+                   }
+                   exp_pi[phrase][tag]+=p[tag]*count;
                }
            }
        }
@@ -150,75 +171,104 @@ public class PhraseCluster {
        return loglikelihood;
    }
    
-   public double PREM(){
-       double [][][]exp_emit=new double [K][c.numContexts][n_words];
-       double [][]exp_pi=new double[n_phrase][K];
-       
-       double loglikelihood=0;
-       double primal=0;
+   public double PREM_phrase_constraints(){
+       double [][][]exp_emit=new double [K][n_positions][n_words];
+       double [][]exp_pi=new double[n_phrases][K];
+       
+       double loglikelihood=0, kl=0, l1lmax=0, primal=0;
+
        //E
-       for(int phrase=0;phrase<n_phrase;phrase++){
+       for(int phrase=0; phrase<n_phrases; phrase++){
            PhraseObjective po=new PhraseObjective(this,phrase);
            po.optimizeWithProjectedGradientDescent();
            double [][] q=po.posterior();
-           loglikelihood+=po.loglikelihood();
-           primal+=po.primal();
+           loglikelihood += po.loglikelihood();
+           kl += po.KL_divergence();
+           l1lmax += po.l1lmax();
+           primal += po.primal();
+           List<Edge> edges = c.getEdgesForPhrase(phrase);
+
            for(int edge=0;edge<q.length;edge++){
-               int []context=c.data[phrase][edge];
-               int contextCnt=context[context.length-1];
+               int count = edges.get(edge).getCount();
                //increment expected count
+               TIntArrayList context = edges.get(edge).getContext();
                for(int tag=0;tag<K;tag++){
-                   for(int pos=0;pos<context.length-1;pos++){
-                       exp_emit[tag][pos][context[pos]]+=q[edge][tag]*contextCnt;
-                   }
-                   exp_pi[phrase][tag]+=q[edge][tag]*contextCnt;
+                   for(int pos=0;pos<n_positions;pos++){
+                       exp_emit[tag][pos][context.get(pos)]+=q[edge][tag]*count;
+                   }
+                   exp_pi[phrase][tag]+=q[edge][tag]*count;
                }
            }
        }
       
+       System.out.println("\tllh: " + loglikelihood + " kl: " + kl + " l1lmax: " + l1lmax + " primal: " + primal);
+       
        //M
        for(double [][]i:exp_emit)
            for(double []j:i)
                arr.F.l1normalize(j);
        emit=exp_emit;
       
        for(double []j:exp_pi)
            arr.F.l1normalize(j);
        pi=exp_pi;
       
        return primal;
    }
    
-   public double PREM_parallel(){
+   public double PREM_phrase_constraints_parallel(){
        assert(pool != null);
        final LinkedBlockingQueue<PhraseObjective> expectations
            = new LinkedBlockingQueue<PhraseObjective>();
       
-       double [][][]exp_emit=new double [K][c.numContexts][n_words];
-       double [][]exp_pi=new double[n_phrase][K];
+       double [][][]exp_emit=new double [K][n_positions][n_words];
+       double [][]exp_pi=new double[n_phrases][K];
       
-       double loglikelihood=0;
-       double primal=0;
+       double loglikelihood=0, kl=0, l1lmax=0, primal=0;
+
        //E
-       for(int phrase=0;phrase<n_phrase;phrase++){
+       for(int phrase=0;phrase<n_phrases;phrase++){
            final int p=phrase;
            pool.execute(new Runnable() {
                public void run() {
                    try {
                        PhraseObjective po = new PhraseObjective(PhraseCluster.this, p);
                        po.optimizeWithProjectedGradientDescent();
                        expectations.put(po);
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
            });
        }
       
        // aggregate the expectations as they become available
-       for(int i=0;i<n_phrase;i++){
+       for(int i=0;i<n_phrases;i++){
            try {
                PhraseObjective po = expectations.take();
                int phrase = po.phrase;
                double [][] q = po.posterior();
-               loglikelihood+=po.loglikelihood();
-               primal+=po.primal();
+               loglikelihood += po.loglikelihood();
+               kl += po.KL_divergence();
+               l1lmax += po.l1lmax();
+               primal += po.primal();
+               List<Edge> edges = c.getEdgesForPhrase(phrase);
+
                for(int edge=0;edge<q.length;edge++){
-                   int []context=c.data[phrase][edge];
-                   int contextCnt=context[context.length-1];
+                   int count = edges.get(edge).getCount();
                    //increment expected count
+                   TIntArrayList context = edges.get(edge).getContext();
                    for(int tag=0;tag<K;tag++){
-                       for(int pos=0;pos<context.length-1;pos++){
-                           exp_emit[tag][pos][context[pos]]+=q[edge][tag]*contextCnt;
-                       }
-                       exp_pi[phrase][tag]+=q[edge][tag]*contextCnt;
+                       for(int pos=0;pos<n_positions;pos++){
+                           exp_emit[tag][pos][context.get(pos)]+=q[edge][tag]*count;
+                       }
+                       exp_pi[phrase][tag]+=q[edge][tag]*count;
                    }
                }
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
       
+       System.out.println("\tllh: " + loglikelihood + " kl: " + kl + " l1lmax: " + l1lmax + " primal: " + primal);
       
        //M
        for(double [][]i:exp_emit)
            for(double []j:i)
                arr.F.l1normalize(j);
        emit=exp_emit;
       
        for(double []j:exp_pi)
            arr.F.l1normalize(j);
        pi=exp_pi;
       
        return primal;
    }
    
+   public double PREM_phrase_context_constraints(){
+       assert (scaleCT > 0);
+       
+       double [][][]exp_emit=new double [K][n_positions][n_words];
+       double [][]exp_pi=new double[n_phrases][K];
+       
+       //E step
+       // TODO: cache the lambda values (the null below)
+       PhraseContextObjective pco = new PhraseContextObjective(this, null);
+       pco.optimizeWithProjectedGradientDescent();
+       
+       //now extract expectations
+       List<Corpus.Edge> edges = c.getEdges();
+       for(int e = 0; e < edges.size(); ++e)
+       {
+           double [] q = pco.posterior(e);
+           Corpus.Edge edge = edges.get(e);
+           
+           TIntArrayList context = edge.getContext();
+           int contextCnt = edge.getCount();
+           //increment expected count
+           for(int tag=0;tag<K;tag++)
+           {
+               for(int pos=0;pos<n_positions;pos++)
+                   exp_emit[tag][pos][context.get(pos)]+=q[tag]*contextCnt;
+               exp_pi[edge.getPhraseId()][tag]+=q[tag]*contextCnt;
+           }
+       }
+       
+       System.out.println("\tllh: " + pco.loglikelihood() + " kl: " + pco.KL_divergence()
+           + " phrase l1lmax: " + pco.phrase_l1lmax() + " context l1lmax: " + pco.context_l1lmax());
+       
+       //M step
+       for(double [][]i:exp_emit)
+           for(double []j:i)
+               arr.F.l1normalize(j);
+       emit=exp_emit;
+       
+       for(double []j:exp_pi)
+           arr.F.l1normalize(j);
+       pi=exp_pi;
+       
+       return pco.primal();
+   }
+   
@@ -286,14 +334,15 @@ public class PhraseCluster {
    
-   public double[]posterior(int phraseID, int[]context){//just phrase posterior prob
-       double[] prob=Arrays.copyOf(pi[phraseID], K);
+   public double[] posterior(Corpus.Edge edge) //just phrase posterior prob
+   {
+       double[] prob=Arrays.copyOf(pi[edge.getPhraseId()], K);
       
+       TIntArrayList ctx = edge.getContext();
        for(int tag=0;tag<K;tag++){
-           for(int pos=0;pos<context.length-1;pos++){
-               prob[tag]*=emit[tag][pos][context[pos]];
+           for(int pos=0;pos<n_positions;pos++){
+               prob[tag]*=emit[tag][pos][ctx.get(pos)];
            }
        }
       
        return prob;
    }
@@ -340,11 +398,11 @@ public class PhraseCluster {
        ps.println("P(tag|phrase)");
-       for (int i = 0; i < n_phrase; ++i)
+       for (int i = 0; i < n_phrases; ++i)
        {
-           ps.print(c.phraseList[i]);
+           ps.print(c.getPhraseString(i));
            for(int j=0;j<pi[i].length;j++){
-               if (pi[i][j] > 1e-10)
+               if (pi[i][j] > EPS)
                    ps.print("\t" + j + ": " + pi[i][j]);
            }
            ps.println();
        }
@@ -355,14 +413,11 @@ public class PhraseCluster {
       
        ps.println("P(word|tag,position)");
        for (int i = 0; i < K; ++i)
        {
-           for(int position=0;position<c.numContexts;position++){
+           for(int position=0;position<n_positions;position++){
                ps.println("tag " + i + " position " + position);
                for(int word=0;word<n_words;word++){
-                   if (emit[i][position][word] > 1e-10)
-                       ps.print(c.wordList[word]+"="+emit[i][position][word]+"\t");
+                   if (emit[i][position][word] > EPS)
+                       ps.print(c.getWord(word)+"="+emit[i][position][word]+"\t");
                }
                ps.println();
            }
@@ -371,19 +426,35 @@ public class PhraseCluster {
    }
    
-   double posterior_l1lmax()
+   double phrase_l1lmax()
    {
        double sum=0;
-       for(int phrase=0;phrase<n_phrase;phrase++)
+       for(int phrase=0; phrase<n_phrases; phrase++)
        {
-           int [][] data = c.data[phrase];
+           List<Edge> edges = c.getEdgesForPhrase(phrase);
            double [] maxes = new double[K];
-           for(int ctx=0;ctx<data.length;ctx++)
+           for (Edge edge : edges)
            {
-               double p[] = posterior(phrase, data[ctx]);
+               double p[] = posterior(edge);
                arr.F.l1normalize(p);
                for(int tag=0;tag<K;tag++)
                    maxes[tag] = Math.max(maxes[tag], p[tag]);
            }
            for(int tag=0;tag<K;tag++)
                sum += maxes[tag];
        }
        return sum;
    }
+   
+   double context_l1lmax()
+   {
+       double sum=0;
+       for(int context=0; context<n_contexts; context++)
+       {
+           List<Edge> edges = c.getEdgesForContext(context);
+           double [] maxes = new double[K];
+           for (Edge edge : edges)
+           {
+               double p[] = posterior(edge);
+               arr.F.l1normalize(p);
+               for(int tag=0;tag<K;tag++)
+                   maxes[tag] = Math.max(maxes[tag], p[tag]);
+           }
+           for(int tag=0;tag<K;tag++)
+               sum += maxes[tag];
+       }
+       return sum;
+   }
 }
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
new file mode 100644
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
@@ -0,0 +1,302 @@
+package phrase;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import optimization.gradientBasedMethods.ProjectedGradientDescent;
+import optimization.gradientBasedMethods.ProjectedObjective;
+import optimization.gradientBasedMethods.stats.OptimizerStats;
+import optimization.linesearch.ArmijoLineSearchMinimizationAlongProjectionArc;
+import optimization.linesearch.InterpolationPickFirstStep;
+import optimization.linesearch.LineSearchMethod;
+import optimization.projections.SimplexProjection;
+import optimization.stopCriteria.CompositeStopingCriteria;
+import optimization.stopCriteria.ProjectedGradientL2Norm;
+import optimization.stopCriteria.StopingCriteria;
+import optimization.stopCriteria.ValueDifference;
+
+public class PhraseContextObjective extends ProjectedObjective
+{
+   private static final double GRAD_DIFF = 0.00002;
+   private static double INIT_STEP_SIZE = 10;
+   private static double VAL_DIFF = 1e-4;
+   private static int ITERATIONS = 100;
+
+   private PhraseCluster c;
+
+   // un-regularized unnormalized posterior, p[edge][tag]
+   // P(tag|edge) \propto P(tag|phrase)P(context|tag)
+   private double p[][];
+
+   // regularized unnormalized posterior
+   // q[edge][tag] propto p[edge][tag]*exp(-lambda)
+   private double q[][];
+   private List<Corpus.Edge> data;
+
+   // log likelihood under q
+   private double loglikelihood;
+   private SimplexProjection projectionPhrase;
+   private SimplexProjection projectionContext;
+
+   double[] newPoint;
+   private int n_param;
+
+   // likelihood under p
+   public double llh;
+
+   private Map<Corpus.Edge, Integer> edgeIndex;
+
+   public PhraseContextObjective(PhraseCluster cluster, double[] startingParameters)
+   {
+       c=cluster;
+       data=c.c.getEdges();
+       n_param=data.size()*c.K*2;
+
+       parameters = startingParameters;
+       if (parameters == null)
+           parameters = new double[n_param];
+
+       newPoint = new double[n_param];
+       gradient = new double[n_param];
+       initP();
+       projectionPhrase = new SimplexProjection(c.scalePT);
+       projectionContext = new SimplexProjection(c.scaleCT);
+       q=new double [data.size()][c.K];
+
+       edgeIndex = new HashMap<Corpus.Edge, Integer>();
+       for (int e=0; e<data.size(); e++)
+           edgeIndex.put(data.get(e), e);
+
+       setParameters(parameters);
+   }
+
+   private void initP(){
+       p=new double[data.size()][];
+       for(int e=0; e<data.size(); e++)
+       {
+           p[e]=c.posterior(data.get(e));
+           llh += data.get(e).getCount() * Math.log(arr.F.l1norm(p[e]));
+           arr.F.l1normalize(p[e]);
+       }
+   }
+
+   // parameter layout: the K*|E| phrase multipliers (lambda) come first,
+   // then the K*|E| context multipliers (omega)
+   int index(Corpus.Edge edge, int tag, boolean phrase)
+   {
+       int e = edgeIndex.get(edge);
+       if (phrase)
+           return tag * data.size() + e;
+       else
+           return (c.K + tag) * data.size() + e;
+   }
+
+   public void setParameters(double[] params) {
+       super.setParameters(params);
+       updateFunction();
+   }
+
+   private void updateFunction()
+   {
+       loglikelihood=0;
+       for (int tag=0; tag<c.K; tag++)
+       {
+           for (int e=0; e<data.size(); e++)
+           {
+               Corpus.Edge edge = data.get(e);
+               q[e][tag] = p[e][tag]*
+                   Math.exp((-parameters[index(edge,tag,true)] - parameters[index(edge,tag,false)]) / edge.getCount());
+           }
+       }
+
+       for (int e=0; e<data.size(); e++)
+       {
+           loglikelihood += data.get(e).getCount() * Math.log(arr.F.l1norm(q[e]));
+           arr.F.l1normalize(q[e]);
+       }
+   }
+
+   public double[] optimizeWithProjectedGradientDescent()
+   {
+       LineSearchMethod ls =
+           new ArmijoLineSearchMinimizationAlongProjectionArc
+               (new InterpolationPickFirstStep(INIT_STEP_SIZE));
+       OptimizerStats stats = new OptimizerStats();
+       ProjectedGradientDescent optimizer = new ProjectedGradientDescent(ls);
+       StopingCriteria stopGrad = new ProjectedGradientL2Norm(GRAD_DIFF);
+       StopingCriteria stopValue = new ValueDifference(VAL_DIFF);
+       CompositeStopingCriteria compositeStop = new CompositeStopingCriteria();
+       compositeStop.add(stopGrad);
+       compositeStop.add(stopValue);
+       optimizer.setMaxIterations(ITERATIONS);
+       updateFunction();
+       optimizer.optimize(this,stats,compositeStop);
+       return parameters;
+   }
+
+   @Override
+   public double getValue()
+   {
+       return loglikelihood;
+   }
+
+   @Override
+   public double[] getGradient()
+   {
+       for (int tag=0; tag<c.K; tag++)
+       {
+           for (int e=0; e<data.size(); e++)
+           {
+               Corpus.Edge edge = data.get(e);
+               gradient[index(edge,tag,true)] = -q[e][tag];
+               gradient[index(edge,tag,false)] = -q[e][tag];
+           }
+       }
+       return gradient;
+   }
+
+   @Override
+   public double[] projectPoint(double[] point)
+   {
+       System.arraycopy(point, 0, newPoint, 0, point.length);
+
+       if (c.scalePT > 0)
+       {
+           // first project using the phrase-tag constraints,
+           // for all p,t: sum_c lambda_ptc < scaleP
+           for (int p = 0; p < c.c.getNumPhrases(); ++p)
+           {
+               List<Corpus.Edge> edges = c.c.getEdgesForPhrase(p);
+               double toProject[] = new double[edges.size()];
+               for(int tag=0;tag<c.K;tag++)
+               {
+                   for(int e=0; e<edges.size(); e++)
+                       toProject[e] = point[index(edges.get(e), tag, true)];
+                   projectionPhrase.project(toProject);
+                   for(int e=0; e<edges.size(); e++)
+                       newPoint[index(edges.get(e), tag, true)] = toProject[e];
+               }
+           }
+       }
+
+       if (c.scaleCT > 1e-6)
+       {
+           // now project using the context-tag constraints,
+           // for all c,t: sum_p omega_pct < scaleC
+           for (int ctx = 0; ctx < c.c.getNumContexts(); ++ctx)
+           {
+               List<Corpus.Edge> edges = c.c.getEdgesForContext(ctx);
+               double toProject[] = new double[edges.size()];
+               for(int tag=0;tag<c.K;tag++)
+               {
+                   for(int e=0; e<edges.size(); e++)
+                       toProject[e] = point[index(edges.get(e), tag, false)];
+                   projectionContext.project(toProject);
+                   for(int e=0; e<edges.size(); e++)
+                       newPoint[index(edges.get(e), tag, false)] = toProject[e];
+               }
+           }
+       }
+
+       return newPoint;
+   }
+
+   public double[] posterior(int edgeId)
+   {
+       return q[edgeId];
+   }
+
+   public double loglikelihood()
+   {
+       return llh;
+   }
+
+   public double KL_divergence()
+   {
+       double res = 0;
+       for (int e=0; e<data.size(); e++)
+           for (int tag=0; tag<c.K; tag++)
+               if (q[e][tag] > 0)
+                   res += q[e][tag] * Math.log(q[e][tag] / p[e][tag]);
+       return res;
+   }
+
+   public double phrase_l1lmax()
+   {
+       double sum=0;
+       for (int p = 0; p < c.c.getNumPhrases(); ++p)
+       {
+           List<Corpus.Edge> edges = c.c.getEdgesForPhrase(p);
+           for(int tag=0;tag<c.K;tag++)
+           {
+               double max=0;
+               for (Corpus.Edge edge: edges)
+                   max = Math.max(max, q[edgeIndex.get(edge)][tag]);
+               sum+=max;
+           }
+       }
+       return sum;
+   }
+
+   public double context_l1lmax()
+   {
+       double sum=0;
+       for (int ctx = 0; ctx < c.c.getNumContexts(); ++ctx)
+       {
+           List<Corpus.Edge> edges = c.c.getEdgesForContext(ctx);
+           for(int tag=0; tag<c.K; tag++)
+           {
+               double max=0;
+               for (Corpus.Edge edge: edges)
+                   max = Math.max(max, q[edgeIndex.get(edge)][tag]);
+               sum+=max;
+           }
+       }
+       return sum;
+   }
+
+   public double primal()
+   {
+       return loglikelihood() - KL_divergence() - c.scalePT * phrase_l1lmax() - c.scaleCT * context_l1lmax();
+   }
+}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java
@@ -10,10 +10,7 @@ import java.util.Scanner;
 
 public class PhraseCorpus {
    
-   public static String LEX_FILENAME="../pdata/lex.out";
-   public static String DATA_FILENAME="../pdata/btec.con";
-   
    public HashMap<String,Integer>wordLex;
    public HashMap<String,Integer>phraseLex;
    
@@ -21,16 +18,8 @@ public class PhraseCorpus {
    
    //data[phrase][num context][position]
    public int data[][][];
-   public int numContexts;
-
-   public static void main(String[] args) {
-       // TODO Auto-generated method stub
-       PhraseCorpus c=new PhraseCorpus(DATA_FILENAME);
-       c.saveLex(LEX_FILENAME);
-       c.loadLex(LEX_FILENAME);
-       c.saveLex(LEX_FILENAME);
-   }
-   
+   public int numContexts;
+   
    public PhraseCorpus(String filename){
        BufferedReader r=io.FileUtil.openBufferedReader(filename);
       
@@ -185,5 +174,13 @@ public class PhraseCorpus {
        }
        return null;
    }
-   
+   
+   public static void main(String[] args) {
+       String LEX_FILENAME="../pdata/lex.out";
+       String DATA_FILENAME="../pdata/btec.con";
+       PhraseCorpus c=new PhraseCorpus(DATA_FILENAME);
+       c.saveLex(LEX_FILENAME);
+       c.loadLex(LEX_FILENAME);
+       c.saveLex(LEX_FILENAME);
+   }
 }
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
index 0fdc169b..015ef106 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
@@ -2,6 +2,7 @@ package phrase;
 
 import java.io.PrintStream;
 import java.util.Arrays;
+import java.util.List;
 
 import optimization.gradientBasedMethods.ProjectedGradientDescent;
 import optimization.gradientBasedMethods.ProjectedObjective;
@@ -17,11 +18,12 @@ import optimization.stopCriteria.StopingCriteria;
 import optimization.stopCriteria.ValueDifference;
 import optimization.util.MathUtils;
 
-public class PhraseObjective extends ProjectedObjective{
-   
-   private static final double GRAD_DIFF = 0.00002;
-   public static double INIT_STEP_SIZE = 10;
-   public static double VAL_DIFF = 0.000001; // FIXME needs to be tuned
+public class PhraseObjective extends ProjectedObjective
+{
+   static final double GRAD_DIFF = 0.00002;
+   static double INIT_STEP_SIZE = 10;
+   static double VAL_DIFF = 1e-6; // FIXME needs to be tuned
+   static int ITERATIONS = 100;
    //private double c1=0.0001; // Wolfe conditions
    //private double c2=0.9;
    private static double lambda[][];
@@ -46,7 +48,7 @@ public class PhraseObjective{
     * q[edge][tag] propto p[edge][tag]*exp(-lambda)
     */
    private double q[][];
-   private int data[][];
+   private List<Corpus.Edge> data;
    
    /**@brief log likelihood of the associated phrase
     *
@@ -66,14 +68,14 @@ public class PhraseObjective{
    public PhraseObjective(PhraseCluster cluster, int phraseIdx){
        phrase=phraseIdx;
        c=cluster;
-       data=c.c.data[phrase];
-       n_param=data.length*c.K;
+       data=c.c.getEdgesForPhrase(phrase);
+       n_param=data.size()*c.K;
       
-       if( lambda==null){
-           lambda=new double[c.c.data.length][];
+       if (lambda==null){
+           lambda=new double[c.c.getNumPhrases()][];
        }
       
-       if(lambda[phrase]==null){
+       if (lambda[phrase]==null){
            lambda[phrase]=new double[n_param];
        }
       
@@ -81,22 +83,17 @@ public class PhraseObjective{
        parameters = lambda[phrase];
        newPoint = new double[n_param];
        gradient = new double[n_param];
        initP();
-       projection=new SimplexProjection(c.scale);
-       q=new double [data.length][c.K];
+       projection=new SimplexProjection(c.scalePT);
+       q=new double [data.size()][c.K];
       
        setParameters(parameters);
    }
    
    private void initP(){
-       int countIdx=data[0].length-1;
-       
-       p=new double[data.length][];
-       for(int edge=0;edge<data.length;edge++){
-           p[edge]=c.posterior(phrase,data[edge]);
-       }
-       for(int edge=0;edge<data.length;edge++){
+       p=new double[data.size()][];
+       for(int edge=0;edge<data.size();edge++){
+           p[edge]=c.posterior(data.get(edge));
            arr.F.l1normalize(p[edge]);
        }
    }
@@ -212,20 +209,22 @@ public class PhraseObjective{
    
-   public double primal()
+   public double l1lmax()
    {
-       double l=loglikelihood()-KL_divergence();
-       
        double sum=0;
        for(int tag=0;tag<c.K;tag++){
            double max=0;
-           for(int edge=0;edge<data.length;edge++){
-               if(q[edge][tag]>max){
-                   max=q[edge][tag];
-               }
+           for(int edge=0;edge<data.size();edge++){
+               if(q[edge][tag]>max)
+                   max=q[edge][tag];
            }
            sum+=max;
        }
-       //System.out.println("l1lmax " + sum);
-//     ps.println(", "+sum);
-       l=l-c.scale*sum;
-       return l;
+       return sum;
+   }
+   
+   public double primal()
+   {
+       return loglikelihood() - KL_divergence() - c.scalePT * l1lmax();
    }
 }
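To make the role of the dual parameters concrete, here is a standalone sketch (not part of the patch) of the posterior update both objective classes are built around: the constrained posterior is q[edge][tag] proportional to p[edge][tag]*exp(-lambda[edge][tag]), renormalised per edge, with the multipliers kept feasible by the SimplexProjection (scaled by scalePT and scaleCT) during the projected gradient steps. The class and method names below are illustrative only.

    // Illustrative sketch of the q-update used by PhraseObjective and
    // PhraseContextObjective: q[e][t] \propto p[e][t] * exp(-lambda[e][t]).
    public class QUpdateSketch
    {
        // p: unregularised posteriors per edge; lambda: current dual parameters
        static double[][] computeQ(double[][] p, double[][] lambda)
        {
            double[][] q = new double[p.length][];
            for (int e = 0; e < p.length; ++e)
            {
                q[e] = new double[p[e].length];
                double z = 0;
                for (int t = 0; t < p[e].length; ++t)
                {
                    q[e][t] = p[e][t] * Math.exp(-lambda[e][t]);
                    z += q[e][t];
                }
                // l1-normalise, as arr.F.l1normalize does in the patch
                for (int t = 0; t < p[e].length; ++t)
                    q[e][t] /= z;
            }
            return q;
        }
    }

With lambda at zero this reduces to q = p, the plain EM posterior; raising lambda[e][t] moves mass away from tag t on edge e. This is how the projected duals enforce the l1/linf sparsity constraints, per phrase in PhraseObjective and per phrase and per context jointly in the new PhraseContextObjective.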