summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPatrick Simianer <pks@pks.rocks>2018-10-13 09:50:32 +0200
committerPatrick Simianer <pks@pks.rocks>2018-10-13 09:50:32 +0200
commitcd792b182dc2f641b9eb79c54db76da98b69c59a (patch)
tree93e84a54e3d92959db977bfe6a64cbc21d745763
parentb78965221808a3f688c1da12a1179db1bbbf1ad5 (diff)
scatter-nd-add
-rw-r--r--tensorflow/scatter-nd-add.py57
1 files changed, 57 insertions, 0 deletions
diff --git a/tensorflow/scatter-nd-add.py b/tensorflow/scatter-nd-add.py
new file mode 100644
index 0000000..7194d8b
--- /dev/null
+++ b/tensorflow/scatter-nd-add.py
@@ -0,0 +1,57 @@
+import numpy as np
+import tensorflow as tf
+
+sess = tf.Session()
+
+#idx = tf.constant([[0,2],[1,2]])
+
+# 4 x 2 | 40K x 256
+m = tf.Variable([[1,2],
+ [0,0],
+ [0,0],
+ [0,0]], dtype=tf.float32)
+# -> 2 x 4 | 256 x 40K
+m_transposed = tf.transpose(m)
+# -> AttributeError: 'Tensor' object has no attribute '_lazy_read'
+m_new = tf.Variable([[1., 0., 0., 0.],
+ [2., 0., 0., 0.]], dtype=tf.float32)
+
+# 1 x 3 | 1 x Y
+idx = tf.constant([1,2,3], dtype=tf.int32)
+idx = sess.run(idx)
+_idx = []
+for j in idx:
+ for i in range(0,m_new.shape[0]):
+ _idx.append([i,j])
+#idx_new = tf.constant(_idx, dtype=tf.int32)
+idx_new = np.full(fill_value=_idx, shape=[6,2], dtype=np.int32)
+
+# 2 x 2
+up = tf.constant([[1,1],[1,1],[1,1]], dtype=tf.float32)
+# 1 x 4
+up_new = tf.reshape(up, [tf.size(up)])
+
+sess.run(tf.global_variables_initializer())
+
+print("m")
+print(sess.run(m))
+print("m_new")
+print(sess.run(m_new))
+print("m_transposed")
+print(sess.run(m_transposed))
+print("idx")
+print(idx)
+print("idx_new")
+#print(sess.run(idx_new))
+print(idx_new)
+print("up")
+print(sess.run(up))
+print("up_new")
+print(sess.run(up_new))
+
+print()
+print("scatter_nd_add")
+print(sess.run(tf.scatter_nd_add(m_new, indices=idx_new, updates=up_new)))
+
+print()
+