@@ -47,9 +47,12 @@ function Lux.initialparameters(
4747 Ew = init_weight (rng, T, emb_size, patch_size * patch_size * d),
4848 Eb = zeros (T, emb_size),
4949 # then the multihead attention output matrix
50- U = init_weight (rng, T, N * N * d, n_patches * n_heads * dh),
51- # and the positional embedding
50+ # U = init_weight(rng, T, N * N * d, n_patches * n_heads * dh),
51+ U = init_weight (rng, T, emb_size, emb_size), # i.e., 128 × 128
52+ # the positional embedding
5253 pos_emb = init_weight (rng, T, emb_size, div (N, patch_size), div (N, patch_size)),
54+ # and a final decoder
55+ dec = init_weight (rng, T, patch_size * patch_size * d, emb_size), # (2738, 128)
5356 )
5457end
5558
@@ -78,12 +81,15 @@ function Lux.parameterlength(
7881 size_wV = n_heads * dh * emb_size
7982 size_Ew = emb_size * patch_size * patch_size * d
8083 size_Eb = emb_size
81- size_U = N * N * d * n_patches * n_heads * dh
82-
83- total_size = size_wQ + size_wK + size_wV + size_Ew + size_Eb + size_U
84+ # size_U = N * N * d * n_patches * n_heads * dh
85+ size_U = emb_size * emb_size
86+ size_dec = patch_size * patch_size * d * emb_size
87+ size_pos_emb = emb_size * div (N, patch_size) * div (N, patch_size)
88+ total_size =
89+ size_wQ + size_wK + size_wV + size_Ew + size_Eb + size_U + size_dec + size_pos_emb
8490 return total_size
8591end
86- Lux. statelength (:: attention ) = 11
92+ Lux. statelength (:: attention ) = 12  # bumped 11 → 12: forward pass now also reads `state.emb_size`
8793
8894# This is what each layer does:
8995# expected input shape: [N, N, d, batch]
@@ -97,6 +103,7 @@ function ((;)::attention)(x, params, state)
97103 sqrtDh = state. sqrtDh
98104 n_heads = state. n_heads
99105 num_patches_1d = state. num_patches_1d
106+ emb_size = state. emb_size
100107
101108 Ew = params. Ew
102109 Eb = params. Eb
@@ -105,6 +112,7 @@ function ((;)::attention)(x, params, state)
105112 wV = params. wV
106113 U = params. U
107114 pos_emb = params. pos_emb
115+ dec = params. dec
108116
109117 batch = size (x, ndims (x))
110118
@@ -133,10 +141,55 @@ function ((;)::attention)(x, params, state)
133141 A = attention_scores (A, V)
134142
135143 # (7) multihead attention
136- MSA = reshape (A, n_heads * np * dh, size (x, ndims (x)))
137- MSA = U * MSA
138- MSA = reshape (MSA, size (x)... )
144+ # MSA = reshape(A, n_heads * np * dh, size(x, ndims(x)))
145+ # MSA = U * MSA
146+ # MSA = reshape(MSA, size(x)...)
147+
148+
149+ # A = reshape(A, n_heads * dh, np, batch) # (emb_size, np, batch)
150+ # A_flat = reshape(A, n_heads * dh, :)
151+ # MSA = U * A_flat # U ∈ (emb_size, emb_size)
152+ # MSA = reshape(MSA, n_heads * dh, np, batch)
153+ # @info "***********************"
154+ # @info "x shape: $(size(x))"
155+ # @info "MSA size: $(size(MSA))"
156+
157+ # # A ∈ (n_heads * dh, np, batch) == (128, 16, batch)
158+ # MSA = reshape(MSA, emb_size, np, batch)
159+
160+ # # Flatten across np × batch
161+ # A_flat = reshape(MSA, emb_size, :) # (128, 16 * batch)
162+
163+ # # Decode each patch embedding into flattened image patch
164+ # decoded_patches = dec * A_flat # (2738, 16 * batch)
165+
166+ # # Reshape back into patch layout: (ps, ps, d, np, batch)
167+ # decoded_patches = reshape(decoded_patches, ps, ps, d, np, batch)
168+
169+ # # Reshape np = 4 × 4 back into grid layout
170+ # patches_grid = reshape(decoded_patches, ps, ps, d, num_patches_1d, num_patches_1d, batch)
171+
172+ # # Reorder axes to reconstruct full image
173+ # output = permutedims(patches_grid, (1, 4, 2, 5, 3, 6)) # (ps, np1d, ps, np1d, d, batch)
174+ # output = reshape(output, N, N, d, batch) # (148, 148, 2, batch)
175+
176+ # # Attention layer does not modify state
177+ # output, state
178+
179+
180+ # (7) multihead attention
181+ # Combine reshapes and matrix multiplications
182+ A_flat = reshape (A, n_heads * dh, :) # (emb_size, np * batch)
183+ MSA = U * A_flat # Apply U (U ∈ (emb_size, emb_size)) -> (emb_size, np * batch)
184+
185+ # (8) Decode each patch and reshape directly into the final image layout
186+ output = reshape (dec * MSA, ps, ps, d, num_patches_1d, num_patches_1d, batch)
187+
188+ # (9) Reorder to reconstruct the full image
189+ output = permutedims (output, (1 , 4 , 2 , 5 , 3 , 6 )) # (ps, np1d, ps, np1d, d, batch)
190+ output = reshape (output, N, N, d, batch) # (148, 148, 2, batch)
139191
140192 # Attention layer does not modify state
141- MSA, state
193+ output, state
194+
142195end
0 commit comments