99
1010import Foundation
1111import MLX
12- import MLXFast
1312import MLXLMCommon
1413import MLXNN
1514
@@ -23,12 +22,40 @@ public struct Gemma3TextConfiguration: Codable {
2322 let rmsNormEps : Float
2423 let vocabularySize : Int
2524 let kvHeads : Int
26- let ropeGlobalBaseFreq : Float
25+ let ropeTheta : Float
2726 let ropeLocalBaseFreq : Float
2827 let ropeTraditional : Bool
2928 let queryPreAttnScalar : Float
3029 let slidingWindow : Int
3130 let slidingWindowPattern : Int
31+ let maxPositionEmbeddings : Int
32+ let ropeScaling : [ String : StringOrNumber ] ?
33+
34+ public init (
35+ modelType: String , hiddenSize: Int , hiddenLayers: Int , intermediateSize: Int ,
36+ attentionHeads: Int , headDim: Int , rmsNormEps: Float , vocabularySize: Int , kvHeads: Int ,
37+ ropeTheta: Float , ropeLocalBaseFreq: Float , ropeTraditional: Bool ,
38+ queryPreAttnScalar: Float , slidingWindow: Int , slidingWindowPattern: Int ,
39+ maxPositionEmbeddings: Int , ropeScaling: [ String : StringOrNumber ] ? = nil
40+ ) {
41+ self . modelType = modelType
42+ self . hiddenSize = hiddenSize
43+ self . hiddenLayers = hiddenLayers
44+ self . intermediateSize = intermediateSize
45+ self . attentionHeads = attentionHeads
46+ self . headDim = headDim
47+ self . rmsNormEps = rmsNormEps
48+ self . vocabularySize = vocabularySize
49+ self . kvHeads = kvHeads
50+ self . ropeTheta = ropeTheta
51+ self . ropeLocalBaseFreq = ropeLocalBaseFreq
52+ self . ropeTraditional = ropeTraditional
53+ self . queryPreAttnScalar = queryPreAttnScalar
54+ self . slidingWindow = slidingWindow
55+ self . slidingWindowPattern = slidingWindowPattern
56+ self . maxPositionEmbeddings = maxPositionEmbeddings
57+ self . ropeScaling = ropeScaling
58+ }
3259
3360 enum CodingKeys : String , CodingKey {
3461 case modelType = " model_type "
@@ -40,12 +67,14 @@ public struct Gemma3TextConfiguration: Codable {
4067 case rmsNormEps = " rms_norm_eps "
4168 case vocabularySize = " vocab_size "
4269 case kvHeads = " num_key_value_heads "
43- case ropeGlobalBaseFreq = " rope_global_base_freq "
70+ case ropeTheta = " rope_theta "
4471 case ropeLocalBaseFreq = " rope_local_base_freq "
4572 case ropeTraditional = " rope_traditional "
4673 case queryPreAttnScalar = " query_pre_attn_scalar "
4774 case slidingWindow = " sliding_window "
4875 case slidingWindowPattern = " sliding_window_pattern "
76+ case maxPositionEmbeddings = " max_position_embeddings "
77+ case ropeScaling = " rope_scaling "
4978 }
5079
5180 enum VLMCodingKeys : String , CodingKey {
@@ -65,16 +94,17 @@ public struct Gemma3TextConfiguration: Codable {
6594 }
6695
6796 modelType = try container. decode ( String . self, forKey: . modelType)
68- hiddenSize = try container. decode ( Int . self, forKey: . hiddenSize)
69- hiddenLayers = try container. decode ( Int . self, forKey: . hiddenLayers)
70- intermediateSize = try container. decode ( Int . self, forKey: . intermediateSize)
97+ hiddenSize = try container. decodeIfPresent ( Int . self, forKey: . hiddenSize) ?? 1152
98+ hiddenLayers = try container. decodeIfPresent ( Int . self, forKey: . hiddenLayers) ?? 26
99+ intermediateSize =
100+ try container. decodeIfPresent ( Int . self, forKey: . intermediateSize) ?? 6912
71101 attentionHeads = try container. decodeIfPresent ( Int . self, forKey: . attentionHeads) ?? 4
72102 headDim = try container. decodeIfPresent ( Int . self, forKey: . headDim) ?? 256
73103 rmsNormEps = try container. decodeIfPresent ( Float . self, forKey: . rmsNormEps) ?? 1.0e-6
74104 vocabularySize = try container. decodeIfPresent ( Int . self, forKey: . vocabularySize) ?? 262144
75105 kvHeads = try container. decodeIfPresent ( Int . self, forKey: . kvHeads) ?? 1
76- ropeGlobalBaseFreq =
77- try container. decodeIfPresent ( Float . self, forKey: . ropeGlobalBaseFreq ) ?? 1_000_000.0
106+ ropeTheta =
107+ try container. decodeIfPresent ( Float . self, forKey: . ropeTheta ) ?? 1_000_000.0
78108 ropeLocalBaseFreq =
79109 try container. decodeIfPresent ( Float . self, forKey: . ropeLocalBaseFreq) ?? 10_000.0
80110 ropeTraditional =
@@ -84,6 +114,10 @@ public struct Gemma3TextConfiguration: Codable {
84114 slidingWindow = try container. decodeIfPresent ( Int . self, forKey: . slidingWindow) ?? 512
85115 slidingWindowPattern =
86116 try container. decodeIfPresent ( Int . self, forKey: . slidingWindowPattern) ?? 6
117+ maxPositionEmbeddings =
118+ try container. decodeIfPresent ( Int . self, forKey: . maxPositionEmbeddings) ?? 32768
119+ ropeScaling =
120+ try container. decodeIfPresent ( [ String : StringOrNumber ] . self, forKey: . ropeScaling)
87121 }
88122}
89123
@@ -106,7 +140,7 @@ class Gemma3Attention: Module {
106140 @ModuleInfo ( key: " q_norm " ) var queryNorm : Gemma . RMSNorm
107141 @ModuleInfo ( key: " k_norm " ) var keyNorm : Gemma . RMSNorm
108142
109- @ModuleInfo var rope : RoPE
143+ @ModuleInfo var rope : OffsetLayer
110144
111145 init ( _ config: Gemma3TextConfiguration , layerIdx: Int ) {
112146 let dim = config. hiddenSize
@@ -131,12 +165,16 @@ class Gemma3Attention: Module {
131165
132166 self . isSliding = ( layerIdx + 1 ) % config. slidingWindowPattern != 0
133167
134- let baseFreq = isSliding ? config. ropeLocalBaseFreq : config. ropeGlobalBaseFreq
135- self . _rope. wrappedValue = RoPE (
136- dimensions: headDim,
137- traditional: config. ropeTraditional,
138- base: baseFreq
139- )
168+ if isSliding {
169+ self . rope = initializeRope (
170+ dims: headDim, base: config. ropeLocalBaseFreq, traditional: false ,
171+ scalingConfig: nil , maxPositionEmbeddings: nil )
172+ } else {
173+ self . rope = initializeRope (
174+ dims: headDim, base: config. ropeTheta, traditional: false ,
175+ scalingConfig: config. ropeScaling,
176+ maxPositionEmbeddings: config. maxPositionEmbeddings)
177+ }
140178
141179 super. init ( )
142180 }
@@ -163,18 +201,8 @@ class Gemma3Attention: Module {
163201 queries = rope ( queries, offset: cache. offset)
164202 keys = rope ( keys, offset: cache. offset)
165203 } else {
166- queries = rope ( queries)
167- keys = rope ( keys)
168- }
169-
170- // Sliding window masking
171- var finalMask = mask
172- if case . array( let maskArray) = mask {
173- let keySeqLen = keys. shape [ 2 ]
174- if maskArray. shape. last! != keySeqLen {
175- let slicedMask = maskArray [ . ellipsis, ( - keySeqLen) ... ]
176- finalMask = . array( slicedMask)
177- }
204+ queries = rope ( queries, offset: 0 )
205+ keys = rope ( keys, offset: 0 )
178206 }
179207
180208 let output = attentionWithCacheUpdate (
@@ -183,7 +211,7 @@ class Gemma3Attention: Module {
183211 values: values,
184212 cache: cache,
185213 scale: scale,
186- mask: finalMask
214+ mask: mask
187215 )
188216 . transposed ( 0 , 2 , 1 , 3 )
189217 . reshaped ( B, L, - 1 )
@@ -296,30 +324,19 @@ public class Gemma3Model: Module {
296324 if layerCache == nil {
297325 layerCache = Array ( repeating: nil as KVCache ? , count: layers. count)
298326 }
299- // Create attention masks
300- var fullMask : MLXFast . ScaledDotProductAttentionMaskMode = . none
301- var slidingWindowMask : MLXFast . ScaledDotProductAttentionMaskMode = . none
302- if mask == nil {
303- let j = config. slidingWindowPattern
304- let globalCache : KVCache ? =
305- ( j > 0 && j <= ( layerCache? . count ?? 0 ) ) ? layerCache ? [ j - 1 ] : nil
306- fullMask = createAttentionMask ( h: h, cache: globalCache)
307- let slidingCache : KVCache ? = layerCache? . first ?? nil
308- slidingWindowMask = createAttentionMask (
309- h: h, cache: slidingCache, windowSize: config. slidingWindow)
310- }
311- for (i, layer) in layers. enumerated ( ) {
312- let isGlobal = ( i % config. slidingWindowPattern == config. slidingWindowPattern - 1 )
313327
314- let localMask : MLXFast . ScaledDotProductAttentionMaskMode
315- if let mask {
316- localMask = mask
317- } else if isGlobal {
318- localMask = fullMask
328+ let globalMask = createAttentionMask ( h: h, cache: cache ? [ config. slidingWindowPattern - 1 ] )
329+ let slidingWindowMask =
330+ if config. slidingWindowPattern > 1 {
331+ createAttentionMask ( h: h, cache: cache ? [ 0 ] , windowSize: config. slidingWindow)
319332 } else {
320- localMask = slidingWindowMask
333+ MLXFast . ScaledDotProductAttentionMaskMode . none
321334 }
322- h = layer ( h, mask: localMask, cache: layerCache ? [ i] )
335+
336+ for (i, layer) in layers. enumerated ( ) {
337+ let isGlobal = ( i % config. slidingWindowPattern == config. slidingWindowPattern - 1 )
338+ let mask = isGlobal ? globalMask : slidingWindowMask
339+ h = layer ( h, mask: mask, cache: layerCache ? [ i] )
323340 }
324341 return norm ( h)
325342 }
0 commit comments