From 500b99183517af88788098281237b70ad63453b9 Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Sat, 21 May 2016 21:25:39 +0200 Subject: rnnlm --- rnnlm/rnnlm.html | 23 +++++++++ rnnlm/rnnlm.js | 151 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 174 insertions(+) create mode 100644 rnnlm/rnnlm.html create mode 100644 rnnlm/rnnlm.js diff --git a/rnnlm/rnnlm.html b/rnnlm/rnnlm.html new file mode 100644 index 0000000..8e137d7 --- /dev/null +++ b/rnnlm/rnnlm.html @@ -0,0 +1,23 @@ + + + + + + + + + +

Training costs

+
    +
+ +

Samples

+ + +

Predict (context: "welcome to my")

+

+ + + + diff --git a/rnnlm/rnnlm.js b/rnnlm/rnnlm.js new file mode 100644 index 0000000..a8d2dc7 --- /dev/null +++ b/rnnlm/rnnlm.js @@ -0,0 +1,151 @@ +var $data = [["", "this", "is", "my", "house", ""], + ["", "welcome", "to", "my", "house", ""], + ["", "welcome", "to", "my", "tiny", "house", ""], + ["", "welcome", "to", "my", "little", "house", ""]]; + +var make_vocab = function ($data) +{ + var k = 0; + var vocab = {}, ivocab = []; + for (var i=0; i<$data.length; i++) { + for (var j=0; j<$data[i].length; j++) { + var w = $data[i][j]; + if (vocab[w]==undefined) { + vocab[w] = k; + ivocab[k] = w; + k++; + } + } + } + + return [vocab,ivocab,ivocab.length]; +} + +var $vocab, $ivocab,$vocab_sz; +[$vocab,$ivocab,$vocab_sz] = make_vocab($data); + +var one_hot = function (n, i) +{ + var m = new R.Mat(n,1); + m.set(i, 0, 1); + + return m; +} + +var time_step = function (src, tgt, model, lh, solver, hidden_sizes) +{ + var G = new R.Graph(); + var inp = one_hot($vocab_sz, $vocab[src]); + var out = R.forwardRNN(G, model, hidden_sizes, inp, lh); + var logprobs = out.o; + var probs = R.softmax(logprobs); + var target = $vocab[tgt]; + cost = -Math.log(probs.w[target]); + logprobs.dw = probs.w; + logprobs.dw[target] -= 1; + G.backward(); + solver.step(model, 0.01, 0.0001, 5.0); + + return [model, cost, out, probs]; +} + +var stopping_criterion = function (c, d, iter, margin=0.01, max_iter=100) +{ + if (Math.abs(c-d) < margin || iter>=max_iter) + return true; + return false; +} + +var train = function ($data, hidden_sizes) +{ + var model = R.initRNN($vocab_sz, hidden_sizes, $vocab_sz); + var solver = new R.Solver(); + lh = {}; + costs = []; + var k = 0; + while (true) + { + var cost = 0.0; + for (var i=0; i<$data.length; i++) { + for (var j=0; j<$data[i].length-1; j++) { + [model, c, lh, probs] = time_step($data[i][j], $data[i][j+1], model, lh, solver, hidden_sizes); + cost += c; + } + } + k++; + costs.push(cost); + if (stopping_criterion(costs[costs.length-2], cost, k)) + break; + } + + return [model, costs]; +} + +var generate = function (model, hidden_sizes) +{ + var prev = {}; + var str = ""; + var src = ""; + while (true) { + var G = new R.Graph(false); + var inp = one_hot($vocab_sz, $vocab[src]); + var lh = R.forwardRNN(G, model, hidden_sizes, inp, prev); + prev = lh; + var logprobs = lh.o; + var probs = R.softmax(logprobs); + var x = R.samplei(probs.w); + src = $ivocab[x]; + str += " "+src; + if (src == "") + break; + } + + return str; +} + +var predict = function (model, hidden_sizes) +{ + context = ["", "welcome", "to", "my"]; + var prev = {}, lh; + var logprobs, probs; + for (var i=0; i'+costs[i]+''); + + var i=0; + while (1) { + var s = generate(model, hidden_sizes); + $("#samples").append("
  • "+s+"
  • "); + i++; + if (i==13) break; + } + + var [t,tp,dist] = predict(model, hidden_sizes); + $("#predict").html(t+" ("+tp+")"); + + return false; +} + +$(document).ready(function() +{ + main(); +}); + -- cgit v1.2.3