diff options
author | Patrick Simianer <p@simianer.de> | 2017-08-04 16:18:12 +0200 |
---|---|---|
committer | Patrick Simianer <p@simianer.de> | 2017-08-04 16:18:12 +0200 |
commit | 7ef9733c608889b99335cb9e7db86262a6e7e528 (patch) | |
tree | f3be0dd62b176756bbb8809504685ac07043bc4e | |
parent | 266b88a783d5af777fb54543851e5f50b89b3170 (diff) |
-rw-r--r-- | nmt/nmt-1.js | 144 | ||||
-rw-r--r-- | nmt/nmt.html | 23 | ||||
-rw-r--r-- | nmt/nmt.js | 218 | ||||
-rw-r--r-- | nmt/notes.txt | 19 |
4 files changed, 404 insertions, 0 deletions
diff --git a/nmt/nmt-1.js b/nmt/nmt-1.js new file mode 100644 index 0000000..3d25713 --- /dev/null +++ b/nmt/nmt-1.js @@ -0,0 +1,144 @@ +/*var encode_time_step = function (src, model, lh, hidden_sizes, G, x=false) +{ + if (x) + G = new R.Graph(false); + /*else + var G = new R.Graph();*/ + var inp = one_hot($vocab_sz_src, $vocab_src[src]); + var out = R.forwardRNN(G, model, hidden_sizes, inp, lh); + + return [G, out]; +} + +var encode_update = function (model, G, dw) +{ + G.backward(); + var solver = new R.Solver(); + solver.step(model, 0.01, 0.0001, 5.0); +} + + +var decode_time_step = function (src, tgt, model, lh, solver, hidden_sizes, context, G) +{ + var inp = cat(one_hot($vocab_sz_tgt, $vocab_tgt[src]), context); + var out = R.forwardRNN(G, model, hidden_sizes, inp, lh); + var logprobs = out.o; + var probs = R.softmax(logprobs); + var target = $vocab_tgt[tgt]; + cost = -Math.log(probs.w[target]); + logprobs.dw = probs.w; + logprobs.dw[target] -= 1; + + return [model, cost, out, probs, G]; +} + + var hidden_sizes = [20]; +/*var train = function (hidden_sizes) +{*/ + var encoder = R.initRNN($vocab_sz_src, hidden_sizes, $vocab_sz_src+100); + encoder["name"] = "encoder"; + var decoder = R.initRNN($vocab_sz_src+100+$vocab_sz_tgt, hidden_sizes, $vocab_sz_tgt); + decoder["name"] = "decoder"; + + decoder["encoder_Wxh0"] = encoder["Wxh0"]; + decoder["encoder_Whh0"] = encoder["Whh0"]; + decoder["encoder_bhh0"] = encoder["bhh0"]; + decoder["encoder_Whd"] = encoder["Whd"]; + decoder["encoder_bd"] = encoder["bd"]; + + var solver = new R.Solver(); + var G = new R.Graph(); + + + lh = {}; + costs = []; + var l = 0; + while (true) + { + var cost = 0.0; + for (var i=0; i<$data.length; i++) { + var last_encode_lh,enc_g; + for (var j=0; j<$data[i]["src"].length; j++) { + [G, lh] = encode_time_step($data[i]["src"][j], encoder, lh, hidden_sizes, G); + } + last_encode_lh = lh; + + var context = last_encode_lh.o; + lh = {} + for (var k=0; k<$data[i]["tgt"].length; k++) { + var [decoder, c, lh, probs, G] = decode_time_step($data[i]["tgt"][k], $data[i]["tgt"][k+1], decoder, lh, solver, hidden_sizes, context, G); + cost += c; + + G.backward(); + solver.step(encoder, 0.01, 0.0001, 5.0); + solver.step(decoder, 0.01, 0.0001, 5.0); + } + + //encode_update(encoder, enc_g, lh); + + } + l++; + costs.push(cost); + if (stopping_criterion(costs[costs.length-2], cost, l)) + break; + } + + //return [encoder, decoder, costs]; + +/*}*/ + + +//var x = ["<bos>", "das", "ist", "<eos>"]; + + + +var k = 3; +var kbest = []; + +for (var q=0; q < k; q++) { + +var x = $data[0]["src"]; +var _, cntxt; +var o = {}; +for (var i=0; i<x.length; i++) + [_, o] = encode_time_step(x[i], encoder, o, hidden_sizes, null, true); +cntxt = o.o; + +var w = "<bos>", lh; +var prev = {}; +var str = ""; +var z = 0; +var p = 0; + while (true) { + var G = new R.Graph(false); + var inp = cat(one_hot($vocab_sz_tgt, $vocab_tgt[w]), cntxt); + var lh = R.forwardRNN(G, decoder, hidden_sizes, inp, prev); + prev = lh; + var logprobs = lh.o; + var probs = R.softmax(logprobs); + var x = R.samplei(probs.w); + p += probs.w[x]; + src = $ivocab_tgt[x]; + str += " "+src; + z++; + if (src == "<eos>" || z==100) + break; + } + kbest.push( {"transl":str, "score":p } ); +} + + + +var main = function () +{ + /*var hidden_sizes = [20]; + var [encoder, decoder, costs] = train($data, hidden_sizes);*/ + + return false; +} + +$(document).ready(function() +{ + main(); +});*/ + diff --git a/nmt/nmt.html b/nmt/nmt.html new file mode 100644 index 0000000..5cf60a5 --- /dev/null +++ b/nmt/nmt.html @@ -0,0 +1,23 @@ +<html> + <head> + <meta charset="utf-8"> + <script type="text/javascript" src="../external/jquery-1.8.3.min.js"></script> + <script type="text/javascript" src="../src/recurrent.js"></script> + <script type="text/javascript" src="nmt.js"> </script> + </head> + <body> + +<p>Training costs</p> +<ol id="costs"> +</ol> + +<p>Samples</p> +<ul id="samples"> +</ul> + +<p>Predict (context: "welcome to my")</p> +<p id="predict"></p> + + </body> +</html> + diff --git a/nmt/nmt.js b/nmt/nmt.js new file mode 100644 index 0000000..6710c36 --- /dev/null +++ b/nmt/nmt.js @@ -0,0 +1,218 @@ +/*var $data = [{"src":["<bos>", "das", "ist", "ein", "kleines", "haus", "<eos>"], + "tgt":["<bos>", "this", "is", "a", "small", "house", "<eos>"]} +];*/ + +var $data = [{"src":["<bos>", "das", "<eos>"], + "tgt":["<bos>", "this", "<eos>"]}];//, +/* {"src":["<bos>", "mein", "<eos>"], + "tgt":["<bos>", "my", "<eos>"]}];*/ + +var make_vocab = function (data) +{ + var k = 0; + var vocab = {}, ivocab = []; + for (var i=0; i<data.length; i++) { + for (var j=0; j<data[i].length; j++) { + var w = data[i][j]; + if (vocab[w]==undefined) { + vocab[w] = k; + ivocab[k] = w; + k++; + } + } + } + + return [vocab,ivocab,ivocab.length]; +} + +var $vocab_src, $ivocab_src,$vocab_sz_src; +var $vocab_tgt, $ivocab_tgt,$vocab_sz_tgt; +[$vocab_src,$ivocab_src,$vocab_sz_src] = make_vocab($data.map(function(i){return i["src"]})); +$vocab_sz_src++; +[$vocab_tgt,$ivocab_tgt,$vocab_sz_tgt] = make_vocab($data.map(function(i){return i["tgt"]})); + +var one_hot = function (n, i) +{ + var m = new R.Mat(n,1); + m.set(i, 0, 1); + + return m; +} + +var stopping_criterion = function (c, d, iter, margin=0.01, max_iter=100) +{ + if (Math.abs(c-d) < margin || iter>=max_iter) + return true; + return false; +} + +var cat = function (a, b) +{ + R.assert(a.d==1 && b.d==1); + var m = new R.Mat(a.n+b.n, 1); + var i; + for (i=0; i<a.n; i++) { + m.w[i] = a.w[i]; + m.dw[i] = a.dw[i]; + } + for (var j=0; j<b.n; j++) { + m.w[i+j] = b.w[j]; + m.dw[i+j] = b.dw[j]; + } + + return m; +} + +function shuffle(array) { + var currentIndex = array.length, temporaryValue, randomIndex; + + // While there remain elements to shuffle... + while (0 !== currentIndex) { + + // Pick a remaining element... + randomIndex = Math.floor(Math.random() * currentIndex); + currentIndex -= 1; + + // And swap it with the current element. + temporaryValue = array[currentIndex]; + array[currentIndex] = array[randomIndex]; + array[randomIndex] = temporaryValue; + } + + return array; +} + +var hidden_sizes = [5]; +var encoder = R.initRNN($vocab_sz_src, hidden_sizes, $vocab_sz_src+100); +encoder["name"] = "encoder"; +var decoder = R.initRNN($vocab_sz_src+100+$vocab_sz_tgt, hidden_sizes, $vocab_sz_tgt); +decoder["name"] = "decoder"; + +var solver = new R.Solver(); +var h = {}; +var inp = null; +costs = []; +var epoch = 0; + +decoder["encoder_Wxh0"] = encoder["Wxh0"]; +decoder["encoder_Whh0"] = encoder["Whh0"]; +decoder["encoder_bhh0"] = encoder["bhh0"]; +decoder["encoder_Whd"] = encoder["Whd"]; +decoder["encoder_bd"] = encoder["bd"]; + +var x = encoder.Whh0.w[0] +var y = decoder.Whh0.w[0] + +var enc_hs = []; +var dec_hs = []; +var dec_inps = []; + +while (true) +{ + $data = shuffle($data); + var c = 0.0; + for (var i=0; i<$data.length; i++) { + var last_enc_h; + var G = new R.Graph(); + for (var j=0; j<$data[i]["src"].length; j++) { + var src = $data[i]["src"][j]; + var inp = one_hot($vocab_sz_src, $vocab_src[src]); + var h = R.forwardRNN(G, encoder, hidden_sizes, inp, h); + last_enc_h = h; + enc_hs.push(h); + } + + for (var z=0; z<G.backprop.length; z++) { + G.backprop[z]["mark"] = "encoder"; + } + + var context = last_enc_h.o; + var h = {} + for (var k=0; k<$data[i]["tgt"].length; k++) { + var src = $data[i]["tgt"][k], + tgt = $data[i]["tgt"][k+1]; + + inp = cat(one_hot($vocab_sz_tgt, $vocab_tgt[src]), context); + dec_inps.push(inp); + h = R.forwardRNN(G, decoder, hidden_sizes, inp, h); + dec_hs.push(h); + + var logprobs = h.o; + var probs = R.softmax(logprobs); + var target = $vocab_tgt[tgt]; + cost = -Math.log(probs.w[target]); + if (!cost || cost==Infinity) cost = 0; // hmmm + logprobs.dw = probs.w; + logprobs.dw[target] -= 1; + c += cost; + + // update weights + G.backward(); + // copy grads? decoder.Wxh0 -> encoder.Whd ? + for (var z=0; z<last_enc_h.o.dw.length; z++) { + last_enc_h.o.dw[z] = dec_inps[0].dw[z]; + } + G.backward1(); + //exit(); + solver.step(decoder, 0.01, 0.000001, 5.0); + } + } + + epoch++; + costs.push(c); + if (stopping_criterion(costs[costs.length-2], cost, epoch)) + break; +} + + + +var k = 10; +var samples = []; + +for (var q=0; q < k; q++) { + +//var x = ["<bos>", "das", "ist", "ein", "kleines", "haus", "<eos>"]; +var x = ["<bos>", "das", "<eos>"]; +var _, cntxt; +var o = {}; +var G = new R.Graph(false); +for (var i=0; i<x.length; i++) { + var src = x[i]; + var inp = one_hot($vocab_sz_src, $vocab_src[src]); + var h = R.forwardRNN(G, encoder, hidden_sizes, inp, h); +} + +cntxt = h.o; + +var w = "<bos>", lh; +var prev = {}; +var str = ""; +var z = 0; +var p = 1; + while (true) { + var inp = cat(one_hot($vocab_sz_tgt, $vocab_tgt[w]), cntxt); + var lh = R.forwardRNN(G, decoder, hidden_sizes, inp, prev); + prev = lh; + var logprobs = lh.o; + var probs = R.softmax(logprobs); + var x = R.samplei(probs.w); + p *= probs.w[x]; + src = $ivocab_tgt[x]; + str += " "+src; + z++; + if (src == "<eos>" || z==100) + break; + } + samples.push( {"transl":str, "score":p } ); +} + +best = ""; +best_score = -999999; +for (var q = 0; q<samples.length; q++) { + if (best_score < samples[q].score) { + best = samples[q].transl + } +} + +alert(best); + diff --git a/nmt/notes.txt b/nmt/notes.txt new file mode 100644 index 0000000..0bc0b9c --- /dev/null +++ b/nmt/notes.txt @@ -0,0 +1,19 @@ +forward: + single tick + advance hidden state + add functions to G + +Graph + does backprop + sets .dw fields of matrices + +Solver + updates models + step cache doesn't matter + +input +forward +calc loss,set deriv in last layer +backward +solver + |