From 19b320070754dd34146a3921db825c7ed7a3094e Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Tue, 5 Jul 2016 11:26:29 +0200 Subject: more test contexts, use LSTM, shuffle, fix --- rnnlm/rnnlm.html | 14 ++++++++++++- rnnlm/rnnlm.js | 60 ++++++++++++++++++++++++++++++++++++++++++++------------ 2 files changed, 61 insertions(+), 13 deletions(-) diff --git a/rnnlm/rnnlm.html b/rnnlm/rnnlm.html index 8e137d7..b9642f4 100644 --- a/rnnlm/rnnlm.html +++ b/rnnlm/rnnlm.html @@ -15,8 +15,20 @@ +

Predict (context: "--")

+

+
+

Predict (context: "welcome")

+

+
+

Predict (context: "welcome to")

+

+

Predict (context: "welcome to my")

-

+

+
+

Predict (context: "welcome to my house")

+

diff --git a/rnnlm/rnnlm.js b/rnnlm/rnnlm.js index a8d2dc7..70a66b6 100644 --- a/rnnlm/rnnlm.js +++ b/rnnlm/rnnlm.js @@ -32,11 +32,10 @@ var one_hot = function (n, i) return m; } -var time_step = function (src, tgt, model, lh, solver, hidden_sizes) +var time_step = function (src, tgt, model, lh, solver, hidden_sizes, G) { - var G = new R.Graph(); var inp = one_hot($vocab_sz, $vocab[src]); - var out = R.forwardRNN(G, model, hidden_sizes, inp, lh); + var out = R.forwardLSTM(G, model, hidden_sizes, inp, lh); var logprobs = out.o; var probs = R.softmax(logprobs); var target = $vocab[tgt]; @@ -58,17 +57,19 @@ var stopping_criterion = function (c, d, iter, margin=0.01, max_iter=100) var train = function ($data, hidden_sizes) { - var model = R.initRNN($vocab_sz, hidden_sizes, $vocab_sz); + var model = R.initLSTM($vocab_sz, hidden_sizes, $vocab_sz); var solver = new R.Solver(); lh = {}; costs = []; var k = 0; while (true) { + $data = shuffle($data); var cost = 0.0; for (var i=0; i<$data.length; i++) { + var G = new R.Graph(); for (var j=0; j<$data[i].length-1; j++) { - [model, c, lh, probs] = time_step($data[i][j], $data[i][j+1], model, lh, solver, hidden_sizes); + [model, c, lh, probs] = time_step($data[i][j], $data[i][j+1], model, lh, solver, hidden_sizes, G); cost += c; } } @@ -89,7 +90,7 @@ var generate = function (model, hidden_sizes) while (true) { var G = new R.Graph(false); var inp = one_hot($vocab_sz, $vocab[src]); - var lh = R.forwardRNN(G, model, hidden_sizes, inp, prev); + var lh = R.forwardLSTM(G, model, hidden_sizes, inp, prev); prev = lh; var logprobs = lh.o; var probs = R.softmax(logprobs); @@ -103,15 +104,14 @@ var generate = function (model, hidden_sizes) return str; } -var predict = function (model, hidden_sizes) +var predict = function (model, hidden_sizes, context) { - context = ["", "welcome", "to", "my"]; var prev = {}, lh; var logprobs, probs; for (var i=0; i"]; + var [t,tp,dist] = predict(model, hidden_sizes, context); + $("#predict0").html(t+" 
("+tp+")"); + + context = ["", "welcome"]; + var [t,tp,dist] = predict(model, hidden_sizes, context); + $("#predict1").html(t+" ("+tp+")"); + + context = ["", "welcome", "to"]; + var [t,tp,dist] = predict(model, hidden_sizes, context); + $("#predict2").html(t+" ("+tp+")"); + + context = ["", "welcome", "to", "my"]; + var [t,tp,dist] = predict(model, hidden_sizes, context); + $("#predict3").html(t+" ("+tp+")"); + + context = ["", "welcome", "to", "my", "house"]; + var [t,tp,dist] = predict(model, hidden_sizes, context); + $("#predict4").html(t+" ("+tp+")"); return false; } -- cgit v1.2.3