From 19b320070754dd34146a3921db825c7ed7a3094e Mon Sep 17 00:00:00 2001 From: Patrick Simianer
Date: Tue, 5 Jul 2016 11:26:29 +0200
Subject: more test contxts, use LSTM, shuffle, fix
---
rnnlm/rnnlm.js | 60 ++++++++++++++++++++++++++++++++++++++++++++++------------
1 file changed, 48 insertions(+), 12 deletions(-)
(limited to 'rnnlm/rnnlm.js')
diff --git a/rnnlm/rnnlm.js b/rnnlm/rnnlm.js
index a8d2dc7..70a66b6 100644
--- a/rnnlm/rnnlm.js
+++ b/rnnlm/rnnlm.js
@@ -32,11 +32,10 @@ var one_hot = function (n, i)
return m;
}
-var time_step = function (src, tgt, model, lh, solver, hidden_sizes)
+var time_step = function (src, tgt, model, lh, solver, hidden_sizes, G)
{
- var G = new R.Graph();
var inp = one_hot($vocab_sz, $vocab[src]);
- var out = R.forwardRNN(G, model, hidden_sizes, inp, lh);
+ var out = R.forwardLSTM(G, model, hidden_sizes, inp, lh);
var logprobs = out.o;
var probs = R.softmax(logprobs);
var target = $vocab[tgt];
@@ -58,17 +57,19 @@ var stopping_criterion = function (c, d, iter, margin=0.01, max_iter=100)
var train = function ($data, hidden_sizes)
{
- var model = R.initRNN($vocab_sz, hidden_sizes, $vocab_sz);
+ var model = R.initLSTM($vocab_sz, hidden_sizes, $vocab_sz);
var solver = new R.Solver();
lh = {};
costs = [];
var k = 0;
while (true)
{
+ $data = shuffle($data);
var cost = 0.0;
for (var i=0; i<$data.length; i++) {
+ var G = new R.Graph();
for (var j=0; j<$data[i].length-1; j++) {
- [model, c, lh, probs] = time_step($data[i][j], $data[i][j+1], model, lh, solver, hidden_sizes);
+ [model, c, lh, probs] = time_step($data[i][j], $data[i][j+1], model, lh, solver, hidden_sizes, G);
cost += c;
}
}
@@ -89,7 +90,7 @@ var generate = function (model, hidden_sizes)
while (true) {
var G = new R.Graph(false);
var inp = one_hot($vocab_sz, $vocab[src]);
- var lh = R.forwardRNN(G, model, hidden_sizes, inp, prev);
+ var lh = R.forwardLSTM(G, model, hidden_sizes, inp, prev);
prev = lh;
var logprobs = lh.o;
var probs = R.softmax(logprobs);
@@ -103,15 +104,14 @@ var generate = function (model, hidden_sizes)
return str;
}
-var predict = function (model, hidden_sizes)
+var predict = function (model, hidden_sizes, context)
{
- context = ["