author     Chris Dyer <cdyer@cs.cmu.edu>  2012-02-21 17:51:44 -0500
committer  Chris Dyer <cdyer@cs.cmu.edu>  2012-02-21 17:51:44 -0500
commit     6704c23f34940dde3951155fd77246bb6229ba95 (patch)
tree       c08249bcfdc92d2c0c83bcfc83d13abea6d08c0e /training
parent     48efe9ed9e6e8c5373dacd83493e1aee484ee070 (diff)
use lbfgs
Diffstat (limited to 'training')
-rw-r--r--  training/Makefile.am    2
-rw-r--r--  training/lbl_model.cc  33
2 files changed, 29 insertions, 6 deletions
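
This commit swaps the per-example stochastic gradient update for batch L-BFGS: after each pass over the data the translation matrix is flattened into a plain parameter vector, handed to LBFGSOptimizer together with the negated log-likelihood and its gradient, and copied back. A minimal sketch of that driver pattern, assuming only the LBFGSOptimizer interface visible in the diff below (a constructor taking the parameter count and what appears to be a history size, Optimize(objective, gradient, &params), and HasConverged()); the quadratic objective is a hypothetical stand-in for the real likelihood computation:

    #include <vector>
    #include "optimize.h"  // LBFGSOptimizer, as used in lbl_model.cc below

    // Sketch: minimize f(x) = 0.5 * ||x||^2 with L-BFGS (stand-in objective).
    void FitBatch(std::vector<double>* x) {
      LBFGSOptimizer lbfgs(x->size(), 100);   // 100 matches the value used in this commit
      std::vector<double> grad(x->size());
      for (int iter = 0; iter < 1000; ++iter) {
        double obj = 0;
        for (size_t i = 0; i < x->size(); ++i) {
          obj += 0.5 * (*x)[i] * (*x)[i];
          grad[i] = (*x)[i];                  // gradient of the stand-in objective at x
        }
        lbfgs.Optimize(obj, grad, x);         // one quasi-Newton update of x
        if (lbfgs.HasConverged()) break;
      }
    }
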
diff --git a/training/Makefile.am b/training/Makefile.am
index 330341ac..991ac210 100644
--- a/training/Makefile.am
+++ b/training/Makefile.am
@@ -50,7 +50,7 @@ test_ngram_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteva
model1_SOURCES = model1.cc ttables.cc
model1_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz
-lbl_model_SOURCES = lbl_model.cc ttables.cc
+lbl_model_SOURCES = lbl_model.cc optimize.cc
lbl_model_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz
grammar_convert_SOURCES = grammar_convert.cc
diff --git a/training/lbl_model.cc b/training/lbl_model.cc
index 4759eedc..eb3e194d 100644
--- a/training/lbl_model.cc
+++ b/training/lbl_model.cc
@@ -16,6 +16,7 @@
#include <boost/program_options/variables_map.hpp>
#include <Eigen/Dense>
+#include "optimize.h"
#include "array2d.h"
#include "m.h"
#include "lattice.h"
@@ -26,7 +27,7 @@
namespace po = boost::program_options;
using namespace std;
-#define kDIMENSIONS 25
+#define kDIMENSIONS 8
typedef Eigen::Matrix<float, kDIMENSIONS, 1> RVector;
typedef Eigen::Matrix<float, 1, kDIMENSIONS> RTVector;
typedef Eigen::Matrix<float, kDIMENSIONS, kDIMENSIONS> TMatrix;
@@ -69,6 +70,21 @@ void Normalize(RVector* v) {
*v /= norm;
}
+void Flatten(const TMatrix& m, vector<double>* v) {
+ unsigned c = 0;
+ v->resize(kDIMENSIONS * kDIMENSIONS);
+ for (unsigned i = 0; i < kDIMENSIONS; ++i)
+ for (unsigned j = 0; j < kDIMENSIONS; ++j)
+ (*v)[c++] = m(i,j);
+}
+
+void Unflatten(const vector<double>& v, TMatrix* m) {
+ unsigned c = 0;
+ for (unsigned i = 0; i < kDIMENSIONS; ++i)
+ for (unsigned j = 0; j < kDIMENSIONS; ++j)
+ (*m)(i, j) = v[c++];
+}
+
int main(int argc, char** argv) {
po::variables_map conf;
if (!InitCommandLine(argc, argv, &conf)) return 1;
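
The Flatten/Unflatten helpers added above exist because the optimizer works on a flat vector<double> while the model keeps its parameters in a fixed-size float-valued Eigen matrix; the element-wise copy doubles as the float/double conversion, which is why a raw memcpy or Eigen::Map would not do here. A small round-trip check, sketched under the assumption that kDIMENSIONS, TMatrix, Flatten, and Unflatten are as defined above:

    #include <cassert>
    #include <vector>

    void TestFlattenRoundTrip() {
      TMatrix a = TMatrix::Random();
      std::vector<double> flat;
      Flatten(a, &flat);                        // row-major copy, float -> double
      assert(flat.size() == kDIMENSIONS * kDIMENSIONS);
      TMatrix b;
      Unflatten(flat, &b);                      // double -> float, same order
      assert(a.isApprox(b));                    // equal up to the precision change
    }
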
@@ -76,7 +92,7 @@ int main(int argc, char** argv) {
const int ITERATIONS = conf["iterations"].as<unsigned>();
const float eta = conf["eta"].as<float>();
const double diagonal_tension = conf["diagonal_tension"].as<double>();
- bool SGD = true;
+ bool SGD = false;
if (diagonal_tension < 0.0) {
cerr << "Invalid value for diagonal_tension: must be >= 0\n";
return 1;
@@ -121,6 +137,7 @@ int main(int argc, char** argv) {
cerr << "Number of target word types: " << vocab_e.size() << endl;
const float num_examples = lc;
+ LBFGSOptimizer lbfgs(kDIMENSIONS * kDIMENSIONS, 100);
r_trg.resize(TD::NumWords() + 1);
r_src.resize(TD::NumWords() + 1);
if (conf.count("random_seed")) {
@@ -130,7 +147,7 @@ int main(int argc, char** argv) {
cerr << "Random seed: " << seed << endl;
srand(seed);
}
- TMatrix t = TMatrix::Random() / 100.0;
+ TMatrix t = TMatrix::Random() / 1024.0;
for (unsigned i = 1; i < r_trg.size(); ++i) {
r_trg[i] = RVector::Random();
r_src[i] = RVector::Random();
@@ -145,6 +162,8 @@ int main(int argc, char** argv) {
TMatrix g;
vector<TMatrix> exp_src;
vector<double> z_src;
+ vector<double> flat_g, flat_t;
+ Flatten(t, &flat_t);
for (int iter = 0; iter < ITERATIONS; ++iter) {
cerr << "ITERATION " << (iter + 1) << endl;
ReadFile rf(fname);
@@ -236,8 +255,6 @@ int main(int argc, char** argv) {
if (SGD) {
t -= g * eta / num_examples;
g *= 0;
- } else {
- assert(!"not implemented");
}
}
@@ -250,6 +267,12 @@ int main(int argc, char** argv) {
cerr << " log_2 likelihood: " << base2_likelihood << endl;
cerr << " cross entropy: " << (-base2_likelihood / denom) << endl;
cerr << " perplexity: " << pow(2.0, -base2_likelihood / denom) << endl;
+ if (!SGD) {
+ Flatten(g, &flat_g);
+ lbfgs.Optimize(-likelihood, flat_g, &flat_t);
+ Unflatten(flat_t, &t);
+ if (lbfgs.HasConverged()) break;
+ }
cerr << t << endl;
}
cerr << "TRANSLATION MATRIX:" << endl << t << endl;