summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/recurrent.js35
1 files changed, 27 insertions, 8 deletions
diff --git a/src/recurrent.js b/src/recurrent.js
index 4c20a39..4b2846b 100644
--- a/src/recurrent.js
+++ b/src/recurrent.js
@@ -116,7 +116,20 @@ var R = {}; // the Recurrent library
Graph.prototype = {
backward: function() {
for(var i=this.backprop.length-1;i>=0;i--) {
- this.backprop[i](); // tick!
+ var debug = false;
+ if (this.backprop[i]["mark"] == "encoder") {
+ debug=true;
+ break;
+ }
+ this.backprop[i](debug); // tick!
+ }
+ },
+ backward1: function() {
+ for(var i=this.backprop.length-1;i>=0;i--) {
+ var debug = false;
+ if (this.backprop[i]["mark"]) debug=true;
+ if (this.backprop[i]["mark"] != "encoder") continue;
+ this.backprop[i](debug); // tick!
}
},
rowPluck: function(m, ix) {
@@ -127,7 +140,7 @@ var R = {}; // the Recurrent library
for(var i=0,n=d;i<n;i++){ out.w[i] = m.w[d * ix + i]; } // copy over the data
if(this.needs_backprop) {
- var backward = function() {
+ var backward = function(debug=false) {
for(var i=0,n=d;i<n;i++){ m.dw[d * ix + i] += out.dw[i]; }
}
this.backprop.push(backward);
@@ -143,7 +156,7 @@ var R = {}; // the Recurrent library
}
if(this.needs_backprop) {
- var backward = function() {
+ var backward = function(debug=false) {
for(var i=0;i<n;i++) {
// grad for z = tanh(x) is (1 - z^2)
var mwi = out.w[i];
@@ -163,7 +176,7 @@ var R = {}; // the Recurrent library
}
if(this.needs_backprop) {
- var backward = function() {
+ var backward = function(debug=false) {
for(var i=0;i<n;i++) {
// grad for z = tanh(x) is (1 - z^2)
var mwi = out.w[i];
@@ -181,7 +194,7 @@ var R = {}; // the Recurrent library
out.w[i] = Math.max(0, m.w[i]); // relu
}
if(this.needs_backprop) {
- var backward = function() {
+ var backward = function(debug=false) {
for(var i=0;i<n;i++) {
m.dw[i] += m.w[i] > 0 ? out.dw[i] : 0.0;
}
@@ -208,7 +221,7 @@ var R = {}; // the Recurrent library
}
if(this.needs_backprop) {
- var backward = function() {
+ var backward = function(debug=false) {
for(var i=0;i<m1.n;i++) { // loop over rows of m1
for(var j=0;j<m2.d;j++) { // loop over cols of m2
for(var k=0;k<m1.d;k++) { // dot product loop
@@ -231,11 +244,15 @@ var R = {}; // the Recurrent library
out.w[i] = m1.w[i] + m2.w[i];
}
if(this.needs_backprop) {
- var backward = function() {
+ var backward = function(debug=false) {
for(var i=0,n=m1.w.length;i<n;i++) {
m1.dw[i] += out.dw[i];
m2.dw[i] += out.dw[i];
}
+ /*if (debug) {
+ if (m1.dw[0]!=0) alert(m1.dw[0]);
+ if (m2.dw[0]!=0) alert(m2.dw[0]);
+ }*/
}
this.backprop.push(backward);
}
@@ -249,7 +266,7 @@ var R = {}; // the Recurrent library
out.w[i] = m1.w[i] * m2.w[i];
}
if(this.needs_backprop) {
- var backward = function() {
+ var backward = function(debug=false) {
for(var i=0,n=m1.w.length;i<n;i++) {
m1.dw[i] += m2.w[i] * out.dw[i];
m2.dw[i] += m1.w[i] * out.dw[i];
@@ -292,6 +309,7 @@ var R = {}; // the Recurrent library
var num_tot = 0;
for(var k in model) {
if(model.hasOwnProperty(k)) {
+ if (k == "name") continue;
var m = model[k]; // mat ref
if(!(k in this.step_cache)) { this.step_cache[k] = new Mat(m.n, m.d); }
var s = this.step_cache[k];
@@ -313,6 +331,7 @@ var R = {}; // the Recurrent library
num_tot++;
// update (and regularize)
+ var q = m.w[i];
m.w[i] += - step_size * mdwi / Math.sqrt(s.w[i] + this.smooth_eps) - regc * m.w[i];
m.dw[i] = 0; // reset gradients for next iteration
}