From 2faca3e7b3b8e4eba6c036c635a5b23883e72337 Mon Sep 17 00:00:00 2001
From: Chris Dyer
Date: Fri, 24 Feb 2012 00:47:48 -0500
Subject: load embeddings from file

---
 training/lbl_model.cc | 74 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 68 insertions(+), 6 deletions(-)

diff --git a/training/lbl_model.cc b/training/lbl_model.cc
index a114bba7..2af848b5 100644
--- a/training/lbl_model.cc
+++ b/training/lbl_model.cc
@@ -28,7 +28,7 @@
 namespace po = boost::program_options;
 using namespace std;
 
-#define kDIMENSIONS 110
+#define kDIMENSIONS 100
 typedef Eigen::Matrix RVector;
 typedef Eigen::Matrix RTVector;
 typedef Eigen::Matrix TMatrix;
@@ -40,7 +40,9 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
         ("input,i",po::value(),"Input file")
         ("iterations,I",po::value()->default_value(1000),"Number of iterations of training")
         ("regularization_strength,C",po::value()->default_value(0.1),"L2 regularization strength (0 for no regularization)")
-        ("eta,e", po::value()->default_value(0.1f), "Eta for SGD")
+        ("eta", po::value()->default_value(0.1f), "Eta for SGD")
+        ("source_embeddings,f", po::value<string>(), "File containing source embeddings (if unset, random vectors will be used)")
+        ("target_embeddings,e", po::value<string>(), "File containing target embeddings (if unset, random vectors will be used)")
         ("random_seed,s", po::value(), "Random seed")
         ("diagonal_tension,T", po::value()->default_value(4.0), "How sharp or flat around the diagonal is the alignment distribution (0 = uniform, >0 sharpens)")
         ("testset,x", po::value(), "After training completes, compute the log likelihood of this set of sentence pairs under the learned model");
@@ -106,6 +108,64 @@ double ApplyRegularization(const double C,
   return reg;
 }
 
+// Reads word embeddings from a text file into *pv, which is indexed by WordID.
+// Each line must look like "word v1 v2 ... vN": a single space ends the word,
+// and the values that follow are separated by whitespace. Words whose ID lies
+// outside *pv are skipped. A line with fewer than kDIMENSIONS values is padded
+// with random values in [-1, 1] (the range Eigen's Random() draws from, which
+// main uses to initialize the vectors); surplus values are ignored. Each of
+// those two conditions is reported once on stderr. Aborts on a line that
+// starts with a space or contains none.
+void LoadEmbeddings(const string& filename, vector<RVector>* pv) {
+  vector<RVector>& v = *pv;
+  cerr << "Reading embeddings from " << filename << " ...\n";
+  ReadFile rf(filename);
+  istream& in = *rf.stream();
+  string line;
+  unsigned lc = 0;
+  while(getline(in, line)) {
+    ++lc;
+    const size_t word_end = line.find(' ');
+    if (word_end == string::npos || word_end == 0) {
+      cerr << "Parse error reading line " << lc << ":\n" << line << endl;
+      abort();
+    }
+    const WordID w = TD::Convert(line.substr(0, word_end));
+    // If WordID is signed, a negative ID wraps to a huge value and is skipped.
+    if (static_cast<size_t>(w) >= v.size()) continue;
+    RVector& curv = v[w];
+    // strtod skips whitespace and reports where it stopped: extra separators
+    // are not read as 0, and a non-numeric token ends the row's values.
+    const char* p = line.c_str() + word_end + 1;
+    size_t c = 0;
+    bool surplus = false;
+    for (;;) {
+      char* end = NULL;
+      const double val = strtod(p, &end);
+      if (end == p) break;  // nothing numeric left on this line
+      if (c == kDIMENSIONS) { surplus = true; break; }
+      curv[c++] = val;
+      p = end;
+    }
+    if (c != kDIMENSIONS) {
+      static bool warned_short = false;
+      if (!warned_short) {
+        cerr << " read " << c << " dimensions from embedding file, but built with " << kDIMENSIONS << " (filling in with random values)\n";
+        warned_short = true;
+      }
+      // rand() alone yields integers in [0, RAND_MAX]; rescale to [-1, 1].
+      for (; c < kDIMENSIONS; ++c) curv[c] = 2.0 * rand() / RAND_MAX - 1.0;
+    }
+    if (surplus) {
+      static bool warned_long = false;
+      if (!warned_long) {
+        cerr << " embedding file contains more dimensions than configured with, truncating.\n";
+        warned_long = true;
+      }
+    }
+  }
+}
+
 int main(int argc, char** argv) {
   po::variables_map conf;
   if (!InitCommandLine(argc, argv, &conf)) return 1;
@@ -175,11 +235,11 @@ int main(int argc, char** argv) {
   for (unsigned i = 1; i < r_trg.size(); ++i) {
     r_trg[i] = RVector::Random();
     r_src[i] = RVector::Random();
-    r_trg[i][i % kDIMENSIONS] = 0.5;
-    r_src[i][(i-1) % kDIMENSIONS] = 0.5;
-    Normalize(&r_trg[i]);
-    Normalize(&r_src[i]);
   }
+  if (conf.count("source_embeddings"))
+    LoadEmbeddings(conf["source_embeddings"].as<string>(), &r_src);
+  if (conf.count("target_embeddings"))
+    LoadEmbeddings(conf["target_embeddings"].as<string>(), &r_trg);
   vector > trg_pos(TD::NumWords() + 1);
 
   // do optimization
@@ -242,6 +302,8 @@ int main(int argc, char** argv) {
 
       // model expectations for a single target generation with
       // uniform alignment prior
+      // TODO: when using a non-uniform alignment, m_exp will be
+      // a function of j (below)
       double m_z = 0;
       TMatrix m_exp = TMatrix::Zero();
       for (unsigned i = 0; i < src.size(); ++i) {
-- 
cgit v1.2.3