From 7ef9733c608889b99335cb9e7db86262a6e7e528 Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Fri, 4 Aug 2017 16:18:12 +0200 Subject: nmt wip --- nmt/nmt-1.js | 144 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 144 insertions(+) create mode 100644 nmt/nmt-1.js (limited to 'nmt/nmt-1.js') diff --git a/nmt/nmt-1.js b/nmt/nmt-1.js new file mode 100644 index 0000000..3d25713 --- /dev/null +++ b/nmt/nmt-1.js @@ -0,0 +1,144 @@ +/*var encode_time_step = function (src, model, lh, hidden_sizes, G, x=false) +{ + if (x) + G = new R.Graph(false); + /*else + var G = new R.Graph();*/ + var inp = one_hot($vocab_sz_src, $vocab_src[src]); + var out = R.forwardRNN(G, model, hidden_sizes, inp, lh); + + return [G, out]; +} + +var encode_update = function (model, G, dw) +{ + G.backward(); + var solver = new R.Solver(); + solver.step(model, 0.01, 0.0001, 5.0); +} + + +var decode_time_step = function (src, tgt, model, lh, solver, hidden_sizes, context, G) +{ + var inp = cat(one_hot($vocab_sz_tgt, $vocab_tgt[src]), context); + var out = R.forwardRNN(G, model, hidden_sizes, inp, lh); + var logprobs = out.o; + var probs = R.softmax(logprobs); + var target = $vocab_tgt[tgt]; + cost = -Math.log(probs.w[target]); + logprobs.dw = probs.w; + logprobs.dw[target] -= 1; + + return [model, cost, out, probs, G]; +} + + var hidden_sizes = [20]; +/*var train = function (hidden_sizes) +{*/ + var encoder = R.initRNN($vocab_sz_src, hidden_sizes, $vocab_sz_src+100); + encoder["name"] = "encoder"; + var decoder = R.initRNN($vocab_sz_src+100+$vocab_sz_tgt, hidden_sizes, $vocab_sz_tgt); + decoder["name"] = "decoder"; + + decoder["encoder_Wxh0"] = encoder["Wxh0"]; + decoder["encoder_Whh0"] = encoder["Whh0"]; + decoder["encoder_bhh0"] = encoder["bhh0"]; + decoder["encoder_Whd"] = encoder["Whd"]; + decoder["encoder_bd"] = encoder["bd"]; + + var solver = new R.Solver(); + var G = new R.Graph(); + + + lh = {}; + costs = []; + var l = 0; + while (true) + { + var cost = 0.0; + for (var i=0; i<$data.length; i++) { + var last_encode_lh,enc_g; + for (var j=0; j<$data[i]["src"].length; j++) { + [G, lh] = encode_time_step($data[i]["src"][j], encoder, lh, hidden_sizes, G); + } + last_encode_lh = lh; + + var context = last_encode_lh.o; + lh = {} + for (var k=0; k<$data[i]["tgt"].length; k++) { + var [decoder, c, lh, probs, G] = decode_time_step($data[i]["tgt"][k], $data[i]["tgt"][k+1], decoder, lh, solver, hidden_sizes, context, G); + cost += c; + + G.backward(); + solver.step(encoder, 0.01, 0.0001, 5.0); + solver.step(decoder, 0.01, 0.0001, 5.0); + } + + //encode_update(encoder, enc_g, lh); + + } + l++; + costs.push(cost); + if (stopping_criterion(costs[costs.length-2], cost, l)) + break; + } + + //return [encoder, decoder, costs]; + +/*}*/ + + +//var x = ["", "das", "ist", ""]; + + + +var k = 3; +var kbest = []; + +for (var q=0; q < k; q++) { + +var x = $data[0]["src"]; +var _, cntxt; +var o = {}; +for (var i=0; i" || z==100) + break; + } + kbest.push( {"transl":str, "score":p } ); +} + + + +var main = function () +{ + /*var hidden_sizes = [20]; + var [encoder, decoder, costs] = train($data, hidden_sizes);*/ + + return false; +} + +$(document).ready(function() +{ + main(); +});*/ + -- cgit v1.2.3