
Commit 5e4c906

qwen2-vl working
1 parent 8a68223 commit 5e4c906

6 files changed: +120, -50 lines


Libraries/LMCommon/LanguageModel.swift

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,8 @@ public struct THW: Sendable {
     public var values: (Int, Int, Int) {
         (t, h, w)
     }
+
+    public var product: Int { t * h * w }
 }
 
 extension Array where Element == THW {

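The new `product` helper is what later turns a patch grid into a token count for Qwen2-VL. A minimal sketch, assuming the unlabeled three-argument initializer used elsewhere in this commit (`.init(gridT, gridH, gridW)`); the grid values are illustrative:

```swift
import MLXLMCommon

// Illustrative only: a 1 x 32 x 32 patch grid has 1024 patches in total.
let grid = THW(1, 32, 32)
print(grid.values)   // (1, 32, 32)
print(grid.product)  // 1 * 32 * 32 = 1024
```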
Libraries/LMCommon/InputProcessor.swift renamed to Libraries/LMCommon/UserInput.swift

Lines changed: 20 additions & 15 deletions
@@ -4,21 +4,21 @@ import CoreImage
 import Foundation
 import MLX
 
-public enum UserInputPrompt: Sendable {
-    case text(String)
-    case messages([[String: String]])
-
-    public func asMessages() -> [[String: String]] {
-        switch self {
-        case .text(let text):
-            return [["role": "user", "content": text]]
-        case .messages(let messages):
-            return messages
+public struct UserInput: Sendable {
+
+    public enum Prompt: Sendable {
+        case text(String)
+        case messages([[String: String]])
+
+        public func asMessages() -> [[String: String]] {
+            switch self {
+            case .text(let text):
+                return [["role": "user", "content": text]]
+            case .messages(let messages):
+                return messages
+            }
         }
     }
-}
-
-public struct UserInput: Sendable {
 
     public enum Image: Sendable {
         case ciImage(CIImage)
@@ -83,8 +83,13 @@ public struct UserInput: Sendable {
         }
     }
 
-    public var prompt: UserInputPrompt
+    public struct Processing: Sendable {
+        public var resize: CGSize?
+    }
+
+    public var prompt: Prompt
     public var images = [Image]()
+    public var processing: Processing = .init()
 
     public init(prompt: String, images: [Image] = [Image]()) {
         self.prompt = .text(prompt)
@@ -96,7 +101,7 @@ public struct UserInput: Sendable {
         self.images = images
     }
 
-    public init(prompt: UserInputPrompt, images: [Image] = [Image]()) {
+    public init(prompt: Prompt, images: [Image] = [Image]()) {
         self.prompt = prompt
         self.images = images
     }

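A minimal usage sketch of the reshaped API: `Prompt` now nests inside `UserInput`, and the new `Processing` value carries an optional resize hint down to the processors. The image URL here is hypothetical, and `.url` is the `Image` case used by llm-tool later in this commit:

```swift
import CoreGraphics
import Foundation
import MLXLMCommon

let imageURL = URL(fileURLWithPath: "/tmp/example.jpg")  // hypothetical path

var input = UserInput(
    prompt: .messages([["role": "user", "content": "Describe this image"]]),
    images: [.url(imageURL)]
)
// ask the processor to fit the image inside 512 x 512 before its own resampling
input.processing.resize = CGSize(width: 512, height: 512)
```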
Libraries/VLM/MediaProcessing.swift

Lines changed: 13 additions & 1 deletion
@@ -1,7 +1,8 @@
 // Copyright © 2024 Apple Inc.
 
-@preconcurrency import CoreImage.CIFilterBuiltins
+import CoreImage.CIFilterBuiltins
 import MLX
+import MLXLMCommon
 
 private let context = CIContext()
 
@@ -111,4 +112,15 @@ public enum MediaProcessing {
 
         return array
     }
+
+    static func apply(_ image: CIImage, processing: UserInput.Processing?) -> CIImage {
+        var image = image
+
+        if let resize = processing?.resize {
+            let scale = bestFitScale(image.extent.size, in: resize)
+            image = image.transformed(by: CGAffineTransform(scaleX: scale, y: scale))
+        }
+
+        return image
+    }
 }

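`bestFitScale` is referenced by the new `apply` hook but lies outside this diff; a typical aspect-preserving implementation would look like the sketch below (name and behavior are assumptions, not taken from the commit):

```swift
import CoreGraphics

// Assumed behavior: the largest uniform scale that fits `size` inside `target`
// without changing the aspect ratio.
func bestFitScale(_ size: CGSize, in target: CGSize) -> CGFloat {
    min(target.width / size.width, target.height / size.height)
}

// e.g. a 4032 x 3024 photo with a 1024 x 1024 resize hint scales by ~0.254
// to 1024 x 768 before the model-specific resampling runs.
```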
Libraries/VLM/Models/Paligemma.swift

Lines changed: 6 additions & 3 deletions
@@ -450,7 +450,7 @@ public class PaligGemmaProcessor: UserInputProcessor {
         self.tokenizer = tokenizer
     }
 
-    public func convert(image: CIImage) -> MLXArray {
+    private func prepare(image: CIImage, processing: UserInput.Processing?) -> MLXArray {
         // based on image_processing_siglip from transformers
         var image = image
 
@@ -459,6 +459,9 @@ public class PaligGemmaProcessor: UserInputProcessor {
         // do (implicitly by using sRGB rasters directly)
         image = MediaProcessing.inSRGBToneCurveSpace(image)
 
+        // apply user instructions
+        image = MediaProcessing.apply(image, processing: processing)
+
         image = MediaProcessing.resampleBicubic(image, to: config.size.cgSize)
         image = MediaProcessing.normalize(
             image, mean: config.imageMeanTuple, std: config.imageStdTuple)
@@ -473,7 +476,7 @@ public class PaligGemmaProcessor: UserInputProcessor {
         default: throw VLMError.singleImageAllowed
         }
 
-        // this doesn't have a chat template so just use the last message
+        // this doesn't have a chat template so just use the last message.
         var prompt = input.prompt.asMessages().last?["content"] ?? ""
 
         // based on transformers/processing_paligemma
@@ -486,7 +489,7 @@ public class PaligGemmaProcessor: UserInputProcessor {
         let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
         let mask = ones(like: promptArray)
 
-        let pixels = try convert(image: input.images[0].asCIImage())
+        let pixels = try prepare(image: input.images[0].asCIImage(), processing: input.processing)
 
         return LMInput(text: .init(tokens: promptArray, mask: mask), image: .init(pixels: pixels))
     }

Libraries/VLM/Models/Qwen2VL.swift

Lines changed: 64 additions & 29 deletions
@@ -283,11 +283,8 @@ private enum Vision {
     }
 
     func callAsFunction(sequenceLength: Int) -> MLXArray {
-        let inverseFreq =
-            1.0
-            / (pow(
-                theta,
-                MLXArray(stride(from: 0, to: dimensions, by: 2)).asType(.float32) / dimensions))
+        let p = MLXArray(stride(from: 0, to: dimensions, by: 2)).asType(.float32) / dimensions
+        let inverseFreq = 1.0 / pow(theta, p)
         let seq = MLXArray(0 ..< sequenceLength).asType(inverseFreq.dtype)
         let freqs = outer(seq, inverseFreq)
         return freqs
@@ -370,16 +367,6 @@ private enum Vision {
         self._proj.wrappedValue = Linear(dims, dims)
     }
 
-    private func makeMask(cuSequenceLengths: MLXArray, sequenceLength: Int) -> MLXArray {
-        let starts = cuSequenceLengths[.newAxis, ..<(-1)]
-        let ends = cuSequenceLengths[.newAxis, 1...]
-        let indices = MLXArray(0 ..< sequenceLength)[0..., .newAxis]
-        var mask = (indices .>= starts) & (indices .< ends)
-        mask = mask.any(axis: -1)
-        mask = mask[.newAxis] & mask[0..., .newAxis]
-        return 1 - mask
-    }
-
     public func callAsFunction(
         _ x: MLXArray, gridThw: [THW], rotaryPositionEmbedding: MLXArray
     ) -> MLXArray {
@@ -391,6 +378,10 @@ private enum Vision {
         let s = split(qkv, parts: 3, axis: 1)
         var (q, k, v) = (s[0], s[1], s[2])
 
+        q = q.reshaped(sequenceLength, numHeads, -1)
+        k = k.reshaped(sequenceLength, numHeads, -1)
+        v = v.reshaped(sequenceLength, numHeads, -1)
+
         q = applyMultimodalRotaryPositionEmbedding(q, freqs: rotaryPositionEmbedding)
         k = applyMultimodalRotaryPositionEmbedding(k, freqs: rotaryPositionEmbedding)
 
@@ -530,8 +521,6 @@ private enum Vision {
         let rotaryPositionEmbedFull = rotaryPositionEmbedding(sequenceLength: maxGridSize)[
             indices]
 
-        print("rot_pos_emb(), \(maxGridSize) \(gridThw), \(rotaryPositionEmbedFull.shape)")
-
         return rotaryPositionEmbedFull.reshaped(indices.dim(0), -1)
     }
 
@@ -634,16 +623,20 @@ public class Qwen2VLProcessor: UserInputProcessor {
         return (hBar, wBar)
     }
 
-    public func preprocess(images: [CIImage]) throws -> (MLXArray, THW) {
+    public func preprocess(images: [CIImage], processing: UserInput.Processing?) throws -> (
+        MLXArray, THW
+    ) {
 
         // image_processing_qwen2_vl._preprocess
 
+        let images = images.map { MediaProcessing.apply($0, processing: processing) }
+
         let size = images[0].extent.size
         let (resizedHeight, resizedWidth) = try targetSize(
             height: Int(size.height), width: Int(size.width),
             factor: config.patchSize * config.mergeSize,
             minPixels: config.size.minPixels, maxPixels: config.size.maxPixels)
-        let resizedSize = CGSize(width: resizedHeight, height: resizedWidth)
+        let resizedSize = CGSize(width: resizedWidth, height: resizedHeight)
 
         let processedImages =
             try images
@@ -691,26 +684,68 @@ public class Qwen2VLProcessor: UserInputProcessor {
         return (flattenedPatches, .init(gridT, gridH, gridW))
     }
 
-    public func prepare(input: UserInput) throws -> LMInput {
-        // this doesn't have a chat template so just use the last message
-        let prompt = input.prompt.asMessages().last?["content"] ?? ""
+    public func prepare(prompt: UserInput.Prompt, imageTHW: [THW]?) -> String {
+        // the tokenizer does have a chat template and it expects messages
+        // like this:
+        //
+        // [{'role': 'user', 'content': [{'type': 'text', 'text': 'What are these?'},
+        //     {'type': 'image'}, {'type': 'image'}, {'type': 'image'}]}]
+        //
+        // The output of the prompt template is fed into
+        // image_processing_qwen2_vl.preprocess where it is further augmented
+        // by replacing tokens according to imageTHW.
+        //
+        // Neither the structured content nor the postprocessing of the template
+        // are supported in current Tokenizer/Jinja (swift) so handle that here.
+
+        var messages = prompt.asMessages()
+        if messages[0]["role"] != "system" {
+            messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0)
+        }
+
+        let lastIndex = messages.count - 1
+        var lastMessage = messages[lastIndex]["content"] ?? ""
+
+        // image_processing_qwen2_vl.preprocess -- inject image_pad tokens for each image
+        let mergeLength = config.mergeSize * config.mergeSize
+        for thw in imageTHW ?? [] {
+            lastMessage += "<|vision_start|>"
+            lastMessage += Array(repeating: "<|image_pad|>", count: thw.product / mergeLength)
+                .joined()
+            lastMessage += "<|vision_end|>"
+        }
+
+        messages[lastIndex]["content"] = lastMessage
+
+        return
+            messages
+            .map {
+                "<|im_start|>\($0["role"] ?? "user")\n\($0["content"] ?? "")<|im_end|>"
+            }
+            .joined(separator: "\n")
+            + "\n<|im_start|>assistant\n"
+    }
 
+    public func prepare(input: UserInput) throws -> LMInput {
         if input.images.isEmpty {
             // just a straight text prompt
+            let prompt = prepare(prompt: input.prompt, imageTHW: nil)
            let promptTokens = try tokenizer.encode(text: prompt)
             return LMInput(tokens: MLXArray(promptTokens))
         }
 
-        let promptTokens = try tokenizer.encode(text: prompt)
-        let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
-        let mask = ones(like: promptArray)
-
         // image_processing_qwen2_vl.preprocess
-        let images = try input.images.map { try preprocess(images: [$0.asCIImage()]) }
+        let images = try input.images.map {
+            try preprocess(images: [$0.asCIImage()], processing: input.processing)
+        }
         let pixels = concatenated(images.map { $0.0 })
         let image = LMInput.ProcessedImage(pixels: pixels, imageGridThw: images.map { $0.1 })
 
-        print("image \(image.pixels.shape), \(image.imageGridThw)")
+        // processing_qwen2_vl.Qwen2VLProcessor
+        let prompt = prepare(prompt: input.prompt, imageTHW: image.imageGridThw)
+        let promptTokens = try tokenizer.encode(text: prompt)
+        let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
+        let mask = ones(like: promptArray)
 
         return LMInput(text: .init(tokens: promptArray, mask: mask), image: image)
     }
@@ -773,7 +808,7 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider {
                 imageIndices.append(i)
             }
         }
-        // TODO look at the inputIds -- I think I am missing something here
+
         inputEmbeds[0..., MLXArray(imageIndices), 0...] = imageFeatures
         return inputEmbeds
     }

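To make the token-injection comment concrete, here is the arithmetic `prepare(prompt:imageTHW:)` performs; the grid values and `mergeSize` of 2 are illustrative, not taken from the commit:

```swift
import MLXLMCommon

// Worked example of the pad-token arithmetic above.
let thw = THW(1, 36, 36)                  // (t, h, w) grid returned by preprocess()
let mergeLength = 2 * 2                   // config.mergeSize * config.mergeSize
let padCount = thw.product / mergeLength  // 1 * 36 * 36 / 4 = 324
// The last user message gains "<|vision_start|>", 324 "<|image_pad|>" tokens,
// and "<|vision_end|>"; each message is then rendered as
// "<|im_start|>role\ncontent<|im_end|>" and the string ends with
// "\n<|im_start|>assistant\n".
```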
Tools/llm-tool/LLMTool.swift

Lines changed: 15 additions & 2 deletions
@@ -258,14 +258,27 @@ struct VLMCommand: AsyncParsableCommand {
     mutating func run() async throws {
         let modelContainer = try await memory.start { [args] in
             try await args.load(
-                defaultModel: "mlx-community/paligemma-3b-mix-448-8bit",
+                defaultModel: MLXVLM.ModelRegistry.paligemma3bMix4488bit.name,
                 modelFactory: VLMModelFactory.shared)
         }
         let modelConfiguration = modelContainer.configuration
 
         let prompt = generate.prompt ?? modelConfiguration.defaultPrompt
 
-        let input = UserInput(prompt: prompt, images: image.map { .url($0) })
+        var input = UserInput(prompt: prompt, images: image.map { .url($0) })
+
+        if !resize.isEmpty {
+            let size: CGSize
+            if resize.count == 1 {
+                let v = resize[0]
+                size = CGSize(width: v, height: v)
+            } else {
+                let v0 = resize[0]
+                let v1 = resize[1]
+                size = CGSize(width: v0, height: v1)
+            }
+            input.processing.resize = size
+        }
 
         let result = try await modelContainer.perform { [generate] context in
             let input = try context.processor.prepare(input: input)

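The resize handling maps one value to a square bound and two values to width then height. A small sketch of that mapping as a standalone helper (the helper name and the `[Int]` element type are assumptions; the CLI option that fills `resize` is declared outside this hunk):

```swift
import CoreGraphics

// Assumed helper mirroring the CLI logic above; not part of the commit.
// One value is treated as a square bound, two values as width then height.
func resizeSize(_ resize: [Int]) -> CGSize? {
    switch resize.count {
    case 0: return nil
    case 1: return CGSize(width: resize[0], height: resize[0])
    default: return CGSize(width: resize[0], height: resize[1])
    }
}
```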