@@ -283,11 +283,8 @@ private enum Vision {
        }

        func callAsFunction(sequenceLength: Int) -> MLXArray {
-            let inverseFreq =
-                1.0
-                / (pow(
-                    theta,
-                    MLXArray(stride(from: 0, to: dimensions, by: 2)).asType(.float32) / dimensions))
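+            // standard RoPE inverse frequencies: inverseFreq[i] = theta^(-2i / dimensions)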
+            let p = MLXArray(stride(from: 0, to: dimensions, by: 2)).asType(.float32) / dimensions
+            let inverseFreq = 1.0 / pow(theta, p)
            let seq = MLXArray(0 ..< sequenceLength).asType(inverseFreq.dtype)
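+            // outer product: freqs[s, i] = s * inverseFreq[i], shape (sequenceLength, dimensions / 2)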
            let freqs = outer(seq, inverseFreq)
            return freqs
@@ -370,16 +367,6 @@ private enum Vision {
            self._proj.wrappedValue = Linear(dims, dims)
        }

-        private func makeMask(cuSequenceLengths: MLXArray, sequenceLength: Int) -> MLXArray {
-            let starts = cuSequenceLengths[.newAxis, ..<(-1)]
-            let ends = cuSequenceLengths[.newAxis, 1...]
-            let indices = MLXArray(0 ..< sequenceLength)[0..., .newAxis]
-            var mask = (indices .>= starts) & (indices .< ends)
-            mask = mask.any(axis: -1)
-            mask = mask[.newAxis] & mask[0..., .newAxis]
-            return 1 - mask
-        }
-
        public func callAsFunction(
            _ x: MLXArray, gridThw: [THW], rotaryPositionEmbedding: MLXArray
        ) -> MLXArray {
@@ -391,6 +378,10 @@ private enum Vision {
            let s = split(qkv, parts: 3, axis: 1)
            var (q, k, v) = (s[0], s[1], s[2])

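+            // unflatten (sequenceLength, dims) into per-head (sequenceLength, numHeads, headDim)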
+            q = q.reshaped(sequenceLength, numHeads, -1)
+            k = k.reshaped(sequenceLength, numHeads, -1)
+            v = v.reshaped(sequenceLength, numHeads, -1)
+
            q = applyMultimodalRotaryPositionEmbedding(q, freqs: rotaryPositionEmbedding)
            k = applyMultimodalRotaryPositionEmbedding(k, freqs: rotaryPositionEmbedding)

@@ -530,8 +521,6 @@
            let rotaryPositionEmbedFull = rotaryPositionEmbedding(sequenceLength: maxGridSize)[
                indices]

-            print("rot_pos_emb(), \(maxGridSize) \(gridThw), \(rotaryPositionEmbedFull.shape)")
-
            return rotaryPositionEmbedFull.reshaped(indices.dim(0), -1)
        }

@@ -634,16 +623,20 @@ public class Qwen2VLProcessor: UserInputProcessor {
        return (hBar, wBar)
    }

-    public func preprocess(images: [CIImage]) throws -> (MLXArray, THW) {
+    public func preprocess(images: [CIImage], processing: UserInput.Processing?) throws -> (
+        MLXArray, THW
+    ) {

        // image_processing_qwen2_vl._preprocess

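+        // apply any caller-supplied processing (e.g. a requested resize) first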
+        let images = images.map { MediaProcessing.apply($0, processing: processing) }
+
        let size = images[0].extent.size
        let (resizedHeight, resizedWidth) = try targetSize(
            height: Int(size.height), width: Int(size.width),
            factor: config.patchSize * config.mergeSize,
            minPixels: config.size.minPixels, maxPixels: config.size.maxPixels)
-        let resizedSize = CGSize(width: resizedHeight, height: resizedWidth)
+        let resizedSize = CGSize(width: resizedWidth, height: resizedHeight)

        let processedImages =
            try images
@@ -691,26 +684,68 @@ public class Qwen2VLProcessor: UserInputProcessor {
        return (flattenedPatches, .init(gridT, gridH, gridW))
    }

-    public func prepare(input: UserInput) throws -> LMInput {
-        // this doesn't have a chat template so just use the last message
-        let prompt = input.prompt.asMessages().last?["content"] ?? ""
+    public func prepare(prompt: UserInput.Prompt, imageTHW: [THW]?) -> String {
+        // the tokenizer does have a chat template and it expects messages
+        // like this:
+        //
+        // [{'role': 'user', 'content': [{'type': 'text', 'text': 'What are these?'},
+        //     {'type': 'image'}, {'type': 'image'}, {'type': 'image'}]}]
+        //
+        // The output of the prompt template is fed into
+        // image_processing_qwen2_vl.preprocess where it is further augmented
+        // by replacing tokens according to imageTHW.
+        //
+        // Neither the structured content nor the postprocessing of the template
+        // are supported in the current Tokenizer/Jinja (swift) so handle that here.
+
+        var messages = prompt.asMessages()
+        if messages[0]["role"] != "system" {
+            messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0)
+        }
+
+        let lastIndex = messages.count - 1
+        var lastMessage = messages[lastIndex]["content"] ?? ""
+
+        // image_processing_qwen2_vl.preprocess -- inject image_pad tokens for each image
+        let mergeLength = config.mergeSize * config.mergeSize
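+        // one <|image_pad|> per merged patch group: e.g. a (1, 28, 28) grid with mergeSize 2 gives 784 / 4 = 196 pads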
+        for thw in imageTHW ?? [] {
+            lastMessage += "<|vision_start|>"
+            lastMessage += Array(repeating: "<|image_pad|>", count: thw.product / mergeLength)
+                .joined()
+            lastMessage += "<|vision_end|>"
+        }
+
+        messages[lastIndex]["content"] = lastMessage
+
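+        // render in ChatML form, e.g.:
+        //
+        // <|im_start|>system
+        // You are a helpful assistant.<|im_end|>
+        // <|im_start|>user
+        // What are these?<|vision_start|><|image_pad|>...<|vision_end|><|im_end|>
+        // <|im_start|>assistant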
+        return
+            messages
+            .map {
+                "<|im_start|>\($0["role"] ?? "user")\n\($0["content"] ?? "")<|im_end|>"
+            }
+            .joined(separator: "\n")
+            + "\n<|im_start|>assistant\n"
+    }

+    public func prepare(input: UserInput) throws -> LMInput {
        if input.images.isEmpty {
            // just a straight text prompt
+            let prompt = prepare(prompt: input.prompt, imageTHW: nil)
            let promptTokens = try tokenizer.encode(text: prompt)
            return LMInput(tokens: MLXArray(promptTokens))
        }

-        let promptTokens = try tokenizer.encode(text: prompt)
-        let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
-        let mask = ones(like: promptArray)
-
        // image_processing_qwen2_vl.preprocess
-        let images = try input.images.map { try preprocess(images: [$0.asCIImage()]) }
+        let images = try input.images.map {
+            try preprocess(images: [$0.asCIImage()], processing: input.processing)
+        }
        let pixels = concatenated(images.map { $0.0 })
        let image = LMInput.ProcessedImage(pixels: pixels, imageGridThw: images.map { $0.1 })

-        print("image \(image.pixels.shape), \(image.imageGridThw)")
+        // processing_qwen2_vl.Qwen2VLProcessor
+        let prompt = prepare(prompt: input.prompt, imageTHW: image.imageGridThw)
+        let promptTokens = try tokenizer.encode(text: prompt)
+        let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
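+        // batch of one with no padding, so the attention mask is all ones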
+        let mask = ones(like: promptArray)

        return LMInput(text: .init(tokens: promptArray, mask: mask), image: image)
    }
@@ -773,7 +808,7 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider {
                imageIndices.append(i)
            }
        }
-        // TODO look at the inputIds -- I think I am missing something here
+
        inputEmbeds[0..., MLXArray(imageIndices), 0...] = imageFeatures
        return inputEmbeds
    }