diff options
Diffstat (limited to 'rnnlm')
-rw-r--r-- | rnnlm/rnnlm.html | 14 | ||||
-rw-r--r-- | rnnlm/rnnlm.js | 60 |
2 files changed, 61 insertions, 13 deletions
diff --git a/rnnlm/rnnlm.html b/rnnlm/rnnlm.html index 8e137d7..b9642f4 100644 --- a/rnnlm/rnnlm.html +++ b/rnnlm/rnnlm.html @@ -15,8 +15,20 @@ <ul id="samples"> </ul> +<p>Predict (context: "--")</p> +<p id="predict0"></p> +<hr /> +<p>Predict (context: "welcome")</p> +<p id="predict1"></p> +<hr /> +<p>Predict (context: "welcome to")</p> +<p id="predict2"></p> +<hr /> <p>Predict (context: "welcome to my")</p> -<p id="predict"></p> +<p id="predict3"></p> +<hr /> +<p>Predict (context: "welcome to my house")</p> +<p id="predict4"></p> </body> </html> diff --git a/rnnlm/rnnlm.js b/rnnlm/rnnlm.js index a8d2dc7..70a66b6 100644 --- a/rnnlm/rnnlm.js +++ b/rnnlm/rnnlm.js @@ -32,11 +32,10 @@ var one_hot = function (n, i) return m; } -var time_step = function (src, tgt, model, lh, solver, hidden_sizes) +var time_step = function (src, tgt, model, lh, solver, hidden_sizes, G) { - var G = new R.Graph(); var inp = one_hot($vocab_sz, $vocab[src]); - var out = R.forwardRNN(G, model, hidden_sizes, inp, lh); + var out = R.forwardLSTM(G, model, hidden_sizes, inp, lh); var logprobs = out.o; var probs = R.softmax(logprobs); var target = $vocab[tgt]; @@ -58,17 +57,19 @@ var stopping_criterion = function (c, d, iter, margin=0.01, max_iter=100) var train = function ($data, hidden_sizes) { - var model = R.initRNN($vocab_sz, hidden_sizes, $vocab_sz); + var model = R.initLSTM($vocab_sz, hidden_sizes, $vocab_sz); var solver = new R.Solver(); lh = {}; costs = []; var k = 0; while (true) { + $data = shuffle($data); var cost = 0.0; for (var i=0; i<$data.length; i++) { + var G = new R.Graph(); for (var j=0; j<$data[i].length-1; j++) { - [model, c, lh, probs] = time_step($data[i][j], $data[i][j+1], model, lh, solver, hidden_sizes); + [model, c, lh, probs] = time_step($data[i][j], $data[i][j+1], model, lh, solver, hidden_sizes, G); cost += c; } } @@ -89,7 +90,7 @@ var generate = function (model, hidden_sizes) while (true) { var G = new R.Graph(false); var inp = one_hot($vocab_sz, $vocab[src]); - var lh = R.forwardRNN(G, model, hidden_sizes, inp, prev); + var lh = R.forwardLSTM(G, model, hidden_sizes, inp, prev); prev = lh; var logprobs = lh.o; var probs = R.softmax(logprobs); @@ -103,15 +104,14 @@ var generate = function (model, hidden_sizes) return str; } -var predict = function (model, hidden_sizes) +var predict = function (model, hidden_sizes, context) { - context = ["<bos>", "welcome", "to", "my"]; var prev = {}, lh; var logprobs, probs; for (var i=0; i<context.length; i++) { var G = new R.Graph(false); var inp = one_hot($vocab_sz, $vocab[context[i]]); - lh = R.forwardRNN(G, model, hidden_sizes, inp, prev); + lh = R.forwardLSTM(G, model, hidden_sizes, inp, prev); prev = lh; logprobs = lh.o; probs = R.softmax(logprobs); @@ -122,9 +122,28 @@ var predict = function (model, hidden_sizes) return [$ivocab[maxi], probs.w[maxi], probs.w]; } +function shuffle(array) { + var currentIndex = array.length, temporaryValue, randomIndex; + + // While there remain elements to shuffle... + while (0 !== currentIndex) { + + // Pick a remaining element... + randomIndex = Math.floor(Math.random() * currentIndex); + currentIndex -= 1; + + // And swap it with the current element. + temporaryValue = array[currentIndex]; + array[currentIndex] = array[randomIndex]; + array[randomIndex] = temporaryValue; + } + + return array; +} + var main = function () { - var hidden_sizes = [20]; + var hidden_sizes = [10]; var [model, costs] = train($data, hidden_sizes); for (var i=0; i<costs.length; i++) @@ -138,8 +157,25 @@ var main = function () if (i==13) break; } - var [t,tp,dist] = predict(model, hidden_sizes); - $("#predict").html(t+" ("+tp+")"); + var context = ["<bos>"]; + var [t,tp,dist] = predict(model, hidden_sizes, context); + $("#predict0").html(t+" ("+tp+")"); + + context = ["<bos>", "welcome"]; + var [t,tp,dist] = predict(model, hidden_sizes, context); + $("#predict1").html(t+" ("+tp+")"); + + context = ["<bos>", "welcome", "to"]; + var [t,tp,dist] = predict(model, hidden_sizes, context); + $("#predict2").html(t+" ("+tp+")"); + + context = ["<bos>", "welcome", "to", "my"]; + var [t,tp,dist] = predict(model, hidden_sizes, context); + $("#predict3").html(t+" ("+tp+")"); + + context = ["<bos>", "welcome", "to", "my", "house"]; + var [t,tp,dist] = predict(model, hidden_sizes, context); + $("#predict4").html(t+" ("+tp+")"); return false; } |