@@ -22,12 +22,40 @@ public struct Gemma3TextConfiguration: Codable {
2222 let rmsNormEps : Float
2323 let vocabularySize : Int
2424 let kvHeads : Int
25- let ropeGlobalBaseFreq : Float
25+ let ropeTheta : Float
2626 let ropeLocalBaseFreq : Float
2727 let ropeTraditional : Bool
2828 let queryPreAttnScalar : Float
2929 let slidingWindow : Int
3030 let slidingWindowPattern : Int
31+ let maxPositionEmbeddings : Int
32+ let ropeScaling : [ String : StringOrNumber ] ?
33+
34+ public init (
35+ modelType: String , hiddenSize: Int , hiddenLayers: Int , intermediateSize: Int ,
36+ attentionHeads: Int , headDim: Int , rmsNormEps: Float , vocabularySize: Int , kvHeads: Int ,
37+ ropeTheta: Float , ropeLocalBaseFreq: Float , ropeTraditional: Bool ,
38+ queryPreAttnScalar: Float , slidingWindow: Int , slidingWindowPattern: Int ,
39+ maxPositionEmbeddings: Int , ropeScaling: [ String : StringOrNumber ] ? = nil
40+ ) {
41+ self . modelType = modelType
42+ self . hiddenSize = hiddenSize
43+ self . hiddenLayers = hiddenLayers
44+ self . intermediateSize = intermediateSize
45+ self . attentionHeads = attentionHeads
46+ self . headDim = headDim
47+ self . rmsNormEps = rmsNormEps
48+ self . vocabularySize = vocabularySize
49+ self . kvHeads = kvHeads
50+ self . ropeTheta = ropeTheta
51+ self . ropeLocalBaseFreq = ropeLocalBaseFreq
52+ self . ropeTraditional = ropeTraditional
53+ self . queryPreAttnScalar = queryPreAttnScalar
54+ self . slidingWindow = slidingWindow
55+ self . slidingWindowPattern = slidingWindowPattern
56+ self . maxPositionEmbeddings = maxPositionEmbeddings
57+ self . ropeScaling = ropeScaling
58+ }
3159
3260 enum CodingKeys : String , CodingKey {
3361 case modelType = " model_type "
@@ -39,12 +67,14 @@ public struct Gemma3TextConfiguration: Codable {
3967 case rmsNormEps = " rms_norm_eps "
4068 case vocabularySize = " vocab_size "
4169 case kvHeads = " num_key_value_heads "
42- case ropeGlobalBaseFreq = " rope_global_base_freq "
70+ case ropeTheta = " rope_theta "
4371 case ropeLocalBaseFreq = " rope_local_base_freq "
4472 case ropeTraditional = " rope_traditional "
4573 case queryPreAttnScalar = " query_pre_attn_scalar "
4674 case slidingWindow = " sliding_window "
4775 case slidingWindowPattern = " sliding_window_pattern "
76+ case maxPositionEmbeddings = " max_position_embeddings "
77+ case ropeScaling = " rope_scaling "
4878 }
4979
5080 enum VLMCodingKeys : String , CodingKey {
@@ -64,16 +94,17 @@ public struct Gemma3TextConfiguration: Codable {
6494 }
6595
6696 modelType = try container. decode ( String . self, forKey: . modelType)
67- hiddenSize = try container. decode ( Int . self, forKey: . hiddenSize)
68- hiddenLayers = try container. decode ( Int . self, forKey: . hiddenLayers)
69- intermediateSize = try container. decode ( Int . self, forKey: . intermediateSize)
97+ hiddenSize = try container. decodeIfPresent ( Int . self, forKey: . hiddenSize) ?? 1152
98+ hiddenLayers = try container. decodeIfPresent ( Int . self, forKey: . hiddenLayers) ?? 26
99+ intermediateSize =
100+ try container. decodeIfPresent ( Int . self, forKey: . intermediateSize) ?? 6912
70101 attentionHeads = try container. decodeIfPresent ( Int . self, forKey: . attentionHeads) ?? 4
71102 headDim = try container. decodeIfPresent ( Int . self, forKey: . headDim) ?? 256
72103 rmsNormEps = try container. decodeIfPresent ( Float . self, forKey: . rmsNormEps) ?? 1.0e-6
73104 vocabularySize = try container. decodeIfPresent ( Int . self, forKey: . vocabularySize) ?? 262144
74105 kvHeads = try container. decodeIfPresent ( Int . self, forKey: . kvHeads) ?? 1
75- ropeGlobalBaseFreq =
76- try container. decodeIfPresent ( Float . self, forKey: . ropeGlobalBaseFreq ) ?? 1_000_000.0
106+ ropeTheta =
107+ try container. decodeIfPresent ( Float . self, forKey: . ropeTheta ) ?? 1_000_000.0
77108 ropeLocalBaseFreq =
78109 try container. decodeIfPresent ( Float . self, forKey: . ropeLocalBaseFreq) ?? 10_000.0
79110 ropeTraditional =
@@ -83,6 +114,10 @@ public struct Gemma3TextConfiguration: Codable {
83114 slidingWindow = try container. decodeIfPresent ( Int . self, forKey: . slidingWindow) ?? 512
84115 slidingWindowPattern =
85116 try container. decodeIfPresent ( Int . self, forKey: . slidingWindowPattern) ?? 6
117+ maxPositionEmbeddings =
118+ try container. decodeIfPresent ( Int . self, forKey: . maxPositionEmbeddings) ?? 32768
119+ ropeScaling =
120+ try container. decodeIfPresent ( [ String : StringOrNumber ] . self, forKey: . ropeScaling)
86121 }
87122}
88123
@@ -105,7 +140,7 @@ class Gemma3Attention: Module {
105140 @ModuleInfo ( key: " q_norm " ) var queryNorm : Gemma . RMSNorm
106141 @ModuleInfo ( key: " k_norm " ) var keyNorm : Gemma . RMSNorm
107142
108- @ModuleInfo var rope : RoPE
143+ @ModuleInfo var rope : OffsetLayer
109144
110145 init ( _ config: Gemma3TextConfiguration , layerIdx: Int ) {
111146 let dim = config. hiddenSize
@@ -130,12 +165,16 @@ class Gemma3Attention: Module {
130165
131166 self . isSliding = ( layerIdx + 1 ) % config. slidingWindowPattern != 0
132167
133- let baseFreq = isSliding ? config. ropeLocalBaseFreq : config. ropeGlobalBaseFreq
134- self . _rope. wrappedValue = RoPE (
135- dimensions: headDim,
136- traditional: config. ropeTraditional,
137- base: baseFreq
138- )
168+ if isSliding {
169+ self . rope = initializeRope (
170+ dims: headDim, base: config. ropeLocalBaseFreq, traditional: false ,
171+ scalingConfig: nil , maxPositionEmbeddings: nil )
172+ } else {
173+ self . rope = initializeRope (
174+ dims: headDim, base: config. ropeTheta, traditional: false ,
175+ scalingConfig: config. ropeScaling,
176+ maxPositionEmbeddings: config. maxPositionEmbeddings)
177+ }
139178
140179 super. init ( )
141180 }
@@ -162,18 +201,8 @@ class Gemma3Attention: Module {
162201 queries = rope ( queries, offset: cache. offset)
163202 keys = rope ( keys, offset: cache. offset)
164203 } else {
165- queries = rope ( queries)
166- keys = rope ( keys)
167- }
168-
169- // Sliding window masking
170- var finalMask = mask
171- if case . array( let maskArray) = mask {
172- let keySeqLen = keys. shape [ 2 ]
173- if maskArray. shape. last! != keySeqLen {
174- let slicedMask = maskArray [ . ellipsis, ( - keySeqLen) ... ]
175- finalMask = . array( slicedMask)
176- }
204+ queries = rope ( queries, offset: 0 )
205+ keys = rope ( keys, offset: 0 )
177206 }
178207
179208 let output = attentionWithCacheUpdate (
@@ -182,7 +211,7 @@ class Gemma3Attention: Module {
182211 values: values,
183212 cache: cache,
184213 scale: scale,
185- mask: finalMask
214+ mask: mask
186215 )
187216 . transposed ( 0 , 2 , 1 , 3 )
188217 . reshaped ( B, L, - 1 )
@@ -295,30 +324,19 @@ public class Gemma3Model: Module {
295324 if layerCache == nil {
296325 layerCache = Array ( repeating: nil as KVCache ? , count: layers. count)
297326 }
298- // Create attention masks
299- var fullMask : MLXFast . ScaledDotProductAttentionMaskMode = . none
300- var slidingWindowMask : MLXFast . ScaledDotProductAttentionMaskMode = . none
301- if mask == nil {
302- let j = config. slidingWindowPattern
303- let globalCache : KVCache ? =
304- ( j > 0 && j <= ( layerCache? . count ?? 0 ) ) ? layerCache ? [ j - 1 ] : nil
305- fullMask = createAttentionMask ( h: h, cache: globalCache)
306- let slidingCache : KVCache ? = layerCache? . first ?? nil
307- slidingWindowMask = createAttentionMask (
308- h: h, cache: slidingCache, windowSize: config. slidingWindow)
309- }
310- for (i, layer) in layers. enumerated ( ) {
311- let isGlobal = ( i % config. slidingWindowPattern == config. slidingWindowPattern - 1 )
312327
313- let localMask : MLXFast . ScaledDotProductAttentionMaskMode
314- if let mask {
315- localMask = mask
316- } else if isGlobal {
317- localMask = fullMask
328+ let globalMask = createAttentionMask ( h: h, cache: cache ? [ config. slidingWindowPattern - 1 ] )
329+ let slidingWindowMask =
330+ if config. slidingWindowPattern > 1 {
331+ createAttentionMask ( h: h, cache: cache ? [ 0 ] , windowSize: config. slidingWindow)
318332 } else {
319- localMask = slidingWindowMask
333+ MLXFast . ScaledDotProductAttentionMaskMode . none
320334 }
321- h = layer ( h, mask: localMask, cache: layerCache ? [ i] )
335+
336+ for (i, layer) in layers. enumerated ( ) {
337+ let isGlobal = ( i % config. slidingWindowPattern == config. slidingWindowPattern - 1 )
338+ let mask = isGlobal ? globalMask : slidingWindowMask
339+ h = layer ( h, mask: mask, cache: layerCache ? [ i] )
322340 }
323341 return norm ( h)
324342 }
0 commit comments