@@ -25,55 +25,6 @@ private func rotateHalf(_ x: MLXArray) -> MLXArray {
25
25
26
26
private enum Language {
27
27
28
- fileprivate class Qwen2RotaryEmbedding {
29
-
30
- private let dimensions : Int
31
- private let maxPositionEmbeddings : Int
32
- private let base : Float
33
-
34
- private let inverseFreq : MLXArray
35
-
36
- private var cachedSequenceLength = 0
37
- private var cachedCos = MLXArray ( 0 )
38
- private var cachedSin = MLXArray ( 0 )
39
-
40
- public init ( dimensions: Int , maxPositionEmbeddings: Int , base: Float ) {
41
- self . dimensions = dimensions
42
- self . maxPositionEmbeddings = maxPositionEmbeddings
43
- self . base = base
44
-
45
- self . inverseFreq =
46
- 1.0
47
- / ( pow (
48
- base,
49
- MLXArray ( stride ( from: 0 , to: dimensions, by: 2 ) ) . asType ( . float32) / dimensions) )
50
-
51
- buildCache ( length: maxPositionEmbeddings)
52
- }
53
-
54
- private func buildCache( length: Int ) {
55
- cachedSequenceLength = length
56
- let t = MLXArray ( 0 ..< cachedSequenceLength) . asType ( . float32)
57
- let freqs = outer ( t, inverseFreq)
58
-
59
- // Different from paper, but it uses a different permutation in order to obtain the same calculation
60
- let emb = concatenated ( [ freqs, freqs] , axis: - 1 )
61
- cachedCos = cos ( emb)
62
- cachedSin = sin ( emb)
63
- }
64
-
65
- public func callAsFunction( _ x: MLXArray , sequenceLength: Int ) -> ( MLXArray , MLXArray ) {
66
- if sequenceLength > self . cachedSequenceLength {
67
- buildCache ( length: sequenceLength)
68
- }
69
-
70
- return (
71
- cachedCos [ 0 ..< sequenceLength] . asType ( x. dtype) ,
72
- cachedSin [ 0 ..< sequenceLength] . asType ( x. dtype)
73
- )
74
- }
75
- }
76
-
77
28
/// Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors
78
29
static private func applyMultimodalRotaryPositionEmbedding(
79
30
q: MLXArray , k: MLXArray , cos: MLXArray , sin: MLXArray ,
@@ -114,7 +65,7 @@ private enum Language {
114
65
@ModuleInfo ( key: " v_proj " ) var wv : Linear
115
66
@ModuleInfo ( key: " o_proj " ) var wo : Linear
116
67
117
- let rotaryEmbedding : Qwen2RotaryEmbedding
68
+ @ ModuleInfo ( key : " rotary_emb " ) var rotaryEmbedding : RoPE
118
69
119
70
public init ( _ args: Qwen2VLConfiguration . TextConfiguration ) {
120
71
let dim = args. hiddenSize
@@ -143,9 +94,8 @@ private enum Language {
143
94
fatalError ( " rope_scaling['mrope_section'] must be an array of integers " )
144
95
}
145
96
146
- self . rotaryEmbedding = Qwen2RotaryEmbedding (
147
- dimensions: headDim, maxPositionEmbeddings: args. maxpPositionEmbeddings,
148
- base: args. ropeTheta)
97
+ self . _rotaryEmbedding. wrappedValue = RoPE (
98
+ dimensions: headDim, traditional: args. ropeTraditional, base: args. ropeTheta)
149
99
}
150
100
151
101
public func callAsFunction(
@@ -162,30 +112,11 @@ private enum Language {
162
112
keys = keys. reshaped ( B, L, kvHeads, headDim) . transposed ( 0 , 2 , 1 , 3 )
163
113
values = values. reshaped ( B, L, kvHeads, headDim) . transposed ( 0 , 2 , 1 , 3 )
164
114
165
- var kvSequenceLength = keys. dim ( - 2 )
166
- var positionIds : MLXArray
167
- if let cache {
168
- kvSequenceLength += cache. offset + 1
169
- positionIds = MLXArray ( cache. offset ..< ( cache. offset + L) )
170
- } else {
171
- positionIds = MLXArray ( 0 ..< L)
172
- }
173
-
174
- positionIds = expandedDimensions ( positionIds, axis: 0 )
175
- positionIds = tiled ( positionIds, repetitions: [ 3 , 1 , 1 ] )
176
-
177
- let ( cos, sin) = rotaryEmbedding ( values, sequenceLength: kvSequenceLength)
178
-
179
- let mask : MLXArray ? =
180
- if var mask {
181
- mask [ . newAxis, . newAxis, 0 ... , 0 ... ] [ 0 ... , 0 ... , 0 ... , ..< keys. dim ( - 2 ) ]
182
- } else {
183
- nil
184
- }
115
+ let offset = cache? . offset ?? 0
116
+ let mask = mask ? [ 0 ... , 0 ..< keys. dim ( - 2 ) ]
185
117
186
- ( queries, keys) = applyMultimodalRotaryPositionEmbedding (
187
- q: queries, k: keys, cos: cos, sin: sin, positionIds: positionIds,
188
- mropeSection: mropeSection)
118
+ queries = rotaryEmbedding ( queries, offset: offset)
119
+ keys = rotaryEmbedding ( keys, offset: offset)
189
120
190
121
if let cache {
191
122
( keys, values) = cache. update ( keys: keys, values: values)
@@ -450,34 +381,25 @@ private enum Vision {
450
381
}
451
382
452
383
public func callAsFunction(
453
- _ x: MLXArray , cuSequenceLengths : MLXArray , rotaryPositionEmbedding: MLXArray
384
+ _ x: MLXArray , gridThw : [ THW ] , rotaryPositionEmbedding: MLXArray
454
385
) -> MLXArray {
455
386
let sequenceLength = x. dim ( 0 )
387
+ let B = gridThw [ 0 ] . t
388
+ let L = sequenceLength / B
456
389
457
- let qkv = qkv ( x)
458
- . reshaped ( sequenceLength, 3 , self . numHeads, - 1 )
459
- . transposed ( 1 , 0 , 2 , 3 )
460
- let s = split ( qkv, parts: 3 )
390
+ let qkv = qkv ( x) . reshaped ( sequenceLength, 3 , - 1 )
391
+ let s = split ( qkv, parts: 3 , axis: 1 )
461
392
var ( q, k, v) = ( s [ 0 ] , s [ 1 ] , s [ 2 ] )
462
393
463
- print ( " rotaryPositionEmbedding \( rotaryPositionEmbedding. shape) " )
464
-
465
- q =
466
- applyMultimodalRotaryPositionEmbedding (
467
- expandedDimensions ( q, axis: 0 ) , freqs: rotaryPositionEmbedding) [ 0 ]
468
- k =
469
- applyMultimodalRotaryPositionEmbedding (
470
- expandedDimensions ( k, axis: 0 ) , freqs: rotaryPositionEmbedding) [ 0 ]
471
-
472
- q = q. transposed ( 0 , 2 , 1 , 3 )
473
- k = k. transposed ( 0 , 2 , 1 , 3 )
474
- v = v. transposed ( 0 , 2 , 1 , 3 )
394
+ q = applyMultimodalRotaryPositionEmbedding ( q, freqs: rotaryPositionEmbedding)
395
+ k = applyMultimodalRotaryPositionEmbedding ( k, freqs: rotaryPositionEmbedding)
475
396
476
- let mask = makeMask (
477
- cuSequenceLengths: cuSequenceLengths, sequenceLength: sequenceLength)
397
+ q = q. reshaped ( B, L, numHeads, - 1 ) . transposed ( 0 , 2 , 1 , 3 )
398
+ k = k. reshaped ( B, L, numHeads, - 1 ) . transposed ( 0 , 2 , 1 , 3 )
399
+ v = v. reshaped ( B, L, numHeads, - 1 ) . transposed ( 0 , 2 , 1 , 3 )
478
400
479
401
let output = MLXFast . scaledDotProductAttention (
480
- queries: q, keys: k, values: v, scale: scale, mask: mask
402
+ queries: q, keys: k, values: v, scale: scale, mask: nil
481
403
)
482
404
. transposed ( 0 , 2 , 1 , 3 )
483
405
. reshaped ( sequenceLength, - 1 )
@@ -523,13 +445,13 @@ private enum Vision {
523
445
}
524
446
525
447
func callAsFunction(
526
- _ hiddenStates: MLXArray , cuSequenceLengths : MLXArray , rotaryPositionEmbedding: MLXArray
448
+ _ hiddenStates: MLXArray , gridThw : [ THW ] , rotaryPositionEmbedding: MLXArray
527
449
) -> MLXArray {
528
450
var hiddenStates =
529
451
hiddenStates
530
452
+ attention(
531
453
norm1 ( hiddenStates) ,
532
- cuSequenceLengths : cuSequenceLengths ,
454
+ gridThw : gridThw ,
533
455
rotaryPositionEmbedding: rotaryPositionEmbedding
534
456
)
535
457
hiddenStates = hiddenStates + mlp( norm2 ( hiddenStates) )
@@ -619,22 +541,9 @@ private enum Vision {
619
541
620
542
let batchSize = gridThw. count
621
543
622
- // Calculate cu_seqlens for each item in the batch
623
- var collect = [ MLXArray] ( )
624
- for thw in gridThw {
625
- let sequenceLength = thw. h * thw. w
626
- collect. append ( repeated ( MLXArray ( sequenceLength) , count: thw. t) )
627
- }
628
-
629
- // Concatenate the cu_seqlens for all items in the batch
630
- var cuSeqLengths = concatenated ( collect)
631
-
632
- cuSeqLengths = cumsum ( cuSeqLengths. asType ( Int32 . self) , axis: 0 )
633
- cuSeqLengths = padded ( cuSeqLengths, width: [ 1 , 0 ] , mode: . constant, value: MLXArray ( 0 ) )
634
-
635
544
for block in blocks {
636
545
hiddenStates = block (
637
- hiddenStates, cuSequenceLengths : cuSeqLengths ,
546
+ hiddenStates, gridThw : gridThw ,
638
547
rotaryPositionEmbedding: rotaryPositionEmbedding)
639
548
}
640
549
@@ -797,7 +706,7 @@ public class Qwen2VLProcessor: UserInputProcessor {
797
706
let mask = ones ( like: promptArray)
798
707
799
708
// image_processing_qwen2_vl.preprocess
800
- let images = try input. images. map { try preprocess ( images: [ $0] ) }
709
+ let images = try input. images. map { try preprocess ( images: [ $0. asCIImage ( ) ] ) }
801
710
let pixels = concatenated ( images. map { $0. 0 } )
802
711
let image = LMInput . ProcessedImage ( pixels: pixels, imageGridThw: images. map { $0. 1 } )
803
712
@@ -873,8 +782,12 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider {
873
782
-> PrepareResult
874
783
{
875
784
let gridThw = input. image? . imageGridThw
785
+
786
+ let dtype = visionModel. patchEmbed. proj. weight. dtype
787
+ let pixels = input. image? . pixels. asType ( dtype)
788
+
876
789
let inputEmbeddings = self . inputEmbeddings (
877
- inputIds: input. text. tokens, pixelValues: input . image ? . pixels, gridThw: gridThw)
790
+ inputIds: input. text. tokens, pixelValues: pixels, gridThw: gridThw)
878
791
879
792
let result = languageModel ( nil , cache: cache, inputEmbedding: inputEmbeddings)
880
793
0 commit comments