summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/lbl_model.cc147
1 files changed, 126 insertions, 21 deletions
diff --git a/training/lbl_model.cc b/training/lbl_model.cc
index ccd29255..4759eedc 100644
--- a/training/lbl_model.cc
+++ b/training/lbl_model.cc
@@ -5,13 +5,18 @@
int main() { std::cerr << "Please rebuild with --with-eigen PATH\n"; return 1; }
#else
+#include <cstdlib>
+#include <algorithm>
#include <cmath>
#include <set>
+#include <cstring> // memset
+#include <ctime>
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
#include <Eigen/Dense>
+#include "array2d.h"
#include "m.h"
#include "lattice.h"
#include "stringlib.h"
@@ -21,7 +26,7 @@
namespace po = boost::program_options;
using namespace std;
-#define kDIMENSIONS 10
+#define kDIMENSIONS 25
typedef Eigen::Matrix<float, kDIMENSIONS, 1> RVector;
typedef Eigen::Matrix<float, 1, kDIMENSIONS> RTVector;
typedef Eigen::Matrix<float, kDIMENSIONS, kDIMENSIONS> TMatrix;
@@ -32,6 +37,8 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
opts.add_options()
("input,i",po::value<string>(),"Input file")
("iterations,I",po::value<unsigned>()->default_value(1000),"Number of iterations of training")
+ ("eta,e", po::value<float>()->default_value(0.1f), "Eta for SGD")
+ ("random_seed", po::value<unsigned>(), "Random seed")
("diagonal_tension,T", po::value<double>()->default_value(4.0), "How sharp or flat around the diagonal is the alignment distribution (0 = uniform, >0 sharpens)")
("testset,x", po::value<string>(), "After training completes, compute the log likelihood of this set of sentence pairs under the learned model");
po::options_description clo("Command line options");
@@ -57,12 +64,19 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
return true;
}
+void Normalize(RVector* v) {
+ float norm = v->norm();
+ *v /= norm;
+}
+
int main(int argc, char** argv) {
po::variables_map conf;
if (!InitCommandLine(argc, argv, &conf)) return 1;
const string fname = conf["input"].as<string>();
const int ITERATIONS = conf["iterations"].as<unsigned>();
+ const float eta = conf["eta"].as<float>();
const double diagonal_tension = conf["diagonal_tension"].as<double>();
+ bool SGD = true;
if (diagonal_tension < 0.0) {
cerr << "Invalid value for diagonal_tension: must be >= 0\n";
return 1;
@@ -70,14 +84,15 @@ int main(int argc, char** argv) {
string testset;
if (conf.count("testset")) testset = conf["testset"].as<string>();
- int lc = 0;
+ unsigned lc = 0;
vector<double> unnormed_a_i;
string line;
string ssrc, strg;
bool flag = false;
Lattice src, trg;
- set<WordID> vocab_e;
+ vector<WordID> vocab_e;
{ // read through corpus, initialize int map, check lines are good
+ set<WordID> svocab_e;
cerr << "INITIAL READ OF " << fname << endl;
ReadFile rf(fname);
istream& in = *rf.stream();
@@ -97,13 +112,39 @@ int main(int argc, char** argv) {
unnormed_a_i.resize(src.size());
for (unsigned i = 0; i < trg.size(); ++i) {
assert(trg[i].size() == 1);
- vocab_e.insert(trg[i][0].label);
+ svocab_e.insert(trg[i][0].label);
}
}
+ copy(svocab_e.begin(), svocab_e.end(), back_inserter(vocab_e));
}
if (flag) cerr << endl;
+ cerr << "Number of target word types: " << vocab_e.size() << endl;
+ const float num_examples = lc;
+
+ r_trg.resize(TD::NumWords() + 1);
+ r_src.resize(TD::NumWords() + 1);
+ if (conf.count("random_seed")) {
+ srand(conf["random_seed"].as<unsigned>());
+ } else {
+ unsigned seed = time(NULL);
+ cerr << "Random seed: " << seed << endl;
+ srand(seed);
+ }
+ TMatrix t = TMatrix::Random() / 100.0;
+ for (unsigned i = 1; i < r_trg.size(); ++i) {
+ r_trg[i] = RVector::Random();
+ r_src[i] = RVector::Random();
+ r_trg[i][i % kDIMENSIONS] = 0.5;
+ r_src[i][(i-1) % kDIMENSIONS] = 0.5;
+ Normalize(&r_trg[i]);
+ Normalize(&r_src[i]);
+ }
+ vector<set<unsigned> > trg_pos(TD::NumWords() + 1);
// do optimization
+ TMatrix g;
+ vector<TMatrix> exp_src;
+ vector<double> z_src;
for (int iter = 0; iter < ITERATIONS; ++iter) {
cerr << "ITERATION " << (iter + 1) << endl;
ReadFile rf(fname);
@@ -112,9 +153,8 @@ int main(int argc, char** argv) {
double denom = 0.0;
lc = 0;
flag = false;
- while(true) {
- getline(in, line);
- if (!in) break;
+ g *= 0;
+ while(getline(in, line)) {
++lc;
if (lc % 1000 == 0) { cerr << '.'; flag = true; }
if (lc %50000 == 0) { cerr << " [" << lc << "]\n" << flush; flag = false; }
@@ -122,23 +162,86 @@ int main(int argc, char** argv) {
LatticeTools::ConvertTextToLattice(ssrc, &src);
LatticeTools::ConvertTextToLattice(strg, &trg);
denom += trg.size();
- vector<double> probs(src.size() + 1);
- for (int j = 0; j < trg.size(); ++j) {
- const WordID& f_j = trg[j][0].label;
- double sum = 0;
- const double j_over_ts = double(j) / trg.size();
- double az = 0;
- for (int ta = 0; ta < src.size(); ++ta) {
- unnormed_a_i[ta] = exp(-fabs(double(ta) / src.size() - j_over_ts) * diagonal_tension);
- az += unnormed_a_i[ta];
+
+ exp_src.clear(); exp_src.resize(src.size(), TMatrix::Zero());
+ z_src.clear(); z_src.resize(src.size(), 0.0);
+ Array2D<TMatrix> exp_refs(src.size(), trg.size(), TMatrix::Zero());
+ Array2D<double> z_refs(src.size(), trg.size(), 0.0);
+ for (unsigned j = 0; j < trg.size(); ++j)
+ trg_pos[trg[j][0].label].insert(j);
+
+ for (unsigned i = 0; i < src.size(); ++i) {
+ const RVector& r_s = r_src[src[i][0].label];
+ const RTVector pred = r_s.transpose() * t;
+ TMatrix& exp_m = exp_src[i];
+ double& z = z_src[i];
+ for (unsigned k = 0; k < vocab_e.size(); ++k) {
+ const WordID v_k = vocab_e[k];
+ const RVector& r_t = r_trg[v_k];
+ const double dot_prod = pred * r_t;
+ const double u = exp(dot_prod);
+ z += u;
+ const TMatrix v = r_s * r_t.transpose() * u;
+ exp_m += v;
+ set<unsigned>& ref_locs = trg_pos[v_k];
+ if (!ref_locs.empty()) {
+ for (set<unsigned>::iterator it = ref_locs.begin(); it != ref_locs.end(); ++it) {
+ TMatrix& exp_ref_ij = exp_refs(i, *it);
+ double& z_ref_ij = z_refs(i, *it);
+ z_ref_ij += u;
+ exp_ref_ij += v;
+ }
+ }
+ }
+ }
+ for (unsigned j = 0; j < trg.size(); ++j)
+ trg_pos[trg[j][0].label].clear();
+
+ // model expectations for a single target generation with
+ // uniform alignment prior
+ double m_z = 0;
+ TMatrix m_exp = TMatrix::Zero();
+ for (unsigned i = 0; i < src.size(); ++i) {
+ m_exp += exp_src[i];
+ m_z += z_src[i];
+ }
+ m_exp /= m_z;
+
+ Array2D<bool> al(src.size(), trg.size(), false);
+ for (unsigned j = 0; j < trg.size(); ++j) {
+ double ref_z = 0;
+ TMatrix ref_exp = TMatrix::Zero();
+ int max_i = 0;
+ double max_s = -9999999;
+ for (unsigned i = 0; i < src.size(); ++i) {
+ ref_exp += exp_refs(i, j);
+ ref_z += z_refs(i, j);
+ if (log(z_refs(i, j)) > max_s) {
+ max_s = log(z_refs(i, j));
+ max_i = i;
+ }
+ // TODO handle alignment prob
+ }
+ if (ref_z <= 0) {
+ cerr << "TRG=" << TD::Convert(trg[j][0].label) << endl;
+ cerr << " LINE=" << line << endl;
+ cerr << " REF_EXP=\n" << ref_exp << endl;
+ cerr << " M_EXP=\n" << m_exp << endl;
+ abort();
}
- for (int i = 1; i <= src.size(); ++i) {
- const double prob_a_i = unnormed_a_i[i-1] / az;
- // TODO
- probs[i] = 1; // tt.prob(src[i-1][0].label, f_j) * prob_a_i;
- sum += probs[i];
+ al(max_i, j) = true;
+ ref_exp /= ref_z;
+ g += m_exp - ref_exp;
+ likelihood += log(ref_z) - log(m_z);
+ if (SGD) {
+ t -= g * eta / num_examples;
+ g *= 0;
+ } else {
+ assert(!"not implemented");
}
}
+
+ if (iter == (ITERATIONS - 1) || lc == 28) { cerr << al << endl; }
}
if (flag) { cerr << endl; }
@@ -147,7 +250,9 @@ int main(int argc, char** argv) {
cerr << " log_2 likelihood: " << base2_likelihood << endl;
cerr << " cross entropy: " << (-base2_likelihood / denom) << endl;
cerr << " perplexity: " << pow(2.0, -base2_likelihood / denom) << endl;
+ cerr << t << endl;
}
+ cerr << "TRANSLATION MATRIX:" << endl << t << endl;
return 0;
}