diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/Makefile.am | 2 | ||||
| -rw-r--r-- | training/lbl_model.cc | 33 | 
2 files changed, 29 insertions, 6 deletions
| diff --git a/training/Makefile.am b/training/Makefile.am index 330341ac..991ac210 100644 --- a/training/Makefile.am +++ b/training/Makefile.am @@ -50,7 +50,7 @@ test_ngram_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteva  model1_SOURCES = model1.cc ttables.cc  model1_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz -lbl_model_SOURCES = lbl_model.cc ttables.cc +lbl_model_SOURCES = lbl_model.cc optimize.cc  lbl_model_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz  grammar_convert_SOURCES = grammar_convert.cc diff --git a/training/lbl_model.cc b/training/lbl_model.cc index 4759eedc..eb3e194d 100644 --- a/training/lbl_model.cc +++ b/training/lbl_model.cc @@ -16,6 +16,7 @@  #include <boost/program_options/variables_map.hpp>  #include <Eigen/Dense> +#include "optimize.h"  #include "array2d.h"  #include "m.h"  #include "lattice.h" @@ -26,7 +27,7 @@  namespace po = boost::program_options;  using namespace std; -#define kDIMENSIONS 25 +#define kDIMENSIONS 8  typedef Eigen::Matrix<float, kDIMENSIONS, 1> RVector;  typedef Eigen::Matrix<float, 1, kDIMENSIONS> RTVector;  typedef Eigen::Matrix<float, kDIMENSIONS, kDIMENSIONS> TMatrix; @@ -69,6 +70,21 @@ void Normalize(RVector* v) {    *v /= norm;  } +void Flatten(const TMatrix& m, vector<double>* v) { +  unsigned c = 0; +  v->resize(kDIMENSIONS * kDIMENSIONS); +  for (unsigned i = 0; i < kDIMENSIONS; ++i) +    for (unsigned j = 0; j < kDIMENSIONS; ++j) +      (*v)[c++] = m(i,j); +} + +void Unflatten(const vector<double>& v, TMatrix* m) { +  unsigned c = 0; +  for (unsigned i = 0; i < kDIMENSIONS; ++i) +    for (unsigned j = 0; j < kDIMENSIONS; ++j) +      (*m)(i, j) = v[c++]; +} +  int main(int argc, char** argv) {    po::variables_map conf;    if (!InitCommandLine(argc, argv, &conf)) return 1; @@ -76,7 +92,7 @@ int main(int argc, char** argv) {    const int ITERATIONS = conf["iterations"].as<unsigned>();    const float eta = conf["eta"].as<float>();    const double diagonal_tension = conf["diagonal_tension"].as<double>(); -  bool SGD = true; +  bool SGD = false;    if (diagonal_tension < 0.0) {      cerr << "Invalid value for diagonal_tension: must be >= 0\n";      return 1; @@ -121,6 +137,7 @@ int main(int argc, char** argv) {    cerr << "Number of target word types: " << vocab_e.size() << endl;    const float num_examples = lc; +  LBFGSOptimizer lbfgs(kDIMENSIONS * kDIMENSIONS, 100);    r_trg.resize(TD::NumWords() + 1);    r_src.resize(TD::NumWords() + 1);    if (conf.count("random_seed")) { @@ -130,7 +147,7 @@ int main(int argc, char** argv) {      cerr << "Random seed: " << seed << endl;      srand(seed);    } -  TMatrix t = TMatrix::Random() / 100.0; +  TMatrix t = TMatrix::Random() / 1024.0;    for (unsigned i = 1; i < r_trg.size(); ++i) {      r_trg[i] = RVector::Random();      r_src[i] = RVector::Random(); @@ -145,6 +162,8 @@ int main(int argc, char** argv) {    TMatrix g;    vector<TMatrix> exp_src;    vector<double> z_src; +  vector<double> flat_g, flat_t; +  Flatten(t, &flat_t);    for (int iter = 0; iter < ITERATIONS; ++iter) {      cerr << "ITERATION " << (iter + 1) << endl;      ReadFile rf(fname); @@ -236,8 +255,6 @@ int main(int argc, char** argv) {          if (SGD) {            t -= g * eta / num_examples;            g *= 0; -        } else { -          assert(!"not implemented");          }        } @@ -250,6 +267,12 @@ int main(int argc, char** argv) {      cerr << "  log_2 likelihood: " << base2_likelihood << endl;      cerr << "   cross entropy: " << (-base2_likelihood / denom) << endl;      cerr << "      perplexity: " << pow(2.0, -base2_likelihood / denom) << endl; +    if (!SGD) { +      Flatten(g, &flat_g); +      lbfgs.Optimize(-likelihood, flat_g, &flat_t); +      Unflatten(flat_t, &t); +      if (lbfgs.HasConverged()) break; +    }      cerr << t << endl;    }    cerr << "TRANSLATION MATRIX:" << endl << t << endl; | 
