diff options
Diffstat (limited to 'tensorflow/transformer-attention2.py')
-rw-r--r-- | tensorflow/transformer-attention2.py | 36 |
1 files changed, 36 insertions, 0 deletions
# tensorflow/transformer-attention2.py (new file, index 0000000..c214934)
"""Toy NumPy walk-through of scaled dot-product attention.

Random word states are projected to queries, keys and values; the scaled
query/key scores are normalised row-wise with a softmax and used to take a
weighted sum of the values.  Intermediate results are printed.
"""
import numpy as np
import math

dmodel = 32
embedding_dim = 8
nwords = 3
num_heads = 4

# Integer form of the original `dmodel/num_heads == embedding_dim` check.
assert dmodel == num_heads * embedding_dim

# One random state vector per word: (nwords, embedding_dim).
states = np.random.rand(nwords, embedding_dim)


def softmax(m):
    """Return the softmax of *m* taken along its last axis.

    Each row of a 2-D input is normalised independently, so every row of
    the result sums to 1.

    Args:
        m: array-like of scores; for the attention scores in this script it
            is (nwords, nwords).

    Returns:
        numpy.ndarray of floats with the same shape as *m*.
    """
    m = np.asarray(m, dtype=float)
    # Subtracting the per-row maximum leaves the result unchanged
    # mathematically but keeps np.exp from overflowing on large scores.
    shifted = m - np.max(m, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    # keepdims=True keeps the sums as a column, so each row is divided by
    # its own sum; without it the (rows,) vector broadcasts across columns
    # and rows are divided by the wrong sums.
    return exps / np.sum(exps, axis=-1, keepdims=True)


# Projection matrices: (embedding_dim, dmodel).
Wqs = np.random.rand(embedding_dim, dmodel)
Wks = np.random.rand(embedding_dim, dmodel)
Wvs = np.random.rand(embedding_dim, dmodel)

# Projected states: (nwords, dmodel) each.
queries = np.matmul(states, Wqs)
keys = np.matmul(states, Wks)
values = np.matmul(states, Wvs)

print(values)

# Query/key dot products, (nwords, nwords), scaled by sqrt(dmodel / num_heads).
out = np.matmul(queries, np.transpose(keys))
out = out / math.sqrt(dmodel / float(num_heads))

# Row i holds the weights word i assigns to every word.
out = softmax(out)
print(out)

# Weighted sum of the value vectors: (nwords, dmodel).
out = np.matmul(out, values)

# Left-multiplies by a random (nwords, nwords) matrix, i.e. mixes the rows.
# NOTE(review): an output projection would usually right-multiply by a
# (dmodel, dmodel) matrix instead -- confirm the intent of this step.
out = np.matmul(np.random.rand(nwords, out.shape[0]), out)
print(out.shape)
print(out)