From 7ef9733c608889b99335cb9e7db86262a6e7e528 Mon Sep 17 00:00:00 2001 From: Patrick Simianer
Date: Fri, 4 Aug 2017 16:18:12 +0200
Subject: nmt wip
---
nmt/nmt-1.js | 144 ++++++++++++++++++++++++++++++++++++++
nmt/nmt.html | 23 +++++++
nmt/nmt.js | 218 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
nmt/notes.txt | 19 +++++
4 files changed, 404 insertions(+)
create mode 100644 nmt/nmt-1.js
create mode 100644 nmt/nmt.html
create mode 100644 nmt/nmt.js
create mode 100644 nmt/notes.txt
(limited to 'nmt')
diff --git a/nmt/nmt-1.js b/nmt/nmt-1.js
new file mode 100644
index 0000000..3d25713
--- /dev/null
+++ b/nmt/nmt-1.js
@@ -0,0 +1,144 @@
+/*var encode_time_step = function (src, model, lh, hidden_sizes, G, x=false)
+{
+ if (x)
+ G = new R.Graph(false);
+ /*else
+ var G = new R.Graph();*/
+ var inp = one_hot($vocab_sz_src, $vocab_src[src]);
+ var out = R.forwardRNN(G, model, hidden_sizes, inp, lh);
+
+ return [G, out];
+}
+
+var encode_update = function (model, G, dw)
+{
+ G.backward();
+ var solver = new R.Solver();
+ solver.step(model, 0.01, 0.0001, 5.0);
+}
+
+
+var decode_time_step = function (src, tgt, model, lh, solver, hidden_sizes, context, G)
+{
+ var inp = cat(one_hot($vocab_sz_tgt, $vocab_tgt[src]), context);
+ var out = R.forwardRNN(G, model, hidden_sizes, inp, lh);
+ var logprobs = out.o;
+ var probs = R.softmax(logprobs);
+ var target = $vocab_tgt[tgt];
+ cost = -Math.log(probs.w[target]);
+ logprobs.dw = probs.w;
+ logprobs.dw[target] -= 1;
+
+ return [model, cost, out, probs, G];
+}
+
+ var hidden_sizes = [20];
+/*var train = function (hidden_sizes)
+{*/
+ var encoder = R.initRNN($vocab_sz_src, hidden_sizes, $vocab_sz_src+100);
+ encoder["name"] = "encoder";
+ var decoder = R.initRNN($vocab_sz_src+100+$vocab_sz_tgt, hidden_sizes, $vocab_sz_tgt);
+ decoder["name"] = "decoder";
+
+ decoder["encoder_Wxh0"] = encoder["Wxh0"];
+ decoder["encoder_Whh0"] = encoder["Whh0"];
+ decoder["encoder_bhh0"] = encoder["bhh0"];
+ decoder["encoder_Whd"] = encoder["Whd"];
+ decoder["encoder_bd"] = encoder["bd"];
+
+ var solver = new R.Solver();
+ var G = new R.Graph();
+
+
+ lh = {};
+ costs = [];
+ var l = 0;
+ while (true)
+ {
+ var cost = 0.0;
+ for (var i=0; i<$data.length; i++) {
+ var last_encode_lh,enc_g;
+ for (var j=0; j<$data[i]["src"].length; j++) {
+ [G, lh] = encode_time_step($data[i]["src"][j], encoder, lh, hidden_sizes, G);
+ }
+ last_encode_lh = lh;
+
+ var context = last_encode_lh.o;
+ lh = {}
+ for (var k=0; k<$data[i]["tgt"].length; k++) {
+ var [decoder, c, lh, probs, G] = decode_time_step($data[i]["tgt"][k], $data[i]["tgt"][k+1], decoder, lh, solver, hidden_sizes, context, G);
+ cost += c;
+
+ G.backward();
+ solver.step(encoder, 0.01, 0.0001, 5.0);
+ solver.step(decoder, 0.01, 0.0001, 5.0);
+ }
+
+ //encode_update(encoder, enc_g, lh);
+
+ }
+ l++;
+ costs.push(cost);
+ if (stopping_criterion(costs[costs.length-2], cost, l))
+ break;
+ }
+
+ //return [encoder, decoder, costs];
+
+/*}*/
+
+
+//var x = [" Training costs Samples Predict (context: "welcome to my")
+
+
+
+
+
+