From 9e45f895aaec5c7a2f362aa532ca5ca4325e102b Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 21 Feb 2012 11:53:01 -0500 Subject: basic lbl model, nothing to see here --- training/lbl_model.cc | 147 ++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 126 insertions(+), 21 deletions(-) diff --git a/training/lbl_model.cc b/training/lbl_model.cc index ccd29255..4759eedc 100644 --- a/training/lbl_model.cc +++ b/training/lbl_model.cc @@ -5,13 +5,18 @@ int main() { std::cerr << "Please rebuild with --with-eigen PATH\n"; return 1; } #else +#include +#include #include #include +#include // memset +#include #include #include #include +#include "array2d.h" #include "m.h" #include "lattice.h" #include "stringlib.h" @@ -21,7 +26,7 @@ namespace po = boost::program_options; using namespace std; -#define kDIMENSIONS 10 +#define kDIMENSIONS 25 typedef Eigen::Matrix RVector; typedef Eigen::Matrix RTVector; typedef Eigen::Matrix TMatrix; @@ -32,6 +37,8 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { opts.add_options() ("input,i",po::value(),"Input file") ("iterations,I",po::value()->default_value(1000),"Number of iterations of training") + ("eta,e", po::value()->default_value(0.1f), "Eta for SGD") + ("random_seed", po::value(), "Random seed") ("diagonal_tension,T", po::value()->default_value(4.0), "How sharp or flat around the diagonal is the alignment distribution (0 = uniform, >0 sharpens)") ("testset,x", po::value(), "After training completes, compute the log likelihood of this set of sentence pairs under the learned model"); po::options_description clo("Command line options"); @@ -57,12 +64,19 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { return true; } +void Normalize(RVector* v) { + float norm = v->norm(); + *v /= norm; +} + int main(int argc, char** argv) { po::variables_map conf; if (!InitCommandLine(argc, argv, &conf)) return 1; const string fname = conf["input"].as(); const int ITERATIONS = conf["iterations"].as(); + const float eta = conf["eta"].as(); const double diagonal_tension = conf["diagonal_tension"].as(); + bool SGD = true; if (diagonal_tension < 0.0) { cerr << "Invalid value for diagonal_tension: must be >= 0\n"; return 1; @@ -70,14 +84,15 @@ int main(int argc, char** argv) { string testset; if (conf.count("testset")) testset = conf["testset"].as(); - int lc = 0; + unsigned lc = 0; vector unnormed_a_i; string line; string ssrc, strg; bool flag = false; Lattice src, trg; - set vocab_e; + vector vocab_e; { // read through corpus, initialize int map, check lines are good + set svocab_e; cerr << "INITIAL READ OF " << fname << endl; ReadFile rf(fname); istream& in = *rf.stream(); @@ -97,13 +112,39 @@ int main(int argc, char** argv) { unnormed_a_i.resize(src.size()); for (unsigned i = 0; i < trg.size(); ++i) { assert(trg[i].size() == 1); - vocab_e.insert(trg[i][0].label); + svocab_e.insert(trg[i][0].label); } } + copy(svocab_e.begin(), svocab_e.end(), back_inserter(vocab_e)); } if (flag) cerr << endl; + cerr << "Number of target word types: " << vocab_e.size() << endl; + const float num_examples = lc; + + r_trg.resize(TD::NumWords() + 1); + r_src.resize(TD::NumWords() + 1); + if (conf.count("random_seed")) { + srand(conf["random_seed"].as()); + } else { + unsigned seed = time(NULL); + cerr << "Random seed: " << seed << endl; + srand(seed); + } + TMatrix t = TMatrix::Random() / 100.0; + for (unsigned i = 1; i < r_trg.size(); ++i) { + r_trg[i] = RVector::Random(); + r_src[i] = RVector::Random(); + r_trg[i][i % kDIMENSIONS] = 0.5; + r_src[i][(i-1) % kDIMENSIONS] = 0.5; + Normalize(&r_trg[i]); + Normalize(&r_src[i]); + } + vector > trg_pos(TD::NumWords() + 1); // do optimization + TMatrix g; + vector exp_src; + vector z_src; for (int iter = 0; iter < ITERATIONS; ++iter) { cerr << "ITERATION " << (iter + 1) << endl; ReadFile rf(fname); @@ -112,9 +153,8 @@ int main(int argc, char** argv) { double denom = 0.0; lc = 0; flag = false; - while(true) { - getline(in, line); - if (!in) break; + g *= 0; + while(getline(in, line)) { ++lc; if (lc % 1000 == 0) { cerr << '.'; flag = true; } if (lc %50000 == 0) { cerr << " [" << lc << "]\n" << flush; flag = false; } @@ -122,23 +162,86 @@ int main(int argc, char** argv) { LatticeTools::ConvertTextToLattice(ssrc, &src); LatticeTools::ConvertTextToLattice(strg, &trg); denom += trg.size(); - vector probs(src.size() + 1); - for (int j = 0; j < trg.size(); ++j) { - const WordID& f_j = trg[j][0].label; - double sum = 0; - const double j_over_ts = double(j) / trg.size(); - double az = 0; - for (int ta = 0; ta < src.size(); ++ta) { - unnormed_a_i[ta] = exp(-fabs(double(ta) / src.size() - j_over_ts) * diagonal_tension); - az += unnormed_a_i[ta]; + + exp_src.clear(); exp_src.resize(src.size(), TMatrix::Zero()); + z_src.clear(); z_src.resize(src.size(), 0.0); + Array2D exp_refs(src.size(), trg.size(), TMatrix::Zero()); + Array2D z_refs(src.size(), trg.size(), 0.0); + for (unsigned j = 0; j < trg.size(); ++j) + trg_pos[trg[j][0].label].insert(j); + + for (unsigned i = 0; i < src.size(); ++i) { + const RVector& r_s = r_src[src[i][0].label]; + const RTVector pred = r_s.transpose() * t; + TMatrix& exp_m = exp_src[i]; + double& z = z_src[i]; + for (unsigned k = 0; k < vocab_e.size(); ++k) { + const WordID v_k = vocab_e[k]; + const RVector& r_t = r_trg[v_k]; + const double dot_prod = pred * r_t; + const double u = exp(dot_prod); + z += u; + const TMatrix v = r_s * r_t.transpose() * u; + exp_m += v; + set& ref_locs = trg_pos[v_k]; + if (!ref_locs.empty()) { + for (set::iterator it = ref_locs.begin(); it != ref_locs.end(); ++it) { + TMatrix& exp_ref_ij = exp_refs(i, *it); + double& z_ref_ij = z_refs(i, *it); + z_ref_ij += u; + exp_ref_ij += v; + } + } + } + } + for (unsigned j = 0; j < trg.size(); ++j) + trg_pos[trg[j][0].label].clear(); + + // model expectations for a single target generation with + // uniform alignment prior + double m_z = 0; + TMatrix m_exp = TMatrix::Zero(); + for (unsigned i = 0; i < src.size(); ++i) { + m_exp += exp_src[i]; + m_z += z_src[i]; + } + m_exp /= m_z; + + Array2D al(src.size(), trg.size(), false); + for (unsigned j = 0; j < trg.size(); ++j) { + double ref_z = 0; + TMatrix ref_exp = TMatrix::Zero(); + int max_i = 0; + double max_s = -9999999; + for (unsigned i = 0; i < src.size(); ++i) { + ref_exp += exp_refs(i, j); + ref_z += z_refs(i, j); + if (log(z_refs(i, j)) > max_s) { + max_s = log(z_refs(i, j)); + max_i = i; + } + // TODO handle alignment prob + } + if (ref_z <= 0) { + cerr << "TRG=" << TD::Convert(trg[j][0].label) << endl; + cerr << " LINE=" << line << endl; + cerr << " REF_EXP=\n" << ref_exp << endl; + cerr << " M_EXP=\n" << m_exp << endl; + abort(); } - for (int i = 1; i <= src.size(); ++i) { - const double prob_a_i = unnormed_a_i[i-1] / az; - // TODO - probs[i] = 1; // tt.prob(src[i-1][0].label, f_j) * prob_a_i; - sum += probs[i]; + al(max_i, j) = true; + ref_exp /= ref_z; + g += m_exp - ref_exp; + likelihood += log(ref_z) - log(m_z); + if (SGD) { + t -= g * eta / num_examples; + g *= 0; + } else { + assert(!"not implemented"); } } + + if (iter == (ITERATIONS - 1) || lc == 28) { cerr << al << endl; } } if (flag) { cerr << endl; } @@ -147,7 +250,9 @@ int main(int argc, char** argv) { cerr << " log_2 likelihood: " << base2_likelihood << endl; cerr << " cross entropy: " << (-base2_likelihood / denom) << endl; cerr << " perplexity: " << pow(2.0, -base2_likelihood / denom) << endl; + cerr << t << endl; } + cerr << "TRANSLATION MATRIX:" << endl << t << endl; return 0; } -- cgit v1.2.3