Skip to content

Commit 8a68223

Browse files
committed
qwen2 image processing
1 parent 1ac4265 commit 8a68223

File tree

1 file changed

+27
-114
lines changed

1 file changed

+27
-114
lines changed

Libraries/VLM/Models/Qwen2VL.swift

Lines changed: 27 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -25,55 +25,6 @@ private func rotateHalf(_ x: MLXArray) -> MLXArray {
2525

2626
private enum Language {
2727

28-
fileprivate class Qwen2RotaryEmbedding {
29-
30-
private let dimensions: Int
31-
private let maxPositionEmbeddings: Int
32-
private let base: Float
33-
34-
private let inverseFreq: MLXArray
35-
36-
private var cachedSequenceLength = 0
37-
private var cachedCos = MLXArray(0)
38-
private var cachedSin = MLXArray(0)
39-
40-
public init(dimensions: Int, maxPositionEmbeddings: Int, base: Float) {
41-
self.dimensions = dimensions
42-
self.maxPositionEmbeddings = maxPositionEmbeddings
43-
self.base = base
44-
45-
self.inverseFreq =
46-
1.0
47-
/ (pow(
48-
base,
49-
MLXArray(stride(from: 0, to: dimensions, by: 2)).asType(.float32) / dimensions))
50-
51-
buildCache(length: maxPositionEmbeddings)
52-
}
53-
54-
private func buildCache(length: Int) {
55-
cachedSequenceLength = length
56-
let t = MLXArray(0 ..< cachedSequenceLength).asType(.float32)
57-
let freqs = outer(t, inverseFreq)
58-
59-
// Different from paper, but it uses a different permutation in order to obtain the same calculation
60-
let emb = concatenated([freqs, freqs], axis: -1)
61-
cachedCos = cos(emb)
62-
cachedSin = sin(emb)
63-
}
64-
65-
public func callAsFunction(_ x: MLXArray, sequenceLength: Int) -> (MLXArray, MLXArray) {
66-
if sequenceLength > self.cachedSequenceLength {
67-
buildCache(length: sequenceLength)
68-
}
69-
70-
return (
71-
cachedCos[0 ..< sequenceLength].asType(x.dtype),
72-
cachedSin[0 ..< sequenceLength].asType(x.dtype)
73-
)
74-
}
75-
}
76-
7728
/// Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors
7829
static private func applyMultimodalRotaryPositionEmbedding(
7930
q: MLXArray, k: MLXArray, cos: MLXArray, sin: MLXArray,
@@ -114,7 +65,7 @@ private enum Language {
11465
@ModuleInfo(key: "v_proj") var wv: Linear
11566
@ModuleInfo(key: "o_proj") var wo: Linear
11667

117-
let rotaryEmbedding: Qwen2RotaryEmbedding
68+
@ModuleInfo(key: "rotary_emb") var rotaryEmbedding: RoPE
11869

11970
public init(_ args: Qwen2VLConfiguration.TextConfiguration) {
12071
let dim = args.hiddenSize
@@ -143,9 +94,8 @@ private enum Language {
14394
fatalError("rope_scaling['mrope_section'] must be an array of integers")
14495
}
14596

146-
self.rotaryEmbedding = Qwen2RotaryEmbedding(
147-
dimensions: headDim, maxPositionEmbeddings: args.maxpPositionEmbeddings,
148-
base: args.ropeTheta)
97+
self._rotaryEmbedding.wrappedValue = RoPE(
98+
dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta)
14999
}
150100

151101
public func callAsFunction(
@@ -162,30 +112,11 @@ private enum Language {
162112
keys = keys.reshaped(B, L, kvHeads, headDim).transposed(0, 2, 1, 3)
163113
values = values.reshaped(B, L, kvHeads, headDim).transposed(0, 2, 1, 3)
164114

165-
var kvSequenceLength = keys.dim(-2)
166-
var positionIds: MLXArray
167-
if let cache {
168-
kvSequenceLength += cache.offset + 1
169-
positionIds = MLXArray(cache.offset ..< (cache.offset + L))
170-
} else {
171-
positionIds = MLXArray(0 ..< L)
172-
}
173-
174-
positionIds = expandedDimensions(positionIds, axis: 0)
175-
positionIds = tiled(positionIds, repetitions: [3, 1, 1])
176-
177-
let (cos, sin) = rotaryEmbedding(values, sequenceLength: kvSequenceLength)
178-
179-
let mask: MLXArray? =
180-
if var mask {
181-
mask[.newAxis, .newAxis, 0..., 0...][0..., 0..., 0..., ..<keys.dim(-2)]
182-
} else {
183-
nil
184-
}
115+
let offset = cache?.offset ?? 0
116+
let mask = mask?[0..., 0 ..< keys.dim(-2)]
185117

186-
(queries, keys) = applyMultimodalRotaryPositionEmbedding(
187-
q: queries, k: keys, cos: cos, sin: sin, positionIds: positionIds,
188-
mropeSection: mropeSection)
118+
queries = rotaryEmbedding(queries, offset: offset)
119+
keys = rotaryEmbedding(keys, offset: offset)
189120

190121
if let cache {
191122
(keys, values) = cache.update(keys: keys, values: values)
@@ -450,34 +381,25 @@ private enum Vision {
450381
}
451382

452383
public func callAsFunction(
453-
_ x: MLXArray, cuSequenceLengths: MLXArray, rotaryPositionEmbedding: MLXArray
384+
_ x: MLXArray, gridThw: [THW], rotaryPositionEmbedding: MLXArray
454385
) -> MLXArray {
455386
let sequenceLength = x.dim(0)
387+
let B = gridThw[0].t
388+
let L = sequenceLength / B
456389

457-
let qkv = qkv(x)
458-
.reshaped(sequenceLength, 3, self.numHeads, -1)
459-
.transposed(1, 0, 2, 3)
460-
let s = split(qkv, parts: 3)
390+
let qkv = qkv(x).reshaped(sequenceLength, 3, -1)
391+
let s = split(qkv, parts: 3, axis: 1)
461392
var (q, k, v) = (s[0], s[1], s[2])
462393

463-
print("rotaryPositionEmbedding \(rotaryPositionEmbedding.shape)")
464-
465-
q =
466-
applyMultimodalRotaryPositionEmbedding(
467-
expandedDimensions(q, axis: 0), freqs: rotaryPositionEmbedding)[0]
468-
k =
469-
applyMultimodalRotaryPositionEmbedding(
470-
expandedDimensions(k, axis: 0), freqs: rotaryPositionEmbedding)[0]
471-
472-
q = q.transposed(0, 2, 1, 3)
473-
k = k.transposed(0, 2, 1, 3)
474-
v = v.transposed(0, 2, 1, 3)
394+
q = applyMultimodalRotaryPositionEmbedding(q, freqs: rotaryPositionEmbedding)
395+
k = applyMultimodalRotaryPositionEmbedding(k, freqs: rotaryPositionEmbedding)
475396

476-
let mask = makeMask(
477-
cuSequenceLengths: cuSequenceLengths, sequenceLength: sequenceLength)
397+
q = q.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3)
398+
k = k.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3)
399+
v = v.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3)
478400

479401
let output = MLXFast.scaledDotProductAttention(
480-
queries: q, keys: k, values: v, scale: scale, mask: mask
402+
queries: q, keys: k, values: v, scale: scale, mask: nil
481403
)
482404
.transposed(0, 2, 1, 3)
483405
.reshaped(sequenceLength, -1)
@@ -523,13 +445,13 @@ private enum Vision {
523445
}
524446

525447
func callAsFunction(
526-
_ hiddenStates: MLXArray, cuSequenceLengths: MLXArray, rotaryPositionEmbedding: MLXArray
448+
_ hiddenStates: MLXArray, gridThw: [THW], rotaryPositionEmbedding: MLXArray
527449
) -> MLXArray {
528450
var hiddenStates =
529451
hiddenStates
530452
+ attention(
531453
norm1(hiddenStates),
532-
cuSequenceLengths: cuSequenceLengths,
454+
gridThw: gridThw,
533455
rotaryPositionEmbedding: rotaryPositionEmbedding
534456
)
535457
hiddenStates = hiddenStates + mlp(norm2(hiddenStates))
@@ -619,22 +541,9 @@ private enum Vision {
619541

620542
let batchSize = gridThw.count
621543

622-
// Calculate cu_seqlens for each item in the batch
623-
var collect = [MLXArray]()
624-
for thw in gridThw {
625-
let sequenceLength = thw.h * thw.w
626-
collect.append(repeated(MLXArray(sequenceLength), count: thw.t))
627-
}
628-
629-
// Concatenate the cu_seqlens for all items in the batch
630-
var cuSeqLengths = concatenated(collect)
631-
632-
cuSeqLengths = cumsum(cuSeqLengths.asType(Int32.self), axis: 0)
633-
cuSeqLengths = padded(cuSeqLengths, width: [1, 0], mode: .constant, value: MLXArray(0))
634-
635544
for block in blocks {
636545
hiddenStates = block(
637-
hiddenStates, cuSequenceLengths: cuSeqLengths,
546+
hiddenStates, gridThw: gridThw,
638547
rotaryPositionEmbedding: rotaryPositionEmbedding)
639548
}
640549

@@ -797,7 +706,7 @@ public class Qwen2VLProcessor: UserInputProcessor {
797706
let mask = ones(like: promptArray)
798707

799708
// image_processing_qwen2_vl.preprocess
800-
let images = try input.images.map { try preprocess(images: [$0]) }
709+
let images = try input.images.map { try preprocess(images: [$0.asCIImage()]) }
801710
let pixels = concatenated(images.map { $0.0 })
802711
let image = LMInput.ProcessedImage(pixels: pixels, imageGridThw: images.map { $0.1 })
803712

@@ -873,8 +782,12 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider {
873782
-> PrepareResult
874783
{
875784
let gridThw = input.image?.imageGridThw
785+
786+
let dtype = visionModel.patchEmbed.proj.weight.dtype
787+
let pixels = input.image?.pixels.asType(dtype)
788+
876789
let inputEmbeddings = self.inputEmbeddings(
877-
inputIds: input.text.tokens, pixelValues: input.image?.pixels, gridThw: gridThw)
790+
inputIds: input.text.tokens, pixelValues: pixels, gridThw: gridThw)
878791

879792
let result = languageModel(nil, cache: cache, inputEmbedding: inputEmbeddings)
880793

0 commit comments

Comments
 (0)