From 4037e35c511aec96f780276aa4e3c1493e19eba1 Mon Sep 17 00:00:00 2001 From: "trevor.cohn" Date: Thu, 15 Jul 2010 22:48:44 +0000 Subject: Option to run on single word phrases before moving to larger ones. git-svn-id: https://ws10smt.googlecode.com/svn/trunk@272 ec762483-ff6d-05da-a07a-a48fb63a330f --- .../prjava/src/phrase/C2F.java | 16 ++++-- .../prjava/src/phrase/Corpus.java | 26 +++++++++ .../prjava/src/phrase/PhraseCluster.java | 67 ++++++++++++++++++---- .../prjava/src/phrase/Trainer.java | 13 +++-- 4 files changed, 99 insertions(+), 23 deletions(-) (limited to 'gi/posterior-regularisation/prjava') diff --git a/gi/posterior-regularisation/prjava/src/phrase/C2F.java b/gi/posterior-regularisation/prjava/src/phrase/C2F.java index a8e557f2..31fd4fda 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/C2F.java +++ b/gi/posterior-regularisation/prjava/src/phrase/C2F.java @@ -38,10 +38,10 @@ public class C2F { n_contexts=c.getNumContexts(); //number of words in a phrase to be considered - //currently the first and last word - //if the phrase has length 1 - //use the same word for two positions - n_positions=2; + //currently the first and last word in source and target + //if the phrase has length 1 in either dimension then + //we use the same word for two positions + n_positions=c.phraseEdges(c.getEdges().get(0).getPhrase()).size(); emit=new double [K][n_positions][n_words]; pi=new double[n_contexts][K]; @@ -156,9 +156,13 @@ public class C2F { double[] prob=Arrays.copyOf(pi[edge.getContextId()], K); TIntArrayList phrase = edge.getPhrase(); + TIntArrayList offsets = c.phraseEdges(phrase); for(int tag=0;tag edges = new ArrayList(); private List> phraseToContext = new ArrayList>(); private List> contextToPhrase = new ArrayList>(); + public int splitSentinel; + public int phraseSentinel; + + public Corpus() + { + splitSentinel = wordLexicon.insert(""); + phraseSentinel = wordLexicon.insert(""); + } public class Edge { @@ -157,6 +165,11 @@ public class Corpus return b.toString(); } + public boolean isSentinel(int wordId) + { + return wordId == splitSentinel || wordId == phraseSentinel; + } + static Corpus readFromFile(Reader in) throws IOException { Corpus c = new Corpus(); @@ -218,6 +231,19 @@ public class Corpus return c; } + + TIntArrayList phraseEdges(TIntArrayList phrase) + { + TIntArrayList r = new TIntArrayList(4); + for (int p = 0; p < phrase.size(); ++p) + { + if (p == 0 || phrase.get(p-1) == splitSentinel) + r.add(p); + if (p == phrase.size() - 1 || phrase.get(p+1) == splitSentinel) + r.add(p); + } + return r; + } public void printStats(PrintStream out) { diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java index a369b319..5efaf52e 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java +++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java @@ -69,16 +69,29 @@ public class PhraseCluster { pool = Executors.newFixedThreadPool(threads); } - public double EM() + public double EM(boolean skipBigPhrases) { double [][][]exp_emit=new double [K][n_positions][n_words]; double [][]exp_pi=new double[n_phrases][K]; + if (skipBigPhrases) + { + for(double [][]i:exp_emit) + for(double []j:i) + Arrays.fill(j, 1e-100); + } + double loglikelihood=0; //E for(int phrase=0; phrase < n_phrases; phrase++) { + if (skipBigPhrases && c.getPhrase(phrase).size() >= 2) + { + System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K); + continue; + } + List contexts = c.getEdgesForPhrase(phrase); for (int ctx=0; ctx= 2) + { + System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K); + continue; + } + PhraseObjective po = new PhraseObjective(this, phrase, scalePT, (cacheLambda) ? lambdaPT[phrase] : null); boolean ok = po.optimizeWithProjectedGradientDescent(); if (!ok) ++failures; @@ -292,7 +319,7 @@ public class PhraseCluster { return primal; } - public double PREM_phrase_constraints_parallel(final double scalePT) + public double PREM_phrase_constraints_parallel(final double scalePT, boolean skipBigPhrases) { assert(pool != null); @@ -302,10 +329,17 @@ public class PhraseCluster { double [][][]exp_emit=new double [K][n_positions][n_words]; double [][]exp_pi=new double[n_phrases][K]; + if (skipBigPhrases) + { + for(double [][]i:exp_emit) + for(double []j:i) + Arrays.fill(j, 1e-100); + } + double loglikelihood=0, kl=0, l1lmax=0, primal=0; final AtomicInteger failures = new AtomicInteger(0); final AtomicLong elapsed = new AtomicLong(0l); - int iterations=0; + int iterations=0, n=n_phrases; long start = System.currentTimeMillis(); if (lambdaPT == null && cacheLambda) @@ -313,6 +347,12 @@ public class PhraseCluster { //E for(int phrase=0;phrase= 2) + { + n -= 1; + System.arraycopy(pi[phrase], 0, exp_pi[phrase], 0, K); + continue; + } final int p=phrase; pool.execute(new Runnable() { public void run() { @@ -337,7 +377,7 @@ public class PhraseCluster { } // aggregate the expectations as they become available - for(int count=0;count