Skip to content

Commit 1b62ad0

Browse files
committed
support for LLMBasic (mlx-swift-examples)
- ml-explore/mlx-swift-examples#454
- fixes #27
- move ChatSession integration tests into a new test target so we can more easily control when they run
- make a ChatSession _unit_ (more or less) test
- fix Sendable / thread safety issues uncovered by LLMBasic
- collect TestTokenizer and friends in their own file; fix warnings in tests
- UserInputProcessors -> structs
1 parent 27a2f21 commit 1b62ad0

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

44 files changed

+927
-577
lines changed

Libraries/Embedders/Pooling.swift

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,6 @@
22

33
import Foundation
44
import MLX
5-
import MLXLinalg
65
import MLXNN
76

87
public struct PoolingConfiguration: Codable {

Libraries/Embedders/Qwen3.swift

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,6 @@
22

33
import Foundation
44
import MLX
5-
import MLXFast
65
import MLXLMCommon
76
import MLXNN
87

Libraries/MLXLLM/Models/AfMoE.swift

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,6 @@
88

99
import Foundation
1010
import MLX
11-
import MLXFast
1211
import MLXLMCommon
1312
import MLXNN
1413

Libraries/MLXLLM/Models/BaichuanM1.swift

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,10 +7,8 @@
77

88
import Foundation
99
import MLX
10-
import MLXFast
1110
import MLXLMCommon
1211
import MLXNN
13-
import MLXRandom
1412

1513
public struct BaichuanM1Configuration: Codable, Sendable {
1614
var vocabularySize: Int

Libraries/MLXLLM/Models/Bitnet.swift

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77

88
import Foundation
99
import MLX
10-
import MLXFast
1110
import MLXLMCommon
1211
import MLXNN
1312
import Tokenizers
@@ -55,15 +54,15 @@ private func makeBitLinearKernel() -> MLXFast.MLXFastKernel {
5554
}
5655
"""
5756

58-
return metalKernel(
57+
return MLXFast.metalKernel(
5958
name: "bitlinear_matmul",
6059
inputNames: ["x", "packed_weights", "weight_scale"],
6160
outputNames: ["out"],
6261
source: source
6362
)
6463
}
6564

66-
final class BitLinearKernelManager: @unchecked Sendable {
65+
private final class BitLinearKernelManager: Sendable {
6766
static let shared = BitLinearKernelManager()
6867

6968
let bitlinearKernel: MLXFast.MLXFastKernel

Libraries/MLXLLM/Models/DeepseekV3.swift

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,6 @@
22

33
import Foundation
44
import MLX
5-
import MLXFast
65
import MLXLMCommon
76
import MLXNN
87

Libraries/MLXLLM/Models/Exaone4.swift

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77

88
import Foundation
99
import MLX
10-
import MLXFast
1110
import MLXLMCommon
1211
import MLXNN
1312

Libraries/MLXLLM/Models/GPTOSS.swift

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,10 +7,8 @@
77

88
import Foundation
99
import MLX
10-
import MLXFast
1110
import MLXLMCommon
1211
import MLXNN
13-
import MLXRandom
1412

1513
// MARK: - Configuration
1614

Libraries/MLXLLM/Models/Gemma3Text.swift

Lines changed: 66 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99

1010
import Foundation
1111
import MLX
12-
import MLXFast
1312
import MLXLMCommon
1413
import MLXNN
1514

@@ -23,12 +22,40 @@ public struct Gemma3TextConfiguration: Codable {
2322
let rmsNormEps: Float
2423
let vocabularySize: Int
2524
let kvHeads: Int
26-
let ropeGlobalBaseFreq: Float
25+
let ropeTheta: Float
2726
let ropeLocalBaseFreq: Float
2827
let ropeTraditional: Bool
2928
let queryPreAttnScalar: Float
3029
let slidingWindow: Int
3130
let slidingWindowPattern: Int
31+
let maxPositionEmbeddings: Int
32+
let ropeScaling: [String: StringOrNumber]?
33+
34+
public init(
35+
modelType: String, hiddenSize: Int, hiddenLayers: Int, intermediateSize: Int,
36+
attentionHeads: Int, headDim: Int, rmsNormEps: Float, vocabularySize: Int, kvHeads: Int,
37+
ropeTheta: Float, ropeLocalBaseFreq: Float, ropeTraditional: Bool,
38+
queryPreAttnScalar: Float, slidingWindow: Int, slidingWindowPattern: Int,
39+
maxPositionEmbeddings: Int, ropeScaling: [String: StringOrNumber]? = nil
40+
) {
41+
self.modelType = modelType
42+
self.hiddenSize = hiddenSize
43+
self.hiddenLayers = hiddenLayers
44+
self.intermediateSize = intermediateSize
45+
self.attentionHeads = attentionHeads
46+
self.headDim = headDim
47+
self.rmsNormEps = rmsNormEps
48+
self.vocabularySize = vocabularySize
49+
self.kvHeads = kvHeads
50+
self.ropeTheta = ropeTheta
51+
self.ropeLocalBaseFreq = ropeLocalBaseFreq
52+
self.ropeTraditional = ropeTraditional
53+
self.queryPreAttnScalar = queryPreAttnScalar
54+
self.slidingWindow = slidingWindow
55+
self.slidingWindowPattern = slidingWindowPattern
56+
self.maxPositionEmbeddings = maxPositionEmbeddings
57+
self.ropeScaling = ropeScaling
58+
}
3259

3360
enum CodingKeys: String, CodingKey {
3461
case modelType = "model_type"
@@ -40,12 +67,14 @@ public struct Gemma3TextConfiguration: Codable {
4067
case rmsNormEps = "rms_norm_eps"
4168
case vocabularySize = "vocab_size"
4269
case kvHeads = "num_key_value_heads"
43-
case ropeGlobalBaseFreq = "rope_global_base_freq"
70+
case ropeTheta = "rope_theta"
4471
case ropeLocalBaseFreq = "rope_local_base_freq"
4572
case ropeTraditional = "rope_traditional"
4673
case queryPreAttnScalar = "query_pre_attn_scalar"
4774
case slidingWindow = "sliding_window"
4875
case slidingWindowPattern = "sliding_window_pattern"
76+
case maxPositionEmbeddings = "max_position_embeddings"
77+
case ropeScaling = "rope_scaling"
4978
}
5079

5180
enum VLMCodingKeys: String, CodingKey {
@@ -65,16 +94,17 @@ public struct Gemma3TextConfiguration: Codable {
6594
}
6695

6796
modelType = try container.decode(String.self, forKey: .modelType)
68-
hiddenSize = try container.decode(Int.self, forKey: .hiddenSize)
69-
hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers)
70-
intermediateSize = try container.decode(Int.self, forKey: .intermediateSize)
97+
hiddenSize = try container.decodeIfPresent(Int.self, forKey: .hiddenSize) ?? 1152
98+
hiddenLayers = try container.decodeIfPresent(Int.self, forKey: .hiddenLayers) ?? 26
99+
intermediateSize =
100+
try container.decodeIfPresent(Int.self, forKey: .intermediateSize) ?? 6912
71101
attentionHeads = try container.decodeIfPresent(Int.self, forKey: .attentionHeads) ?? 4
72102
headDim = try container.decodeIfPresent(Int.self, forKey: .headDim) ?? 256
73103
rmsNormEps = try container.decodeIfPresent(Float.self, forKey: .rmsNormEps) ?? 1.0e-6
74104
vocabularySize = try container.decodeIfPresent(Int.self, forKey: .vocabularySize) ?? 262144
75105
kvHeads = try container.decodeIfPresent(Int.self, forKey: .kvHeads) ?? 1
76-
ropeGlobalBaseFreq =
77-
try container.decodeIfPresent(Float.self, forKey: .ropeGlobalBaseFreq) ?? 1_000_000.0
106+
ropeTheta =
107+
try container.decodeIfPresent(Float.self, forKey: .ropeTheta) ?? 1_000_000.0
78108
ropeLocalBaseFreq =
79109
try container.decodeIfPresent(Float.self, forKey: .ropeLocalBaseFreq) ?? 10_000.0
80110
ropeTraditional =
@@ -84,6 +114,10 @@ public struct Gemma3TextConfiguration: Codable {
84114
slidingWindow = try container.decodeIfPresent(Int.self, forKey: .slidingWindow) ?? 512
85115
slidingWindowPattern =
86116
try container.decodeIfPresent(Int.self, forKey: .slidingWindowPattern) ?? 6
117+
maxPositionEmbeddings =
118+
try container.decodeIfPresent(Int.self, forKey: .maxPositionEmbeddings) ?? 32768
119+
ropeScaling =
120+
try container.decodeIfPresent([String: StringOrNumber].self, forKey: .ropeScaling)
87121
}
88122
}
89123

@@ -106,7 +140,7 @@ class Gemma3Attention: Module {
106140
@ModuleInfo(key: "q_norm") var queryNorm: Gemma.RMSNorm
107141
@ModuleInfo(key: "k_norm") var keyNorm: Gemma.RMSNorm
108142

109-
@ModuleInfo var rope: RoPE
143+
@ModuleInfo var rope: OffsetLayer
110144

111145
init(_ config: Gemma3TextConfiguration, layerIdx: Int) {
112146
let dim = config.hiddenSize
@@ -131,12 +165,16 @@ class Gemma3Attention: Module {
131165

132166
self.isSliding = (layerIdx + 1) % config.slidingWindowPattern != 0
133167

134-
let baseFreq = isSliding ? config.ropeLocalBaseFreq : config.ropeGlobalBaseFreq
135-
self._rope.wrappedValue = RoPE(
136-
dimensions: headDim,
137-
traditional: config.ropeTraditional,
138-
base: baseFreq
139-
)
168+
if isSliding {
169+
self.rope = initializeRope(
170+
dims: headDim, base: config.ropeLocalBaseFreq, traditional: false,
171+
scalingConfig: nil, maxPositionEmbeddings: nil)
172+
} else {
173+
self.rope = initializeRope(
174+
dims: headDim, base: config.ropeTheta, traditional: false,
175+
scalingConfig: config.ropeScaling,
176+
maxPositionEmbeddings: config.maxPositionEmbeddings)
177+
}
140178

141179
super.init()
142180
}
@@ -163,18 +201,8 @@ class Gemma3Attention: Module {
163201
queries = rope(queries, offset: cache.offset)
164202
keys = rope(keys, offset: cache.offset)
165203
} else {
166-
queries = rope(queries)
167-
keys = rope(keys)
168-
}
169-
170-
// Sliding window masking
171-
var finalMask = mask
172-
if case .array(let maskArray) = mask {
173-
let keySeqLen = keys.shape[2]
174-
if maskArray.shape.last! != keySeqLen {
175-
let slicedMask = maskArray[.ellipsis, (-keySeqLen)...]
176-
finalMask = .array(slicedMask)
177-
}
204+
queries = rope(queries, offset: 0)
205+
keys = rope(keys, offset: 0)
178206
}
179207

180208
let output = attentionWithCacheUpdate(
@@ -183,7 +211,7 @@ class Gemma3Attention: Module {
183211
values: values,
184212
cache: cache,
185213
scale: scale,
186-
mask: finalMask
214+
mask: mask
187215
)
188216
.transposed(0, 2, 1, 3)
189217
.reshaped(B, L, -1)
@@ -296,30 +324,19 @@ public class Gemma3Model: Module {
296324
if layerCache == nil {
297325
layerCache = Array(repeating: nil as KVCache?, count: layers.count)
298326
}
299-
// Create attention masks
300-
var fullMask: MLXFast.ScaledDotProductAttentionMaskMode = .none
301-
var slidingWindowMask: MLXFast.ScaledDotProductAttentionMaskMode = .none
302-
if mask == nil {
303-
let j = config.slidingWindowPattern
304-
let globalCache: KVCache? =
305-
(j > 0 && j <= (layerCache?.count ?? 0)) ? layerCache?[j - 1] : nil
306-
fullMask = createAttentionMask(h: h, cache: globalCache)
307-
let slidingCache: KVCache? = layerCache?.first ?? nil
308-
slidingWindowMask = createAttentionMask(
309-
h: h, cache: slidingCache, windowSize: config.slidingWindow)
310-
}
311-
for (i, layer) in layers.enumerated() {
312-
let isGlobal = (i % config.slidingWindowPattern == config.slidingWindowPattern - 1)
313327

314-
let localMask: MLXFast.ScaledDotProductAttentionMaskMode
315-
if let mask {
316-
localMask = mask
317-
} else if isGlobal {
318-
localMask = fullMask
328+
let globalMask = createAttentionMask(h: h, cache: cache?[config.slidingWindowPattern - 1])
329+
let slidingWindowMask =
330+
if config.slidingWindowPattern > 1 {
331+
createAttentionMask(h: h, cache: cache?[0], windowSize: config.slidingWindow)
319332
} else {
320-
localMask = slidingWindowMask
333+
MLXFast.ScaledDotProductAttentionMaskMode.none
321334
}
322-
h = layer(h, mask: localMask, cache: layerCache?[i])
335+
336+
for (i, layer) in layers.enumerated() {
337+
let isGlobal = (i % config.slidingWindowPattern == config.slidingWindowPattern - 1)
338+
let mask = isGlobal ? globalMask : slidingWindowMask
339+
h = layer(h, mask: mask, cache: layerCache?[i])
323340
}
324341
return norm(h)
325342
}

Libraries/MLXLLM/Models/Gemma3nText.swift

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99

1010
import Foundation
1111
import MLX
12-
import MLXFast
1312
import MLXLMCommon
1413
import MLXNN
1514

0 commit comments

Comments
 (0)