diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/recurrent.js | 35 |
1 files changed, 27 insertions, 8 deletions
diff --git a/src/recurrent.js b/src/recurrent.js index 4c20a39..4b2846b 100644 --- a/src/recurrent.js +++ b/src/recurrent.js @@ -116,7 +116,20 @@ var R = {}; // the Recurrent library Graph.prototype = { backward: function() { for(var i=this.backprop.length-1;i>=0;i--) { - this.backprop[i](); // tick! + var debug = false; + if (this.backprop[i]["mark"] == "encoder") { + debug=true; + break; + } + this.backprop[i](debug); // tick! + } + }, + backward1: function() { + for(var i=this.backprop.length-1;i>=0;i--) { + var debug = false; + if (this.backprop[i]["mark"]) debug=true; + if (this.backprop[i]["mark"] != "encoder") continue; + this.backprop[i](debug); // tick! } }, rowPluck: function(m, ix) { @@ -127,7 +140,7 @@ var R = {}; // the Recurrent library for(var i=0,n=d;i<n;i++){ out.w[i] = m.w[d * ix + i]; } // copy over the data if(this.needs_backprop) { - var backward = function() { + var backward = function(debug=false) { for(var i=0,n=d;i<n;i++){ m.dw[d * ix + i] += out.dw[i]; } } this.backprop.push(backward); @@ -143,7 +156,7 @@ var R = {}; // the Recurrent library } if(this.needs_backprop) { - var backward = function() { + var backward = function(debug=false) { for(var i=0;i<n;i++) { // grad for z = tanh(x) is (1 - z^2) var mwi = out.w[i]; @@ -163,7 +176,7 @@ var R = {}; // the Recurrent library } if(this.needs_backprop) { - var backward = function() { + var backward = function(debug=false) { for(var i=0;i<n;i++) { // grad for z = tanh(x) is (1 - z^2) var mwi = out.w[i]; @@ -181,7 +194,7 @@ var R = {}; // the Recurrent library out.w[i] = Math.max(0, m.w[i]); // relu } if(this.needs_backprop) { - var backward = function() { + var backward = function(debug=false) { for(var i=0;i<n;i++) { m.dw[i] += m.w[i] > 0 ? out.dw[i] : 0.0; } @@ -208,7 +221,7 @@ var R = {}; // the Recurrent library } if(this.needs_backprop) { - var backward = function() { + var backward = function(debug=false) { for(var i=0;i<m1.n;i++) { // loop over rows of m1 for(var j=0;j<m2.d;j++) { // loop over cols of m2 for(var k=0;k<m1.d;k++) { // dot product loop @@ -231,11 +244,15 @@ var R = {}; // the Recurrent library out.w[i] = m1.w[i] + m2.w[i]; } if(this.needs_backprop) { - var backward = function() { + var backward = function(debug=false) { for(var i=0,n=m1.w.length;i<n;i++) { m1.dw[i] += out.dw[i]; m2.dw[i] += out.dw[i]; } + /*if (debug) { + if (m1.dw[0]!=0) alert(m1.dw[0]); + if (m2.dw[0]!=0) alert(m2.dw[0]); + }*/ } this.backprop.push(backward); } @@ -249,7 +266,7 @@ var R = {}; // the Recurrent library out.w[i] = m1.w[i] * m2.w[i]; } if(this.needs_backprop) { - var backward = function() { + var backward = function(debug=false) { for(var i=0,n=m1.w.length;i<n;i++) { m1.dw[i] += m2.w[i] * out.dw[i]; m2.dw[i] += m1.w[i] * out.dw[i]; @@ -292,6 +309,7 @@ var R = {}; // the Recurrent library var num_tot = 0; for(var k in model) { if(model.hasOwnProperty(k)) { + if (k == "name") continue; var m = model[k]; // mat ref if(!(k in this.step_cache)) { this.step_cache[k] = new Mat(m.n, m.d); } var s = this.step_cache[k]; @@ -313,6 +331,7 @@ var R = {}; // the Recurrent library num_tot++; // update (and regularize) + var q = m.w[i]; m.w[i] += - step_size * mdwi / Math.sqrt(s.w[i] + this.smooth_eps) - regc * m.w[i]; m.dw[i] = 0; // reset gradients for next iteration } |