
Commit 5ddaae3: GPU via KernelAbstraction (#18)

Parent: 1d148e4

File tree: 10 files changed, +944 −173 lines changed

Project.toml

Lines changed: 12 additions & 4 deletions

@@ -4,19 +4,31 @@ authors = ["SCiarella <[email protected]>"]
 version = "0.1.0"

 [deps]
+Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
+KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
 LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

+[sources]
+CoupledNODE = {rev = "main", url = "https://github.com/DEEPDIP-project/CoupledNODE.jl.git"}
+NeuralClosure = {rev = "main", url = "https://github.com/DEEPDIP-project/NeuralClosure.jl.git"}
+
 [compat]
+Atomix = "1.1.1"
 CUDA = "5"
+ChainRulesCore = "1.25.1"
 ComponentArrays = "0.15"
 JuliaFormatter = "1.0.62"
+KernelAbstractions = "0.9.34"
 Lux = "1"
+LuxCUDA = "0.3.3"
 LuxCore = "1"
 NNlib = "0.9.27"
 julia = "1.10"

@@ -33,9 +45,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

-[sources]
-NeuralClosure = {rev = "main", url = "https://github.com/DEEPDIP-project/NeuralClosure.jl.git"}
-CoupledNODE = {rev = "main", url="https://github.com/DEEPDIP-project/CoupledNODE.jl.git"}
-
 [targets]
 test = ["Test", "Adapt", "CoupledNODE", "IncompressibleNavierStokes", "JLD2", "NeuralClosure", "Optimisers", "OrdinaryDiffEqTsit5", "TestItemRunner", "Zygote"]
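The [sources] section now sits above [compat] instead of at the bottom of the file. As a usage sketch (standard Pkg commands, not part of this commit; note that [sources] needs a Pkg version recent enough to understand it), instantiating the project fetches both unregistered dependencies from their main branches:

```julia
# Resolve CoupledNODE and NeuralClosure directly from the GitHub URLs
# listed under [sources], then install everything else from the registry.
using Pkg
Pkg.activate(".")
Pkg.instantiate()
```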

src/AttentionLayer.jl

Lines changed: 1 addition & 0 deletions

@@ -3,6 +3,7 @@ module AttentionLayer
 using CUDA: CUDA
 ArrayType = CUDA.functional() ? CUDA.CuArray : Array

+include("utils.jl")
 include("layer.jl")
 include("attention_cnn.jl")
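The module keeps its existing ArrayType alias, which picks CuArray only when a functional GPU is present. A minimal sketch of how that alias behaves (the array contents below are made up for illustration):

```julia
using CUDA

# Same pattern as in AttentionLayer.jl: fall back to plain Array on CPU.
ArrayType = CUDA.functional() ? CUDA.CuArray : Array

# Allocates on the GPU when available, otherwise stays a host Array.
x = ArrayType(rand(Float32, 16, 16, 2, 1))
```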

src/layer.jl

Lines changed: 36 additions & 39 deletions
@@ -1,6 +1,8 @@
 using Lux: Lux
 using LuxCore: AbstractLuxLayer
+using CUDA
 using Random: AbstractRNG
+using NNlib: batched_mul

 struct attention{F} <: AbstractLuxLayer
     T::Type

@@ -28,7 +30,7 @@ function attention(
     end
     @assert N % patch_size == 0 "N must be divisible by patch_size"
     n_patches = (div(N, patch_size))^d
-    dh = div(emb_size, n_heads)
+    dh = div(emb_size, n_heads) # dimension of each head (scale down the embedding size)
     attention(T, N, d, emb_size, patch_size, n_patches, n_heads, dh, init_weight)
 end
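To make the per-head dimension concrete (values assumed for illustration, not taken from the package):

```julia
# The embedding is split evenly across heads; concatenating the n_heads
# outputs of size dh recovers the full embedding size.
emb_size, n_heads = 64, 4
dh = div(emb_size, n_heads)       # 16
@assert dh * n_heads == emb_size
```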
3436

@@ -40,14 +42,16 @@ function Lux.initialparameters(
4042
)
4143
(;
4244
# the attention weights have this size
43-
wQ = init_weight(rng, T, n_heads, dh, emb_size + 1),
44-
wK = init_weight(rng, T, n_heads, dh, emb_size + 1),
45-
wV = init_weight(rng, T, n_heads, dh, emb_size + 1),
45+
wQ = init_weight(rng, T, n_heads, dh, emb_size),
46+
wK = init_weight(rng, T, n_heads, dh, emb_size),
47+
wV = init_weight(rng, T, n_heads, dh, emb_size),
4648
# then the embedding operator
4749
Ew = init_weight(rng, T, emb_size, patch_size * patch_size * d),
4850
Eb = zeros(T, emb_size),
49-
# then the multihead attention
51+
# then the multihead attention output matrix
5052
U = init_weight(rng, T, N * N * d, n_patches * n_heads * dh),
53+
# and the positional embedding
54+
pos_emb = init_weight(rng, T, emb_size, div(N, patch_size), div(N, patch_size)),
5155
)
5256
end
5357
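The Q/K/V projections drop the extra +1 input row: position information is no longer appended as a scalar index but added via the learnable pos_emb initialized above. A quick shape check of that broadcast (sizes assumed for illustration):

```julia
# pos_emb has no batch dimension, so adding it to the patch embeddings
# broadcasts over the trailing batch axis.
emb_size, N, ps, batch = 8, 16, 4, 2
x_emb   = rand(Float32, emb_size, div(N, ps), div(N, ps), batch)
pos_emb = rand(Float32, emb_size, div(N, ps), div(N, ps))
size(x_emb .+ pos_emb)  # (8, 4, 4, 2)
```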

@@ -61,26 +65,27 @@ function Lux.initialstates(
         d = d,
         emb_size = emb_size,
         patch_size = patch_size,
-        n_patches = n_patches,
+        n_patches = n_patches, # total number of patches
         n_heads = n_heads,
         dh = dh,
         sqrtDh = T(sqrt(dh)),
+        num_patches_1d = div(N, patch_size),
     )
 end
 function Lux.parameterlength(
     (; N, d, n_heads, dh, emb_size, patch_size, n_patches)::attention,
 )
-    size_wQ = n_heads * dh * (emb_size + 1)
-    size_wK = n_heads * dh * (emb_size + 1)
-    size_wV = n_heads * dh * (emb_size + 1)
+    size_wQ = n_heads * dh * emb_size
+    size_wK = n_heads * dh * emb_size
+    size_wV = n_heads * dh * emb_size
     size_Ew = emb_size * patch_size * patch_size * d
     size_Eb = emb_size
     size_U = N * N * d * n_patches * n_heads * dh

     total_size = size_wQ + size_wK + size_wV + size_Ew + size_Eb + size_U
     return total_size
 end
-Lux.statelength(::attention) = 9
+Lux.statelength(::attention) = 11

 # This is what each layer does:
 # expected input shape: [N, N, d, batch]
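For a concrete parameter count mirroring Lux.parameterlength after this change (values assumed for illustration; the function counts the dense weights only):

```julia
# Worked count with illustrative sizes: the (emb_size + 1) factor from the
# old concatenated position index is gone from the Q/K/V projections.
N, d, ps, emb_size, n_heads = 16, 2, 4, 8, 2
dh = div(emb_size, n_heads)           # 4
n_patches = div(N, ps)^d              # 16
size_wQKV = 3 * n_heads * dh * emb_size          # 192
size_Ew = emb_size * ps * ps * d                 # 256
size_Eb = emb_size                               # 8
size_U = N * N * d * n_patches * n_heads * dh    # 65536
size_wQKV + size_Ew + size_Eb + size_U           # 65992
```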
@@ -93,52 +98,44 @@ function ((;)::attention)(x, params, state)
     dh = state.dh
     sqrtDh = state.sqrtDh
     n_heads = state.n_heads
+    num_patches_1d = state.num_patches_1d

     Ew = params.Ew
     Eb = params.Eb
     wQ = params.wQ
     wK = params.wK
     wV = params.wV
     U = params.U
+    pos_emb = params.pos_emb
+
+    batch = size(x, ndims(x))

     # (1) Split the image into patches
-    num_patches = div(N, ps)
-    #The subarray of x here is by default a copy, but it can be a view (its not edited)
-    x_patches = [
-        @view(x[(i*ps+1):(i*ps+ps), (j*ps+1):(j*ps+ps), :, :]) for
-        i = 0:(num_patches-1), j = 0:(num_patches-1)
-    ]
+    x_patches = reshape(x, ps, num_patches_1d, ps, num_patches_1d, d, batch)
+    x_patches = permutedims(x_patches, (1, 3, 5, 2, 4, 6))
     # (2) flatten the patches
-    # reshape is fine and will not create a copy here, as only the first dims are merged, and because julia
-    # is column order, this does not change the shape of the underlying data, this is true for all following reshapes
-    x_pflat = [reshape(p, ps * ps * d, size(p, ndims(p))) for p in x_patches]
-
+    x_patches = reshape(x_patches, ps * ps * d, :)
     # (3) project the patches onto the embedding space
-    x_emb = [Ew * p .+ Eb for p in x_pflat]
+    x_emb = Ew * x_patches .+ Eb
+    x_emb = reshape(x_emb, size(x_emb, 1), num_patches_1d, num_patches_1d, batch)

-    # (4) positional embedding
+    # (4) add the positional embedding
     # notice that we use 1D positional embedding, as suggested [here](https://arxiv.org/pdf/2010.11929)
-    x_lemb = [
-        cat(p, ones(state.T, 1, size(p)[2:end]...) * i; dims = 1) for
-        (i, p) in enumerate(x_emb)
-    ]
+    x_lemb = x_emb .+ pos_emb
+    x_lemb = reshape(x_lemb, size(x_lemb, 1), num_patches_1d * num_patches_1d, batch)

     # (5) compute the attention scores
-    # [!] notice that you can not reuse some variable names otherwise Zygote gets confused
-    Q0 = [wQ[i, :, :] * x_lemb[patchi] for i = 1:n_heads, patchi = 1:np]
-    K0 = [wK[i, :, :] * x_lemb[patchi] for i = 1:n_heads, patchi = 1:np]
-    V0 = [wV[i, :, :] * x_lemb[patchi] for i = 1:n_heads, patchi = 1:np]
-    # Reshape Q, K, V to match desired output dimensions
-    Q = reshape(vcat(Q0...), (n_heads, np, dh, size(x, ndims(x))))
-    K = reshape(vcat(K0...), (n_heads, np, dh, size(x, ndims(x))))
-    V = reshape(vcat(V0...), (n_heads, np, dh, size(x, ndims(x))))
-    # (6) Compute attention scores without mutations
-    A = [Lux.softmax(Q[i, p, :, :] .* K[i, p, :, :] / sqrtDh) for i = 1:n_heads, p = 1:np]
-    A = reshape(vcat(A...), (n_heads, np, dh, size(x, ndims(x))))
-    SA = A .* V
+    Q = compute_QKV(x_lemb, wQ)
+    K = compute_QKV(x_lemb, wK)
+    V = compute_QKV(x_lemb, wV)
+
+    # (6) Compute attention scores
+    A = attention_weights(Q, K)
+    A = Lux.softmax(A / sqrtDh, dims = 3)
+    A = attention_scores(A, V)

     # (7) multihead attention
-    MSA = reshape(SA, n_heads * np * dh, size(x, ndims(x)))
+    MSA = reshape(A, n_heads * np * dh, size(x, ndims(x)))
     MSA = U * MSA
     MSA = reshape(MSA, size(x)...)
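The new helpers compute_QKV, attention_weights, and attention_scores live in the added src/utils.jl, which is not shown in this excerpt. As a rough illustration of the KernelAbstractions approach the commit title refers to, here is a minimal sketch of what a batched Q/K/V projection kernel could look like; the signature, memory layout, and kernel body are assumptions, not the actual utils.jl code, and the real implementation presumably also wires in Atomix and ChainRulesCore (both added to Project.toml above) for atomic updates and custom adjoints:

```julia
using KernelAbstractions

# Hypothetical projection kernel: out[h, p, i, b] = sum_e w[h, i, e] * x[e, p, b],
# i.e. one work-item per (head, patch, batch) triple, looping over the head dim.
@kernel function qkv_kernel!(out, @Const(x), @Const(w))
    h, p, b = @index(Global, NTuple)
    dh = size(w, 2)
    emb = size(w, 3)
    for i in 1:dh
        acc = zero(eltype(out))
        for e in 1:emb
            acc += w[h, i, e] * x[e, p, b]
        end
        out[h, p, i, b] = acc
    end
end

# Launch on whatever backend the input lives on (CPU or CUDA).
function compute_QKV_sketch(x, w)
    n_heads, dh, _ = size(w)
    emb, np, batch = size(x)
    out = similar(x, n_heads, np, dh, batch)
    backend = get_backend(x)
    qkv_kernel!(backend)(out, x, w; ndrange = (n_heads, np, batch))
    KernelAbstractions.synchronize(backend)
    return out
end
```

Because KernelAbstractions compiles the same kernel for CPU and GPU backends, one such definition can replace the per-patch array comprehensions of the old forward pass on both devices.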
