summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--training/lbl_model.cc69
1 files changed, 63 insertions, 6 deletions
diff --git a/training/lbl_model.cc b/training/lbl_model.cc
index a114bba7..2af848b5 100644
--- a/training/lbl_model.cc
+++ b/training/lbl_model.cc
@@ -28,7 +28,7 @@
namespace po = boost::program_options;
using namespace std;
-#define kDIMENSIONS 110
+#define kDIMENSIONS 100
typedef Eigen::Matrix<float, kDIMENSIONS, 1> RVector;
typedef Eigen::Matrix<float, 1, kDIMENSIONS> RTVector;
typedef Eigen::Matrix<float, kDIMENSIONS, kDIMENSIONS> TMatrix;
@@ -40,7 +40,9 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
("input,i",po::value<string>(),"Input file")
("iterations,I",po::value<unsigned>()->default_value(1000),"Number of iterations of training")
("regularization_strength,C",po::value<float>()->default_value(0.1),"L2 regularization strength (0 for no regularization)")
- ("eta,e", po::value<float>()->default_value(0.1f), "Eta for SGD")
+ ("eta", po::value<float>()->default_value(0.1f), "Eta for SGD")
+ ("source_embeddings,f", po::value<string>(), "File containing source embeddings (if unset, random vectors will be used)")
+ ("target_embeddings,e", po::value<string>(), "File containing target embeddings (if unset, random vectors will be used)")
("random_seed,s", po::value<unsigned>(), "Random seed")
("diagonal_tension,T", po::value<double>()->default_value(4.0), "How sharp or flat around the diagonal is the alignment distribution (0 = uniform, >0 sharpens)")
("testset,x", po::value<string>(), "After training completes, compute the log likelihood of this set of sentence pairs under the learned model");
@@ -106,6 +108,59 @@ double ApplyRegularization(const double C,
return reg;
}
+void LoadEmbeddings(const string& filename, vector<RVector>* pv) {
+ vector<RVector>& v = *pv;
+ cerr << "Reading embeddings from " << filename << " ...\n";
+ ReadFile rf(filename);
+ istream& in = *rf.stream();
+ string line;
+ unsigned lc = 0;
+ while(getline(in, line)) {
+ ++lc;
+ size_t cur = line.find(' ');
+ if (cur == string::npos || cur == 0) {
+ cerr << "Parse error reading line " << lc << ":\n" << line << endl;
+ abort();
+ }
+ WordID w = TD::Convert(line.substr(0, cur));
+ if (w >= v.size()) continue;
+ RVector& curv = v[w];
+ line[cur] = 0;
+ size_t start = cur + 1;
+ cur = start + 1;
+ size_t c = 0;
+ while(cur < line.size()) {
+ if (line[cur] == ' ') {
+ line[cur] = 0;
+ curv[c++] = strtod(&line[start], NULL);
+ start = cur + 1;
+ cur = start;
+ if (c == kDIMENSIONS) break;
+ }
+ ++cur;
+ }
+ if (c < kDIMENSIONS && cur != start) {
+ if (cur < line.size()) line[cur] = 0;
+ curv[c++] = strtod(&line[start], NULL);
+ }
+ if (c != kDIMENSIONS) {
+ static bool first = true;
+ if (first) {
+ cerr << " read " << c << " dimensions from embedding file, but built with " << kDIMENSIONS << " (filling in with random values)\n";
+ first = false;
+ }
+ for (; c < kDIMENSIONS; ++c) curv[c] = rand();
+ }
+ if (c == kDIMENSIONS && cur != line.size()) {
+ static bool first = true;
+ if (first) {
+ cerr << " embedding file contains more dimensions than configured with, truncating.\n";
+ first = false;
+ }
+ }
+ }
+}
+
int main(int argc, char** argv) {
po::variables_map conf;
if (!InitCommandLine(argc, argv, &conf)) return 1;
@@ -175,11 +230,11 @@ int main(int argc, char** argv) {
for (unsigned i = 1; i < r_trg.size(); ++i) {
r_trg[i] = RVector::Random();
r_src[i] = RVector::Random();
- r_trg[i][i % kDIMENSIONS] = 0.5;
- r_src[i][(i-1) % kDIMENSIONS] = 0.5;
- Normalize(&r_trg[i]);
- Normalize(&r_src[i]);
}
+ if (conf.count("source_embeddings"))
+ LoadEmbeddings(conf["source_embeddings"].as<string>(), &r_src);
+ if (conf.count("target_embeddings"))
+ LoadEmbeddings(conf["target_embeddings"].as<string>(), &r_trg);
vector<set<unsigned> > trg_pos(TD::NumWords() + 1);
// do optimization
@@ -242,6 +297,8 @@ int main(int argc, char** argv) {
// model expectations for a single target generation with
// uniform alignment prior
+ // TODO: when using a non-uniform alignment, m_exp will be
+ // a function of j (below)
double m_z = 0;
TMatrix m_exp = TMatrix::Zero();
for (unsigned i = 0; i < src.size(); ++i) {