diff options
Diffstat (limited to 'training/lbl_model.cc')
-rw-r--r-- | training/lbl_model.cc | 33 |
1 files changed, 28 insertions, 5 deletions
diff --git a/training/lbl_model.cc b/training/lbl_model.cc index 4759eedc..eb3e194d 100644 --- a/training/lbl_model.cc +++ b/training/lbl_model.cc @@ -16,6 +16,7 @@ #include <boost/program_options/variables_map.hpp> #include <Eigen/Dense> +#include "optimize.h" #include "array2d.h" #include "m.h" #include "lattice.h" @@ -26,7 +27,7 @@ namespace po = boost::program_options; using namespace std; -#define kDIMENSIONS 25 +#define kDIMENSIONS 8 typedef Eigen::Matrix<float, kDIMENSIONS, 1> RVector; typedef Eigen::Matrix<float, 1, kDIMENSIONS> RTVector; typedef Eigen::Matrix<float, kDIMENSIONS, kDIMENSIONS> TMatrix; @@ -69,6 +70,21 @@ void Normalize(RVector* v) { *v /= norm; } +void Flatten(const TMatrix& m, vector<double>* v) { + unsigned c = 0; + v->resize(kDIMENSIONS * kDIMENSIONS); + for (unsigned i = 0; i < kDIMENSIONS; ++i) + for (unsigned j = 0; j < kDIMENSIONS; ++j) + (*v)[c++] = m(i,j); +} + +void Unflatten(const vector<double>& v, TMatrix* m) { + unsigned c = 0; + for (unsigned i = 0; i < kDIMENSIONS; ++i) + for (unsigned j = 0; j < kDIMENSIONS; ++j) + (*m)(i, j) = v[c++]; +} + int main(int argc, char** argv) { po::variables_map conf; if (!InitCommandLine(argc, argv, &conf)) return 1; @@ -76,7 +92,7 @@ int main(int argc, char** argv) { const int ITERATIONS = conf["iterations"].as<unsigned>(); const float eta = conf["eta"].as<float>(); const double diagonal_tension = conf["diagonal_tension"].as<double>(); - bool SGD = true; + bool SGD = false; if (diagonal_tension < 0.0) { cerr << "Invalid value for diagonal_tension: must be >= 0\n"; return 1; @@ -121,6 +137,7 @@ int main(int argc, char** argv) { cerr << "Number of target word types: " << vocab_e.size() << endl; const float num_examples = lc; + LBFGSOptimizer lbfgs(kDIMENSIONS * kDIMENSIONS, 100); r_trg.resize(TD::NumWords() + 1); r_src.resize(TD::NumWords() + 1); if (conf.count("random_seed")) { @@ -130,7 +147,7 @@ int main(int argc, char** argv) { cerr << "Random seed: " << seed << endl; srand(seed); } - TMatrix t = TMatrix::Random() / 100.0; + TMatrix t = TMatrix::Random() / 1024.0; for (unsigned i = 1; i < r_trg.size(); ++i) { r_trg[i] = RVector::Random(); r_src[i] = RVector::Random(); @@ -145,6 +162,8 @@ int main(int argc, char** argv) { TMatrix g; vector<TMatrix> exp_src; vector<double> z_src; + vector<double> flat_g, flat_t; + Flatten(t, &flat_t); for (int iter = 0; iter < ITERATIONS; ++iter) { cerr << "ITERATION " << (iter + 1) << endl; ReadFile rf(fname); @@ -236,8 +255,6 @@ int main(int argc, char** argv) { if (SGD) { t -= g * eta / num_examples; g *= 0; - } else { - assert(!"not implemented"); } } @@ -250,6 +267,12 @@ int main(int argc, char** argv) { cerr << " log_2 likelihood: " << base2_likelihood << endl; cerr << " cross entropy: " << (-base2_likelihood / denom) << endl; cerr << " perplexity: " << pow(2.0, -base2_likelihood / denom) << endl; + if (!SGD) { + Flatten(g, &flat_g); + lbfgs.Optimize(-likelihood, flat_g, &flat_t); + Unflatten(flat_t, &t); + if (lbfgs.HasConverged()) break; + } cerr << t << endl; } cerr << "TRANSLATION MATRIX:" << endl << t << endl; |