summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--rnnlm/rnnlm.html14
-rw-r--r--rnnlm/rnnlm.js60
2 files changed, 61 insertions, 13 deletions
diff --git a/rnnlm/rnnlm.html b/rnnlm/rnnlm.html
index 8e137d7..b9642f4 100644
--- a/rnnlm/rnnlm.html
+++ b/rnnlm/rnnlm.html
@@ -15,8 +15,20 @@
<ul id="samples">
</ul>
+<p>Predict (context: "--")</p>
+<p id="predict0"></p>
+<hr />
+<p>Predict (context: "welcome")</p>
+<p id="predict1"></p>
+<hr />
+<p>Predict (context: "welcome to")</p>
+<p id="predict2"></p>
+<hr />
<p>Predict (context: "welcome to my")</p>
-<p id="predict"></p>
+<p id="predict3"></p>
+<hr />
+<p>Predict (context: "welcome to my house")</p>
+<p id="predict4"></p>
</body>
</html>
diff --git a/rnnlm/rnnlm.js b/rnnlm/rnnlm.js
index a8d2dc7..70a66b6 100644
--- a/rnnlm/rnnlm.js
+++ b/rnnlm/rnnlm.js
@@ -32,11 +32,10 @@ var one_hot = function (n, i)
return m;
}
-var time_step = function (src, tgt, model, lh, solver, hidden_sizes)
+var time_step = function (src, tgt, model, lh, solver, hidden_sizes, G)
{
- var G = new R.Graph();
var inp = one_hot($vocab_sz, $vocab[src]);
- var out = R.forwardRNN(G, model, hidden_sizes, inp, lh);
+ var out = R.forwardLSTM(G, model, hidden_sizes, inp, lh);
var logprobs = out.o;
var probs = R.softmax(logprobs);
var target = $vocab[tgt];
@@ -58,17 +57,19 @@ var stopping_criterion = function (c, d, iter, margin=0.01, max_iter=100)
var train = function ($data, hidden_sizes)
{
- var model = R.initRNN($vocab_sz, hidden_sizes, $vocab_sz);
+ var model = R.initLSTM($vocab_sz, hidden_sizes, $vocab_sz);
var solver = new R.Solver();
lh = {};
costs = [];
var k = 0;
while (true)
{
+ $data = shuffle($data);
var cost = 0.0;
for (var i=0; i<$data.length; i++) {
+ var G = new R.Graph();
for (var j=0; j<$data[i].length-1; j++) {
- [model, c, lh, probs] = time_step($data[i][j], $data[i][j+1], model, lh, solver, hidden_sizes);
+ [model, c, lh, probs] = time_step($data[i][j], $data[i][j+1], model, lh, solver, hidden_sizes, G);
cost += c;
}
}
@@ -89,7 +90,7 @@ var generate = function (model, hidden_sizes)
while (true) {
var G = new R.Graph(false);
var inp = one_hot($vocab_sz, $vocab[src]);
- var lh = R.forwardRNN(G, model, hidden_sizes, inp, prev);
+ var lh = R.forwardLSTM(G, model, hidden_sizes, inp, prev);
prev = lh;
var logprobs = lh.o;
var probs = R.softmax(logprobs);
@@ -103,15 +104,14 @@ var generate = function (model, hidden_sizes)
return str;
}
-var predict = function (model, hidden_sizes)
+var predict = function (model, hidden_sizes, context)
{
- context = ["<bos>", "welcome", "to", "my"];
var prev = {}, lh;
var logprobs, probs;
for (var i=0; i<context.length; i++) {
var G = new R.Graph(false);
var inp = one_hot($vocab_sz, $vocab[context[i]]);
- lh = R.forwardRNN(G, model, hidden_sizes, inp, prev);
+ lh = R.forwardLSTM(G, model, hidden_sizes, inp, prev);
prev = lh;
logprobs = lh.o;
probs = R.softmax(logprobs);
@@ -122,9 +122,28 @@ var predict = function (model, hidden_sizes)
return [$ivocab[maxi], probs.w[maxi], probs.w];
}
+function shuffle(array) {
+ var currentIndex = array.length, temporaryValue, randomIndex;
+
+ // While there remain elements to shuffle...
+ while (0 !== currentIndex) {
+
+ // Pick a remaining element...
+ randomIndex = Math.floor(Math.random() * currentIndex);
+ currentIndex -= 1;
+
+ // And swap it with the current element.
+ temporaryValue = array[currentIndex];
+ array[currentIndex] = array[randomIndex];
+ array[randomIndex] = temporaryValue;
+ }
+
+ return array;
+}
+
var main = function ()
{
- var hidden_sizes = [20];
+ var hidden_sizes = [10];
var [model, costs] = train($data, hidden_sizes);
for (var i=0; i<costs.length; i++)
@@ -138,8 +157,25 @@ var main = function ()
if (i==13) break;
}
- var [t,tp,dist] = predict(model, hidden_sizes);
- $("#predict").html(t+" ("+tp+")");
+ var context = ["<bos>"];
+ var [t,tp,dist] = predict(model, hidden_sizes, context);
+ $("#predict0").html(t+" ("+tp+")");
+
+ context = ["<bos>", "welcome"];
+ var [t,tp,dist] = predict(model, hidden_sizes, context);
+ $("#predict1").html(t+" ("+tp+")");
+
+ context = ["<bos>", "welcome", "to"];
+ var [t,tp,dist] = predict(model, hidden_sizes, context);
+ $("#predict2").html(t+" ("+tp+")");
+
+ context = ["<bos>", "welcome", "to", "my"];
+ var [t,tp,dist] = predict(model, hidden_sizes, context);
+ $("#predict3").html(t+" ("+tp+")");
+
+ context = ["<bos>", "welcome", "to", "my", "house"];
+ var [t,tp,dist] = predict(model, hidden_sizes, context);
+ $("#predict4").html(t+" ("+tp+")");
return false;
}