author     Patrick Simianer <pks@pks.rocks>    2019-04-04 12:34:37 +0200
committer  Patrick Simianer <pks@pks.rocks>    2019-04-04 12:34:37 +0200
commit     847b438603127dd7213cef471bdfe4c1671d8524 (patch)
tree       543fbb06472a60f01aee6a75d6cde2a383daaef3 /tensorflow/transformer-attention2.py
parent     cd792b182dc2f641b9eb79c54db76da98b69c59a (diff)
numpy and tensorflow stuff
Diffstat (limited to 'tensorflow/transformer-attention2.py')
-rw-r--r--  tensorflow/transformer-attention2.py  36
1 file changed, 36 insertions(+), 0 deletions(-)
diff --git a/tensorflow/transformer-attention2.py b/tensorflow/transformer-attention2.py
new file mode 100644
index 0000000..c214934
--- /dev/null
+++ b/tensorflow/transformer-attention2.py
@@ -0,0 +1,36 @@
+import numpy as np
+import math
+
+dmodel = 32
+embedding_dim = 8
+nwords = 3
+num_heads = 4
+
+assert dmodel // num_heads == embedding_dim  # per-head width must equal the embedding dim
+
+states = np.array([np.random.rand(embedding_dim) for i in range(nwords)]) # num. words x embedding dim
+
+def softmax(m):  # row-wise softmax
+    return np.exp(m) / np.sum(np.exp(m), axis=1, keepdims=True)
+
+Wqs = np.random.rand(embedding_dim, dmodel)  # query projection
+Wks = np.random.rand(embedding_dim, dmodel)  # key projection
+Wvs = np.random.rand(embedding_dim, dmodel)  # value projection
+
+queries = np.matmul(states, Wqs)
+keys = np.matmul(states, Wks)
+values = np.matmul(states, Wvs)
+
+print(values)
+
+out = np.matmul(queries, np.transpose(keys))  # attention scores: nwords x nwords
+out = out / math.sqrt(dmodel / float(num_heads))  # scale by sqrt of the per-head dimension
+
+out = softmax(out)
+print(out)
+out = np.matmul(out, values)  # attention-weighted sum of values: nwords x dmodel
+
+out = np.matmul(np.random.rand(nwords, out.shape[0]), out)  # mix positions with a random nwords x nwords matrix
+print(out.shape)
+print(out)
+
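For reference: the committed script keeps all dmodel dimensions in a single attention computation, even though it scales the scores by sqrt(dmodel/num_heads). Below is a minimal sketch of how the same projections could instead be split into num_heads heads of width dmodel/num_heads and recombined; this is not part of the commit, and the names dhead and multi_head_attention are made up for illustration.

import numpy as np
import math

dmodel = 32
embedding_dim = 8
nwords = 3
num_heads = 4
dhead = dmodel // num_heads  # per-head width, here equal to embedding_dim

def softmax(m):
    # numerically stable row-wise softmax
    e = np.exp(m - np.max(m, axis=1, keepdims=True))
    return e / np.sum(e, axis=1, keepdims=True)

def multi_head_attention(states, Wq, Wk, Wv):
    # project, then split the dmodel columns into num_heads blocks of width dhead
    q = np.matmul(states, Wq).reshape(nwords, num_heads, dhead)
    k = np.matmul(states, Wk).reshape(nwords, num_heads, dhead)
    v = np.matmul(states, Wv).reshape(nwords, num_heads, dhead)
    heads = []
    for h in range(num_heads):
        scores = np.matmul(q[:, h, :], k[:, h, :].T) / math.sqrt(dhead)
        heads.append(np.matmul(softmax(scores), v[:, h, :]))
    # concatenate the per-head outputs back to nwords x dmodel
    return np.concatenate(heads, axis=1)

states = np.random.rand(nwords, embedding_dim)
Wq = np.random.rand(embedding_dim, dmodel)
Wk = np.random.rand(embedding_dim, dmodel)
Wv = np.random.rand(embedding_dim, dmodel)
print(multi_head_attention(states, Wq, Wk, Wv).shape)  # (3, 32)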