Skip to content

Commit cfee2aa

Browse files
authored
Reformat U (#30)
Replaced the large parameter-heavy matrix U ∈ (N² × d, n_heads × np × dh) with a smaller linear decoder that maps each patch embedding (128-dim) to its corresponding flattened image patch (2738-dim). This significantly reduces parameter count and improves efficiency while preserving the ability to reconstruct the full image from attention output.
1 parent 0fc3760 commit cfee2aa

File tree

8 files changed

+428
-14
lines changed

8 files changed

+428
-14
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,6 @@ coverage
1212
docs/build/
1313
env
1414
node_modules
15+
16+
# Large file
17+
test/data_train_real.jld2

src/layer.jl

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,12 @@ function Lux.initialparameters(
4747
Ew = init_weight(rng, T, emb_size, patch_size * patch_size * d),
4848
Eb = zeros(T, emb_size),
4949
# then the multihead attention output matrix
50-
U = init_weight(rng, T, N * N * d, n_patches * n_heads * dh),
51-
# and the positional embedding
50+
#U = init_weight(rng, T, N * N * d, n_patches * n_heads * dh),
51+
U = init_weight(rng, T, emb_size, emb_size), # i.e., 128 × 128
52+
# the positional embedding
5253
pos_emb = init_weight(rng, T, emb_size, div(N, patch_size), div(N, patch_size)),
54+
# and a final decoder
55+
dec = init_weight(rng, T, patch_size * patch_size * d, emb_size), # (2738, 128)
5356
)
5457
end
5558

@@ -78,12 +81,15 @@ function Lux.parameterlength(
7881
size_wV = n_heads * dh * emb_size
7982
size_Ew = emb_size * patch_size * patch_size * d
8083
size_Eb = emb_size
81-
size_U = N * N * d * n_patches * n_heads * dh
82-
83-
total_size = size_wQ + size_wK + size_wV + size_Ew + size_Eb + size_U
84+
#size_U = N * N * d * n_patches * n_heads * dh
85+
size_U = emb_size * emb_size
86+
size_dec = patch_size * patch_size * d * emb_size
87+
size_pos_emb = emb_size * div(N, patch_size) * div(N, patch_size)
88+
total_size =
89+
size_wQ + size_wK + size_wV + size_Ew + size_Eb + size_U + size_dec + size_pos_emb
8490
return total_size
8591
end
86-
Lux.statelength(::attention) = 11
92+
Lux.statelength(::attention) = 12
8793

8894
# This is what each layer does:
8995
# expected input shape: [N, N, d, batch]
@@ -97,6 +103,7 @@ function ((;)::attention)(x, params, state)
97103
sqrtDh = state.sqrtDh
98104
n_heads = state.n_heads
99105
num_patches_1d = state.num_patches_1d
106+
emb_size = state.emb_size
100107

101108
Ew = params.Ew
102109
Eb = params.Eb
@@ -105,6 +112,7 @@ function ((;)::attention)(x, params, state)
105112
wV = params.wV
106113
U = params.U
107114
pos_emb = params.pos_emb
115+
dec = params.dec
108116

109117
batch = size(x, ndims(x))
110118

@@ -133,10 +141,55 @@ function ((;)::attention)(x, params, state)
133141
A = attention_scores(A, V)
134142

135143
# (7) multihead attention
136-
MSA = reshape(A, n_heads * np * dh, size(x, ndims(x)))
137-
MSA = U * MSA
138-
MSA = reshape(MSA, size(x)...)
144+
#MSA = reshape(A, n_heads * np * dh, size(x, ndims(x)))
145+
#MSA = U * MSA
146+
#MSA = reshape(MSA, size(x)...)
147+
148+
149+
#A = reshape(A, n_heads * dh, np, batch) # (emb_size, np, batch)
150+
#A_flat = reshape(A, n_heads * dh, :)
151+
#MSA = U * A_flat # U ∈ (emb_size, emb_size)
152+
#MSA = reshape(MSA, n_heads * dh, np, batch)
153+
#@info "***********************"
154+
#@info "x shape: $(size(x))"
155+
#@info "MSA size: $(size(MSA))"
156+
157+
## A ∈ (n_heads * dh, np, batch) == (128, 16, batch)
158+
#MSA = reshape(MSA, emb_size, np, batch)
159+
160+
## Flatten across np × batch
161+
#A_flat = reshape(MSA, emb_size, :) # (128, 16 * batch)
162+
163+
## Decode each patch embedding into flattened image patch
164+
#decoded_patches = dec * A_flat # (2738, 16 * batch)
165+
166+
## Reshape back into patch layout: (ps, ps, d, np, batch)
167+
#decoded_patches = reshape(decoded_patches, ps, ps, d, np, batch)
168+
169+
## Reshape np = 4 × 4 back into grid layout
170+
#patches_grid = reshape(decoded_patches, ps, ps, d, num_patches_1d, num_patches_1d, batch)
171+
172+
## Reorder axes to reconstruct full image
173+
#output = permutedims(patches_grid, (1, 4, 2, 5, 3, 6)) # (ps, np1d, ps, np1d, d, batch)
174+
#output = reshape(output, N, N, d, batch) # (148, 148, 2, batch)
175+
176+
## Attention layer does not modify state
177+
#output, state
178+
179+
180+
# (7) multihead attention
181+
# Combine reshapes and matrix multiplications
182+
A_flat = reshape(A, n_heads * dh, :) # (emb_size, np * batch)
183+
MSA = U * A_flat # Apply U (U ∈ (emb_size, emb_size)) -> (emb_size, np * batch)
184+
185+
# (8) Decode each patch and reshape directly into the final image layout
186+
output = reshape(dec * MSA, ps, ps, d, num_patches_1d, num_patches_1d, batch)
187+
188+
# (9) Reorder to reconstruct the full image
189+
output = permutedims(output, (1, 4, 2, 5, 3, 6)) # (ps, np1d, ps, np1d, d, batch)
190+
output = reshape(output, N, N, d, batch) # (148, 148, 2, batch)
139191

140192
# Attention layer does not modify state
141-
MSA, state
193+
output, state
194+
142195
end

test/config_real.yaml

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
docreatedata: true
2+
docomp: true
3+
ntrajectory: 8
4+
T: "Float32"
5+
params:
6+
D: 2
7+
lims: [0.0, 1.0]
8+
Re: 6000.0
9+
tburn: 0.5
10+
tsim: 5.0
11+
savefreq: 50
12+
ndns: 4096
13+
nles: [128]
14+
filters: ["FaceAverage()"]
15+
icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)"
16+
method: "RKMethods.Wray3(; T)"
17+
bodyforce: "(dim, x, y, t) -> (dim == 1) * 5 * sinpi(8 * y)"
18+
issteadybodyforce: true
19+
processors: "(; log = timelogger(; nupdate=100))"
20+
Δt: 0.00005
21+
seeds:
22+
dns: 123
23+
θ_start: 234
24+
prior: 345
25+
post: 456
26+
closure:
27+
name: "att_3"
28+
type: attentioncnn
29+
radii: [2, 2, 2, 2, 2]
30+
channels: [24, 24, 24, 24, 2]
31+
activations: ["tanh", "tanh", "tanh", "tanh", "identity"]
32+
use_bias: [true, true, true, true, false]
33+
use_attention: [true, false, false, false, false]
34+
emb_sizes: [124, 124, 124, 124, 124]
35+
Ns: [148, 144, 140, 136, 132]
36+
patch_sizes: [37, 36, 35, 34, 33]
37+
n_heads: [4, 4, 4, 4, 4]
38+
sum_attention: [false, false, false, false, false]
39+
rng: "Xoshiro(seeds.θ_start)"
40+
priori:
41+
dotrain: true
42+
nepoch: 10000
43+
batchsize: 64
44+
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
45+
do_plot: false
46+
plot_train: false
47+
posteriori:
48+
dotrain: true
49+
projectorders: "(ProjectOrder.Last, )"
50+
nepoch: 100
51+
opt: "OptimiserChain(Adam(T(1.0e-4)), ClipGrad(0.1))"
52+
nunroll: 5
53+
nunroll_valid: 5
54+
dt: T(1e-5)
55+
do_plot: false
56+
plot_train: false

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ The file will be automatically included inside a `@testset` with title "Title Fo
1111
for (root, dirs, files) in walkdir(@__DIR__)
1212
for file in files
1313
if isnothing(match(r"^test-.*\.jl$", file))
14+
#if isnothing(match(r"^test-couplednode_real.*\.jl$", file))
1415
continue
1516
end
1617
title = titlecase(replace(splitext(file[6:end])[1], "-" => " "))

test/test-cnn_1att.jl

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
using Test
2+
using Lux
3+
using CUDA
4+
using LuxCUDA
5+
using AttentionLayer: attention, attentioncnn
6+
using ComponentArrays: ComponentArray
7+
using Random
8+
using Zygote: Zygote
9+
10+
# Define parameters for the model
11+
T = Float32
12+
N = 128
13+
D = 2
14+
batch = 5
15+
rng = Xoshiro(123)
16+
r = [2, 2, 2, 2, 2]
17+
c = [24, 24, 24, 24, 2]
18+
σ = [tanh, tanh, tanh, tanh, identity]
19+
b = [true, true, true, true, false]
20+
use_attention = [true, false, false, false, false]
21+
sum_attention = [false, false, false, false, false]
22+
Ns = reverse([N + 2 * sum(r[1:i]) for i = 1:length(r)])
23+
patch_sizes = [37, 36, 35, 34, 33]
24+
emb_sizes = [8, 8, 8, 8, 8]
25+
n_heads = [2, 2, 2, 2, 2]
26+
27+
@testset "AttentionCNN (CPU)" begin
28+
29+
# Create the model
30+
closure, θ, st = attentioncnn(
31+
T = T,
32+
D = D,
33+
data_ch = D,
34+
radii = r,
35+
channels = c,
36+
activations = σ,
37+
use_bias = b,
38+
use_attention = use_attention,
39+
emb_sizes = emb_sizes,
40+
Ns = Ns,
41+
patch_sizes = patch_sizes,
42+
n_heads = n_heads,
43+
sum_attention = sum_attention,
44+
rng = rng,
45+
use_cuda = false,
46+
)
47+
48+
@testset "Model setup" begin
49+
@test closure != nothing
50+
@test θ != nothing
51+
@test st != nothing
52+
# Test model structure
53+
@test typeof(closure) <: Lux.Chain
54+
end
55+
56+
# Define input tensor and pass through model
57+
input_tensor = rand(T, N, N, D, batch) # Example input with shape (N, N, D, batch_size)
58+
output = closure(input_tensor, θ, st)
59+
60+
@testset "Model output" begin
61+
@test output != nothing
62+
@test length(output) == 2 # Check that the output is a tuple
63+
@test isa(output[1], Array)
64+
@test size(output[1]) == (N, N, D, batch) # Check final output size
65+
end
66+
67+
@testset "AD" begin
68+
# Test Differentiability by calculating gradients
69+
grad = Zygote.gradient-> sum(abs2, closure(input_tensor, θ, st)[1]), θ)
70+
@test !isnothing(grad) # Ensure gradients were successfully computed
71+
@test sum(grad) != 0.0 # Ensure gradients are not zero
72+
73+
y, back = Zygote.pullback-> sum(abs2, closure(input_tensor, θ, st)[1]), θ)
74+
@test y sum(abs2, closure(input_tensor, θ, st)[1])
75+
y_bar = ones(T, size(y))
76+
θ_bar = back(y_bar)
77+
@test θ_bar != nothing
78+
@test sum(θ_bar) != 0.0 # Ensure gradients are not zero
79+
end
80+
81+
end
82+
83+
@testset "AttentionCNN (GPU)" begin
84+
if !CUDA.functional()
85+
@testset "CUDA not available" begin
86+
@test true
87+
end
88+
return
89+
end
90+
91+
# Create the model
92+
closure, θ, st = attentioncnn(
93+
T = T,
94+
D = D,
95+
data_ch = D,
96+
radii = r,
97+
channels = c,
98+
activations = σ,
99+
use_bias = b,
100+
use_attention = use_attention,
101+
emb_sizes = emb_sizes,
102+
Ns = Ns,
103+
patch_sizes = patch_sizes,
104+
n_heads = n_heads,
105+
sum_attention = sum_attention,
106+
rng = rng,
107+
use_cuda = true,
108+
)
109+
110+
@testset "Model setup" begin
111+
@test closure != nothing
112+
@test θ != nothing
113+
@test st != nothing
114+
# Test model structure
115+
@test typeof(closure) <: Lux.Chain
116+
end
117+
118+
# Define input tensor and pass through model
119+
input_tensor = CUDA.rand(T, N, N, D, batch) # Example input with shape (N, N, D, batch_size)
120+
output = closure(input_tensor, θ, st)
121+
122+
@testset "Model output" begin
123+
@test output != nothing
124+
@test length(output) == 2 # Check that the output is a tuple
125+
@test isa(output[1], CuArray)
126+
@test size(output[1]) == (N, N, D, batch) # Check final output size
127+
end
128+
129+
@testset "AD" begin
130+
# Test Differentiability by calculating gradients
131+
grad = Zygote.gradient-> sum(abs2, closure(input_tensor, θ, st)[1]), θ)
132+
@test !isnothing(grad) # Ensure gradients were successfully computed
133+
@test sum(grad) != 0.0 # Ensure gradients are not zero
134+
135+
y, back = Zygote.pullback-> sum(abs2, closure(input_tensor, θ, st)[1]), θ)
136+
@test y sum(abs2, closure(input_tensor, θ, st)[1])
137+
y_bar = CUDA.ones(T, size(y))
138+
θ_bar = back(y_bar)
139+
@test θ_bar != nothing
140+
@test sum(θ_bar) != 0.0 # Ensure gradients are not zero
141+
end
142+
143+
end

test/test-couplednode.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ sum_attention = [false, false]
156156
end
157157
tmp1, tmp2 = back(λ)
158158
@test size(tmp1) == (18, 18, 2)
159-
@test size(tmp2) == (94194,)
159+
@test size(tmp2) == (2266,)
160160

161161
# Final integration test of the entire train interface
162162
l, trainstate = CoupledNODE.train(
@@ -308,7 +308,7 @@ end
308308
end
309309
tmp1, tmp2 = back(λ)
310310
@test size(tmp1) == (18, 18, 2)
311-
@test size(tmp2) == (94194,)
311+
@test size(tmp2) == (2266,)
312312
@test isa(tmp1, CuArray) # Check if tmp1 is on GPU
313313

314314
# Final integration test of the entire train interface

test/test-couplednode_loader.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,6 @@ end
5555
@info "CNN warm up run"
5656
u = randn(Float32, 32 + 2, 32 + 2, 2, 10) |> device
5757
θ = θ_start |> device
58-
#u = CUDA.rand(Float32, 32+2, 32+2, 2, 10)
59-
#θ = θ_start |> Lux.gpu_device()
6058
output, _ = closure(u, θ, st)
6159

6260
@test size(output) == (32 + 2, 32 + 2, 2, 10)

0 commit comments

Comments
 (0)