Skip to content

Commit 1a5bbcf

Browse files
committed
fix gemma3 + attention mask
- see #27 - a port of ml-explore/mlx-lm#463 (happened after the initial port to swift) - in support of ml-explore/mlx-swift-examples#454
1 parent 7f3e6aa commit 1a5bbcf

File tree

4 files changed

+238
-213
lines changed

4 files changed

+238
-213
lines changed

Libraries/MLXLLM/Models/Gemma3Text.swift

Lines changed: 66 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,40 @@ public struct Gemma3TextConfiguration: Codable {
2222
let rmsNormEps: Float
2323
let vocabularySize: Int
2424
let kvHeads: Int
25-
let ropeGlobalBaseFreq: Float
25+
let ropeTheta: Float
2626
let ropeLocalBaseFreq: Float
2727
let ropeTraditional: Bool
2828
let queryPreAttnScalar: Float
2929
let slidingWindow: Int
3030
let slidingWindowPattern: Int
31+
let maxPositionEmbeddings: Int
32+
let ropeScaling: [String: StringOrNumber]?
33+
34+
public init(
35+
modelType: String, hiddenSize: Int, hiddenLayers: Int, intermediateSize: Int,
36+
attentionHeads: Int, headDim: Int, rmsNormEps: Float, vocabularySize: Int, kvHeads: Int,
37+
ropeTheta: Float, ropeLocalBaseFreq: Float, ropeTraditional: Bool,
38+
queryPreAttnScalar: Float, slidingWindow: Int, slidingWindowPattern: Int,
39+
maxPositionEmbeddings: Int, ropeScaling: [String: StringOrNumber]? = nil
40+
) {
41+
self.modelType = modelType
42+
self.hiddenSize = hiddenSize
43+
self.hiddenLayers = hiddenLayers
44+
self.intermediateSize = intermediateSize
45+
self.attentionHeads = attentionHeads
46+
self.headDim = headDim
47+
self.rmsNormEps = rmsNormEps
48+
self.vocabularySize = vocabularySize
49+
self.kvHeads = kvHeads
50+
self.ropeTheta = ropeTheta
51+
self.ropeLocalBaseFreq = ropeLocalBaseFreq
52+
self.ropeTraditional = ropeTraditional
53+
self.queryPreAttnScalar = queryPreAttnScalar
54+
self.slidingWindow = slidingWindow
55+
self.slidingWindowPattern = slidingWindowPattern
56+
self.maxPositionEmbeddings = maxPositionEmbeddings
57+
self.ropeScaling = ropeScaling
58+
}
3159

3260
enum CodingKeys: String, CodingKey {
3361
case modelType = "model_type"
@@ -39,12 +67,14 @@ public struct Gemma3TextConfiguration: Codable {
3967
case rmsNormEps = "rms_norm_eps"
4068
case vocabularySize = "vocab_size"
4169
case kvHeads = "num_key_value_heads"
42-
case ropeGlobalBaseFreq = "rope_global_base_freq"
70+
case ropeTheta = "rope_theta"
4371
case ropeLocalBaseFreq = "rope_local_base_freq"
4472
case ropeTraditional = "rope_traditional"
4573
case queryPreAttnScalar = "query_pre_attn_scalar"
4674
case slidingWindow = "sliding_window"
4775
case slidingWindowPattern = "sliding_window_pattern"
76+
case maxPositionEmbeddings = "max_position_embeddings"
77+
case ropeScaling = "rope_scaling"
4878
}
4979

5080
enum VLMCodingKeys: String, CodingKey {
@@ -64,16 +94,17 @@ public struct Gemma3TextConfiguration: Codable {
6494
}
6595

6696
modelType = try container.decode(String.self, forKey: .modelType)
67-
hiddenSize = try container.decode(Int.self, forKey: .hiddenSize)
68-
hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers)
69-
intermediateSize = try container.decode(Int.self, forKey: .intermediateSize)
97+
hiddenSize = try container.decodeIfPresent(Int.self, forKey: .hiddenSize) ?? 1152
98+
hiddenLayers = try container.decodeIfPresent(Int.self, forKey: .hiddenLayers) ?? 26
99+
intermediateSize =
100+
try container.decodeIfPresent(Int.self, forKey: .intermediateSize) ?? 6912
70101
attentionHeads = try container.decodeIfPresent(Int.self, forKey: .attentionHeads) ?? 4
71102
headDim = try container.decodeIfPresent(Int.self, forKey: .headDim) ?? 256
72103
rmsNormEps = try container.decodeIfPresent(Float.self, forKey: .rmsNormEps) ?? 1.0e-6
73104
vocabularySize = try container.decodeIfPresent(Int.self, forKey: .vocabularySize) ?? 262144
74105
kvHeads = try container.decodeIfPresent(Int.self, forKey: .kvHeads) ?? 1
75-
ropeGlobalBaseFreq =
76-
try container.decodeIfPresent(Float.self, forKey: .ropeGlobalBaseFreq) ?? 1_000_000.0
106+
ropeTheta =
107+
try container.decodeIfPresent(Float.self, forKey: .ropeTheta) ?? 1_000_000.0
77108
ropeLocalBaseFreq =
78109
try container.decodeIfPresent(Float.self, forKey: .ropeLocalBaseFreq) ?? 10_000.0
79110
ropeTraditional =
@@ -83,6 +114,10 @@ public struct Gemma3TextConfiguration: Codable {
83114
slidingWindow = try container.decodeIfPresent(Int.self, forKey: .slidingWindow) ?? 512
84115
slidingWindowPattern =
85116
try container.decodeIfPresent(Int.self, forKey: .slidingWindowPattern) ?? 6
117+
maxPositionEmbeddings =
118+
try container.decodeIfPresent(Int.self, forKey: .maxPositionEmbeddings) ?? 32768
119+
ropeScaling =
120+
try container.decodeIfPresent([String: StringOrNumber].self, forKey: .ropeScaling)
86121
}
87122
}
88123

@@ -105,7 +140,7 @@ class Gemma3Attention: Module {
105140
@ModuleInfo(key: "q_norm") var queryNorm: Gemma.RMSNorm
106141
@ModuleInfo(key: "k_norm") var keyNorm: Gemma.RMSNorm
107142

108-
@ModuleInfo var rope: RoPE
143+
@ModuleInfo var rope: OffsetLayer
109144

110145
init(_ config: Gemma3TextConfiguration, layerIdx: Int) {
111146
let dim = config.hiddenSize
@@ -130,12 +165,16 @@ class Gemma3Attention: Module {
130165

131166
self.isSliding = (layerIdx + 1) % config.slidingWindowPattern != 0
132167

133-
let baseFreq = isSliding ? config.ropeLocalBaseFreq : config.ropeGlobalBaseFreq
134-
self._rope.wrappedValue = RoPE(
135-
dimensions: headDim,
136-
traditional: config.ropeTraditional,
137-
base: baseFreq
138-
)
168+
if isSliding {
169+
self.rope = initializeRope(
170+
dims: headDim, base: config.ropeLocalBaseFreq, traditional: false,
171+
scalingConfig: nil, maxPositionEmbeddings: nil)
172+
} else {
173+
self.rope = initializeRope(
174+
dims: headDim, base: config.ropeTheta, traditional: false,
175+
scalingConfig: config.ropeScaling,
176+
maxPositionEmbeddings: config.maxPositionEmbeddings)
177+
}
139178

140179
super.init()
141180
}
@@ -162,18 +201,8 @@ class Gemma3Attention: Module {
162201
queries = rope(queries, offset: cache.offset)
163202
keys = rope(keys, offset: cache.offset)
164203
} else {
165-
queries = rope(queries)
166-
keys = rope(keys)
167-
}
168-
169-
// Sliding window masking
170-
var finalMask = mask
171-
if case .array(let maskArray) = mask {
172-
let keySeqLen = keys.shape[2]
173-
if maskArray.shape.last! != keySeqLen {
174-
let slicedMask = maskArray[.ellipsis, (-keySeqLen)...]
175-
finalMask = .array(slicedMask)
176-
}
204+
queries = rope(queries, offset: 0)
205+
keys = rope(keys, offset: 0)
177206
}
178207

179208
let output = attentionWithCacheUpdate(
@@ -182,7 +211,7 @@ class Gemma3Attention: Module {
182211
values: values,
183212
cache: cache,
184213
scale: scale,
185-
mask: finalMask
214+
mask: mask
186215
)
187216
.transposed(0, 2, 1, 3)
188217
.reshaped(B, L, -1)
@@ -295,30 +324,19 @@ public class Gemma3Model: Module {
295324
if layerCache == nil {
296325
layerCache = Array(repeating: nil as KVCache?, count: layers.count)
297326
}
298-
// Create attention masks
299-
var fullMask: MLXFast.ScaledDotProductAttentionMaskMode = .none
300-
var slidingWindowMask: MLXFast.ScaledDotProductAttentionMaskMode = .none
301-
if mask == nil {
302-
let j = config.slidingWindowPattern
303-
let globalCache: KVCache? =
304-
(j > 0 && j <= (layerCache?.count ?? 0)) ? layerCache?[j - 1] : nil
305-
fullMask = createAttentionMask(h: h, cache: globalCache)
306-
let slidingCache: KVCache? = layerCache?.first ?? nil
307-
slidingWindowMask = createAttentionMask(
308-
h: h, cache: slidingCache, windowSize: config.slidingWindow)
309-
}
310-
for (i, layer) in layers.enumerated() {
311-
let isGlobal = (i % config.slidingWindowPattern == config.slidingWindowPattern - 1)
312327

313-
let localMask: MLXFast.ScaledDotProductAttentionMaskMode
314-
if let mask {
315-
localMask = mask
316-
} else if isGlobal {
317-
localMask = fullMask
328+
let globalMask = createAttentionMask(h: h, cache: cache?[config.slidingWindowPattern - 1])
329+
let slidingWindowMask =
330+
if config.slidingWindowPattern > 1 {
331+
createAttentionMask(h: h, cache: cache?[0], windowSize: config.slidingWindow)
318332
} else {
319-
localMask = slidingWindowMask
333+
MLXFast.ScaledDotProductAttentionMaskMode.none
320334
}
321-
h = layer(h, mask: localMask, cache: layerCache?[i])
335+
336+
for (i, layer) in layers.enumerated() {
337+
let isGlobal = (i % config.slidingWindowPattern == config.slidingWindowPattern - 1)
338+
let mask = isGlobal ? globalMask : slidingWindowMask
339+
h = layer(h, mask: mask, cache: layerCache?[i])
322340
}
323341
return norm(h)
324342
}

Libraries/MLXVLM/Models/Gemma3.swift

Lines changed: 24 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ public struct Gemma3TextConfiguration: Codable, Sendable {
4747
_queryPreAttnScalar ?? 256
4848
}
4949

50-
public let ropeGlobalBaseFreq: Float = 1_000_000.0
50+
public let ropeTheta: Float = 1_000_000.0
5151
public let ropeLocalBaseFreq: Float = 10_000.0
5252
public let ropeTraditional: Bool = false
5353
public let mmTokensPerImage: Int = 256
@@ -151,7 +151,7 @@ private class Attention: Module {
151151
@ModuleInfo(key: "q_norm") var queryNorm: Gemma.RMSNorm
152152
@ModuleInfo(key: "k_norm") var keyNorm: Gemma.RMSNorm
153153

154-
@ModuleInfo var rope: RoPE
154+
@ModuleInfo var rope: OffsetLayer
155155

156156
init(config: Gemma3TextConfiguration, layerIdx: Int) {
157157
let dim = config.hiddenSize
@@ -175,12 +175,16 @@ private class Attention: Module {
175175
// Gemma3 uses sliding window attention pattern
176176
self.isSliding = (layerIdx + 1) % config.slidingWindowPattern != 0
177177

178-
let baseFreq = isSliding ? config.ropeLocalBaseFreq : config.ropeGlobalBaseFreq
179-
self._rope.wrappedValue = RoPE(
180-
dimensions: headDim,
181-
traditional: config.ropeTraditional,
182-
base: baseFreq
183-
)
178+
if isSliding {
179+
self.rope = initializeRope(
180+
dims: headDim, base: config.ropeLocalBaseFreq, traditional: false,
181+
scalingConfig: nil, maxPositionEmbeddings: nil)
182+
} else {
183+
self.rope = initializeRope(
184+
dims: headDim, base: config.ropeTheta, traditional: false,
185+
scalingConfig: config.ropeScaling,
186+
maxPositionEmbeddings: config.maxPositionEmbeddings)
187+
}
184188
}
185189

186190
func callAsFunction(
@@ -208,30 +212,20 @@ private class Attention: Module {
208212
queries = rope(queries, offset: cache.offset)
209213
keys = rope(keys, offset: cache.offset)
210214
} else {
211-
queries = rope(queries)
212-
keys = rope(keys)
213-
}
214-
215-
// Handle sliding window masking
216-
var finalMask = mask
217-
if case .array(let maskArray) = mask, maskArray.shape.last! != keys.shape[2] {
218-
let keyLen = keys.shape[2]
219-
let slicedMask = maskArray[.ellipsis, (-keyLen)...]
220-
finalMask = .array(slicedMask)
215+
queries = rope(queries, offset: 0)
216+
keys = rope(keys, offset: 0)
221217
}
222218

223-
// Scaled dot-product attention with native GQA support
224219
let output = attentionWithCacheUpdate(
225220
queries: queries,
226221
keys: keys,
227222
values: values,
228223
cache: cache,
229224
scale: scale,
230-
mask: finalMask
225+
mask: mask
231226
)
232227
.transposed(0, 2, 1, 3)
233228
.reshaped(B, L, -1)
234-
235229
return outputProj(output)
236230
}
237231
}
@@ -346,36 +340,19 @@ private class GemmaModel: Module {
346340
layerCache = Array(repeating: nil as KVCache?, count: layers.count)
347341
}
348342

349-
// Create attention masks for global and sliding window layers
350-
var fullMask: MLXFast.ScaledDotProductAttentionMaskMode = .none
351-
var slidingWindowMask: MLXFast.ScaledDotProductAttentionMaskMode = .none
352-
353-
if mask == nil {
354-
let j = config.slidingWindowPattern
355-
if j > 0 && j <= layerCache!.count {
356-
let globalCache = layerCache?[j - 1]
357-
fullMask = createAttentionMask(h: h, cache: globalCache)
343+
let globalMask = createAttentionMask(h: h, cache: cache?[config.slidingWindowPattern - 1])
344+
let slidingWindowMask =
345+
if config.slidingWindowPattern > 1 {
346+
createAttentionMask(h: h, cache: cache?[0], windowSize: config.slidingWindow)
347+
} else {
348+
MLXFast.ScaledDotProductAttentionMaskMode.none
358349
}
359-
let slidingCache = layerCache?.first ?? nil
360-
slidingWindowMask = createAttentionMask(
361-
h: h, cache: slidingCache, windowSize: config.slidingWindow)
362-
}
363350

364351
for (i, layer) in layers.enumerated() {
365352
let isGlobal = (i % config.slidingWindowPattern == config.slidingWindowPattern - 1)
366-
367-
let localMask: MLXFast.ScaledDotProductAttentionMaskMode
368-
if let mask {
369-
localMask = mask
370-
} else if isGlobal {
371-
localMask = fullMask
372-
} else {
373-
localMask = slidingWindowMask
374-
}
375-
376-
h = layer(h, mask: localMask, cache: layerCache?[i])
353+
let mask = isGlobal ? globalMask : slidingWindowMask
354+
h = layer(h, mask: mask, cache: layerCache?[i])
377355
}
378-
379356
return norm(h)
380357
}
381358
}
@@ -1053,7 +1030,7 @@ public class Gemma3: Module, VLMModel, KVCacheDimensionProvider {
10531030
}
10541031
}
10551032

1056-
public class Gemma3Processor: UserInputProcessor {
1033+
public struct Gemma3Processor: UserInputProcessor {
10571034
private let config: Gemma3ProcessorConfiguration
10581035
private let tokenizer: any Tokenizer
10591036

0 commit comments

Comments
 (0)