11using Lux: Lux
22using LuxCore: AbstractLuxLayer
3+ using CUDA
34using Random: AbstractRNG
5+ using NNlib: batched_mul
46
57struct attention{F} <: AbstractLuxLayer
68 T:: Type
@@ -28,7 +30,7 @@ function attention(
2830 end
2931 @assert N % patch_size == 0 " N must be divisible by patch_size"
3032 n_patches = (div (N, patch_size))^ d
31- dh = div (emb_size, n_heads)
33+ dh = div (emb_size, n_heads) # dimension of each head (scale down the embedding size)
3234 attention (T, N, d, emb_size, patch_size, n_patches, n_heads, dh, init_weight)
3335end
3436
@@ -40,14 +42,16 @@ function Lux.initialparameters(
4042)
4143 (;
4244 # the attention weights have this size
43- wQ = init_weight (rng, T, n_heads, dh, emb_size + 1 ),
44- wK = init_weight (rng, T, n_heads, dh, emb_size + 1 ),
45- wV = init_weight (rng, T, n_heads, dh, emb_size + 1 ),
45+ wQ = init_weight (rng, T, n_heads, dh, emb_size),
46+ wK = init_weight (rng, T, n_heads, dh, emb_size),
47+ wV = init_weight (rng, T, n_heads, dh, emb_size),
4648 # then the embedding operator
4749 Ew = init_weight (rng, T, emb_size, patch_size * patch_size * d),
4850 Eb = zeros (T, emb_size),
49- # then the multihead attention
51+ # then the multihead attention output matrix
5052 U = init_weight (rng, T, N * N * d, n_patches * n_heads * dh),
53+ # and the positional embedding
54+ pos_emb = init_weight (rng, T, emb_size, div (N, patch_size), div (N, patch_size)),
5155 )
5256end
5357
@@ -61,26 +65,27 @@ function Lux.initialstates(
6165 d = d,
6266 emb_size = emb_size,
6367 patch_size = patch_size,
64- n_patches = n_patches,
68+ n_patches = n_patches, # total number of patches
6569 n_heads = n_heads,
6670 dh = dh,
6771 sqrtDh = T (sqrt (dh)),
72+ num_patches_1d = div (N, patch_size),
6873 )
6974end
function Lux.parameterlength(
    (; N, d, n_heads, dh, emb_size, patch_size, n_patches)::attention,
)
    # Total number of trainable scalars; must stay in sync with the arrays
    # returned by `Lux.initialparameters`.
    # Q/K/V projections: one (dh × emb_size) matrix per head.
    size_wQ = n_heads * dh * emb_size
    size_wK = n_heads * dh * emb_size
    size_wV = n_heads * dh * emb_size
    # Patch embedding: dense map from a flattened (ps × ps × d) patch, plus bias.
    size_Ew = emb_size * patch_size * patch_size * d
    size_Eb = emb_size
    # Multihead output projection back to the flattened image.
    size_U = N * N * d * n_patches * n_heads * dh
    # Learnable positional embedding (emb_size × n1d × n1d); this array is
    # created in `initialparameters` but was previously missing from the total.
    size_pos_emb = emb_size * div(N, patch_size)^2
    return size_wQ + size_wK + size_wV + size_Ew + size_Eb + size_U + size_pos_emb
end
# Number of leaf entries in the NamedTuple returned by `Lux.initialstates`.
# NOTE(review): the previous value was 9 and this change adds exactly one state
# field (`num_patches_1d` — the initialstates hunk grows by a single line), so
# the count becomes 10, not 11. Confirm against `initialstates` if more fields
# were added outside this diff.
Lux.statelength(::attention) = 10
8489
8590# This is what each layer does:
8691# expected input shape: [N, N, d, batch]
@@ -93,52 +98,44 @@ function ((;)::attention)(x, params, state)
9398 dh = state. dh
9499 sqrtDh = state. sqrtDh
95100 n_heads = state. n_heads
101+ num_patches_1d = state. num_patches_1d
96102
97103 Ew = params. Ew
98104 Eb = params. Eb
99105 wQ = params. wQ
100106 wK = params. wK
101107 wV = params. wV
102108 U = params. U
109+ pos_emb = params. pos_emb
110+
111+ batch = size (x, ndims (x))
103112
104113 # (1) Split the image into patches
105- num_patches = div (N, ps)
106- # The subarray of x here is by default a copy, but it can be a view (its not edited)
107- x_patches = [
108- @view (x[(i* ps+ 1 ): (i* ps+ ps), (j* ps+ 1 ): (j* ps+ ps), :, :]) for
109- i = 0 : (num_patches- 1 ), j = 0 : (num_patches- 1 )
110- ]
114+ x_patches = reshape (x, ps, num_patches_1d, ps, num_patches_1d, d, batch)
115+ x_patches = permutedims (x_patches, (1 , 3 , 5 , 2 , 4 , 6 ))
111116 # (2) flatten the patches
112- # reshape is fine and will not create a copy here, as only the first dims are merged, and because julia
113- # is column order, this does not change the shape of the underlying data, this is true for all following reshapes
114- x_pflat = [reshape (p, ps * ps * d, size (p, ndims (p))) for p in x_patches]
115-
117+ x_patches = reshape (x_patches, ps * ps * d, :)
116118 # (3) project the patches onto the embedding space
117- x_emb = [Ew * p .+ Eb for p in x_pflat]
119+ x_emb = Ew * x_patches .+ Eb
120+ x_emb = reshape (x_emb, size (x_emb, 1 ), num_patches_1d, num_patches_1d, batch)
118121
119- # (4) positional embedding
122+ # (4) add the positional embedding
120123 # notice that we use 1D positional embedding, as suggested [here](https://arxiv.org/pdf/2010.11929)
121- x_lemb = [
122- cat (p, ones (state. T, 1 , size (p)[2 : end ]. .. ) * i; dims = 1 ) for
123- (i, p) in enumerate (x_emb)
124- ]
124+ x_lemb = x_emb .+ pos_emb
125+ x_lemb = reshape (x_lemb, size (x_lemb, 1 ), num_patches_1d * num_patches_1d, batch)
125126
126127 # (5) compute the attention scores
127- # [!] notice that you can not reuse some variable names otherwise Zygote gets confused
128- Q0 = [wQ[i, :, :] * x_lemb[patchi] for i = 1 : n_heads, patchi = 1 : np]
129- K0 = [wK[i, :, :] * x_lemb[patchi] for i = 1 : n_heads, patchi = 1 : np]
130- V0 = [wV[i, :, :] * x_lemb[patchi] for i = 1 : n_heads, patchi = 1 : np]
131- # Reshape Q, K, V to match desired output dimensions
132- Q = reshape (vcat (Q0... ), (n_heads, np, dh, size (x, ndims (x))))
133- K = reshape (vcat (K0... ), (n_heads, np, dh, size (x, ndims (x))))
134- V = reshape (vcat (V0... ), (n_heads, np, dh, size (x, ndims (x))))
135- # (6) Compute attention scores without mutations
136- A = [Lux. softmax (Q[i, p, :, :] .* K[i, p, :, :] / sqrtDh) for i = 1 : n_heads, p = 1 : np]
137- A = reshape (vcat (A... ), (n_heads, np, dh, size (x, ndims (x))))
138- SA = A .* V
128+ Q = compute_QKV (x_lemb, wQ)
129+ K = compute_QKV (x_lemb, wK)
130+ V = compute_QKV (x_lemb, wV)
131+
132+ # (6) Compute attention scores
133+ A = attention_weights (Q, K)
134+ A = Lux. softmax (A / sqrtDh, dims = 3 )
135+ A = attention_scores (A, V)
139136
140137 # (7) multihead attention
141- MSA = reshape (SA , n_heads * np * dh, size (x, ndims (x)))
138+ MSA = reshape (A , n_heads * np * dh, size (x, ndims (x)))
142139 MSA = U * MSA
143140 MSA = reshape (MSA, size (x)... )
144141
# NOTE: source truncated here — trailing web-scrape artifact ("0 commit comments") removed; the closing `end` of the forward pass is not shown.