summaryrefslogtreecommitdiff
path: root/tensorflow/transformer-attention2.py
blob: c21493464e4dda5fcecd609a1836e58b814ac079 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import numpy as np
import math

# Model dimensions for the toy attention example.
dmodel = 32         # output width of the query/key/value projections
embedding_dim = 8   # width of each input word embedding
nwords = 3          # number of tokens in the toy sequence
num_heads = 4       # number of attention heads

# Configuration check: dmodel / num_heads must equal embedding_dim.
# An explicit check (rather than `assert`) still runs under `python -O`, and
# the integer product avoids comparing a float quotient.
if dmodel != num_heads * embedding_dim:
    raise ValueError(
        f"dmodel ({dmodel}) must equal num_heads ({num_heads}) * "
        f"embedding_dim ({embedding_dim})"
    )

# One random embedding per word: shape (nwords, embedding_dim), values in [0, 1).
# A single call draws the same random stream as generating the rows one by one.
states = np.random.rand(nwords, embedding_dim)

def softmax(m, axis=-1):
    """Return the softmax of *m* normalised along *axis*.

    For a 2-D matrix of attention scores this normalises each row so that it
    sums to 1 (the default ``axis=-1`` is ``axis=1`` for 2-D input).

    Args:
        m: array-like of real-valued scores.
        axis: axis along which probabilities are normalised. Defaults to the
            last axis.

    Returns:
        ``np.ndarray`` with the same shape as *m*; every slice along *axis*
        sums to 1.
    """
    m = np.asarray(m)
    if m.size == 0:
        # Nothing to normalise; np.max would raise on an empty reduction.
        return np.exp(m)
    # Subtracting the per-slice max leaves the result mathematically unchanged
    # but keeps np.exp from overflowing to inf (which would yield nan).
    exps = np.exp(m - np.max(m, axis=axis, keepdims=True))
    # keepdims=True keeps the reduced axis so the division broadcasts per row;
    # a plain (n,) sum would broadcast across columns instead.
    return exps / np.sum(exps, axis=axis, keepdims=True)

# --- Multi-head scaled dot-product self-attention over `states` -------------
# Each head attends over a head_dim-wide slice of the dmodel-wide projections,
# so the 1/sqrt(d_k) scale uses head_dim (= dmodel / num_heads): the width of
# the vectors actually dotted together.
head_dim = dmodel // num_heads

# Input projections, each (embedding_dim, dmodel).
Wqs = np.random.rand(embedding_dim, dmodel)
Wks = np.random.rand(embedding_dim, dmodel)
Wvs = np.random.rand(embedding_dim, dmodel)
# Output projection applied to the concatenated heads, (dmodel, dmodel).
Wo = np.random.rand(dmodel, dmodel)

queries = np.matmul(states, Wqs)  # (nwords, dmodel)
keys    = np.matmul(states, Wks)
values  = np.matmul(states, Wvs)

print(values)


def _split_heads(x):
    """Reshape (tokens, dmodel) -> (num_heads, tokens, head_dim)."""
    return x.reshape(x.shape[0], num_heads, head_dim).transpose(1, 0, 2)


q_heads = _split_heads(queries)
k_heads = _split_heads(keys)
v_heads = _split_heads(values)

# Per-head attention scores, (num_heads, tokens, tokens).
scores = np.matmul(q_heads, k_heads.transpose(0, 2, 1)) / math.sqrt(head_dim)

# Normalise each head's 2-D score matrix with the module's softmax so each
# query's weights over the keys form a distribution.
weights = np.stack([softmax(head_scores) for head_scores in scores])
print(weights)

# Weighted sum of values per head, then concatenate the heads back to
# (tokens, dmodel) in head order.
heads = np.matmul(weights, v_heads)  # (num_heads, tokens, head_dim)
out = heads.transpose(1, 0, 2).reshape(heads.shape[1], dmodel)

# Right-multiplying by Wo mixes features within each token; the shape stays
# (tokens, dmodel).
out = np.matmul(out, Wo)
print(out.shape)
print(out)