@@ -20,17 +20,18 @@ import MLX
2020/// model operations.
2121public final class ChatSession {
2222
23- private enum Model {
24- case container( ModelContainer )
25- case context( ModelContext )
23+ enum Cache {
24+ case empty
25+ case kvcache( [ KVCache ] )
26+ case history( [ Chat . Message ] )
2627 }
2728
28- private let model : Model
29- private var messages : [ Chat . Message ]
30- private var cache : [ KVCache ]
31- private let processing : UserInput . Processing
32- private let generateParameters : GenerateParameters
33- private let additionalContext : [ String : any Sendable ] ?
29+ private let model : ModelContainer
30+ public var instructions : String ?
31+ private let cache : SerialAccessContainer < Cache >
32+ public var processing : UserInput . Processing
33+ public var generateParameters : GenerateParameters
34+ public var additionalContext : [ String : any Sendable ] ?
3435
3536 /// Initialize the `ChatSession`.
3637 ///
@@ -47,9 +48,9 @@ public final class ChatSession {
4748 processing: UserInput . Processing = . init( resize: CGSize ( width: 512 , height: 512 ) ) ,
4849 additionalContext: [ String : any Sendable ] ? = nil
4950 ) {
50- self . model = . container ( model)
51- self . messages = instructions. map { [ . system ( $0 ) ] } ?? [ ]
52- self . cache = [ ]
51+ self . model = model
52+ self . instructions = instructions
53+ self . cache = . init ( . empty )
5354 self . processing = processing
5455 self . generateParameters = generateParameters
5556 self . additionalContext = additionalContext
@@ -70,9 +71,9 @@ public final class ChatSession {
7071 processing: UserInput . Processing = . init( resize: CGSize ( width: 512 , height: 512 ) ) ,
7172 additionalContext: [ String : any Sendable ] ? = nil
7273 ) {
73- self . model = . context ( model)
74- self . messages = instructions. map { [ . system ( $0 ) ] } ?? [ ]
75- self . cache = [ ]
74+ self . model = ModelContainer ( context : model)
75+ self . instructions = instructions
76+ self . cache = . init ( . empty )
7677 self . processing = processing
7778 self . generateParameters = generateParameters
7879 self . additionalContext = additionalContext
@@ -88,21 +89,20 @@ public final class ChatSession {
8889 /// - generateParameters: parameters that control generation
8990 /// - processing: media processing configuration for images/videos
9091 /// - additionalContext: optional model-specific context
91- public convenience init (
92+ public init (
9293 _ model: ModelContainer ,
93- history: [ Chat . Message ] ,
94+ instructions: String ? = nil ,
95+ history: consuming [ Chat . Message ] ,
9496 generateParameters: GenerateParameters = . init( ) ,
9597 processing: UserInput . Processing = . init( resize: CGSize ( width: 512 , height: 512 ) ) ,
9698 additionalContext: [ String : any Sendable ] ? = nil
9799 ) {
98- self . init (
99- model,
100- instructions: nil ,
101- generateParameters: generateParameters,
102- processing: processing,
103- additionalContext: additionalContext
104- )
105- self . messages = history
100+ self . model = model
101+ self . instructions = instructions
102+ self . cache = . init( . history( history) )
103+ self . processing = processing
104+ self . generateParameters = generateParameters
105+ self . additionalContext = additionalContext
106106 }
107107
108108 /// Initialize the `ChatSession` with an existing message history.
@@ -115,21 +115,20 @@ public final class ChatSession {
115115 /// - generateParameters: parameters that control generation
116116 /// - processing: media processing configuration for images/videos
117117 /// - additionalContext: optional model-specific context
118- public convenience init (
118+ public init (
119119 _ model: ModelContext ,
120+ instructions: String ? = nil ,
120121 history: [ Chat . Message ] ,
121122 generateParameters: GenerateParameters = . init( ) ,
122123 processing: UserInput . Processing = . init( resize: CGSize ( width: 512 , height: 512 ) ) ,
123124 additionalContext: [ String : any Sendable ] ? = nil
124125 ) {
125- self . init (
126- model,
127- instructions: nil ,
128- generateParameters: generateParameters,
129- processing: processing,
130- additionalContext: additionalContext
131- )
132- self . messages = history
126+ self . model = ModelContainer ( context: model)
127+ self . instructions = instructions
128+ self . cache = . init( . history( history) )
129+ self . processing = processing
130+ self . generateParameters = generateParameters
131+ self . additionalContext = additionalContext
133132 }
134133
135134 /// Produces a response to a prompt.
@@ -141,45 +140,13 @@ public final class ChatSession {
141140 /// - Returns: the model's response
142141 public func respond(
143142 to prompt: String ,
144- images: [ UserInput . Image ] ,
145- videos: [ UserInput . Video ]
143+ images: consuming [ UserInput . Image ] ,
144+ videos: consuming [ UserInput . Video ]
146145 ) async throws -> String {
147- messages. append ( . user( prompt, images: images, videos: videos) )
148-
149- func generate( context: ModelContext ) async throws -> String {
150- let userInput = UserInput (
151- chat: messages, processing: processing, additionalContext: additionalContext)
152- let input = try await context. processor. prepare ( input: userInput)
153-
154- if cache. isEmpty {
155- cache = context. model. newCache ( parameters: generateParameters)
156- }
157-
158- var output = " "
159- for await generation in try MLXLMCommon . generate (
160- input: input, cache: cache, parameters: generateParameters, context: context
161- ) {
162- if let chunk = generation. chunk {
163- output += chunk
164- }
165- }
166-
167- Stream ( ) . synchronize ( )
168-
169- return output
170- }
171-
172- let output : String
173- switch model {
174- case . container( let container) :
175- output = try await container. perform { context in
176- try await generate ( context: context)
177- }
178- case . context( let context) :
179- output = try await generate ( context: context)
146+ var output = " "
147+ for try await chunk in streamResponse ( to: prompt, images: images, videos: videos) {
148+ output += chunk
180149 }
181-
182- messages. append ( . assistant( output) )
183150 return output
184151 }
185152
@@ -202,7 +169,7 @@ public final class ChatSession {
202169 )
203170 }
204171
205- /// Produces a streaming response to a prompt.
172+ /// Produces a streaming response to a prompt as a stream of `String` chunks.
206173 ///
207174 /// - Parameters:
208175 /// - prompt: the user prompt
@@ -211,16 +178,139 @@ public final class ChatSession {
211178 /// - Returns: a stream of string chunks from the model
212179 public func streamResponse(
213180 to prompt: String ,
214- images: [ UserInput . Image ] ,
215- videos: [ UserInput . Video ]
181+ images: consuming [ UserInput . Image ] ,
182+ videos: consuming [ UserInput . Video ]
216183 ) -> AsyncThrowingStream < String , Error > {
217- messages. append ( . user( prompt, images: images, videos: videos) )
184+ streamMap ( to: prompt, images: images, videos: videos) {
185+ $0. chunk
186+ }
187+ }
188+
189+ /// Produces a streaming response to a prompt as `Generation`.
190+ ///
191+ /// - Parameters:
192+ /// - prompt: the user prompt
193+ /// - images: list of images (for use with VLMs)
194+ /// - videos: list of videos (for use with VLMs)
195+ /// - Returns: a stream of `Generation` from the model
196+ public func streamDetails(
197+ to prompt: String ,
198+ images: consuming [ UserInput . Image ] ,
199+ videos: consuming [ UserInput . Video ]
200+ ) -> AsyncThrowingStream < Generation , Error > {
201+ streamMap ( to: prompt, images: images, videos: videos) {
202+ $0
203+ }
204+ }
218205
219- let ( stream, continuation) = AsyncThrowingStream < String , Error > . makeStream ( )
206+ /// Produces a streaming response to a prompt by transforming the
207+ /// raw `Generation` values.
208+ ///
209+ /// - Parameters:
210+ /// - prompt: the user prompt
211+ /// - images: list of images (for use with VLMs)
212+ /// - videos: list of videos (for use with VLMs)
213+ /// - Returns: a stream of transformed values from the model
214+ private func streamMap< R: Sendable > (
215+ to prompt: String ,
216+ images: consuming [ UserInput . Image ] ,
217+ videos: consuming [ UserInput . Video ] ,
218+ transform: @Sendable @escaping ( Generation ) -> R ?
219+ ) -> AsyncThrowingStream < R , Error > {
220+ let ( stream, continuation) = AsyncThrowingStream < R , Error > . makeStream ( )
221+
222+ // images and videos are not Sendable (MLXArray), but they are consumed here
223+ // and are only sent on to the inner async task
224+ let message = SendableBox < Chat . Message > (
225+ . user( prompt, images: images, videos: videos)
226+ )
220227
221228 let task = Task {
229+ [
230+ model,
231+ instructions, processing, additionalContext, cache, generateParameters
232+ ] in
222233 do {
223- try await self . performStreaming ( continuation: continuation)
234+ try await cache. update { cache in
235+
236+ // these are all Sendable
237+ let processor = await model. processor
238+ let tokenizer = await model. tokenizer
239+ let modelConfiguration = await model. configuration
240+
241+ var messages : [ Chat . Message ] = [ ]
242+ if let instructions {
243+ messages. append ( . system( instructions) )
244+ }
245+
246+ // prepare the cache, if needed. note:
247+ // this is using the LanguageModel (not Sendable) outside
248+ // the protective lock. Assuming the weights are not
249+ // being mutated behind the scenes, this will obey the MLXArray
250+ // contract that they be evaluated if used across threads.
251+ // This is internal to the implementation and this technique
252+ // should not be used in calling code.
253+ //
254+ // The benefit is that callers can be running multiple
255+ // ChatSessions in parallel, as long as the instances
256+ // are distinct. In particular the KVCache cannot
257+ // be shared and that is the lock that is held here.
258+
259+ let model = await model. perform { context in
260+ SendableBox ( context. model)
261+ } . consume ( )
262+
263+ var kvCache : [ KVCache ]
264+ switch cache {
265+ case . empty:
266+ kvCache = model. newCache ( parameters: generateParameters)
267+ cache = . kvcache( kvCache)
268+
269+ case . kvcache( let array) :
270+ kvCache = array
271+
272+ case . history( let history) :
273+ // the KVCache is represented by a chat history
274+ kvCache = model. newCache ( parameters: generateParameters)
275+ cache = . kvcache( kvCache)
276+ messages. append ( contentsOf: history)
277+ }
278+
279+ // prepare the input
280+ messages. append ( message. consume ( ) )
281+
282+ let userInput = UserInput (
283+ chat: messages, processing: processing,
284+ additionalContext: additionalContext)
285+ let input = try await processor. prepare ( input: userInput)
286+
287+ // generate output
288+ let iterator = try TokenIterator (
289+ input: input, model: model, cache: kvCache,
290+ parameters: generateParameters)
291+
292+ let ( stream, task) = MLXLMCommon . generateTask (
293+ promptTokenCount: input. text. tokens. size,
294+ modelConfiguration: modelConfiguration,
295+ tokenizer: tokenizer,
296+ iterator: iterator
297+ )
298+
299+ for await item in stream {
300+ if let value = transform ( item) {
301+ if case . terminated = continuation. yield ( value) {
302+ break
303+ }
304+ }
305+ }
306+
307+ // wait for the task to complete -- this is important in
308+ // the case where we broke the loop early as the generation
309+ // work may continue (briefly) and use the KVCache
310+ await task. value
311+
312+ continuation. finish ( )
313+ }
224314 } catch {
225315 continuation. finish ( throwing: error)
226316 }
@@ -253,48 +343,17 @@ public final class ChatSession {
253343 }
254344
255345 /// Clear the session history and cache, preserving system instructions.
256- public func clear( ) {
257- messages = messages. filter { $0. role == . system }
258- cache = [ ]
259- }
260-
261- // MARK: - Private
262-
263- private func performStreaming(
264- continuation: AsyncThrowingStream < String , Error > . Continuation
265- ) async throws {
266- func stream( context: ModelContext ) async throws {
267- let userInput = UserInput (
268- chat: messages, processing: processing, additionalContext: additionalContext)
269- let input = try await context. processor. prepare ( input: userInput)
270-
271- if cache. isEmpty {
272- cache = context. model. newCache ( parameters: generateParameters)
273- }
274-
275- var fullResponse = " "
276- for await item in try MLXLMCommon . generate (
277- input: input, cache: cache, parameters: generateParameters, context: context
278- ) {
279- if let chunk = item. chunk {
280- fullResponse += chunk
281- continuation. yield ( chunk)
282- }
283- }
284-
285- Stream ( ) . synchronize ( )
286-
287- messages. append ( . assistant( fullResponse) )
288- continuation. finish ( )
346+ public func clear( ) async {
347+ await cache. update { cache in
348+ cache = . empty
289349 }
350+ }
290351
291- switch model {
292- case . container( let container) :
293- try await container. perform { context in
294- try await stream ( context: context)
295- }
296- case . context( let context) :
297- try await stream ( context: context)
298- }
352+ /// Wait for exclusive access to the KVCache.
353+ ///
354+ /// This is useful for cases where a program is terminating and wants to ensure that any
355+ /// async operations are complete.
356+ public func synchronize( ) async {
357+ await cache. read { _ in }
299358 }
300359}
0 commit comments