diff options
Diffstat (limited to 'nmt')
| -rw-r--r-- | nmt/nmt-1.js | 144 | ||||
| -rw-r--r-- | nmt/nmt.html | 23 | ||||
| -rw-r--r-- | nmt/nmt.js | 218 | ||||
| -rw-r--r-- | nmt/notes.txt | 19 | 
4 files changed, 404 insertions, 0 deletions
| diff --git a/nmt/nmt-1.js b/nmt/nmt-1.js new file mode 100644 index 0000000..3d25713 --- /dev/null +++ b/nmt/nmt-1.js @@ -0,0 +1,144 @@ +/*var encode_time_step = function (src, model, lh, hidden_sizes, G, x=false) +{ +  if (x) +    G = new R.Graph(false); +  /*else +    var G = new R.Graph();*/ +  var inp = one_hot($vocab_sz_src, $vocab_src[src]); +  var out = R.forwardRNN(G, model, hidden_sizes, inp, lh); + +  return [G, out]; +} + +var encode_update = function (model, G, dw) +{ +  G.backward(); +  var solver = new R.Solver(); +  solver.step(model, 0.01, 0.0001, 5.0); +} + + +var decode_time_step = function (src, tgt, model, lh, solver, hidden_sizes, context, G) +{ +  var inp = cat(one_hot($vocab_sz_tgt, $vocab_tgt[src]), context); +  var out = R.forwardRNN(G, model, hidden_sizes, inp, lh); +  var logprobs = out.o; +  var probs = R.softmax(logprobs); +  var target = $vocab_tgt[tgt]; +  cost = -Math.log(probs.w[target]); +  logprobs.dw = probs.w; +  logprobs.dw[target] -= 1; + +  return [model, cost, out, probs, G]; +} + +  var hidden_sizes = [20]; +/*var train = function (hidden_sizes) +{*/ +  var encoder = R.initRNN($vocab_sz_src, hidden_sizes, $vocab_sz_src+100); +  encoder["name"] = "encoder"; +  var decoder = R.initRNN($vocab_sz_src+100+$vocab_sz_tgt, hidden_sizes, $vocab_sz_tgt); +  decoder["name"] = "decoder"; + +  decoder["encoder_Wxh0"] = encoder["Wxh0"]; +  decoder["encoder_Whh0"] = encoder["Whh0"]; +  decoder["encoder_bhh0"] = encoder["bhh0"]; +  decoder["encoder_Whd"]  = encoder["Whd"]; +  decoder["encoder_bd"]   = encoder["bd"]; + +  var solver = new R.Solver(); +  var G = new R.Graph(); + + +  lh = {}; +  costs = []; +  var l = 0; +  while (true) +  { +    var cost = 0.0; +    for (var i=0; i<$data.length; i++) { +      var last_encode_lh,enc_g; +      for (var j=0; j<$data[i]["src"].length; j++) { +        [G, lh] = encode_time_step($data[i]["src"][j], encoder, lh, hidden_sizes, G); +      } +      last_encode_lh = lh; + +      var context = last_encode_lh.o; +      lh = {} +      for (var k=0; k<$data[i]["tgt"].length; k++) { +        var [decoder, c, lh, probs, G] = decode_time_step($data[i]["tgt"][k], $data[i]["tgt"][k+1], decoder, lh, solver, hidden_sizes, context, G); +        cost += c; + +        G.backward(); +        solver.step(encoder, 0.01, 0.0001, 5.0); +        solver.step(decoder, 0.01, 0.0001, 5.0); +      } + +      //encode_update(encoder, enc_g, lh); + +    } +    l++; +    costs.push(cost); +    if (stopping_criterion(costs[costs.length-2], cost, l)) +      break; +  } + +  //return [encoder, decoder, costs]; + +/*}*/ + + +//var x = ["<bos>", "das", "ist", "<eos>"]; + + + +var k = 3; +var kbest = []; + +for (var q=0; q < k; q++) { + +var x = $data[0]["src"];  +var _, cntxt; +var o = {}; +for (var i=0; i<x.length; i++) +  [_, o] = encode_time_step(x[i], encoder, o, hidden_sizes, null, true); +cntxt = o.o; + +var w = "<bos>", lh; +var prev = {}; +var str = ""; +var z = 0; +var p = 0; +  while (true) { +    var G = new R.Graph(false); +    var inp = cat(one_hot($vocab_sz_tgt, $vocab_tgt[w]), cntxt); +    var lh = R.forwardRNN(G, decoder, hidden_sizes, inp, prev); +    prev = lh; +    var logprobs = lh.o; +    var probs = R.softmax(logprobs); +    var x = R.samplei(probs.w); +    p += probs.w[x]; +    src = $ivocab_tgt[x]; +    str += " "+src; +    z++; +    if (src == "<eos>" || z==100) +      break; +  } +  kbest.push( {"transl":str, "score":p } ); +} + + + +var main = function () +{ +  /*var hidden_sizes = [20]; +  var [encoder, decoder, costs] = train($data, hidden_sizes);*/ + +  return false; +} + +$(document).ready(function() +{ +  main(); +});*/ + diff --git a/nmt/nmt.html b/nmt/nmt.html new file mode 100644 index 0000000..5cf60a5 --- /dev/null +++ b/nmt/nmt.html @@ -0,0 +1,23 @@ +<html> +  <head> +     <meta charset="utf-8">  +     <script type="text/javascript" src="../external/jquery-1.8.3.min.js"></script> +     <script type="text/javascript" src="../src/recurrent.js"></script> +     <script type="text/javascript" src="nmt.js"> </script> +  </head> +  <body> + +<p>Training costs</p> +<ol id="costs"> +</ol> + +<p>Samples</p> +<ul id="samples"> +</ul> + +<p>Predict (context: "welcome to my")</p> +<p id="predict"></p> + +  </body> +</html> + diff --git a/nmt/nmt.js b/nmt/nmt.js new file mode 100644 index 0000000..6710c36 --- /dev/null +++ b/nmt/nmt.js @@ -0,0 +1,218 @@ +/*var $data = [{"src":["<bos>", "das", "ist", "ein", "kleines", "haus", "<eos>"], +              "tgt":["<bos>", "this", "is", "a", "small", "house", "<eos>"]} +];*/ + +var $data = [{"src":["<bos>", "das", "<eos>"], +              "tgt":["<bos>", "this", "<eos>"]}];//, +/*             {"src":["<bos>", "mein", "<eos>"], +              "tgt":["<bos>", "my", "<eos>"]}];*/ + +var make_vocab = function (data) +{ +  var k = 0; +  var vocab = {}, ivocab = []; +  for (var i=0; i<data.length; i++) { +    for (var j=0; j<data[i].length; j++) { +      var w = data[i][j]; +      if (vocab[w]==undefined) { +        vocab[w] = k; +        ivocab[k] = w; +        k++; +      } +    } +  } + +  return [vocab,ivocab,ivocab.length]; +} + +var $vocab_src, $ivocab_src,$vocab_sz_src; +var $vocab_tgt, $ivocab_tgt,$vocab_sz_tgt; +[$vocab_src,$ivocab_src,$vocab_sz_src] = make_vocab($data.map(function(i){return i["src"]})); +$vocab_sz_src++; +[$vocab_tgt,$ivocab_tgt,$vocab_sz_tgt] = make_vocab($data.map(function(i){return i["tgt"]})); + +var one_hot = function (n, i) +{ +  var m = new R.Mat(n,1); +  m.set(i, 0, 1); + +  return m; +} + +var stopping_criterion = function (c, d, iter, margin=0.01, max_iter=100) +{ +  if (Math.abs(c-d) < margin || iter>=max_iter) +    return true; +  return false; +} + +var cat = function (a, b) +{ +  R.assert(a.d==1 && b.d==1); +  var m = new R.Mat(a.n+b.n, 1); +  var i; +  for (i=0; i<a.n; i++) { +    m.w[i] = a.w[i]; +    m.dw[i] = a.dw[i]; +  } +  for (var j=0; j<b.n; j++) { +    m.w[i+j] = b.w[j]; +    m.dw[i+j] = b.dw[j]; +  } + +  return m; +} + +function shuffle(array) { +  var currentIndex = array.length, temporaryValue, randomIndex; + +  // While there remain elements to shuffle... +  while (0 !== currentIndex) { + +    // Pick a remaining element... +    randomIndex = Math.floor(Math.random() * currentIndex); +    currentIndex -= 1; + +    // And swap it with the current element. +    temporaryValue = array[currentIndex]; +    array[currentIndex] = array[randomIndex]; +    array[randomIndex] = temporaryValue; +  } + +  return array; +} + +var hidden_sizes = [5]; +var encoder = R.initRNN($vocab_sz_src, hidden_sizes, $vocab_sz_src+100); +encoder["name"] = "encoder"; +var decoder = R.initRNN($vocab_sz_src+100+$vocab_sz_tgt, hidden_sizes, $vocab_sz_tgt); +decoder["name"] = "decoder"; + +var solver = new R.Solver(); +var h = {}; +var inp = null; +costs = []; +var epoch = 0; + +decoder["encoder_Wxh0"] = encoder["Wxh0"]; +decoder["encoder_Whh0"] = encoder["Whh0"]; +decoder["encoder_bhh0"] = encoder["bhh0"]; +decoder["encoder_Whd"]  = encoder["Whd"]; +decoder["encoder_bd"]   = encoder["bd"]; + +var x = encoder.Whh0.w[0] +var y = decoder.Whh0.w[0] + +var enc_hs = []; +var dec_hs = []; +var dec_inps = []; + +while (true) +{ +  $data = shuffle($data); +  var c = 0.0; +  for (var i=0; i<$data.length; i++) { +    var last_enc_h; +    var G = new R.Graph(); +    for (var j=0; j<$data[i]["src"].length; j++) { +      var src = $data[i]["src"][j]; +      var inp = one_hot($vocab_sz_src, $vocab_src[src]); +      var h = R.forwardRNN(G, encoder, hidden_sizes, inp, h); +      last_enc_h = h; +      enc_hs.push(h); +    } + +    for (var z=0; z<G.backprop.length; z++) { +      G.backprop[z]["mark"] = "encoder"; +    } + +    var context = last_enc_h.o; +    var h = {} +    for (var k=0; k<$data[i]["tgt"].length; k++) { +      var src = $data[i]["tgt"][k], +          tgt = $data[i]["tgt"][k+1]; + +      inp = cat(one_hot($vocab_sz_tgt, $vocab_tgt[src]), context); +      dec_inps.push(inp); +      h = R.forwardRNN(G, decoder, hidden_sizes, inp, h); +      dec_hs.push(h); +       +      var logprobs = h.o; +      var probs = R.softmax(logprobs); +      var target = $vocab_tgt[tgt]; +      cost = -Math.log(probs.w[target]); +      if (!cost || cost==Infinity) cost = 0; // hmmm +      logprobs.dw = probs.w; +      logprobs.dw[target] -= 1; +      c += cost; + +      // update weights +      G.backward(); +      // copy grads? decoder.Wxh0 -> encoder.Whd ? +      for (var z=0; z<last_enc_h.o.dw.length; z++) { +        last_enc_h.o.dw[z] = dec_inps[0].dw[z]; +      } +      G.backward1(); +      //exit(); +      solver.step(decoder, 0.01, 0.000001, 5.0); +    } +  } + +  epoch++; +  costs.push(c); +  if (stopping_criterion(costs[costs.length-2], cost, epoch)) +    break; +} + + + +var k = 10; +var samples = []; + +for (var q=0; q < k; q++) { + +//var x = ["<bos>", "das", "ist", "ein", "kleines", "haus", "<eos>"];  +var x = ["<bos>", "das", "<eos>"];  +var _, cntxt; +var o = {}; +var G = new R.Graph(false); +for (var i=0; i<x.length; i++) { +    var src = x[i]; +    var inp = one_hot($vocab_sz_src, $vocab_src[src]); +    var h = R.forwardRNN(G, encoder, hidden_sizes, inp, h); +} + +cntxt = h.o; + +var w = "<bos>", lh; +var prev = {}; +var str = ""; +var z = 0; +var p = 1; +  while (true) { +    var inp = cat(one_hot($vocab_sz_tgt, $vocab_tgt[w]), cntxt); +    var lh = R.forwardRNN(G, decoder, hidden_sizes, inp, prev); +    prev = lh; +    var logprobs = lh.o; +    var probs = R.softmax(logprobs); +    var x = R.samplei(probs.w); +    p *= probs.w[x]; +    src = $ivocab_tgt[x]; +    str += " "+src; +    z++; +    if (src == "<eos>" || z==100) +      break; +  } +  samples.push( {"transl":str, "score":p } ); +} + +best = ""; +best_score = -999999; +for (var q = 0; q<samples.length; q++) { +  if (best_score < samples[q].score) { +    best = samples[q].transl +  } +} + +alert(best); + diff --git a/nmt/notes.txt b/nmt/notes.txt new file mode 100644 index 0000000..0bc0b9c --- /dev/null +++ b/nmt/notes.txt @@ -0,0 +1,19 @@ +forward: +  single tick +  advance hidden state +  add functions to G + +Graph + does backprop + sets .dw fields of matrices + +Solver + updates models + step cache doesn't matter + +input +forward +calc loss,set deriv in last layer +backward +solver + | 
