@@ -20,14 +20,15 @@ import MLX
2020/// model operations.
2121public final class ChatSession {
2222
23- private enum Model {
24- case container( ModelContainer )
25- case context( ModelContext )
23+ enum Cache {
24+ case empty
25+ case kvcache( [ KVCache ] )
26+ case history( [ Chat . Message ] )
2627 }
2728
28- private let model : Model
29- private var messages : [ Chat . Message ]
30- private var cache : [ KVCache ]
29+ private let model : ModelContainer
30+ private let instructions : String ?
31+ private let cache : SerialAccessContainer < Cache >
3132 private let processing : UserInput . Processing
3233 private let generateParameters : GenerateParameters
3334 private let additionalContext : [ String : any Sendable ] ?
@@ -47,9 +48,9 @@ public final class ChatSession {
4748 processing: UserInput . Processing = . init( resize: CGSize ( width: 512 , height: 512 ) ) ,
4849 additionalContext: [ String : any Sendable ] ? = nil
4950 ) {
50- self . model = . container ( model)
51- self . messages = instructions. map { [ . system ( $0 ) ] } ?? [ ]
52- self . cache = [ ]
51+ self . model = model
52+ self . instructions = instructions
53+ self . cache = . init ( . empty )
5354 self . processing = processing
5455 self . generateParameters = generateParameters
5556 self . additionalContext = additionalContext
@@ -70,9 +71,9 @@ public final class ChatSession {
7071 processing: UserInput . Processing = . init( resize: CGSize ( width: 512 , height: 512 ) ) ,
7172 additionalContext: [ String : any Sendable ] ? = nil
7273 ) {
73- self . model = . context ( model)
74- self . messages = instructions. map { [ . system ( $0 ) ] } ?? [ ]
75- self . cache = [ ]
74+ self . model = ModelContainer ( context : model)
75+ self . instructions = instructions
76+ self . cache = . init ( . empty )
7677 self . processing = processing
7778 self . generateParameters = generateParameters
7879 self . additionalContext = additionalContext
@@ -88,21 +89,20 @@ public final class ChatSession {
8889 /// - generateParameters: parameters that control generation
8990 /// - processing: media processing configuration for images/videos
9091 /// - additionalContext: optional model-specific context
91- public convenience init (
92+ public init (
9293 _ model: ModelContainer ,
93- history: [ Chat . Message ] ,
94+ instructions: String ? = nil ,
95+ history: consuming [ Chat . Message ] ,
9496 generateParameters: GenerateParameters = . init( ) ,
9597 processing: UserInput . Processing = . init( resize: CGSize ( width: 512 , height: 512 ) ) ,
9698 additionalContext: [ String : any Sendable ] ? = nil
9799 ) {
98- self . init (
99- model,
100- instructions: nil ,
101- generateParameters: generateParameters,
102- processing: processing,
103- additionalContext: additionalContext
104- )
105- self . messages = history
100+ self . model = model
101+ self . instructions = instructions
102+ self . cache = . init( . history( history) )
103+ self . processing = processing
104+ self . generateParameters = generateParameters
105+ self . additionalContext = additionalContext
106106 }
107107
108108 /// Initialize the `ChatSession` with an existing message history.
@@ -115,21 +115,20 @@ public final class ChatSession {
115115 /// - generateParameters: parameters that control generation
116116 /// - processing: media processing configuration for images/videos
117117 /// - additionalContext: optional model-specific context
118- public convenience init (
118+ public init (
119119 _ model: ModelContext ,
120+ instructions: String ? = nil ,
120121 history: [ Chat . Message ] ,
121122 generateParameters: GenerateParameters = . init( ) ,
122123 processing: UserInput . Processing = . init( resize: CGSize ( width: 512 , height: 512 ) ) ,
123124 additionalContext: [ String : any Sendable ] ? = nil
124125 ) {
125- self . init (
126- model,
127- instructions: nil ,
128- generateParameters: generateParameters,
129- processing: processing,
130- additionalContext: additionalContext
131- )
132- self . messages = history
126+ self . model = ModelContainer ( context: model)
127+ self . instructions = instructions
128+ self . cache = . init( . history( history) )
129+ self . processing = processing
130+ self . generateParameters = generateParameters
131+ self . additionalContext = additionalContext
133132 }
134133
135134 /// Produces a response to a prompt.
@@ -141,45 +140,13 @@ public final class ChatSession {
141140 /// - Returns: the model's response
142141 public func respond(
143142 to prompt: String ,
144- images: [ UserInput . Image ] ,
145- videos: [ UserInput . Video ]
143+ images: consuming [ UserInput . Image ] ,
144+ videos: consuming [ UserInput . Video ]
146145 ) async throws -> String {
147- messages. append ( . user( prompt, images: images, videos: videos) )
148-
149- func generate( context: ModelContext ) async throws -> String {
150- let userInput = UserInput (
151- chat: messages, processing: processing, additionalContext: additionalContext)
152- let input = try await context. processor. prepare ( input: userInput)
153-
154- if cache. isEmpty {
155- cache = context. model. newCache ( parameters: generateParameters)
156- }
157-
158- var output = " "
159- for await generation in try MLXLMCommon . generate (
160- input: input, cache: cache, parameters: generateParameters, context: context
161- ) {
162- if let chunk = generation. chunk {
163- output += chunk
164- }
165- }
166-
167- Stream ( ) . synchronize ( )
168-
169- return output
170- }
171-
172- let output : String
173- switch model {
174- case . container( let container) :
175- output = try await container. perform { context in
176- try await generate ( context: context)
177- }
178- case . context( let context) :
179- output = try await generate ( context: context)
146+ var output = " "
147+ for try await chunk in streamResponse ( to: prompt, images: images, videos: videos) {
148+ output += chunk
180149 }
181-
182- messages. append ( . assistant( output) )
183150 return output
184151 }
185152
@@ -211,16 +178,105 @@ public final class ChatSession {
211178 /// - Returns: a stream of string chunks from the model
212179 public func streamResponse(
213180 to prompt: String ,
214- images: [ UserInput . Image ] ,
215- videos: [ UserInput . Video ]
181+ images: consuming [ UserInput . Image ] ,
182+ videos: consuming [ UserInput . Video ]
216183 ) -> AsyncThrowingStream < String , Error > {
217- messages. append ( . user( prompt, images: images, videos: videos) )
218-
219184 let ( stream, continuation) = AsyncThrowingStream < String , Error > . makeStream ( )
220185
186+ // images and videos are not Sendable (MLXArray) but they are consumed
187+ // and are only being sent to the inner async
188+ let message = SendableBox < Chat . Message > (
189+ . user( prompt, images: images, videos: videos)
190+ )
191+
221192 let task = Task {
193+ [
194+ model,
195+ instructions, processing, additionalContext, cache, generateParameters
196+ ] in
222197 do {
223- try await self . performStreaming ( continuation: continuation)
198+ try await cache. update { cache in
199+
200+ // these are all Sendable
201+ let processor = await model. processor
202+ let tokenizer = await model. tokenizer
203+ let modelConfiguration = await model. configuration
204+
205+ var messages : [ Chat . Message ] = [ ]
206+ if let instructions {
207+ messages. append ( . system( instructions) )
208+ }
209+
210+ // prepare the cache, if needed. note:
211+ // this is using the LanguageModel (not Sendable) outside
212+ // the protective lock. Assuming the weights are not
213+ // being mutated behind the scenes, this will obey the MLXArray
214+ // contract that they be evaluated if used across threads.
215+ // This is internal to the implementation and this technique
216+ // should not be used in calling code.
217+ //
218+ // The benefit is that callers can be running multiple
219+ // ChatSessions in parallel, as long as the instances
220+ // are distinct. In particular the KVCache cannot
221+ // be shared and that is the lock that is held here.
222+
223+ let model = await model. perform { context in
224+ SendableBox ( context. model)
225+ } . consume ( )
226+
227+ var kvCache : [ KVCache ]
228+ switch cache {
229+ case . empty:
230+ kvCache = model. newCache ( parameters: generateParameters)
231+ cache = . kvcache( kvCache)
232+
233+ case . kvcache( let array) :
234+ kvCache = array
235+
236+ case . history( let history) :
237+ // the KVCache is represented by a chat history
238+ kvCache = model. newCache ( parameters: generateParameters)
239+ cache = . kvcache( kvCache)
240+ messages. append ( contentsOf: history)
241+ }
242+
243+ // prepare the input
244+ messages. append ( message. consume ( ) )
245+
246+ let userInput = UserInput (
247+ chat: messages, processing: processing,
248+ additionalContext: additionalContext)
249+ let input = try await processor. prepare ( input: userInput)
250+
251+ // generate output
252+ let iterator = try TokenIterator (
253+ input: input, model: model, cache: kvCache,
254+ parameters: generateParameters)
255+
256+ let ( stream, task) = MLXLMCommon . generateTask (
257+ promptTokenCount: input. text. tokens. size,
258+ modelConfiguration: modelConfiguration,
259+ tokenizer: tokenizer,
260+ iterator: iterator
261+ )
262+
263+ var fullResponse = " "
264+ for await item in stream {
265+ if let chunk = item. chunk {
266+ fullResponse += chunk
267+ if case . terminated = continuation. yield ( chunk) {
268+ break
269+ }
270+ }
271+ }
272+
273+ // wait for the task to complete -- this is important in
274+ // the case where we broke the loop early as the generation
275+ // work may continue (briefly) and use the KVCache
276+ await task. value
277+
278+ continuation. finish ( )
279+ }
224280 } catch {
225281 continuation. finish ( throwing: error)
226282 }
@@ -253,48 +309,17 @@ public final class ChatSession {
253309 }
254310
255311 /// Clear the session history and cache, preserving system instructions.
256- public func clear( ) {
257- messages = messages. filter { $0. role == . system }
258- cache = [ ]
259- }
260-
261- // MARK: - Private
262-
263- private func performStreaming(
264- continuation: AsyncThrowingStream < String , Error > . Continuation
265- ) async throws {
266- func stream( context: ModelContext ) async throws {
267- let userInput = UserInput (
268- chat: messages, processing: processing, additionalContext: additionalContext)
269- let input = try await context. processor. prepare ( input: userInput)
270-
271- if cache. isEmpty {
272- cache = context. model. newCache ( parameters: generateParameters)
273- }
274-
275- var fullResponse = " "
276- for await item in try MLXLMCommon . generate (
277- input: input, cache: cache, parameters: generateParameters, context: context
278- ) {
279- if let chunk = item. chunk {
280- fullResponse += chunk
281- continuation. yield ( chunk)
282- }
283- }
284-
285- Stream ( ) . synchronize ( )
286-
287- messages. append ( . assistant( fullResponse) )
288- continuation. finish ( )
312+ public func clear( ) async {
313+ await cache. update { cache in
314+ cache = . empty
289315 }
316+ }
290317
291- switch model {
292- case . container( let container) :
293- try await container. perform { context in
294- try await stream ( context: context)
295- }
296- case . context( let context) :
297- try await stream ( context: context)
298- }
318+ /// Wait for exclusive access to the KVCache.
319+ ///
320+ /// This is useful for cases where a program is terminating and wants to ensure that any
321+ /// async operations are complete.
322+ public func synchronize( ) async {
323+ await cache. read { _ in }
299324 }
300325}
0 commit comments