Skip to content

Commit 5a70788

Browse files
davidkoski and ronaldmannak
authored and committed
fix thread safety issues (ml-explore#55)
* fix thread safety issues - support for ml-explore/mlx-swift-examples#454 - ModelContainer appeared to provide thread safe access to the KVCache and model - but in fact was not -- async token generation could use the KVCache concurrently - if you were to break the async stream early the previously call could still be running * pick up mlx-swift 0.30.3 which has additional thread safety fixes
1 parent 331cb03 commit 5a70788

File tree

11 files changed

+599
-286
lines changed

11 files changed

+599
-286
lines changed

Libraries/MLXLLM/Documentation.docc/Documentation.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@ Example implementations of various Large Language Models (LLMs).
77
- [MLXEmbedders](MLXEmbedders)
88
- [MLXLLM](MLXLLM)
99
- [MLXLMCommon](MLXLMCommon)
10-
- [MLXMNIST](MLXMNIST)
1110
- [MLXVLM](MLXVLM)
12-
- [StableDiffusion](StableDiffusion)
1311

1412
## Quick Start
1513

Libraries/MLXLMCommon/ChatSession.swift

Lines changed: 177 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,18 @@ import MLX
2020
/// model operations.
2121
public final class ChatSession {
2222

23-
private enum Model {
24-
case container(ModelContainer)
25-
case context(ModelContext)
23+
enum Cache {
24+
case empty
25+
case kvcache([KVCache])
26+
case history([Chat.Message])
2627
}
2728

28-
private let model: Model
29-
private var messages: [Chat.Message]
30-
private var cache: [KVCache]
31-
private let processing: UserInput.Processing
32-
private let generateParameters: GenerateParameters
33-
private let additionalContext: [String: any Sendable]?
29+
private let model: ModelContainer
30+
public var instructions: String?
31+
private let cache: SerialAccessContainer<Cache>
32+
public var processing: UserInput.Processing
33+
public var generateParameters: GenerateParameters
34+
public var additionalContext: [String: any Sendable]?
3435

3536
/// Initialize the `ChatSession`.
3637
///
@@ -47,9 +48,9 @@ public final class ChatSession {
4748
processing: UserInput.Processing = .init(resize: CGSize(width: 512, height: 512)),
4849
additionalContext: [String: any Sendable]? = nil
4950
) {
50-
self.model = .container(model)
51-
self.messages = instructions.map { [.system($0)] } ?? []
52-
self.cache = []
51+
self.model = model
52+
self.instructions = instructions
53+
self.cache = .init(.empty)
5354
self.processing = processing
5455
self.generateParameters = generateParameters
5556
self.additionalContext = additionalContext
@@ -70,9 +71,9 @@ public final class ChatSession {
7071
processing: UserInput.Processing = .init(resize: CGSize(width: 512, height: 512)),
7172
additionalContext: [String: any Sendable]? = nil
7273
) {
73-
self.model = .context(model)
74-
self.messages = instructions.map { [.system($0)] } ?? []
75-
self.cache = []
74+
self.model = ModelContainer(context: model)
75+
self.instructions = instructions
76+
self.cache = .init(.empty)
7677
self.processing = processing
7778
self.generateParameters = generateParameters
7879
self.additionalContext = additionalContext
@@ -88,21 +89,20 @@ public final class ChatSession {
8889
/// - generateParameters: parameters that control generation
8990
/// - processing: media processing configuration for images/videos
9091
/// - additionalContext: optional model-specific context
91-
public convenience init(
92+
public init(
9293
_ model: ModelContainer,
93-
history: [Chat.Message],
94+
instructions: String? = nil,
95+
history: consuming [Chat.Message],
9496
generateParameters: GenerateParameters = .init(),
9597
processing: UserInput.Processing = .init(resize: CGSize(width: 512, height: 512)),
9698
additionalContext: [String: any Sendable]? = nil
9799
) {
98-
self.init(
99-
model,
100-
instructions: nil,
101-
generateParameters: generateParameters,
102-
processing: processing,
103-
additionalContext: additionalContext
104-
)
105-
self.messages = history
100+
self.model = model
101+
self.instructions = instructions
102+
self.cache = .init(.history(history))
103+
self.processing = processing
104+
self.generateParameters = generateParameters
105+
self.additionalContext = additionalContext
106106
}
107107

108108
/// Initialize the `ChatSession` with an existing message history.
@@ -115,21 +115,20 @@ public final class ChatSession {
115115
/// - generateParameters: parameters that control generation
116116
/// - processing: media processing configuration for images/videos
117117
/// - additionalContext: optional model-specific context
118-
public convenience init(
118+
public init(
119119
_ model: ModelContext,
120+
instructions: String? = nil,
120121
history: [Chat.Message],
121122
generateParameters: GenerateParameters = .init(),
122123
processing: UserInput.Processing = .init(resize: CGSize(width: 512, height: 512)),
123124
additionalContext: [String: any Sendable]? = nil
124125
) {
125-
self.init(
126-
model,
127-
instructions: nil,
128-
generateParameters: generateParameters,
129-
processing: processing,
130-
additionalContext: additionalContext
131-
)
132-
self.messages = history
126+
self.model = ModelContainer(context: model)
127+
self.instructions = instructions
128+
self.cache = .init(.history(history))
129+
self.processing = processing
130+
self.generateParameters = generateParameters
131+
self.additionalContext = additionalContext
133132
}
134133

135134
/// Produces a response to a prompt.
@@ -141,45 +140,13 @@ public final class ChatSession {
141140
/// - Returns: the model's response
142141
public func respond(
143142
to prompt: String,
144-
images: [UserInput.Image],
145-
videos: [UserInput.Video]
143+
images: consuming [UserInput.Image],
144+
videos: consuming [UserInput.Video]
146145
) async throws -> String {
147-
messages.append(.user(prompt, images: images, videos: videos))
148-
149-
func generate(context: ModelContext) async throws -> String {
150-
let userInput = UserInput(
151-
chat: messages, processing: processing, additionalContext: additionalContext)
152-
let input = try await context.processor.prepare(input: userInput)
153-
154-
if cache.isEmpty {
155-
cache = context.model.newCache(parameters: generateParameters)
156-
}
157-
158-
var output = ""
159-
for await generation in try MLXLMCommon.generate(
160-
input: input, cache: cache, parameters: generateParameters, context: context
161-
) {
162-
if let chunk = generation.chunk {
163-
output += chunk
164-
}
165-
}
166-
167-
Stream().synchronize()
168-
169-
return output
170-
}
171-
172-
let output: String
173-
switch model {
174-
case .container(let container):
175-
output = try await container.perform { context in
176-
try await generate(context: context)
177-
}
178-
case .context(let context):
179-
output = try await generate(context: context)
146+
var output = ""
147+
for try await chunk in streamResponse(to: prompt, images: images, videos: videos) {
148+
output += chunk
180149
}
181-
182-
messages.append(.assistant(output))
183150
return output
184151
}
185152

@@ -202,7 +169,7 @@ public final class ChatSession {
202169
)
203170
}
204171

205-
/// Produces a streaming response to a prompt.
172+
/// Produces a streaming response to a prompt as Strings.
206173
///
207174
/// - Parameters:
208175
/// - prompt: the user prompt
@@ -211,16 +178,139 @@ public final class ChatSession {
211178
/// - Returns: a stream of string chunks from the model
212179
public func streamResponse(
213180
to prompt: String,
214-
images: [UserInput.Image],
215-
videos: [UserInput.Video]
181+
images: consuming [UserInput.Image],
182+
videos: consuming [UserInput.Video]
216183
) -> AsyncThrowingStream<String, Error> {
217-
messages.append(.user(prompt, images: images, videos: videos))
184+
streamMap(to: prompt, images: images, videos: videos) {
185+
$0.chunk
186+
}
187+
}
188+
189+
/// Produces a streaming response to a prompt as `Generation`.
190+
///
191+
/// - Parameters:
192+
/// - prompt: the user prompt
193+
/// - images: list of images (for use with VLMs)
194+
/// - videos: list of videos (for use with VLMs)
195+
/// - Returns: a stream of `Generation` from the model
196+
public func streamDetails(
197+
to prompt: String,
198+
images: consuming [UserInput.Image],
199+
videos: consuming [UserInput.Video]
200+
) -> AsyncThrowingStream<Generation, Error> {
201+
streamMap(to: prompt, images: images, videos: videos) {
202+
$0
203+
}
204+
}
218205

219-
let (stream, continuation) = AsyncThrowingStream<String, Error>.makeStream()
206+
/// Produces a streaming response to a prompt by transforming the
207+
/// raw `Generation` values.
208+
///
209+
/// - Parameters:
210+
/// - prompt: the user prompt
211+
/// - images: list of images (for use with VLMs)
212+
/// - videos: list of videos (for use with VLMs)
213+
/// - Returns: a stream of transformed values from the model
214+
private func streamMap<R: Sendable>(
215+
to prompt: String,
216+
images: consuming [UserInput.Image],
217+
videos: consuming [UserInput.Video],
218+
transform: @Sendable @escaping (Generation) -> R?
219+
) -> AsyncThrowingStream<R, Error> {
220+
let (stream, continuation) = AsyncThrowingStream<R, Error>.makeStream()
221+
222+
// images and videos are not Sendable (MLXArray) but they are consumed
223+
// and are only being sent to the inner async
224+
let message = SendableBox<Chat.Message>(
225+
.user(prompt, images: images, videos: videos)
226+
)
220227

221228
let task = Task {
229+
[
230+
model,
231+
instructions, processing, additionalContext, cache, generateParameters
232+
] in
222233
do {
223-
try await self.performStreaming(continuation: continuation)
234+
try await cache.update { cache in
235+
236+
// these are all Sendable
237+
let processor = await model.processor
238+
let tokenizer = await model.tokenizer
239+
let modelConfiguration = await model.configuration
240+
241+
var messages: [Chat.Message] = []
242+
if let instructions {
243+
messages.append(.system(instructions))
244+
}
245+
246+
// prepare the cache, if needed. note:
247+
// this is using the LanguageModel (not Sendable) outside
248+
// the protective lock. Assuming the weights are not
249+
// being mutated behind the scenes, this will obey the MLXArray
250+
// contract that they be evaluated if used across threads.
251+
// This is internal to the implementation and this technique
252+
// should not be used in calling code.
253+
//
254+
// The benefit is that callers can be running multiple
255+
// ChatSessions in parallel, as long as the instances
256+
// are distinct. In particular the KVCache cannot
257+
// be shared and that is the lock that is held here.
258+
259+
let model = await model.perform { context in
260+
SendableBox(context.model)
261+
}.consume()
262+
263+
var kvCache: [KVCache]
264+
switch cache {
265+
case .empty:
266+
kvCache = model.newCache(parameters: generateParameters)
267+
cache = .kvcache(kvCache)
268+
269+
case .kvcache(let array):
270+
kvCache = array
271+
272+
case .history(let history):
273+
// the KVCache is represented by a chat history
274+
kvCache = model.newCache(parameters: generateParameters)
275+
cache = .kvcache(kvCache)
276+
messages.append(contentsOf: history)
277+
}
278+
279+
// prepare the input
280+
messages.append(message.consume())
281+
282+
let userInput = UserInput(
283+
chat: messages, processing: processing,
284+
additionalContext: additionalContext)
285+
let input = try await processor.prepare(input: userInput)
286+
287+
// generate output
288+
let iterator = try TokenIterator(
289+
input: input, model: model, cache: kvCache,
290+
parameters: generateParameters)
291+
292+
let (stream, task) = MLXLMCommon.generateTask(
293+
promptTokenCount: input.text.tokens.size,
294+
modelConfiguration: modelConfiguration,
295+
tokenizer: tokenizer,
296+
iterator: iterator
297+
)
298+
299+
for await item in stream {
300+
if let value = transform(item) {
301+
if case .terminated = continuation.yield(value) {
302+
break
303+
}
304+
}
305+
}
306+
307+
// wait for the task to complete -- this is important in
308+
// the case where we broke the loop early as the generation
309+
// work may continue (briefly) and use the KVCache
310+
await task.value
311+
312+
continuation.finish()
313+
}
224314
} catch {
225315
continuation.finish(throwing: error)
226316
}
@@ -253,48 +343,17 @@ public final class ChatSession {
253343
}
254344

255345
/// Clear the session history and cache, preserving system instructions.
256-
public func clear() {
257-
messages = messages.filter { $0.role == .system }
258-
cache = []
259-
}
260-
261-
// MARK: - Private
262-
263-
private func performStreaming(
264-
continuation: AsyncThrowingStream<String, Error>.Continuation
265-
) async throws {
266-
func stream(context: ModelContext) async throws {
267-
let userInput = UserInput(
268-
chat: messages, processing: processing, additionalContext: additionalContext)
269-
let input = try await context.processor.prepare(input: userInput)
270-
271-
if cache.isEmpty {
272-
cache = context.model.newCache(parameters: generateParameters)
273-
}
274-
275-
var fullResponse = ""
276-
for await item in try MLXLMCommon.generate(
277-
input: input, cache: cache, parameters: generateParameters, context: context
278-
) {
279-
if let chunk = item.chunk {
280-
fullResponse += chunk
281-
continuation.yield(chunk)
282-
}
283-
}
284-
285-
Stream().synchronize()
286-
287-
messages.append(.assistant(fullResponse))
288-
continuation.finish()
346+
public func clear() async {
347+
await cache.update { cache in
348+
cache = .empty
289349
}
350+
}
290351

291-
switch model {
292-
case .container(let container):
293-
try await container.perform { context in
294-
try await stream(context: context)
295-
}
296-
case .context(let context):
297-
try await stream(context: context)
298-
}
352+
/// Wait for exclusive access to the KVCache.
353+
///
354+
/// This is useful for cases where a program is terminating and wants to ensure that any
355+
/// async operations are complete.
356+
public func synchronize() async {
357+
await cache.read { _ in }
299358
}
300359
}

Libraries/MLXLMCommon/Documentation.docc/Documentation.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,5 @@ Common language model code.
77
- [MLXEmbedders](MLXEmbedders)
88
- [MLXLLM](MLXLLM)
99
- [MLXLMCommon](MLXLMCommon)
10-
- [MLXMNIST](MLXMNIST)
1110
- [MLXVLM](MLXVLM)
12-
- [StableDiffusion](StableDiffusion)
1311

0 commit comments

Comments (0)