Skip to content

Commit 9c2f561

Browse files
committed
fix thread safety issues
- support for ml-explore/mlx-swift-examples#454 - ModelContainer appeared to provide thread-safe access to the KVCache and model, but in fact it did not -- async token generation could use the KVCache concurrently - if you were to break the async stream early, the previous call could still be running. Also ran swift-format.
1 parent a1addb4 commit 9c2f561

File tree

7 files changed

+539
-272
lines changed

7 files changed

+539
-272
lines changed

Libraries/MLXLMCommon/ChatSession.swift

Lines changed: 139 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,15 @@ import MLX
2020
/// model operations.
2121
public final class ChatSession {
2222

23-
private enum Model {
24-
case container(ModelContainer)
25-
case context(ModelContext)
23+
enum Cache {
24+
case empty
25+
case kvcache([KVCache])
26+
case history([Chat.Message])
2627
}
2728

28-
private let model: Model
29-
private var messages: [Chat.Message]
30-
private var cache: [KVCache]
29+
private let model: ModelContainer
30+
private let instructions: String?
31+
private let cache: SerialAccessContainer<Cache>
3132
private let processing: UserInput.Processing
3233
private let generateParameters: GenerateParameters
3334
private let additionalContext: [String: any Sendable]?
@@ -47,9 +48,9 @@ public final class ChatSession {
4748
processing: UserInput.Processing = .init(resize: CGSize(width: 512, height: 512)),
4849
additionalContext: [String: any Sendable]? = nil
4950
) {
50-
self.model = .container(model)
51-
self.messages = instructions.map { [.system($0)] } ?? []
52-
self.cache = []
51+
self.model = model
52+
self.instructions = instructions
53+
self.cache = .init(.empty)
5354
self.processing = processing
5455
self.generateParameters = generateParameters
5556
self.additionalContext = additionalContext
@@ -70,9 +71,9 @@ public final class ChatSession {
7071
processing: UserInput.Processing = .init(resize: CGSize(width: 512, height: 512)),
7172
additionalContext: [String: any Sendable]? = nil
7273
) {
73-
self.model = .context(model)
74-
self.messages = instructions.map { [.system($0)] } ?? []
75-
self.cache = []
74+
self.model = ModelContainer(context: model)
75+
self.instructions = instructions
76+
self.cache = .init(.empty)
7677
self.processing = processing
7778
self.generateParameters = generateParameters
7879
self.additionalContext = additionalContext
@@ -88,21 +89,20 @@ public final class ChatSession {
8889
/// - generateParameters: parameters that control generation
8990
/// - processing: media processing configuration for images/videos
9091
/// - additionalContext: optional model-specific context
91-
public convenience init(
92+
public init(
9293
_ model: ModelContainer,
93-
history: [Chat.Message],
94+
instructions: String? = nil,
95+
history: consuming [Chat.Message],
9496
generateParameters: GenerateParameters = .init(),
9597
processing: UserInput.Processing = .init(resize: CGSize(width: 512, height: 512)),
9698
additionalContext: [String: any Sendable]? = nil
9799
) {
98-
self.init(
99-
model,
100-
instructions: nil,
101-
generateParameters: generateParameters,
102-
processing: processing,
103-
additionalContext: additionalContext
104-
)
105-
self.messages = history
100+
self.model = model
101+
self.instructions = instructions
102+
self.cache = .init(.history(history))
103+
self.processing = processing
104+
self.generateParameters = generateParameters
105+
self.additionalContext = additionalContext
106106
}
107107

108108
/// Initialize the `ChatSession` with an existing message history.
@@ -115,21 +115,20 @@ public final class ChatSession {
115115
/// - generateParameters: parameters that control generation
116116
/// - processing: media processing configuration for images/videos
117117
/// - additionalContext: optional model-specific context
118-
public convenience init(
118+
public init(
119119
_ model: ModelContext,
120+
instructions: String? = nil,
120121
history: [Chat.Message],
121122
generateParameters: GenerateParameters = .init(),
122123
processing: UserInput.Processing = .init(resize: CGSize(width: 512, height: 512)),
123124
additionalContext: [String: any Sendable]? = nil
124125
) {
125-
self.init(
126-
model,
127-
instructions: nil,
128-
generateParameters: generateParameters,
129-
processing: processing,
130-
additionalContext: additionalContext
131-
)
132-
self.messages = history
126+
self.model = ModelContainer(context: model)
127+
self.instructions = instructions
128+
self.cache = .init(.history(history))
129+
self.processing = processing
130+
self.generateParameters = generateParameters
131+
self.additionalContext = additionalContext
133132
}
134133

135134
/// Produces a response to a prompt.
@@ -141,45 +140,13 @@ public final class ChatSession {
141140
/// - Returns: the model's response
142141
public func respond(
143142
to prompt: String,
144-
images: [UserInput.Image],
145-
videos: [UserInput.Video]
143+
images: consuming [UserInput.Image],
144+
videos: consuming [UserInput.Video]
146145
) async throws -> String {
147-
messages.append(.user(prompt, images: images, videos: videos))
148-
149-
func generate(context: ModelContext) async throws -> String {
150-
let userInput = UserInput(
151-
chat: messages, processing: processing, additionalContext: additionalContext)
152-
let input = try await context.processor.prepare(input: userInput)
153-
154-
if cache.isEmpty {
155-
cache = context.model.newCache(parameters: generateParameters)
156-
}
157-
158-
var output = ""
159-
for await generation in try MLXLMCommon.generate(
160-
input: input, cache: cache, parameters: generateParameters, context: context
161-
) {
162-
if let chunk = generation.chunk {
163-
output += chunk
164-
}
165-
}
166-
167-
Stream().synchronize()
168-
169-
return output
170-
}
171-
172-
let output: String
173-
switch model {
174-
case .container(let container):
175-
output = try await container.perform { context in
176-
try await generate(context: context)
177-
}
178-
case .context(let context):
179-
output = try await generate(context: context)
146+
var output = ""
147+
for try await chunk in streamResponse(to: prompt, images: images, videos: videos) {
148+
output += chunk
180149
}
181-
182-
messages.append(.assistant(output))
183150
return output
184151
}
185152

@@ -211,16 +178,105 @@ public final class ChatSession {
211178
/// - Returns: a stream of string chunks from the model
212179
public func streamResponse(
213180
to prompt: String,
214-
images: [UserInput.Image],
215-
videos: [UserInput.Video]
181+
images: consuming [UserInput.Image],
182+
videos: consuming [UserInput.Video]
216183
) -> AsyncThrowingStream<String, Error> {
217-
messages.append(.user(prompt, images: images, videos: videos))
218-
219184
let (stream, continuation) = AsyncThrowingStream<String, Error>.makeStream()
220185

186+
// images and videos are not Sendable (MLXArray) but they are consumed
187+
// and are only being sent to the inner async
188+
let message = SendableBox<Chat.Message>(
189+
.user(prompt, images: images, videos: videos)
190+
)
191+
221192
let task = Task {
193+
[
194+
model,
195+
instructions, processing, additionalContext, cache, generateParameters
196+
] in
222197
do {
223-
try await self.performStreaming(continuation: continuation)
198+
try await cache.update { cache in
199+
200+
// these are all Sendable
201+
let processor = await model.processor
202+
let tokenizer = await model.tokenizer
203+
let modelConfiguration = await model.configuration
204+
205+
var messages: [Chat.Message] = []
206+
if let instructions {
207+
messages.append(.system(instructions))
208+
}
209+
210+
// prepare the cache, if needed. note:
211+
// this is using the LanguageModel (not Sendable) outside
212+
// the protective lock. Assuming the weights are not
213+
// being mutated behind the scenes, this will obey the MLXArray
214+
// contract that they be evaluated if used across threads.
215+
// This is internal to the implementation and this technique
216+
// should not be used in calling code.
217+
//
218+
// The benefit is that callers can be running multiple
219+
// ChatSessions in parallel, as long as the instances
220+
// are distinct. In particular the KVCache cannot
221+
// be shared and that is the lock that is held here.
222+
223+
let model = await model.perform { context in
224+
SendableBox(context.model)
225+
}.consume()
226+
227+
var kvCache: [KVCache]
228+
switch cache {
229+
case .empty:
230+
kvCache = model.newCache(parameters: generateParameters)
231+
cache = .kvcache(kvCache)
232+
233+
case .kvcache(let array):
234+
kvCache = array
235+
236+
case .history(let history):
237+
// the KVCache is represented by a chat history
238+
kvCache = model.newCache(parameters: generateParameters)
239+
cache = .kvcache(kvCache)
240+
messages.append(contentsOf: history)
241+
}
242+
243+
// prepare the input
244+
messages.append(message.consume())
245+
246+
let userInput = UserInput(
247+
chat: messages, processing: processing,
248+
additionalContext: additionalContext)
249+
let input = try await processor.prepare(input: userInput)
250+
251+
// generate output
252+
let iterator = try TokenIterator(
253+
input: input, model: model, cache: kvCache,
254+
parameters: generateParameters)
255+
256+
let (stream, task) = MLXLMCommon.generateTask(
257+
promptTokenCount: input.text.tokens.size,
258+
modelConfiguration: modelConfiguration,
259+
tokenizer: tokenizer,
260+
iterator: iterator
261+
)
262+
263+
var fullResponse = ""
264+
for await item in stream {
265+
if let chunk = item.chunk {
266+
fullResponse += chunk
267+
if case .terminated = continuation.yield(chunk) {
268+
break
269+
}
270+
}
271+
}
272+
273+
// wait for the task to complete -- this is important in
274+
// the case where we broke the loop early as the generation
275+
// work may continue (briefly) and use the KVCache
276+
await task.value
277+
278+
continuation.finish()
279+
}
224280
} catch {
225281
continuation.finish(throwing: error)
226282
}
@@ -253,48 +309,17 @@ public final class ChatSession {
253309
}
254310

255311
/// Clear the session history and cache, preserving system instructions.
256-
public func clear() {
257-
messages = messages.filter { $0.role == .system }
258-
cache = []
259-
}
260-
261-
// MARK: - Private
262-
263-
private func performStreaming(
264-
continuation: AsyncThrowingStream<String, Error>.Continuation
265-
) async throws {
266-
func stream(context: ModelContext) async throws {
267-
let userInput = UserInput(
268-
chat: messages, processing: processing, additionalContext: additionalContext)
269-
let input = try await context.processor.prepare(input: userInput)
270-
271-
if cache.isEmpty {
272-
cache = context.model.newCache(parameters: generateParameters)
273-
}
274-
275-
var fullResponse = ""
276-
for await item in try MLXLMCommon.generate(
277-
input: input, cache: cache, parameters: generateParameters, context: context
278-
) {
279-
if let chunk = item.chunk {
280-
fullResponse += chunk
281-
continuation.yield(chunk)
282-
}
283-
}
284-
285-
Stream().synchronize()
286-
287-
messages.append(.assistant(fullResponse))
288-
continuation.finish()
312+
public func clear() async {
313+
await cache.update { cache in
314+
cache = .empty
289315
}
316+
}
290317

291-
switch model {
292-
case .container(let container):
293-
try await container.perform { context in
294-
try await stream(context: context)
295-
}
296-
case .context(let context):
297-
try await stream(context: context)
298-
}
318+
/// Wait for exclusive access to the KVCache.
319+
///
320+
/// This is useful for cases where a program is terminating and wants to ensure that any
321+
/// async operations are complete.
322+
public func synchronize() async {
323+
await cache.read { _ in }
299324
}
300325
}

0 commit comments

Comments
 (0)