diff options
Diffstat (limited to 'rnnlm')
| -rw-r--r-- | rnnlm/rnnlm.html | 14 | ||||
| -rw-r--r-- | rnnlm/rnnlm.js | 60 | 
2 files changed, 61 insertions, 13 deletions
| diff --git a/rnnlm/rnnlm.html b/rnnlm/rnnlm.html index 8e137d7..b9642f4 100644 --- a/rnnlm/rnnlm.html +++ b/rnnlm/rnnlm.html @@ -15,8 +15,20 @@  <ul id="samples">  </ul> +<p>Predict (context: "--")</p> +<p id="predict0"></p> +<hr /> +<p>Predict (context: "welcome")</p> +<p id="predict1"></p> +<hr /> +<p>Predict (context: "welcome to")</p> +<p id="predict2"></p> +<hr />  <p>Predict (context: "welcome to my")</p> -<p id="predict"></p> +<p id="predict3"></p> +<hr /> +<p>Predict (context: "welcome to my house")</p> +<p id="predict4"></p>    </body>  </html> diff --git a/rnnlm/rnnlm.js b/rnnlm/rnnlm.js index a8d2dc7..70a66b6 100644 --- a/rnnlm/rnnlm.js +++ b/rnnlm/rnnlm.js @@ -32,11 +32,10 @@ var one_hot = function (n, i)    return m;  } -var time_step = function (src, tgt, model, lh, solver, hidden_sizes) +var time_step = function (src, tgt, model, lh, solver, hidden_sizes, G)  { -  var G = new R.Graph();    var inp = one_hot($vocab_sz, $vocab[src]); -  var out = R.forwardRNN(G, model, hidden_sizes, inp, lh); +  var out = R.forwardLSTM(G, model, hidden_sizes, inp, lh);    var logprobs = out.o;    var probs = R.softmax(logprobs);    var target = $vocab[tgt]; @@ -58,17 +57,19 @@ var stopping_criterion = function (c, d, iter, margin=0.01, max_iter=100)  var train = function ($data, hidden_sizes)  { -  var model = R.initRNN($vocab_sz, hidden_sizes, $vocab_sz); +  var model = R.initLSTM($vocab_sz, hidden_sizes, $vocab_sz);    var solver = new R.Solver();    lh = {};    costs = [];    var k = 0;    while (true)    { +    $data = shuffle($data);      var cost = 0.0;      for (var i=0; i<$data.length; i++) { +      var G = new R.Graph();        for (var j=0; j<$data[i].length-1; j++) { -        [model, c, lh, probs] = time_step($data[i][j], $data[i][j+1], model, lh, solver, hidden_sizes); +        [model, c, lh, probs] = time_step($data[i][j], $data[i][j+1], model, lh, solver, hidden_sizes, G);          cost += c;        }      } @@ -89,7 +90,7 @@ var generate = function (model, hidden_sizes)    while (true) {      var G = new R.Graph(false);      var inp = one_hot($vocab_sz, $vocab[src]); -    var lh = R.forwardRNN(G, model, hidden_sizes, inp, prev); +    var lh = R.forwardLSTM(G, model, hidden_sizes, inp, prev);      prev = lh;      var logprobs = lh.o;      var probs = R.softmax(logprobs); @@ -103,15 +104,14 @@ var generate = function (model, hidden_sizes)    return str;  } -var predict = function (model, hidden_sizes) +var predict = function (model, hidden_sizes, context)  { -  context = ["<bos>", "welcome", "to", "my"];    var prev = {}, lh;    var logprobs, probs;    for (var i=0; i<context.length; i++) {      var G = new R.Graph(false);      var inp = one_hot($vocab_sz, $vocab[context[i]]); -    lh = R.forwardRNN(G, model, hidden_sizes, inp, prev); +    lh = R.forwardLSTM(G, model, hidden_sizes, inp, prev);      prev = lh;      logprobs = lh.o;      probs = R.softmax(logprobs); @@ -122,9 +122,28 @@ var predict = function (model, hidden_sizes)    return [$ivocab[maxi], probs.w[maxi], probs.w];  } +function shuffle(array) { +  var currentIndex = array.length, temporaryValue, randomIndex; + +  // While there remain elements to shuffle... +  while (0 !== currentIndex) { + +    // Pick a remaining element... +    randomIndex = Math.floor(Math.random() * currentIndex); +    currentIndex -= 1; + +    // And swap it with the current element. +    temporaryValue = array[currentIndex]; +    array[currentIndex] = array[randomIndex]; +    array[randomIndex] = temporaryValue; +  } + +  return array; +} +  var main = function ()  { -  var hidden_sizes = [20]; +  var hidden_sizes = [10];    var [model, costs] = train($data, hidden_sizes);    for (var i=0; i<costs.length; i++) @@ -138,8 +157,25 @@ var main = function ()      if (i==13) break;    } -  var [t,tp,dist] = predict(model, hidden_sizes); -  $("#predict").html(t+" ("+tp+")"); +  var context = ["<bos>"]; +  var [t,tp,dist] = predict(model, hidden_sizes, context); +  $("#predict0").html(t+" ("+tp+")"); + +  context = ["<bos>", "welcome"]; +  var [t,tp,dist] = predict(model, hidden_sizes, context); +  $("#predict1").html(t+" ("+tp+")"); + +  context = ["<bos>", "welcome", "to"]; +  var [t,tp,dist] = predict(model, hidden_sizes, context); +  $("#predict2").html(t+" ("+tp+")"); + +  context = ["<bos>", "welcome", "to", "my"]; +  var [t,tp,dist] = predict(model, hidden_sizes, context); +  $("#predict3").html(t+" ("+tp+")"); + +  context = ["<bos>", "welcome", "to", "my", "house"]; +  var [t,tp,dist] = predict(model, hidden_sizes, context); +  $("#predict4").html(t+" ("+tp+")");    return false;  } | 
