From 6704c23f34940dde3951155fd77246bb6229ba95 Mon Sep 17 00:00:00 2001
From: Chris Dyer <cdyer@cs.cmu.edu>
Date: Tue, 21 Feb 2012 17:51:44 -0500
Subject: use lbfgs

---
 training/Makefile.am  |  2 +-
 training/lbl_model.cc | 33 ++++++++++++++++++++++++++++-----
 2 files changed, 29 insertions(+), 6 deletions(-)

diff --git a/training/Makefile.am b/training/Makefile.am
index 330341ac..991ac210 100644
--- a/training/Makefile.am
+++ b/training/Makefile.am
@@ -50,7 +50,7 @@ test_ngram_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteva
 model1_SOURCES = model1.cc ttables.cc
 model1_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz
 
-lbl_model_SOURCES = lbl_model.cc ttables.cc
+lbl_model_SOURCES = lbl_model.cc optimize.cc
 lbl_model_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz
 
 grammar_convert_SOURCES = grammar_convert.cc
diff --git a/training/lbl_model.cc b/training/lbl_model.cc
index 4759eedc..eb3e194d 100644
--- a/training/lbl_model.cc
+++ b/training/lbl_model.cc
@@ -16,6 +16,7 @@
 #include <boost/program_options/variables_map.hpp>
 #include <Eigen/Dense>
 
+#include "optimize.h"
 #include "array2d.h"
 #include "m.h"
 #include "lattice.h"
@@ -26,7 +27,7 @@
 namespace po = boost::program_options;
 using namespace std;
 
-#define kDIMENSIONS 25
+#define kDIMENSIONS 8
 typedef Eigen::Matrix<float, kDIMENSIONS, 1> RVector;
 typedef Eigen::Matrix<float, 1, kDIMENSIONS> RTVector;
 typedef Eigen::Matrix<float, kDIMENSIONS, kDIMENSIONS> TMatrix;
@@ -69,6 +70,21 @@ void Normalize(RVector* v) {
   *v /= norm;
 }
 
+void Flatten(const TMatrix& m, vector<double>* v) {
+  unsigned c = 0;
+  v->resize(kDIMENSIONS * kDIMENSIONS);
+  for (unsigned i = 0; i < kDIMENSIONS; ++i)
+    for (unsigned j = 0; j < kDIMENSIONS; ++j)
+      (*v)[c++] = m(i,j);
+}
+
+void Unflatten(const vector<double>& v, TMatrix* m) {
+  unsigned c = 0;
+  for (unsigned i = 0; i < kDIMENSIONS; ++i)
+    for (unsigned j = 0; j < kDIMENSIONS; ++j)
+      (*m)(i, j) = v[c++];
+}
+
 int main(int argc, char** argv) {
   po::variables_map conf;
   if (!InitCommandLine(argc, argv, &conf)) return 1;
@@ -76,7 +92,7 @@ int main(int argc, char** argv) {
   const int ITERATIONS = conf["iterations"].as<unsigned>();
   const float eta = conf["eta"].as<float>();
   const double diagonal_tension = conf["diagonal_tension"].as<double>();
-  bool SGD = true;
+  bool SGD = false;
   if (diagonal_tension < 0.0) {
     cerr << "Invalid value for diagonal_tension: must be >= 0\n";
     return 1;
@@ -121,6 +137,7 @@ int main(int argc, char** argv) {
   cerr << "Number of target word types: " << vocab_e.size() << endl;
   const float num_examples = lc;
 
+  LBFGSOptimizer lbfgs(kDIMENSIONS * kDIMENSIONS, 100);
   r_trg.resize(TD::NumWords() + 1);
   r_src.resize(TD::NumWords() + 1);
   if (conf.count("random_seed")) {
@@ -130,7 +147,7 @@ int main(int argc, char** argv) {
     cerr << "Random seed: " << seed << endl;
     srand(seed);
   }
-  TMatrix t = TMatrix::Random() / 100.0;
+  TMatrix t = TMatrix::Random() / 1024.0;
   for (unsigned i = 1; i < r_trg.size(); ++i) {
     r_trg[i] = RVector::Random();
     r_src[i] = RVector::Random();
@@ -145,6 +162,8 @@ int main(int argc, char** argv) {
   TMatrix g;
   vector<TMatrix> exp_src;
   vector<double> z_src;
+  vector<double> flat_g, flat_t;
+  Flatten(t, &flat_t);
   for (int iter = 0; iter < ITERATIONS; ++iter) {
     cerr << "ITERATION " << (iter + 1) << endl;
     ReadFile rf(fname);
@@ -236,8 +255,6 @@ int main(int argc, char** argv) {
         if (SGD) {
           t -= g * eta / num_examples;
           g *= 0;
-        } else {
-          assert(!"not implemented");
         }
       }
       
@@ -250,6 +267,12 @@ int main(int argc, char** argv) {
     cerr << "  log_2 likelihood: " << base2_likelihood << endl;
     cerr << "   cross entropy: " << (-base2_likelihood / denom) << endl;
     cerr << "      perplexity: " << pow(2.0, -base2_likelihood / denom) << endl;
+    if (!SGD) {
+      Flatten(g, &flat_g);
+      lbfgs.Optimize(-likelihood, flat_g, &flat_t);
+      Unflatten(flat_t, &t);
+      if (lbfgs.HasConverged()) break;
+    }
     cerr << t << endl;
   }
   cerr << "TRANSLATION MATRIX:" << endl << t << endl;
-- 
cgit v1.2.3