@@ -283,11 +283,8 @@ private enum Vision {
283283 }
284284
285285 func callAsFunction( sequenceLength: Int ) -> MLXArray {
286- let inverseFreq =
287- 1.0
288- / ( pow (
289- theta,
290- MLXArray ( stride ( from: 0 , to: dimensions, by: 2 ) ) . asType ( . float32) / dimensions) )
286+ let p = MLXArray ( stride ( from: 0 , to: dimensions, by: 2 ) ) . asType ( . float32) / dimensions
287+ let inverseFreq = 1.0 / pow( theta, p)
291288 let seq = MLXArray ( 0 ..< sequenceLength) . asType ( inverseFreq. dtype)
292289 let freqs = outer ( seq, inverseFreq)
293290 return freqs
@@ -370,16 +367,6 @@ private enum Vision {
370367 self . _proj. wrappedValue = Linear ( dims, dims)
371368 }
372369
373- private func makeMask( cuSequenceLengths: MLXArray , sequenceLength: Int ) -> MLXArray {
374- let starts = cuSequenceLengths [ . newAxis, ..< ( - 1 ) ]
375- let ends = cuSequenceLengths [ . newAxis, 1 ... ]
376- let indices = MLXArray ( 0 ..< sequenceLength) [ 0 ... , . newAxis]
377- var mask = ( indices .>= starts) & ( indices .< ends)
378- mask = mask. any ( axis: - 1 )
379- mask = mask [ . newAxis] & mask [ 0 ... , . newAxis]
380- return 1 - mask
381- }
382-
383370 public func callAsFunction(
384371 _ x: MLXArray , gridThw: [ THW ] , rotaryPositionEmbedding: MLXArray
385372 ) -> MLXArray {
@@ -391,6 +378,10 @@ private enum Vision {
391378 let s = split ( qkv, parts: 3 , axis: 1 )
392379 var ( q, k, v) = ( s [ 0 ] , s [ 1 ] , s [ 2 ] )
393380
381+ q = q. reshaped ( sequenceLength, numHeads, - 1 )
382+ k = k. reshaped ( sequenceLength, numHeads, - 1 )
383+ v = v. reshaped ( sequenceLength, numHeads, - 1 )
384+
394385 q = applyMultimodalRotaryPositionEmbedding ( q, freqs: rotaryPositionEmbedding)
395386 k = applyMultimodalRotaryPositionEmbedding ( k, freqs: rotaryPositionEmbedding)
396387
@@ -530,8 +521,6 @@ private enum Vision {
530521 let rotaryPositionEmbedFull = rotaryPositionEmbedding ( sequenceLength: maxGridSize) [
531522 indices]
532523
533- print ( " rot_pos_emb(), \( maxGridSize) \( gridThw) , \( rotaryPositionEmbedFull. shape) " )
534-
535524 return rotaryPositionEmbedFull. reshaped ( indices. dim ( 0 ) , - 1 )
536525 }
537526
@@ -634,16 +623,20 @@ public class Qwen2VLProcessor: UserInputProcessor {
634623 return ( hBar, wBar)
635624 }
636625
637- public func preprocess( images: [ CIImage ] ) throws -> ( MLXArray , THW ) {
626+ public func preprocess( images: [ CIImage ] , processing: UserInput . Processing ? ) throws -> (
627+ MLXArray , THW
628+ ) {
638629
639630 // image_processing_qwen2_vl._preprocess
640631
632+ let images = images. map { MediaProcessing . apply ( $0, processing: processing) }
633+
641634 let size = images [ 0 ] . extent. size
642635 let ( resizedHeight, resizedWidth) = try targetSize (
643636 height: Int ( size. height) , width: Int ( size. width) ,
644637 factor: config. patchSize * config. mergeSize,
645638 minPixels: config. size. minPixels, maxPixels: config. size. maxPixels)
646- let resizedSize = CGSize ( width: resizedHeight , height: resizedWidth )
639+ let resizedSize = CGSize ( width: resizedWidth , height: resizedHeight )
647640
648641 let processedImages =
649642 try images
@@ -691,26 +684,68 @@ public class Qwen2VLProcessor: UserInputProcessor {
691684 return ( flattenedPatches, . init( gridT, gridH, gridW) )
692685 }
693686
694- public func prepare( input: UserInput ) throws -> LMInput {
695- // this doesn't have a chat template so just use the last message
696- let prompt = input. prompt. asMessages ( ) . last ? [ " content " ] ?? " "
687+ public func prepare( prompt: UserInput . Prompt , imageTHW: [ THW ] ? ) -> String {
688+ // the tokenizer does have a chat template and it expects messages
689+ // like this:
690+ //
691+ // [{'role': 'user', 'content': [{'type': 'text', 'text': 'What are these?'},
692+ // {'type': 'image'}, {'type': 'image'}, {'type': 'image'}]}]
693+ //
694+ // The output of the prompt template is fed into
695+ // image_processing_qwen2_vl.preprocess where it is further augmented
696+ // by replacing tokens according to imageTHW.
697+ //
698+ // Neither the structured content nor the postprocessing of the template
699+ // are supported in current Tokenizer/Jinja (swift) so handle that here.
700+
701+ var messages = prompt. asMessages ( )
702+ if messages [ 0 ] [ " role " ] != " system " {
703+ messages. insert ( [ " role " : " system " , " content " : " You are a helpful assistant. " ] , at: 0 )
704+ }
705+
706+ let lastIndex = messages. count - 1
707+ var lastMessage = messages [ lastIndex] [ " content " ] ?? " "
708+
709+ // image_processing_qwen2_vl.preprocess -- inject image_pad tokens for each image
710+ let mergeLength = config. mergeSize * config. mergeSize
711+ for thw in imageTHW ?? [ ] {
712+ lastMessage += " <|vision_start|> "
713+ lastMessage += Array ( repeating: " <|image_pad|> " , count: thw. product / mergeLength)
714+ . joined ( )
715+ lastMessage += " <|vision_end|> "
716+ }
717+
718+ messages [ lastIndex] [ " content " ] = lastMessage
719+
720+ return
721+ messages
722+ . map {
723+ " <|im_start|> \( $0 [ " role " ] ?? " user " ) \n \( $0 [ " content " ] ?? " " ) <|im_end|> "
724+ }
725+ . joined ( separator: " \n " )
726+ + " \n <|im_start|>assistant \n "
727+ }
697728
729+ public func prepare( input: UserInput ) throws -> LMInput {
698730 if input. images. isEmpty {
699731 // just a straight text prompt
732+ let prompt = prepare ( prompt: input. prompt, imageTHW: nil )
700733 let promptTokens = try tokenizer. encode ( text: prompt)
701734 return LMInput ( tokens: MLXArray ( promptTokens) )
702735 }
703736
704- let promptTokens = try tokenizer. encode ( text: prompt)
705- let promptArray = MLXArray ( promptTokens) . expandedDimensions ( axis: 0 )
706- let mask = ones ( like: promptArray)
707-
708737 // image_processing_qwen2_vl.preprocess
709- let images = try input. images. map { try preprocess ( images: [ $0. asCIImage ( ) ] ) }
738+ let images = try input. images. map {
739+ try preprocess ( images: [ $0. asCIImage ( ) ] , processing: input. processing)
740+ }
710741 let pixels = concatenated ( images. map { $0. 0 } )
711742 let image = LMInput . ProcessedImage ( pixels: pixels, imageGridThw: images. map { $0. 1 } )
712743
713- print ( " image \( image. pixels. shape) , \( image. imageGridThw) " )
744+ // processing_qwen2_vl.Qwen2VLProcessor
745+ let prompt = prepare ( prompt: input. prompt, imageTHW: image. imageGridThw)
746+ let promptTokens = try tokenizer. encode ( text: prompt)
747+ let promptArray = MLXArray ( promptTokens) . expandedDimensions ( axis: 0 )
748+ let mask = ones ( like: promptArray)
714749
715750 return LMInput ( text: . init( tokens: promptArray, mask: mask) , image: image)
716751 }
@@ -773,7 +808,7 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider {
773808 imageIndices. append ( i)
774809 }
775810 }
776- // TODO look at the inputIds -- I think I am missing something here
811+
777812 inputEmbeds [ 0 ... , MLXArray ( imageIndices) , 0 ... ] = imageFeatures
778813 return inputEmbeds
779814 }
0 commit comments