author     trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>  2010-06-28 19:34:58 +0000
committer  trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>  2010-06-28 19:34:58 +0000
commit     89afb013e40bf9c749956420e77ac72773874844 (patch)
tree       2e18cf00d768507ae5cbc4cb085e22165f3b6fe6 /gi/posterior-regularisation/em.cc
parent     c1fd8f9ba0c8a6f5ab07a3390d3f1a3910a3267e (diff)
First bits of code for PR training
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@44 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation/em.cc')
-rw-r--r--  gi/posterior-regularisation/em.cc  830
1 files changed, 830 insertions, 0 deletions
diff --git a/gi/posterior-regularisation/em.cc b/gi/posterior-regularisation/em.cc
new file mode 100644
index 00000000..f6c9fd68
--- /dev/null
+++ b/gi/posterior-regularisation/em.cc
@@ -0,0 +1,830 @@
+// Input of the form:
+// " the phantom of the opera " tickets for <PHRASE> tonight ? ||| C=1 ||| seats for <PHRASE> ? </s> ||| C=1 ||| i see <PHRASE> ? </s> ||| C=1
+// phrase TAB [context]+
+// where context = phrase ||| C=... which are separated by |||
+
+// Model parameterised as follows:
+// - each phrase, p, is allocated a latent state, t
+// - this is used to generate the contexts, c
+// - each context is generated using 4 independent multinomials, one for each position LL, L, R, RR
+
+// Training with EM:
+// - e-step is estimating P(t|p,c) for all p,c
+// - m-step is estimating model parameters P(p,c,t) = P(t) P(p|t) P(c|t)
+
+// Sexing it up:
+// - constrain the posteriors P(t|c) and P(t|p) to have few high-magnitude entries
+// - improve the generation of phrase internals, e.g., generate edge words from
+// different distribution to central words
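+//
+// Putting the pieces together, the joint used below (see conditional_probs) is
+//   P(p, c, t) = P(t) * prod_{j in {LL, L, R, RR}} P(c_j | t)
+// with an optional phrase term when includePhraseProb is set:
+//   P(|p| | t) (geometric) * prod_{w in p} P(w | t)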
+
+#include "alphabet.hh"
+#include "log_add.hh"
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <cstdlib>
+#include <cstring>
+#include <ctime>
+#include <fstream>
+#include <iostream>
+#include <iterator>
+#include <map>
+#include <sstream>
+#include <stdexcept>
+#include <vector>
+#include <tr1/random>
+#include <tr1/tuple>
+#include <nlopt.h>
+
+using namespace std;
+using namespace std::tr1;
+
+const int numTags = 5;
+const int numIterations = 100;
+const bool posterior_regularisation = true;
+const double PHRASE_VIOLATION_WEIGHT = 10;
+const double CONTEXT_VIOLATION_WEIGHT = 0;
+const bool includePhraseProb = false;
+
+// Data structures:
+Alphabet<string> lexicon;
+typedef vector<int> Phrase;
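+// a context is the four surrounding word ids, in order (LL, L, R, RR)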
+typedef tuple<int, int, int, int> Context;
+Alphabet<Phrase> phrases;
+Alphabet<Context> contexts;
+
+typedef map<int, int> ContextCounts;
+typedef map<int, int> PhraseCounts;
+typedef map<int, ContextCounts> PhraseToContextCounts;
+typedef map<int, PhraseCounts> ContextToPhraseCounts;
+
+PhraseToContextCounts concordancePhraseToContexts;
+ContextToPhraseCounts concordanceContextToPhrases;
+
+typedef vector<double> Dist;
+typedef vector<Dist> ConditionalDist;
+Dist prior; // class -> P(class)
+vector<ConditionalDist> probCtx; // word -> class -> P(word | class), for each position of context word
+ConditionalDist probPhrase; // class -> P(word | class)
+Dist probPhraseLength; // class -> P(length | class) expressed as geometric distribution parameter
+
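+// uniform noise in [0, 0.1) is added to the initial parameters to break symmetry between tags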
+mt19937 randomGenerator((size_t) time(NULL));
+uniform_real<double> uniDist(0.0, 1e-1);
+variate_generator< mt19937, uniform_real<double> > rng(randomGenerator, uniDist);
+
+void addRandomNoise(Dist &d);
+void normalise(Dist &d);
+void addTo(Dist &d, const Dist &e);
+int argmax(const Dist &d);
+
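+// base index into the flat lambda vector for each (phrase, context) pair;
+// each pair owns a block of numTags consecutive entries (see lambda_index below)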
+map<Phrase, map<Context, int> > lambda_indices;
+
+Dist conditional_probs(const Phrase &phrase, const Context &context, double *normalisation = 0);
+template <typename T>
+Dist
+penalised_conditionals(const Phrase &phrase, const Context &context,
+ const T &lambda, double *normalisation);
+//Dist penalised_conditionals(const Phrase &phrase, const Context &context, const double *lambda, double *normalisation = 0);
+double penalised_log_likelihood(int n, const double *lambda, double *gradient, void *data);
+void optimise_lambda(double delta, double gamma, vector<double> &lambda);
+double expected_violation_phrases(const double *lambda);
+double expected_violation_contexts(const double *lambda);
+double primal_kl_divergence(const double *lambda);
+double dual(const double *lambda);
+void print_primal_dual(const double *lambda, double delta, double gamma);
+
+ostream &operator<<(ostream &, const Phrase &);
+ostream &operator<<(ostream &, const Context &);
+ostream &operator<<(ostream &, const Dist &);
+ostream &operator<<(ostream &, const ConditionalDist &);
+
+int
+main(int argc, char *argv[])
+{
+ randomGenerator.seed(time(NULL));
+
+ int edges = 0;
+ istream &input = cin;
+ while (input.good())
+ {
+ // read the phrase
+ string phraseString;
+ Phrase phrase;
+ getline(input, phraseString, '\t');
+ istringstream pinput(phraseString);
+ string token;
+ while (pinput >> token)
+ phrase.push_back(lexicon.insert(token));
+ int phraseId = phrases.insert(phrase);
+
+ // read the rest, storing each context
+ string remainder;
+ getline(input, remainder, '\n');
+ istringstream rinput(remainder);
+ Context context(-1, -1, -1, -1);
+ int index = 0;
+ while (rinput >> token)
+ {
+ if (token != "|||" && token != "<PHRASE>")
+ {
+ if (index < 4)
+ {
+ // eugh! damn templates
+ switch (index)
+ {
+ case 0: get<0>(context) = lexicon.insert(token); break;
+ case 1: get<1>(context) = lexicon.insert(token); break;
+ case 2: get<2>(context) = lexicon.insert(token); break;
+ case 3: get<3>(context) = lexicon.insert(token); break;
+ default: assert(false);
+ }
+ index += 1;
+ }
+ else if (token.find("C=") == 0)
+ {
+ int contextId = contexts.insert(context);
+ int count = atoi(token.substr(strlen("C=")).c_str());
+ concordancePhraseToContexts[phraseId][contextId] += count;
+ concordanceContextToPhrases[contextId][phraseId] += count;
+ index = 0;
+ context = Context(-1, -1, -1, -1);
+ edges += 1;
+ }
+ }
+ }
+
+ // trigger EOF
+ input >> ws;
+ }
+
+ cout << "Read in " << phrases.size() << " phrases"
+ << " and " << contexts.size() << " contexts"
+ << " and " << edges << " edges"
+ << " and " << lexicon.size() << " word types\n";
+
+ // FIXME: filter out low count phrases and low count contexts (based on individual words?)
+ // now populate model parameters with uniform + random noise
+ prior.resize(numTags, 1.0);
+ addRandomNoise(prior);
+ normalise(prior);
+
+ probCtx.resize(4, ConditionalDist(numTags, Dist(lexicon.size(), 1.0)));
+ if (includePhraseProb)
+ probPhrase.resize(numTags, Dist(lexicon.size(), 1.0));
+ for (int t = 0; t < numTags; ++t)
+ {
+ for (int j = 0; j < 4; ++j)
+ {
+ addRandomNoise(probCtx[j][t]);
+ normalise(probCtx[j][t]);
+ }
+ if (includePhraseProb)
+ {
+ addRandomNoise(probPhrase[t]);
+ normalise(probPhrase[t]);
+ }
+ }
+ if (includePhraseProb)
+ {
+ probPhraseLength.resize(numTags, 0.5); // geometric distribution p=0.5
+ addRandomNoise(probPhraseLength);
+ }
+
+ cout << "\tprior: " << prior << "\n";
+ //cout << "\tcontext: " << probCtx << "\n";
+ //cout << "\tphrase: " << probPhrase << "\n";
+ //cout << "\tphraseLen: " << probPhraseLength << endl;
+
+ vector<double> lambda;
+
+ // now do EM training
+ for (int iteration = 0; iteration < numIterations; ++iteration)
+ {
+ cout << "EM iteration " << iteration << endl;
+
+ if (posterior_regularisation)
+ optimise_lambda(PHRASE_VIOLATION_WEIGHT, CONTEXT_VIOLATION_WEIGHT, lambda);
+ //cout << "\tlambda " << lambda << endl;
+
+ Dist countsPrior(numTags, 0.0);
+ vector<ConditionalDist> countsCtx(4, ConditionalDist(numTags, Dist(lexicon.size(), 1e-10)));
+ ConditionalDist countsPhrase(numTags, Dist(lexicon.size(), 1e-10));
+ Dist countsPhraseLength(numTags, 0.0);
+ Dist nPhrases(numTags, 0.0);
+
+ double llh = 0;
+ for (PhraseToContextCounts::iterator pcit = concordancePhraseToContexts.begin();
+ pcit != concordancePhraseToContexts.end(); ++pcit)
+ {
+ const Phrase &phrase = phrases.type(pcit->first);
+
+ // e-step: estimate latent class probs; compile (class,word) stats for m-step
+ for (ContextCounts::iterator ccit = pcit->second.begin();
+ ccit != pcit->second.end(); ++ccit)
+ {
+ const Context &context = contexts.type(ccit->first);
+
+ double z = 0;
+ Dist tagCounts;
+ if (!posterior_regularisation)
+ tagCounts = conditional_probs(phrase, context, &z);
+ else
+ tagCounts = penalised_conditionals(phrase, context, lambda, &z);
+
+ llh += log(z) * ccit->second;
+
+            for (int t = 0; t < numTags; ++t)
+            {
+                countsPrior[t] += tagCounts[t] * ccit->second;
+
+                // get<j> needs a compile-time index, so unroll the four context positions
+                countsCtx[0][t][get<0>(context)] += tagCounts[t] * ccit->second;
+                countsCtx[1][t][get<1>(context)] += tagCounts[t] * ccit->second;
+                countsCtx[2][t][get<2>(context)] += tagCounts[t] * ccit->second;
+                countsCtx[3][t][get<3>(context)] += tagCounts[t] * ccit->second;
+
+ if (includePhraseProb)
+ {
+ for (Phrase::const_iterator pit = phrase.begin(); pit != phrase.end(); ++pit)
+ countsPhrase[t][*pit] += tagCounts[t] * ccit->second;
+ countsPhraseLength[t] += phrase.size() * tagCounts[t] * ccit->second;
+ nPhrases[t] += tagCounts[t] * ccit->second;
+ }
+ }
+ }
+ }
+
+ cout << "M-step\n";
+
+ // m-step: normalise prior and (class,word) stats and assign to model parameters
+ normalise(countsPrior);
+ prior = countsPrior;
+ for (int t = 0; t < numTags; ++t)
+ {
+ //cout << "\t\tt " << t << " prior " << countsPrior[t] << "\n";
+ for (int j = 0; j < 4; ++j)
+ normalise(countsCtx[j][t]);
+ if (includePhraseProb)
+ {
+ normalise(countsPhrase[t]);
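+                // ML estimate of the geometric length parameter: #phrase occurrences / total expected length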
+ countsPhraseLength[t] = nPhrases[t] / countsPhraseLength[t];
+ }
+ }
+ probCtx = countsCtx;
+ if (includePhraseProb)
+ {
+ probPhrase = countsPhrase;
+ probPhraseLength = countsPhraseLength;
+ }
+
+ double *larray = new double[lambda.size()];
+ copy(lambda.begin(), lambda.end(), larray);
+ print_primal_dual(larray, PHRASE_VIOLATION_WEIGHT, CONTEXT_VIOLATION_WEIGHT);
+ delete [] larray;
+
+ //cout << "\tllh " << llh << endl;
+ //cout << "\tprior: " << prior << "\n";
+ //cout << "\tcontext: " << probCtx << "\n";
+ //cout << "\tphrase: " << probPhrase << "\n";
+ //cout << "\tphraseLen: " << probPhraseLength << "\n";
+ }
+
+ // output class membership
+ for (PhraseToContextCounts::iterator pcit = concordancePhraseToContexts.begin();
+ pcit != concordancePhraseToContexts.end(); ++pcit)
+ {
+ const Phrase &phrase = phrases.type(pcit->first);
+ for (ContextCounts::iterator ccit = pcit->second.begin();
+ ccit != pcit->second.end(); ++ccit)
+ {
+ const Context &context = contexts.type(ccit->first);
+ Dist tagCounts = conditional_probs(phrase, context, 0);
+ cout << phrase << " ||| " << context << " ||| " << argmax(tagCounts) << "\n";
+ }
+ }
+
+ return 0;
+}
+
+void addRandomNoise(Dist &d)
+{
+ for (Dist::iterator dit = d.begin(); dit != d.end(); ++dit)
+ *dit += rng();
+}
+
+void normalise(Dist &d)
+{
+ double z = 0;
+ for (Dist::iterator dit = d.begin(); dit != d.end(); ++dit)
+ z += *dit;
+ for (Dist::iterator dit = d.begin(); dit != d.end(); ++dit)
+ *dit /= z;
+}
+
+void addTo(Dist &d, const Dist &e)
+{
+ assert(d.size() == e.size());
+ for (int i = 0; i < (int) d.size(); ++i)
+ d[i] += e[i];
+}
+
+int argmax(const Dist &d)
+{
+ double best = d[0];
+ int index = 0;
+ for (int i = 1; i < (int) d.size(); ++i)
+ {
+ if (d[i] > best)
+ {
+ best = d[i];
+ index = i;
+ }
+ }
+ return index;
+}
+
+ostream &operator<<(ostream &out, const Phrase &phrase)
+{
+ for (Phrase::const_iterator pit = phrase.begin(); pit != phrase.end(); ++pit)
+ lexicon.display(((pit == phrase.begin()) ? out : out << " "), *pit);
+ return out;
+}
+
+ostream &operator<<(ostream &out, const Context &context)
+{
+ lexicon.display(out, get<0>(context));
+ lexicon.display(out << " ", get<1>(context));
+ lexicon.display(out << " <PHRASE> ", get<2>(context));
+ lexicon.display(out << " ", get<3>(context));
+ return out;
+}
+
+ostream &operator<<(ostream &out, const Dist &dist)
+{
+ for (Dist::const_iterator dit = dist.begin(); dit != dist.end(); ++dit)
+ out << ((dit == dist.begin()) ? "" : " ") << *dit;
+ return out;
+}
+
+ostream &operator<<(ostream &out, const ConditionalDist &dist)
+{
+ for (ConditionalDist::const_iterator dit = dist.begin(); dit != dist.end(); ++dit)
+ out << ((dit == dist.begin()) ? "" : "; ") << *dit;
+ return out;
+}
+
+// FIXME: slow - just use the phrase index, context index to do the mapping
+// (n.b. it's a sparse setup, not just equal to 3d array index)
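+// returns the position of lambda_{phrase, context, tag} within the flat lambda vector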
+int
+lambda_index(const Phrase &phrase, const Context &context, int tag)
+{
+ return lambda_indices[phrase][context] + tag;
+}
+
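+// e-step posterior under the projected (penalised) distribution:
+//   q(t | p, c) is proportional to p_theta(t | p, c) * exp(-lambda_{p,c,t})
+// optionally returns the unnormalised mass in *normalisation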
+template <typename T>
+Dist
+penalised_conditionals(const Phrase &phrase, const Context &context,
+ const T &lambda, double *normalisation)
+{
+ Dist d = conditional_probs(phrase, context, 0);
+
+ double z = 0;
+ for (int t = 0; t < numTags; ++t)
+ {
+ d[t] *= exp(-lambda[lambda_index(phrase, context, t)]);
+ z += d[t];
+ }
+
+ if (normalisation)
+ *normalisation = z;
+
+ for (int t = 0; t < numTags; ++t)
+ d[t] /= z;
+
+ return d;
+}
+
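+// model posterior p_theta(t | p, c): normalise P(t) * prod_j P(c_j | t) (and the
+// phrase terms when includePhraseProb); optionally returns the normaliser Z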
+Dist
+conditional_probs(const Phrase &phrase, const Context &context, double *normalisation)
+{
+ Dist tagCounts(numTags, 0.0);
+ double z = 0;
+ for (int t = 0; t < numTags; ++t)
+ {
+ double prob = prior[t];
+ prob *= (probCtx[0][t][get<0>(context)] * probCtx[1][t][get<1>(context)] *
+ probCtx[2][t][get<2>(context)] * probCtx[3][t][get<3>(context)]);
+
+ if (includePhraseProb)
+ {
+ prob *= pow(1 - probPhraseLength[t], phrase.size() - 1) * probPhraseLength[t];
+ for (Phrase::const_iterator pit = phrase.begin(); pit != phrase.end(); ++pit)
+ prob *= probPhrase[t][*pit];
+ }
+
+ tagCounts[t] = prob;
+ z += prob;
+ }
+ if (normalisation)
+ *normalisation = z;
+
+ for (int t = 0; t < numTags; ++t)
+ tagCounts[t] /= z;
+
+ return tagCounts;
+}
+
+double
+penalised_log_likelihood(int n, const double *lambda, double *grad, void *)
+{
+ // return log Z(lambda, theta) over the corpus
+ // where theta are the global parameters (prior, probCtx*, probPhrase*)
+ // and lambda are lagrange multipliers for the posterior sparsity constraints
+ //
+ // this is formulated as:
+    //    f = log Z(lambda) = sum_i log ( sum_t p_theta(t|p_i,c_i) exp [-lambda_{t,p_i,c_i}] )
+ // where i indexes the training examples - specifying the (p, c) pair (which may occur with count > 1)
+ //
+ // with derivative:
+    //    f'_{tpc} = frac { - count(p,c) p_theta(t|p,c) exp (-lambda_{t,p,c}) }
+    //                    { sum_t' p_theta(t'|p,c) exp (-lambda_{t',p,c}) }
+
+ //cout << "penalised_log_likelihood with lambda ";
+ //copy(lambda, lambda+n, ostream_iterator<double>(cout, " "));
+ //cout << "\n";
+
+ double f = 0;
+ if (grad)
+ {
+ for (int i = 0; i < n; ++i)
+ grad[i] = 0.0;
+ }
+
+ for (int p = 0; p < phrases.size(); ++p)
+ {
+ const Phrase &phrase = phrases.type(p);
+ PhraseToContextCounts::const_iterator pcit = concordancePhraseToContexts.find(p);
+ for (ContextCounts::const_iterator ccit = pcit->second.begin();
+ ccit != pcit->second.end(); ++ccit)
+ {
+ const Context &context = contexts.type(ccit->first);
+ double z = 0;
+ Dist scores = penalised_conditionals(phrase, context, lambda, &z);
+
+ f += ccit->second * log(z);
+ //cout << "\tphrase: " << phrase << " context: " << context << " count: " << ccit->second << " z " << z << endl;
+ //cout << "\t\tscores: " << scores << "\n";
+
+ if (grad)
+ {
+ for (int t = 0; t < numTags; ++t)
+ {
+ int i = lambda_index(phrase, context, t); // FIXME: redundant lookups
+ assert(grad[i] == 0.0);
+ grad[i] = - ccit->second * scores[t];
+ }
+ }
+ }
+ }
+
+ //cout << "penalised_log_likelihood returning " << f;
+ //if (grad)
+ //{
+ //cout << "\ngradient: ";
+ //copy(grad, grad+n, ostream_iterator<double>(cout, " "));
+ //}
+ //cout << "\n";
+
+ return f;
+}
+
+typedef struct
+{
+ // one of p or c should be set to -1, in which case it will be marginalised out
+ // i.e. sum_p' lambda_{p'ct} <= threshold
+ // or sum_c' lambda_{pc't} <= threshold
+    int p, c, t;
+    double threshold;
+} constraint_data;
+
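+// inequality constraint g(lambda) <= 0 for a single (p, t) or (c, t) pair:
+// g = (sum of the relevant lambdas) - threshold; also fills in its gradient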
+double
+constraint_and_gradient(int n, const double *lambda, double *grad, void *data)
+{
+ constraint_data *d = (constraint_data *) data;
+ assert(d->t >= 0);
+ assert(d->threshold >= 0);
+
+ //cout << "constraint_and_gradient: t " << d->t << " p " << d->p << " c " << d->c << " tau " << d->threshold << endl;
+ //cout << "\tlambda ";
+ //copy(lambda, lambda+n, ostream_iterator<double>(cout, " "));
+ //cout << "\n";
+
+ // FIXME: it's crazy to use a dense gradient here => will only have a handful of non-zero entries
+ if (grad)
+ {
+ for (int i = 0; i < n; ++i)
+ grad[i] = 0.0;
+ }
+
+ //cout << "constraint_and_gradient: " << d->p << "; " << d->c << "; " << d->t << "; " << d->threshold << endl;
+
+ if (d->p >= 0)
+ {
+ assert(d->c < 0);
+ // sum_c lambda_pct <= delta [a.k.a. threshold]
+ // => sum_c lambda_pct - delta <= 0
+ // derivative_pct = { 1, if p and t match; 0, otherwise }
+
+ double val = -d->threshold;
+
+ const Phrase &phrase = phrases.type(d->p);
+ PhraseToContextCounts::const_iterator pcit = concordancePhraseToContexts.find(d->p);
+ assert(pcit != concordancePhraseToContexts.end());
+ for (ContextCounts::const_iterator ccit = pcit->second.begin();
+ ccit != pcit->second.end(); ++ccit)
+ {
+ const Context &context = contexts.type(ccit->first);
+ int i = lambda_index(phrase, context, d->t);
+ val += lambda[i];
+ if (grad) grad[i] = 1;
+ }
+ //cout << "\treturning " << val << endl;
+
+ return val;
+ }
+ else
+ {
+ assert(d->c >= 0);
+ assert(d->p < 0);
+ // sum_p lambda_pct <= gamma [a.k.a. threshold]
+ // => sum_p lambda_pct - gamma <= 0
+ // derivative_pct = { 1, if c and t match; 0, otherwise }
+
+ double val = -d->threshold;
+
+ const Context &context = contexts.type(d->c);
+ ContextToPhraseCounts::iterator cpit = concordanceContextToPhrases.find(d->c);
+ assert(cpit != concordanceContextToPhrases.end());
+ for (PhraseCounts::iterator pcit = cpit->second.begin();
+ pcit != cpit->second.end(); ++pcit)
+ {
+ const Phrase &phrase = phrases.type(pcit->first);
+ int i = lambda_index(phrase, context, d->t);
+ val += lambda[i];
+ if (grad) grad[i] = 1;
+ }
+ //cout << "\treturning " << val << endl;
+
+ return val;
+ }
+}
+
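+// optimise the lagrange multipliers for the posterior regularisation projection:
+// minimise log Z(lambda) subject to lambda >= 0 and, for each tag t,
+//   sum_c lambda_{pct} <= delta   (one constraint per phrase p)
+//   sum_p lambda_{pct} <= gamma   (one constraint per context c)
+// using COBYLA from NLopt; the solution is written back into lambdav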
+void
+optimise_lambda(double delta, double gamma, vector<double> &lambdav)
+{
+ int num_lambdas = lambdav.size();
+ if (lambda_indices.empty() || lambdav.empty())
+ {
+ lambda_indices.clear();
+ lambdav.clear();
+
+ int i = 0;
+ for (int p = 0; p < phrases.size(); ++p)
+ {
+ const Phrase &phrase = phrases.type(p);
+ PhraseToContextCounts::iterator pcit = concordancePhraseToContexts.find(p);
+ for (ContextCounts::iterator ccit = pcit->second.begin();
+ ccit != pcit->second.end(); ++ccit)
+ {
+ const Context &context = contexts.type(ccit->first);
+ lambda_indices[phrase][context] = i;
+ i += numTags;
+ }
+ }
+ num_lambdas = i;
+ lambdav.resize(num_lambdas);
+ }
+    //cout << "optimise_lambda: #lagrange multipliers " << num_lambdas << endl;
+
+ // FIXME: better to work with an implicit representation to save memory usage
+ int num_constraints = (((delta > 0) ? phrases.size() : 0) + ((gamma > 0) ? contexts.size() : 0)) * numTags;
+ //cout << "optimise_lambda: #constraints " << num_constraints << endl;
+ constraint_data *data = new constraint_data[num_constraints];
+ int i = 0;
+ if (delta > 0)
+ {
+ for (int p = 0; p < phrases.size(); ++p)
+ {
+ for (int t = 0; t < numTags; ++t, ++i)
+ {
+ constraint_data &d = data[i];
+ d.p = p;
+ d.c = -1;
+ d.t = t;
+ d.threshold = delta;
+ }
+ }
+ }
+
+ if (gamma > 0)
+ {
+ for (int c = 0; c < contexts.size(); ++c)
+ {
+ for (int t = 0; t < numTags; ++t, ++i)
+ {
+ constraint_data &d = data[i];
+ d.p = -1;
+ d.c = c;
+ d.t = t;
+ d.threshold = gamma;
+ }
+ }
+ }
+ assert(i == num_constraints);
+
+ double lambda[num_lambdas];
+ double lb[num_lambdas], ub[num_lambdas];
+ for (i = 0; i < num_lambdas; ++i)
+ {
+ lambda[i] = lambdav[i]; // starting value
+ lb[i] = 0; // lower bound
+ if (delta <= 0) // upper bound
+ ub[i] = gamma;
+ else if (gamma <= 0)
+ ub[i] = delta;
+ else
+ assert(false);
+ }
+
+ //print_primal_dual(lambda, delta, gamma);
+
+ double minf;
+ int error_code = nlopt_minimize_constrained(NLOPT_LN_COBYLA, num_lambdas, penalised_log_likelihood, NULL,
+ num_constraints, constraint_and_gradient, data, sizeof(constraint_data),
+ lb, ub, lambda, &minf, -HUGE_VAL, 0.0, 0.0, 1e-4, NULL, 0, 0.0);
+ //cout << "optimise error code " << error_code << endl;
+
+ //print_primal_dual(lambda, delta, gamma);
+
+ delete [] data;
+
+ if (error_code < 0)
+ cout << "WARNING: optimisation failed with error code: " << error_code << endl;
+ //else
+ //{
+ //cout << "success; minf " << minf << endl;
+ //print_primal_dual(lambda, delta, gamma);
+ //}
+
+ lambdav = vector<double>(&lambda[0], &lambda[0] + num_lambdas);
+}
+
+// FIXME: inefficient - cache the scores
+double
+expected_violation_phrases(const double *lambda)
+{
+ // sum_pt max_c E_q[phi_pct]
+ double violation = 0;
+
+ for (int p = 0; p < phrases.size(); ++p)
+ {
+ const Phrase &phrase = phrases.type(p);
+ PhraseToContextCounts::const_iterator pcit = concordancePhraseToContexts.find(p);
+
+ for (int t = 0; t < numTags; ++t)
+ {
+ double best = 0;
+ for (ContextCounts::const_iterator ccit = pcit->second.begin();
+ ccit != pcit->second.end(); ++ccit)
+ {
+ const Context &context = contexts.type(ccit->first);
+ Dist scores = penalised_conditionals(phrase, context, lambda, 0);
+ best = max(best, scores[t]);
+ }
+ violation += best;
+ }
+ }
+
+ return violation;
+}
+
+// FIXME: inefficient - cache the scores
+double
+expected_violation_contexts(const double *lambda)
+{
+ // sum_ct max_p E_q[phi_pct]
+ double violation = 0;
+
+ for (int c = 0; c < contexts.size(); ++c)
+ {
+ const Context &context = contexts.type(c);
+ ContextToPhraseCounts::iterator cpit = concordanceContextToPhrases.find(c);
+
+ for (int t = 0; t < numTags; ++t)
+ {
+ double best = 0;
+ for (PhraseCounts::iterator pit = cpit->second.begin();
+ pit != cpit->second.end(); ++pit)
+ {
+ const Phrase &phrase = phrases.type(pit->first);
+ Dist scores = penalised_conditionals(phrase, context, lambda, 0);
+ best = max(best, scores[t]);
+ }
+ violation += best;
+ }
+ }
+
+ return violation;
+}
+
+// FIXME: possibly inefficient
+double
+primal_likelihood() // FIXME: primal evaluation needs to use lambda and calculate l1linf terms
+{
+ double llh = 0;
+ for (int p = 0; p < phrases.size(); ++p)
+ {
+ const Phrase &phrase = phrases.type(p);
+ PhraseToContextCounts::const_iterator pcit = concordancePhraseToContexts.find(p);
+ for (ContextCounts::const_iterator ccit = pcit->second.begin();
+ ccit != pcit->second.end(); ++ccit)
+ {
+ const Context &context = contexts.type(ccit->first);
+ double z = 0;
+ Dist scores = conditional_probs(phrase, context, &z);
+ llh += ccit->second * log(z);
+ }
+ }
+ return llh;
+}
+
+// FIXME: inefficient - cache the scores
+double
+primal_kl_divergence(const double *lambda)
+{
+ // return KL(q || p) = sum_y q(y) { log q(y) - log p(y | x) }
+ // = sum_y q(y) { log p(y | x) - lambda . phi(x, y) - log Z - log p(y | x) }
+ // = sum_y q(y) { - lambda . phi(x, y) } - log Z
+ // and q(y) factors with each edge, ditto for Z
+
+ double feature_sum = 0, log_z = 0;
+ for (int p = 0; p < phrases.size(); ++p)
+ {
+ const Phrase &phrase = phrases.type(p);
+ PhraseToContextCounts::const_iterator pcit = concordancePhraseToContexts.find(p);
+ for (ContextCounts::const_iterator ccit = pcit->second.begin();
+ ccit != pcit->second.end(); ++ccit)
+ {
+ const Context &context = contexts.type(ccit->first);
+
+ double local_z = 0;
+ double local_f = 0;
+ Dist d = conditional_probs(phrase, context, 0);
+ for (int t = 0; t < numTags; ++t)
+ {
+ int i = lambda_index(phrase, context, t);
+ double s = d[t] * exp(-lambda[i]);
+ local_f += lambda[i] * s;
+ local_z += s;
+ }
+
+ log_z += ccit->second * log(local_z);
+ feature_sum += ccit->second * (local_f / local_z);
+ }
+ }
+
+ return -feature_sum - log_z;
+}
+
+// FIXME: inefficient - cache the scores
+double
+dual(const double *lambda)
+{
+    // return log(Z) = log { sum_y p(y | x) exp( - lambda . phi(x, y) ) }
+    // n.b. the sign is flipped relative to the dual objective as we're minimising
+
+    double logz = 0;
+ for (int p = 0; p < phrases.size(); ++p)
+ {
+ const Phrase &phrase = phrases.type(p);
+ PhraseToContextCounts::const_iterator pcit = concordancePhraseToContexts.find(p);
+ for (ContextCounts::const_iterator ccit = pcit->second.begin();
+ ccit != pcit->second.end(); ++ccit)
+ {
+ const Context &context = contexts.type(ccit->first);
+            double lz = 0;
+            Dist scores = penalised_conditionals(phrase, context, lambda, &lz);
+            logz += ccit->second * log(lz);
+ }
+ }
+    return logz;
+}
+
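+// report training diagnostics: data log-likelihood, KL(q || p_theta) and the
+// expected constraint violations (the l1/linf terms) for phrases and contexts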
+void
+print_primal_dual(const double *lambda, double delta, double gamma)
+{
+ double likelihood = primal_likelihood();
+ double kl = primal_kl_divergence(lambda);
+ double sum_pt = expected_violation_phrases(lambda);
+ double sum_ct = expected_violation_contexts(lambda);
+ //double d = dual(lambda);
+
+ cout << "\tllh=" << likelihood
+ << " kl=" << kl
+ << " violations phrases=" << sum_pt
+ << " contexts=" << sum_ct
+ //<< " primal=" << (kl + delta * sum_pt + gamma * sum_ct)
+ //<< " dual=" << d
+ << " objective=" << (likelihood - kl + delta * sum_pt + gamma * sum_ct)
+ << endl;
+}