From ab3534c45f463e541f3baf05006a50b64e3bbe31 Mon Sep 17 00:00:00 2001
From: "trevor.cohn"
Date: Mon, 28 Jun 2010 19:34:58 +0000
Subject: First bits of code for PR training

git-svn-id: https://ws10smt.googlecode.com/svn/trunk@44 ec762483-ff6d-05da-a07a-a48fb63a330f
---
 gi/posterior-regularisation/em.cc | 830 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 830 insertions(+)
 create mode 100644 gi/posterior-regularisation/em.cc

diff --git a/gi/posterior-regularisation/em.cc b/gi/posterior-regularisation/em.cc
new file mode 100644
index 00000000..f6c9fd68
--- /dev/null
+++ b/gi/posterior-regularisation/em.cc
@@ -0,0 +1,830 @@
+// Input of the form:
+// " the phantom of the opera " tickets for tonight ? ||| C=1 ||| seats for ? ||| C=1 ||| i see ? ||| C=1
+// phrase TAB [context]+
+// where each context = the context words followed by ||| C=count, and
+// successive contexts are separated by |||
+
+// Model parameterised as follows:
+// - each phrase, p, is allocated a latent state, t
+// - this is used to generate the contexts, c
+// - each context is generated using 4 independent multinomials, one for each position LL, L, R, RR
+
+// Training with EM:
+// - e-step is estimating P(t|p,c) for all p,c
+// - m-step is estimating model parameters P(p,c,t) = P(t) P(p|t) P(c|t)
+
+// Sexing it up:
+// - constrain the posteriors P(t|c) and P(t|p) to have few high-magnitude entries
+// - improve the generation of phrase internals, e.g., generate edge words from
+//   a different distribution to central words
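+
+// Worked example (illustrative only, based on the sample line above): for
+// phrase p = "the phantom of the opera" with context c = (LL=tickets, L=for,
+// R=tonight, RR=?), the model scores the pair as
+//   P(p,c,t) = P(t) * P_LL(tickets|t) * P_L(for|t) * P_R(tonight|t) * P_RR(?|t)
+// for each latent state t, with an optional P(p|t) factor (a geometric length
+// penalty times a per-word multinomial) when includePhraseProb is enabled.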
+
+#include "alphabet.hh"
+#include "log_add.hh"
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <cstdlib>
+#include <cstring>
+#include <ctime>
+#include <iostream>
+#include <iterator>
+#include <map>
+#include <sstream>
+#include <vector>
+#include <tr1/random>
+#include <tr1/tuple>
+#include <nlopt.h>
+
+using namespace std;
+using namespace std::tr1;
+
+const int numTags = 5;
+const int numIterations = 100;
+const bool posterior_regularisation = true;
+const double PHRASE_VIOLATION_WEIGHT = 10;
+const double CONTEXT_VIOLATION_WEIGHT = 0;
+const bool includePhraseProb = false;
+
+// Data structures: Alphabet<T> (see alphabet.hh) maps between values of type T
+// and dense integer ids
+Alphabet<string> lexicon;
+typedef vector<int> Phrase;
+typedef tuple<int, int, int, int> Context;
+Alphabet<Phrase> phrases;
+Alphabet<Context> contexts;
+
+typedef map<int, int> ContextCounts;
+typedef map<int, int> PhraseCounts;
+typedef map<int, ContextCounts> PhraseToContextCounts;
+typedef map<int, PhraseCounts> ContextToPhraseCounts;
+
+PhraseToContextCounts concordancePhraseToContexts;
+ContextToPhraseCounts concordanceContextToPhrases;
+
+typedef vector<double> Dist;
+typedef vector<Dist> ConditionalDist;
+Dist prior; // class -> P(class)
+vector<ConditionalDist> probCtx; // position -> class -> P(word | class), for each of the 4 context positions
+ConditionalDist probPhrase; // class -> P(word | class)
+Dist probPhraseLength; // class -> P(length | class) expressed as geometric distribution parameter
+
+mt19937 randomGenerator((size_t) time(NULL));
+uniform_real<double> uniDist(0.0, 1e-1);
+variate_generator< mt19937, uniform_real<double> > rng(randomGenerator, uniDist);
+
+void addRandomNoise(Dist &d);
+void normalise(Dist &d);
+void addTo(Dist &d, const Dist &e);
+int argmax(const Dist &d);
+
+map<Phrase, map<Context, int> > lambda_indices;
+
+Dist conditional_probs(const Phrase &phrase, const Context &context, double *normalisation = 0);
+template <typename T>
+Dist
+penalised_conditionals(const Phrase &phrase, const Context &context,
+                       const T &lambda, double *normalisation);
+//Dist penalised_conditionals(const Phrase &phrase, const Context &context, const double *lambda, double *normalisation = 0);
+double penalised_log_likelihood(int n, const double *lambda, double *gradient, void *data);
+void optimise_lambda(double delta, double gamma, vector<double> &lambda);
+double expected_violation_phrases(const double *lambda);
+double expected_violation_contexts(const double *lambda);
+double primal_kl_divergence(const double *lambda);
+double dual(const double *lambda);
+void print_primal_dual(const double *lambda, double delta, double gamma);
+
+ostream &operator<<(ostream &, const Phrase &);
+ostream &operator<<(ostream &, const Context &);
+ostream &operator<<(ostream &, const Dist &);
+ostream &operator<<(ostream &, const ConditionalDist &);
+
+int
+main(int argc, char *argv[])
+{
+    randomGenerator.seed(time(NULL));
+
+    int edges = 0;
+    istream &input = cin;
+    while (input.good())
+    {
+        // read the phrase
+        string phraseString;
+        Phrase phrase;
+        getline(input, phraseString, '\t');
+        istringstream pinput(phraseString);
+        string token;
+        while (pinput >> token)
+            phrase.push_back(lexicon.insert(token));
+        int phraseId = phrases.insert(phrase);
+
+        // read the rest, storing each context
+        string remainder;
+        getline(input, remainder, '\n');
+        istringstream rinput(remainder);
+        Context context(-1, -1, -1, -1);
+        int index = 0;
+        while (rinput >> token)
+        {
+            if (token != "|||" && token != "")
+            {
+                if (index < 4)
+                {
+                    // eugh! damn templates - tuple elements need compile-time indices
+                    switch (index)
+                    {
+                        case 0: get<0>(context) = lexicon.insert(token); break;
+                        case 1: get<1>(context) = lexicon.insert(token); break;
+                        case 2: get<2>(context) = lexicon.insert(token); break;
+                        case 3: get<3>(context) = lexicon.insert(token); break;
+                        default: assert(false);
+                    }
+                    index += 1;
+                }
+                else if (token.find("C=") == 0)
+                {
+                    int contextId = contexts.insert(context);
+                    int count = atoi(token.substr(strlen("C=")).c_str());
+                    concordancePhraseToContexts[phraseId][contextId] += count;
+                    concordanceContextToPhrases[contextId][phraseId] += count;
+                    index = 0;
+                    context = Context(-1, -1, -1, -1);
+                    edges += 1;
+                }
+            }
+        }
+
+        // trigger EOF
+        input >> ws;
+    }
+
+    cout << "Read in " << phrases.size() << " phrases"
+         << " and " << contexts.size() << " contexts"
+         << " and " << edges << " edges"
+         << " and " << lexicon.size() << " word types\n";
+
+    // FIXME: filter out low count phrases and low count contexts (based on individual words?)
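+
+    // A restatement of the updates coded below (nothing beyond what the code
+    // does): the e-step computes, for every (phrase p, context c) edge,
+    //     P(t|p,c) propto P(t) * prod_{j=1..4} P_j(c_j|t)   [optionally * P(p|t)]
+    // multiplied by exp(-lambda_{tpc}) when posterior_regularisation is on,
+    // and accumulates these posteriors, weighted by the edge count C, as
+    // expected counts; the m-step then renormalises those counts to give the
+    // new prior and per-position multinomials.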
+
+    // now populate model parameters with uniform + random noise
+    prior.resize(numTags, 1.0);
+    addRandomNoise(prior);
+    normalise(prior);
+
+    probCtx.resize(4, ConditionalDist(numTags, Dist(lexicon.size(), 1.0)));
+    if (includePhraseProb)
+        probPhrase.resize(numTags, Dist(lexicon.size(), 1.0));
+    for (int t = 0; t < numTags; ++t)
+    {
+        for (int j = 0; j < 4; ++j)
+        {
+            addRandomNoise(probCtx[j][t]);
+            normalise(probCtx[j][t]);
+        }
+        if (includePhraseProb)
+        {
+            addRandomNoise(probPhrase[t]);
+            normalise(probPhrase[t]);
+        }
+    }
+    if (includePhraseProb)
+    {
+        probPhraseLength.resize(numTags, 0.5); // geometric distribution p=0.5
+        addRandomNoise(probPhraseLength);
+    }
+
+    cout << "\tprior: " << prior << "\n";
+    //cout << "\tcontext: " << probCtx << "\n";
+    //cout << "\tphrase: " << probPhrase << "\n";
+    //cout << "\tphraseLen: " << probPhraseLength << endl;
+
+    vector<double> lambda;
+
+    // now do EM training
+    for (int iteration = 0; iteration < numIterations; ++iteration)
+    {
+        cout << "EM iteration " << iteration << endl;
+
+        if (posterior_regularisation)
+            optimise_lambda(PHRASE_VIOLATION_WEIGHT, CONTEXT_VIOLATION_WEIGHT, lambda);
+        //cout << "\tlambda " << lambda << endl;
+
+        Dist countsPrior(numTags, 0.0);
+        vector<ConditionalDist> countsCtx(4, ConditionalDist(numTags, Dist(lexicon.size(), 1e-10)));
+        ConditionalDist countsPhrase(numTags, Dist(lexicon.size(), 1e-10));
+        Dist countsPhraseLength(numTags, 0.0);
+        Dist nPhrases(numTags, 0.0);
+
+        double llh = 0;
+        for (PhraseToContextCounts::iterator pcit = concordancePhraseToContexts.begin();
+             pcit != concordancePhraseToContexts.end(); ++pcit)
+        {
+            const Phrase &phrase = phrases.type(pcit->first);
+
+            // e-step: estimate latent class probs; compile (class,word) stats for m-step
+            for (ContextCounts::iterator ccit = pcit->second.begin();
+                 ccit != pcit->second.end(); ++ccit)
+            {
+                const Context &context = contexts.type(ccit->first);
+
+                double z = 0;
+                Dist tagCounts;
+                if (!posterior_regularisation)
+                    tagCounts = conditional_probs(phrase, context, &z);
+                else
+                    tagCounts = penalised_conditionals(phrase, context, lambda, &z);
+
+                llh += log(z) * ccit->second;
+                for (int t = 0; t < numTags; ++t)
+                    countsPrior[t] += tagCounts[t] * ccit->second;
+
+                for (int t = 0; t < numTags; ++t)
+                {
+                    // tally the context word observed at each of the 4 positions
+                    // (tuple elements need compile-time indices, hence no j loop)
+                    countsCtx[0][t][get<0>(context)] += tagCounts[t] * ccit->second;
+                    countsCtx[1][t][get<1>(context)] += tagCounts[t] * ccit->second;
+                    countsCtx[2][t][get<2>(context)] += tagCounts[t] * ccit->second;
+                    countsCtx[3][t][get<3>(context)] += tagCounts[t] * ccit->second;
+
+                    if (includePhraseProb)
+                    {
+                        for (Phrase::const_iterator pit = phrase.begin(); pit != phrase.end(); ++pit)
+                            countsPhrase[t][*pit] += tagCounts[t] * ccit->second;
+                        countsPhraseLength[t] += phrase.size() * tagCounts[t] * ccit->second;
+                        nPhrases[t] += tagCounts[t] * ccit->second;
+                    }
+                }
+            }
+        }
+
+        cout << "M-step\n";
+
+        // m-step: normalise prior and (class,word) stats and assign to model parameters
+        normalise(countsPrior);
+        prior = countsPrior;
+        for (int t = 0; t < numTags; ++t)
+        {
+            //cout << "\t\tt " << t << " prior " << countsPrior[t] << "\n";
+            for (int j = 0; j < 4; ++j)
+                normalise(countsCtx[j][t]);
+            if (includePhraseProb)
+            {
+                normalise(countsPhrase[t]);
+                countsPhraseLength[t] = nPhrases[t] / countsPhraseLength[t];
+            }
+        }
+        probCtx = countsCtx;
+        if (includePhraseProb)
+        {
+            probPhrase = countsPhrase;
+            probPhraseLength = countsPhraseLength;
+        }
+
+        if (posterior_regularisation)
+        {
+            double *larray = new double[lambda.size()];
+            copy(lambda.begin(), lambda.end(), larray);
+            print_primal_dual(larray, PHRASE_VIOLATION_WEIGHT, CONTEXT_VIOLATION_WEIGHT);
+            delete [] larray;
+        }
+
+        //cout << "\tllh " << llh << endl;
+        //cout << "\tprior: " << prior << "\n";
+        //cout << "\tcontext: " << probCtx << "\n";
+        //cout << "\tphrase: " << probPhrase << "\n";
+        //cout << "\tphraseLen: " << probPhraseLength << "\n";
+    }
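+
+    // The loop below prints one line per (phrase, context) edge using the
+    // (unpenalised) posterior mode; e.g. a line of the form (tag value
+    // illustrative only):
+    //   the phantom of the opera ||| tickets for tonight ? ||| 3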
probPhrase << "\n"; + //cout << "\tphraseLen: " << probPhraseLength << "\n"; + } + + // output class membership + for (PhraseToContextCounts::iterator pcit = concordancePhraseToContexts.begin(); + pcit != concordancePhraseToContexts.end(); ++pcit) + { + const Phrase &phrase = phrases.type(pcit->first); + for (ContextCounts::iterator ccit = pcit->second.begin(); + ccit != pcit->second.end(); ++ccit) + { + const Context &context = contexts.type(ccit->first); + Dist tagCounts = conditional_probs(phrase, context, 0); + cout << phrase << " ||| " << context << " ||| " << argmax(tagCounts) << "\n"; + } + } + + return 0; +} + +void addRandomNoise(Dist &d) +{ + for (Dist::iterator dit = d.begin(); dit != d.end(); ++dit) + *dit += rng(); +} + +void normalise(Dist &d) +{ + double z = 0; + for (Dist::iterator dit = d.begin(); dit != d.end(); ++dit) + z += *dit; + for (Dist::iterator dit = d.begin(); dit != d.end(); ++dit) + *dit /= z; +} + +void addTo(Dist &d, const Dist &e) +{ + assert(d.size() == e.size()); + for (int i = 0; i < (int) d.size(); ++i) + d[i] += e[i]; +} + +int argmax(const Dist &d) +{ + double best = d[0]; + int index = 0; + for (int i = 1; i < (int) d.size(); ++i) + { + if (d[i] > best) + { + best = d[i]; + index = i; + } + } + return index; +} + +ostream &operator<<(ostream &out, const Phrase &phrase) +{ + for (Phrase::const_iterator pit = phrase.begin(); pit != phrase.end(); ++pit) + lexicon.display(((pit == phrase.begin()) ? out : out << " "), *pit); + return out; +} + +ostream &operator<<(ostream &out, const Context &context) +{ + lexicon.display(out, get<0>(context)); + lexicon.display(out << " ", get<1>(context)); + lexicon.display(out << " ", get<2>(context)); + lexicon.display(out << " ", get<3>(context)); + return out; +} + +ostream &operator<<(ostream &out, const Dist &dist) +{ + for (Dist::const_iterator dit = dist.begin(); dit != dist.end(); ++dit) + out << ((dit == dist.begin()) ? "" : " ") << *dit; + return out; +} + +ostream &operator<<(ostream &out, const ConditionalDist &dist) +{ + for (ConditionalDist::const_iterator dit = dist.begin(); dit != dist.end(); ++dit) + out << ((dit == dist.begin()) ? "" : "; ") << *dit; + return out; +} + +// FIXME: slow - just use the phrase index, context index to do the mapping +// (n.b. 
+int
+lambda_index(const Phrase &phrase, const Context &context, int tag)
+{
+    return lambda_indices[phrase][context] + tag;
+}
+
+template <typename T>
+Dist
+penalised_conditionals(const Phrase &phrase, const Context &context,
+                       const T &lambda, double *normalisation)
+{
+    Dist d = conditional_probs(phrase, context, 0);
+
+    double z = 0;
+    for (int t = 0; t < numTags; ++t)
+    {
+        d[t] *= exp(-lambda[lambda_index(phrase, context, t)]);
+        z += d[t];
+    }
+
+    if (normalisation)
+        *normalisation = z;
+
+    for (int t = 0; t < numTags; ++t)
+        d[t] /= z;
+
+    return d;
+}
+
+Dist
+conditional_probs(const Phrase &phrase, const Context &context, double *normalisation)
+{
+    Dist tagCounts(numTags, 0.0);
+    double z = 0;
+    for (int t = 0; t < numTags; ++t)
+    {
+        double prob = prior[t];
+        prob *= (probCtx[0][t][get<0>(context)] * probCtx[1][t][get<1>(context)] *
+                 probCtx[2][t][get<2>(context)] * probCtx[3][t][get<3>(context)]);
+
+        if (includePhraseProb)
+        {
+            prob *= pow(1 - probPhraseLength[t], phrase.size() - 1) * probPhraseLength[t];
+            for (Phrase::const_iterator pit = phrase.begin(); pit != phrase.end(); ++pit)
+                prob *= probPhrase[t][*pit];
+        }
+
+        tagCounts[t] = prob;
+        z += prob;
+    }
+    if (normalisation)
+        *normalisation = z;
+
+    for (int t = 0; t < numTags; ++t)
+        tagCounts[t] /= z;
+
+    return tagCounts;
+}
+
+double
+penalised_log_likelihood(int n, const double *lambda, double *grad, void *)
+{
+    // return log Z(lambda, theta) over the corpus
+    // where theta are the global parameters (prior, probCtx*, probPhrase*)
+    // and lambda are lagrange multipliers for the posterior sparsity constraints
+    //
+    // this is formulated as:
+    // f = log Z(lambda) = sum_i log ( sum_t p_theta(t|p_i,c_i) exp [-lambda_{t,p_i,c_i}] )
+    // where i indexes the training examples - specifying the (p, c) pair (which may occur with count > 1)
+    //
+    // with derivative:
+    // f'_{tpc} = - count(p,c) p_theta(t|p,c) exp (-lambda_{t,p,c})
+    //            / sum_t' p_theta(t'|p,c) exp (-lambda_{t',p,c})
+
+    //cout << "penalised_log_likelihood with lambda ";
+    //copy(lambda, lambda+n, ostream_iterator<double>(cout, " "));
+    //cout << "\n";
+
+    double f = 0;
+    if (grad)
+    {
+        for (int i = 0; i < n; ++i)
+            grad[i] = 0.0;
+    }
+
+    for (int p = 0; p < phrases.size(); ++p)
+    {
+        const Phrase &phrase = phrases.type(p);
+        PhraseToContextCounts::const_iterator pcit = concordancePhraseToContexts.find(p);
+        for (ContextCounts::const_iterator ccit = pcit->second.begin();
+             ccit != pcit->second.end(); ++ccit)
+        {
+            const Context &context = contexts.type(ccit->first);
+            double z = 0;
+            Dist scores = penalised_conditionals(phrase, context, lambda, &z);
+
+            f += ccit->second * log(z);
+            //cout << "\tphrase: " << phrase << " context: " << context << " count: " << ccit->second << " z " << z << endl;
+            //cout << "\t\tscores: " << scores << "\n";
+
+            if (grad)
+            {
+                for (int t = 0; t < numTags; ++t)
+                {
+                    int i = lambda_index(phrase, context, t); // FIXME: redundant lookups
+                    assert(grad[i] == 0.0);
+                    grad[i] = - ccit->second * scores[t];
+                }
+            }
+        }
+    }
+
+    //cout << "penalised_log_likelihood returning " << f;
+    //if (grad)
+    //{
+    //    cout << "\ngradient: ";
+    //    copy(grad, grad+n, ostream_iterator<double>(cout, " "));
+    //}
+    //cout << "\n";
+
+    return f;
+}
+
+typedef struct
+{
+    // one of p or c should be set to -1, in which case it will be marginalised out
+    // i.e. sum_p' lambda_{p'ct} <= threshold
+    //   or sum_c' lambda_{pc't} <= threshold
+    int p, c, t;
+    double threshold;
+} constraint_data;
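+
+// For example, with numTags = 5 and delta > 0 each phrase p contributes five
+// inequality constraints, one per tag t: sum_c lambda_{pct} - delta <= 0;
+// symmetrically each context contributes five when gamma > 0. These are
+// enumerated explicitly in optimise_lambda below.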
+double
+constraint_and_gradient(int n, const double *lambda, double *grad, void *data)
+{
+    constraint_data *d = (constraint_data *) data;
+    assert(d->t >= 0);
+    assert(d->threshold >= 0);
+
+    //cout << "constraint_and_gradient: t " << d->t << " p " << d->p << " c " << d->c << " tau " << d->threshold << endl;
+    //cout << "\tlambda ";
+    //copy(lambda, lambda+n, ostream_iterator<double>(cout, " "));
+    //cout << "\n";
+
+    // FIXME: it's crazy to use a dense gradient here => will only have a handful of non-zero entries
+    if (grad)
+    {
+        for (int i = 0; i < n; ++i)
+            grad[i] = 0.0;
+    }
+
+    //cout << "constraint_and_gradient: " << d->p << "; " << d->c << "; " << d->t << "; " << d->threshold << endl;
+
+    if (d->p >= 0)
+    {
+        assert(d->c < 0);
+        // sum_c lambda_pct <= delta [a.k.a. threshold]
+        // => sum_c lambda_pct - delta <= 0
+        // derivative_pct = { 1, if p and t match; 0, otherwise }
+
+        double val = -d->threshold;
+
+        const Phrase &phrase = phrases.type(d->p);
+        PhraseToContextCounts::const_iterator pcit = concordancePhraseToContexts.find(d->p);
+        assert(pcit != concordancePhraseToContexts.end());
+        for (ContextCounts::const_iterator ccit = pcit->second.begin();
+             ccit != pcit->second.end(); ++ccit)
+        {
+            const Context &context = contexts.type(ccit->first);
+            int i = lambda_index(phrase, context, d->t);
+            val += lambda[i];
+            if (grad) grad[i] = 1;
+        }
+        //cout << "\treturning " << val << endl;
+
+        return val;
+    }
+    else
+    {
+        assert(d->c >= 0);
+        assert(d->p < 0);
+        // sum_p lambda_pct <= gamma [a.k.a. threshold]
+        // => sum_p lambda_pct - gamma <= 0
+        // derivative_pct = { 1, if c and t match; 0, otherwise }
+
+        double val = -d->threshold;
+
+        const Context &context = contexts.type(d->c);
+        ContextToPhraseCounts::iterator cpit = concordanceContextToPhrases.find(d->c);
+        assert(cpit != concordanceContextToPhrases.end());
+        for (PhraseCounts::iterator pcit = cpit->second.begin();
+             pcit != cpit->second.end(); ++pcit)
+        {
+            const Phrase &phrase = phrases.type(pcit->first);
+            int i = lambda_index(phrase, context, d->t);
+            val += lambda[i];
+            if (grad) grad[i] = 1;
+        }
+        //cout << "\treturning " << val << endl;
+
+        return val;
+    }
+}
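+
+// n.b. the call below uses NLopt's old-style nlopt_minimize_constrained entry
+// point with NLOPT_LN_COBYLA; COBYLA is a derivative-free algorithm, so the
+// gradients filled in by the objective and constraint callbacks above are not
+// actually consulted for this particular algorithm choice.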
+void
+optimise_lambda(double delta, double gamma, vector<double> &lambdav)
+{
+    int num_lambdas = lambdav.size();
+    if (lambda_indices.empty() || lambdav.empty())
+    {
+        lambda_indices.clear();
+        lambdav.clear();
+
+        int i = 0;
+        for (int p = 0; p < phrases.size(); ++p)
+        {
+            const Phrase &phrase = phrases.type(p);
+            PhraseToContextCounts::iterator pcit = concordancePhraseToContexts.find(p);
+            for (ContextCounts::iterator ccit = pcit->second.begin();
+                 ccit != pcit->second.end(); ++ccit)
+            {
+                const Context &context = contexts.type(ccit->first);
+                lambda_indices[phrase][context] = i;
+                i += numTags;
+            }
+        }
+        num_lambdas = i;
+        lambdav.resize(num_lambdas);
+    }
+    //cout << "optimise_lambda: #lagrange multipliers " << num_lambdas << endl;
+
+    // FIXME: better to work with an implicit representation to save memory usage
+    int num_constraints = (((delta > 0) ? phrases.size() : 0) + ((gamma > 0) ? contexts.size() : 0)) * numTags;
+    //cout << "optimise_lambda: #constraints " << num_constraints << endl;
+    constraint_data *data = new constraint_data[num_constraints];
+    int i = 0;
+    if (delta > 0)
+    {
+        for (int p = 0; p < phrases.size(); ++p)
+        {
+            for (int t = 0; t < numTags; ++t, ++i)
+            {
+                constraint_data &d = data[i];
+                d.p = p;
+                d.c = -1;
+                d.t = t;
+                d.threshold = delta;
+            }
+        }
+    }
+
+    if (gamma > 0)
+    {
+        for (int c = 0; c < contexts.size(); ++c)
+        {
+            for (int t = 0; t < numTags; ++t, ++i)
+            {
+                constraint_data &d = data[i];
+                d.p = -1;
+                d.c = c;
+                d.t = t;
+                d.threshold = gamma;
+            }
+        }
+    }
+    assert(i == num_constraints);
+
+    vector<double> lambda(num_lambdas), lb(num_lambdas), ub(num_lambdas);
+    for (i = 0; i < num_lambdas; ++i)
+    {
+        lambda[i] = lambdav[i]; // starting value
+        lb[i] = 0;              // lower bound
+        if (delta <= 0)         // upper bound; n.b. only one of delta, gamma may be active at present
+            ub[i] = gamma;
+        else if (gamma <= 0)
+            ub[i] = delta;
+        else
+            assert(false);
+    }
+
+    //print_primal_dual(&lambda[0], delta, gamma);
+
+    double minf;
+    int error_code = nlopt_minimize_constrained(NLOPT_LN_COBYLA, num_lambdas, penalised_log_likelihood, NULL,
+                                                num_constraints, constraint_and_gradient, data, sizeof(constraint_data),
+                                                &lb[0], &ub[0], &lambda[0], &minf, -HUGE_VAL, 0.0, 0.0, 1e-4, NULL, 0, 0.0);
+    //cout << "optimise error code " << error_code << endl;
+
+    //print_primal_dual(&lambda[0], delta, gamma);
+
+    delete [] data;
+
+    if (error_code < 0)
+        cout << "WARNING: optimisation failed with error code: " << error_code << endl;
+    //else
+    //{
+    //    cout << "success; minf " << minf << endl;
+    //    print_primal_dual(&lambda[0], delta, gamma);
+    //}
+
+    lambdav = lambda;
+}
+
+// FIXME: inefficient - cache the scores
+double
+expected_violation_phrases(const double *lambda)
+{
+    // sum_pt max_c E_q[phi_pct]
+    double violation = 0;
+
+    for (int p = 0; p < phrases.size(); ++p)
+    {
+        const Phrase &phrase = phrases.type(p);
+        PhraseToContextCounts::const_iterator pcit = concordancePhraseToContexts.find(p);
+
+        for (int t = 0; t < numTags; ++t)
+        {
+            double best = 0;
+            for (ContextCounts::const_iterator ccit = pcit->second.begin();
+                 ccit != pcit->second.end(); ++ccit)
+            {
+                const Context &context = contexts.type(ccit->first);
+                Dist scores = penalised_conditionals(phrase, context, lambda, 0);
+                best = max(best, scores[t]);
+            }
+            violation += best;
+        }
+    }
+
+    return violation;
+}
+
+// FIXME: inefficient - cache the scores
+double
+expected_violation_contexts(const double *lambda)
+{
+    // sum_ct max_p E_q[phi_pct]
+    double violation = 0;
+
+    for (int c = 0; c < contexts.size(); ++c)
+    {
+        const Context &context = contexts.type(c);
+        ContextToPhraseCounts::iterator cpit = concordanceContextToPhrases.find(c);
+
+        for (int t = 0; t < numTags; ++t)
+        {
+            double best = 0;
+            for (PhraseCounts::iterator pit = cpit->second.begin();
+                 pit != cpit->second.end(); ++pit)
+            {
+                const Phrase &phrase = phrases.type(pit->first);
+                Dist scores = penalised_conditionals(phrase, context, lambda, 0);
+                best = max(best, scores[t]);
+            }
+            violation += best;
+        }
+    }
+
+    return violation;
+}
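+
+// Re the caching FIXMEs: the violation and primal/dual evaluations all rescore
+// the same (phrase, context) edges; one possible scheme (a sketch only) is to
+// precompute, once per lambda,
+//   map<pair<int, int>, Dist> score_cache; // (phraseId, contextId) -> P(t|p,c)
+// and consult it wherever penalised_conditionals is currently recomputed.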
+// FIXME: possibly inefficient
+double
+primal_likelihood() // FIXME: primal evaluation needs to use lambda and calculate l1linf terms
+{
+    double llh = 0;
+    for (int p = 0; p < phrases.size(); ++p)
+    {
+        const Phrase &phrase = phrases.type(p);
+        PhraseToContextCounts::const_iterator pcit = concordancePhraseToContexts.find(p);
+        for (ContextCounts::const_iterator ccit = pcit->second.begin();
+             ccit != pcit->second.end(); ++ccit)
+        {
+            const Context &context = contexts.type(ccit->first);
+            double z = 0;
+            Dist scores = conditional_probs(phrase, context, &z);
+            llh += ccit->second * log(z);
+        }
+    }
+    return llh;
+}
+
+// FIXME: inefficient - cache the scores
+double
+primal_kl_divergence(const double *lambda)
+{
+    // return KL(q || p) = sum_y q(y) { log q(y) - log p(y | x) }
+    //                   = sum_y q(y) { log p(y | x) - lambda . phi(x, y) - log Z - log p(y | x) }
+    //                   = sum_y q(y) { - lambda . phi(x, y) } - log Z
+    // and q(y) factors with each edge, ditto for Z
+
+    double feature_sum = 0, log_z = 0;
+    for (int p = 0; p < phrases.size(); ++p)
+    {
+        const Phrase &phrase = phrases.type(p);
+        PhraseToContextCounts::const_iterator pcit = concordancePhraseToContexts.find(p);
+        for (ContextCounts::const_iterator ccit = pcit->second.begin();
+             ccit != pcit->second.end(); ++ccit)
+        {
+            const Context &context = contexts.type(ccit->first);
+
+            double local_z = 0;
+            double local_f = 0;
+            Dist d = conditional_probs(phrase, context, 0);
+            for (int t = 0; t < numTags; ++t)
+            {
+                int i = lambda_index(phrase, context, t);
+                double s = d[t] * exp(-lambda[i]);
+                local_f += lambda[i] * s;
+                local_z += s;
+            }
+
+            log_z += ccit->second * log(local_z);
+            feature_sum += ccit->second * (local_f / local_z);
+        }
+    }
+
+    return -feature_sum - log_z;
+}
+
+// FIXME: inefficient - cache the scores
+double
+dual(const double *lambda)
+{
+    // return log Z(lambda) = sum_i log { sum_t p_theta(t|p_i,c_i) exp( - lambda_{t,p_i,c_i} ) }
+    // summed over the training edges, weighted by their counts
+    double f = 0;
+    for (int p = 0; p < phrases.size(); ++p)
+    {
+        const Phrase &phrase = phrases.type(p);
+        PhraseToContextCounts::const_iterator pcit = concordancePhraseToContexts.find(p);
+        for (ContextCounts::const_iterator ccit = pcit->second.begin();
+             ccit != pcit->second.end(); ++ccit)
+        {
+            const Context &context = contexts.type(ccit->first);
+            double lz = 0;
+            Dist scores = penalised_conditionals(phrase, context, lambda, &lz);
+            f += log(lz) * ccit->second;
+        }
+    }
+    return f;
+}
+
+void
+print_primal_dual(const double *lambda, double delta, double gamma)
+{
+    double likelihood = primal_likelihood();
+    double kl = primal_kl_divergence(lambda);
+    double sum_pt = expected_violation_phrases(lambda);
+    double sum_ct = expected_violation_contexts(lambda);
+    //double d = dual(lambda);
+
+    cout << "\tllh=" << likelihood
+         << " kl=" << kl
+         << " violations phrases=" << sum_pt
+         << " contexts=" << sum_ct
+         //<< " primal=" << (kl + delta * sum_pt + gamma * sum_ct)
+         //<< " dual=" << d
+         << " objective=" << (likelihood - kl + delta * sum_pt + gamma * sum_ct)
+         << endl;
+}
--
cgit v1.2.3