summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava
diff options
context:
space:
mode:
authortrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-12 19:48:54 +0000
committertrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-12 19:48:54 +0000
commitfc5860d6df8c30149cee280f38d7a11889102663 (patch)
tree5a490776602a3f8acac2b0c7b6d5ab8f88c29606 /gi/posterior-regularisation/prjava
parenta753ce95a0c285952fd64d658f2bd4a5c5db8e4c (diff)
Updated launcher to include agreement model.
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@226 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation/prjava')
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/Agree.java6
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java91
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java18
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java8
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/Trainer.java56
5 files changed, 104 insertions, 75 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Agree.java b/gi/posterior-regularisation/prjava/src/phrase/Agree.java
index d5b949b0..d61e6eef 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/Agree.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/Agree.java
@@ -12,8 +12,8 @@ import java.util.List;
import phrase.Corpus.Edge;
public class Agree {
- private PhraseCluster model1;
- private C2F model2;
+ PhraseCluster model1;
+ C2F model2;
Corpus c;
private int K,n_phrases, n_words, n_contexts, n_positions1,n_positions2;
@@ -32,7 +32,7 @@ public class Agree {
*/
public Agree(int numCluster, Corpus corpus){
- model1=new PhraseCluster(numCluster, corpus, 0, 0, 0);
+ model1=new PhraseCluster(numCluster, corpus);
model2=new C2F(numCluster,corpus);
c=corpus;
n_words=c.getNumWords();
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index 7bc63c33..abd868c4 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -18,21 +18,16 @@ import util.MathUtil;
public class PhraseCluster {
public int K;
- public double scalePT, scaleCT;
private int n_phrases, n_words, n_contexts, n_positions;
public Corpus c;
public ExecutorService pool;
// emit[tag][position][word] = p(word | tag, position in context)
- private double emit[][][];
+ double emit[][][];
// pi[phrase][tag] = p(tag | phrase)
- private double pi[][];
+ double pi[][];
- double alphaEmit;
- double alphaPi;
-
- public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads,
- double alphaEmit, double alphaPi)
+ public PhraseCluster(int numCluster, Corpus corpus)
{
K=numCluster;
c=corpus;
@@ -40,33 +35,34 @@ public class PhraseCluster {
n_phrases=c.getNumPhrases();
n_contexts=c.getNumContexts();
n_positions=c.getNumContextPositions();
- this.scalePT = scalep;
- this.scaleCT = scalec;
- if (threads > 0)
- pool = Executors.newFixedThreadPool(threads);
-
+
emit=new double [K][n_positions][n_words];
pi=new double[n_phrases][K];
for(double [][]i:emit)
- {
for(double []j:i)
- {
- arr.F.randomise(j, alphaEmit <= 0);
- if (alphaEmit > 0)
- digammaNormalize(j, alphaEmit);
- }
- }
-
+ arr.F.randomise(j, true);
for(double []j:pi)
- {
- arr.F.randomise(j, alphaPi <= 0);
- if (alphaPi > 0)
- digammaNormalize(j, alphaPi);
- }
+ arr.F.randomise(j, true);
+ }
+
+ public void initialiseVB(double alphaEmit, double alphaPi)
+ {
+ assert alphaEmit > 0;
+ assert alphaPi > 0;
+
+ for(double [][]i:emit)
+ for(double []j:i)
+ digammaNormalize(j, alphaEmit);
- this.alphaEmit = alphaEmit;
- this.alphaPi = alphaPi;
+ for(double []j:pi)
+ digammaNormalize(j, alphaPi);
+ }
+
+ void useThreadPool(int threads)
+ {
+ assert threads > 0;
+ pool = Executors.newFixedThreadPool(threads);
}
public double EM()
@@ -116,7 +112,7 @@ public class PhraseCluster {
return loglikelihood;
}
- public double VBEM()
+ public double VBEM(double alphaEmit, double alphaPi)
{
// FIXME: broken - needs to be done entirely in log-space
@@ -216,9 +212,22 @@ public class PhraseCluster {
return kl;
}
- public double PREM_phrase_constraints(){
- assert (scaleCT <= 0);
-
+ public double PREM(double scalePT, double scaleCT)
+ {
+ if (scaleCT == 0)
+ {
+ if (pool != null)
+ return PREM_phrase_constraints_parallel(scalePT);
+ else
+ return PREM_phrase_constraints(scalePT);
+ }
+ else
+ return this.PREM_phrase_context_constraints(scalePT, scaleCT);
+ }
+
+
+ public double PREM_phrase_constraints(double scalePT)
+ {
double [][][]exp_emit=new double[K][n_positions][n_words];
double [][]exp_pi=new double[n_phrases][K];
@@ -226,7 +235,7 @@ public class PhraseCluster {
int failures=0, iterations=0;
//E
for(int phrase=0; phrase<n_phrases; phrase++){
- PhraseObjective po=new PhraseObjective(this,phrase);
+ PhraseObjective po=new PhraseObjective(this, phrase, scalePT);
boolean ok = po.optimizeWithProjectedGradientDescent();
if (!ok) ++failures;
iterations += po.getNumberUpdateCalls();
@@ -234,7 +243,7 @@ public class PhraseCluster {
loglikelihood += po.loglikelihood();
kl += po.KL_divergence();
l1lmax += po.l1lmax();
- primal += po.primal();
+ primal += po.primal(scalePT);
List<Edge> edges = c.getEdgesForPhrase(phrase);
for(int edge=0;edge<q.length;edge++){
@@ -272,10 +281,9 @@ public class PhraseCluster {
return primal;
}
- public double PREM_phrase_constraints_parallel()
+ public double PREM_phrase_constraints_parallel(final double scalePT)
{
assert(pool != null);
- assert(scaleCT <= 0);
final LinkedBlockingQueue<PhraseObjective> expectations
= new LinkedBlockingQueue<PhraseObjective>();
@@ -294,7 +302,7 @@ public class PhraseCluster {
public void run() {
try {
//System.out.println("" + Thread.currentThread().getId() + " optimising lambda for " + p);
- PhraseObjective po = new PhraseObjective(PhraseCluster.this, p);
+ PhraseObjective po = new PhraseObjective(PhraseCluster.this, p, scalePT);
boolean ok = po.optimizeWithProjectedGradientDescent();
if (!ok) failures.incrementAndGet();
//System.out.println("" + Thread.currentThread().getId() + " done optimising lambda for " + p);
@@ -322,7 +330,7 @@ public class PhraseCluster {
loglikelihood += po.loglikelihood();
kl += po.KL_divergence();
l1lmax += po.l1lmax();
- primal += po.primal();
+ primal += po.primal(scalePT);
iterations += po.getNumberUpdateCalls();
@@ -366,15 +374,14 @@ public class PhraseCluster {
return primal;
}
- public double PREM_phrase_context_constraints(){
- assert (scaleCT > 0);
-
+ public double PREM_phrase_context_constraints(double scalePT, double scaleCT)
+ {
double[][][] exp_emit = new double [K][n_positions][n_words];
double[][] exp_pi = new double[n_phrases][K];
double[] lambda = null;
//E step
- PhraseContextObjective pco = new PhraseContextObjective(this, lambda, pool);
+ PhraseContextObjective pco = new PhraseContextObjective(this, lambda, pool, scalePT, scaleCT);
lambda = pco.optimizeWithProjectedGradientDescent();
//now extract expectations
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
index 15bd29c2..ff135a3d 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
@@ -59,12 +59,18 @@ public class PhraseContextObjective extends ProjectedObjective
private long actualProjectionTime;
private ExecutorService pool;
- public PhraseContextObjective(PhraseCluster cluster, double[] startingParameters, ExecutorService pool)
+ double scalePT;
+ double scaleCT;
+
+ public PhraseContextObjective(PhraseCluster cluster, double[] startingParameters, ExecutorService pool,
+ double scalePT, double scaleCT)
{
c=cluster;
data=c.c.getEdges();
n_param=data.size()*c.K*2;
this.pool=pool;
+ this.scalePT = scalePT;
+ this.scaleCT = scaleCT;
parameters = startingParameters;
if (parameters == null)
@@ -73,8 +79,8 @@ public class PhraseContextObjective extends ProjectedObjective
newPoint = new double[n_param];
gradient = new double[n_param];
initP();
- projectionPhrase = new SimplexProjection(c.scalePT);
- projectionContext = new SimplexProjection(c.scaleCT);
+ projectionPhrase = new SimplexProjection(scalePT);
+ projectionContext = new SimplexProjection(scaleCT);
q=new double [data.size()][c.K];
edgeIndex = new HashMap<Edge, Integer>();
@@ -151,7 +157,7 @@ public class PhraseContextObjective extends ProjectedObjective
//System.out.println("projectPoint: " + Arrays.toString(point));
Arrays.fill(newPoint, 0, newPoint.length, 0);
- if (c.scalePT > 0)
+ if (scalePT > 0)
{
// first project using the phrase-tag constraints,
// for all p,t: sum_c lambda_ptc < scaleP
@@ -201,7 +207,7 @@ public class PhraseContextObjective extends ProjectedObjective
}
//System.out.println("after PT " + Arrays.toString(newPoint));
- if (c.scaleCT > 1e-6)
+ if (scaleCT > 1e-6)
{
// now project using the context-tag constraints,
// for all c,t: sum_p omega_pct < scaleC
@@ -399,6 +405,6 @@ public class PhraseContextObjective extends ProjectedObjective
// L - KL(q||p) - scalePT * l1lmax_phrase - scaleCT * l1lmax_context
public double primal()
{
- return loglikelihood() - KL_divergence() - c.scalePT * phrase_l1lmax() - c.scalePT * context_l1lmax();
+ return loglikelihood() - KL_divergence() - scalePT * phrase_l1lmax() - scalePT * context_l1lmax();
}
} \ No newline at end of file
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
index cc12546d..33167c20 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
@@ -63,7 +63,7 @@ public class PhraseObjective extends ProjectedObjective
*/
public double llh;
- public PhraseObjective(PhraseCluster cluster, int phraseIdx){
+ public PhraseObjective(PhraseCluster cluster, int phraseIdx, double scale){
phrase=phraseIdx;
c=cluster;
data=c.c.getEdgesForPhrase(phrase);
@@ -81,7 +81,7 @@ public class PhraseObjective extends ProjectedObjective
newPoint = new double[n_param];
gradient = new double[n_param];
initP();
- projection=new SimplexProjection(c.scalePT);
+ projection=new SimplexProjection(scale);
q=new double [data.size()][c.K];
setParameters(parameters);
@@ -220,8 +220,8 @@ public class PhraseObjective extends ProjectedObjective
return sum;
}
- public double primal()
+ public double primal(double scale)
{
- return loglikelihood() - KL_divergence() - c.scalePT * l1lmax();
+ return loglikelihood() - KL_divergence() - scale * l1lmax();
}
}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
index 439fb337..240c4d64 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
@@ -30,6 +30,7 @@ public class Trainer
parser.accepts("variational-bayes");
parser.accepts("alpha-emit").withRequiredArg().ofType(Double.class).defaultsTo(0.1);
parser.accepts("alpha-pi").withRequiredArg().ofType(Double.class).defaultsTo(0.01);
+ parser.accepts("agree");
OptionSet options = parser.parse(args);
if (options.has("help") || !options.has("in"))
@@ -37,7 +38,7 @@ public class Trainer
try {
parser.printHelpOn(System.err);
} catch (IOException e) {
- System.err.println("This should never happen. Really.");
+ System.err.println("This should never happen.");
e.printStackTrace();
}
System.exit(1);
@@ -75,34 +76,46 @@ public class Trainer
System.exit(1);
}
- System.out.println("Running with " + tags + " tags " +
- "for " + em_iterations + " EM and " + pr_iterations + " PR iterations " +
- "with scale " + scale_phrase + " phrase and " + scale_context + " context " +
- "and " + threads + " threads");
- System.out.println();
+ if (!options.has("agree"))
+ System.out.println("Running with " + tags + " tags " +
+ "for " + em_iterations + " EM and " + pr_iterations + " PR iterations " +
+ "with scale " + scale_phrase + " phrase and " + scale_context + " context " +
+ "and " + threads + " threads");
+ else
+ System.out.println("Running agreement model with " + tags + " tags " +
+ "for " + em_iterations);
+
+ System.out.println();
- PhraseCluster cluster = new PhraseCluster(tags, corpus, scale_phrase, scale_context, threads, alphaEmit, alphaPi);
+ PhraseCluster cluster = null;
+ Agree agree = null;
+ if (options.has("agree"))
+ agree = new Agree(tags, corpus);
+ else
+ {
+ cluster = new PhraseCluster(tags, corpus);
+ if (threads > 0) cluster.useThreadPool(threads);
+ if (vb) cluster.initialiseVB(alphaEmit, alphaPi);
+ }
double last = 0;
for (int i=0; i<em_iterations+pr_iterations; i++)
{
double o;
- if (i < em_iterations)
- {
- if (!vb)
- o = cluster.EM();
- else
- o = cluster.VBEM();
- }
- else if (scale_context == 0)
+ if (agree != null)
+ o = agree.EM();
+ else
{
- if (threads >= 1)
- o = cluster.PREM_phrase_constraints_parallel();
+ if (i < em_iterations)
+ {
+ if (!vb)
+ o = cluster.EM();
+ else
+ o = cluster.VBEM(alphaEmit, alphaPi);
+ }
else
- o = cluster.PREM_phrase_constraints();
+ o = cluster.PREM(scale_phrase, scale_context);
}
- else
- o = cluster.PREM_phrase_context_constraints();
System.out.println("ITER: "+i+" objective: " + o);
@@ -120,6 +133,9 @@ public class Trainer
last = o;
}
+ if (cluster == null)
+ cluster = agree.model1;
+
double pl1lmax = cluster.phrase_l1lmax();
double cl1lmax = cluster.context_l1lmax();
System.out.println("\nFinal posterior phrase l1lmax " + pl1lmax + " context l1lmax " + cl1lmax);