diff --git a/Libraries/LMCommon/LanguageModel.swift b/Libraries/LMCommon/LanguageModel.swift
index 5c68cba2..b6edf27c 100644
--- a/Libraries/LMCommon/LanguageModel.swift
+++ b/Libraries/LMCommon/LanguageModel.swift
@@ -20,6 +20,8 @@ public struct THW: Sendable {
     public var values: (Int, Int, Int) {
         (t, h, w)
     }
+
+    public var product: Int { t * h * w }
 }
 
 extension Array where Element == THW {
diff --git a/Libraries/LMCommon/InputProcessor.swift b/Libraries/LMCommon/UserInput.swift
similarity index 83%
rename from Libraries/LMCommon/InputProcessor.swift
rename to Libraries/LMCommon/UserInput.swift
index a9ae917c..85312281 100644
--- a/Libraries/LMCommon/InputProcessor.swift
+++ b/Libraries/LMCommon/UserInput.swift
@@ -4,21 +4,21 @@ import CoreImage
 import Foundation
 import MLX
 
-public enum UserInputPrompt: Sendable {
-    case text(String)
-    case messages([[String: String]])
-
-    public func asMessages() -> [[String: String]] {
-        switch self {
-        case .text(let text):
-            return [["role": "user", "content": text]]
-        case .messages(let messages):
-            return messages
+public struct UserInput: Sendable {
+
+    public enum Prompt: Sendable {
+        case text(String)
+        case messages([[String: String]])
+
+        public func asMessages() -> [[String: String]] {
+            switch self {
+            case .text(let text):
+                return [["role": "user", "content": text]]
+            case .messages(let messages):
+                return messages
+            }
         }
     }
-}
-
-public struct UserInput: Sendable {
 
     public enum Image: Sendable {
         case ciImage(CIImage)
@@ -83,8 +83,13 @@ public struct UserInput: Sendable {
         }
     }
 
-    public var prompt: UserInputPrompt
+    public struct Processing: Sendable {
+        public var resize: CGSize?
+    }
+
+    public var prompt: Prompt
     public var images = [Image]()
+    public var processing: Processing = .init()
 
     public init(prompt: String, images: [Image] = [Image]()) {
         self.prompt = .text(prompt)
@@ -96,7 +101,7 @@ public struct UserInput: Sendable {
         self.images = images
     }
 
-    public init(prompt: UserInputPrompt, images: [Image] = [Image]()) {
+    public init(prompt: Prompt, images: [Image] = [Image]()) {
         self.prompt = prompt
         self.images = images
     }
diff --git a/Libraries/VLM/MediaProcessing.swift b/Libraries/VLM/MediaProcessing.swift
index 6de707e3..b662d7a7 100644
--- a/Libraries/VLM/MediaProcessing.swift
+++ b/Libraries/VLM/MediaProcessing.swift
@@ -1,7 +1,8 @@
 // Copyright © 2024 Apple Inc.
 
-@preconcurrency import CoreImage.CIFilterBuiltins
+import CoreImage.CIFilterBuiltins
 import MLX
+import MLXLMCommon
 
 private let context = CIContext()
@@ -111,4 +112,15 @@ public enum MediaProcessing {
 
         return array
     }
+
+    static func apply(_ image: CIImage, processing: UserInput.Processing?) -> CIImage {
+        var image = image
+
+        if let resize = processing?.resize {
+            let scale = bestFitScale(image.extent.size, in: resize)
+            image = image.transformed(by: CGAffineTransform(scaleX: scale, y: scale))
+        }
+
+        return image
+    }
 }
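For orientation, here is a minimal caller-side sketch of the reshaped API above: the prompt type is now the nested `UserInput.Prompt`, and per-request preprocessing instructions travel in the new `UserInput.Processing` struct, which `MediaProcessing.apply` consumes. The prompt text and image URL are placeholders, not part of this diff.

```swift
import CoreGraphics
import Foundation
import MLXLMCommon

// Hypothetical caller: ask the processor to fit the image within 448x448
// before the model-specific preprocessing in the VLM processors runs.
let url = URL(fileURLWithPath: "/tmp/example.jpg")
var input = UserInput(prompt: "Describe this image", images: [.url(url)])
input.processing.resize = CGSize(width: 448, height: 448)

// Equivalent, using the nested Prompt enum explicitly:
var chat = UserInput(
    prompt: .messages([["role": "user", "content": "Describe this image"]]),
    images: [.url(url)])
chat.processing.resize = CGSize(width: 448, height: 448)
```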
diff --git a/Libraries/VLM/Models/Paligemma.swift b/Libraries/VLM/Models/Paligemma.swift
index e50cfb32..330cd56a 100644
--- a/Libraries/VLM/Models/Paligemma.swift
+++ b/Libraries/VLM/Models/Paligemma.swift
@@ -450,7 +450,7 @@ public class PaligGemmaProcessor: UserInputProcessor {
         self.tokenizer = tokenizer
     }
 
-    public func convert(image: CIImage) -> MLXArray {
+    private func prepare(image: CIImage, processing: UserInput.Processing?) -> MLXArray {
         // based on image_processing_siglip from transformers
         var image = image
 
@@ -459,6 +459,9 @@ public class PaligGemmaProcessor: UserInputProcessor {
         // do (implicitly by using sRGB rasters directly)
         image = MediaProcessing.inSRGBToneCurveSpace(image)
 
+        // apply user instructions
+        image = MediaProcessing.apply(image, processing: processing)
+
         image = MediaProcessing.resampleBicubic(image, to: config.size.cgSize)
         image = MediaProcessing.normalize(
             image, mean: config.imageMeanTuple, std: config.imageStdTuple)
@@ -473,7 +476,7 @@ public class PaligGemmaProcessor: UserInputProcessor {
         default:
             throw VLMError.singleImageAllowed
         }
-        // this doesn't have a chat template so just use the last message
+        // this doesn't have a chat template so just use the last message.
         var prompt = input.prompt.asMessages().last?["content"] ?? ""
 
         // based on transformers/processing_paligemma
@@ -486,7 +489,7 @@ public class PaligGemmaProcessor: UserInputProcessor {
         let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
         let mask = ones(like: promptArray)
 
-        let pixels = try convert(image: input.images[0].asCIImage())
+        let pixels = try prepare(image: input.images[0].asCIImage(), processing: input.processing)
 
         return LMInput(text: .init(tokens: promptArray, mask: mask), image: .init(pixels: pixels))
     }
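In the Qwen2VL vision rotary embedding below, the inverse-frequency computation is refactored into two lines. As a quick sanity check, this is what those lines evaluate to for a small, hypothetical configuration (a plain-Swift stand-in for the MLX code; `dimensions = 4` and `theta = 10000` are made-up values):

```swift
import Foundation

// Hypothetical configuration mirroring the two refactored lines in the diff below.
let dimensions = 4
let theta: Float = 10_000

let p = stride(from: 0, to: dimensions, by: 2).map { Float($0) / Float(dimensions) }
// p == [0.0, 0.5]
let inverseFreq = p.map { 1.0 / pow(theta, $0) }
// inverseFreq == [1.0, 0.01]
// freqs = outer(0..<sequenceLength, inverseFreq): one row of rotary phase
// frequencies per position, dimensions / 2 columns.
```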
diff --git a/Libraries/VLM/Models/Qwen2VL.swift b/Libraries/VLM/Models/Qwen2VL.swift
index 1d7cabb1..80ac829f 100644
--- a/Libraries/VLM/Models/Qwen2VL.swift
+++ b/Libraries/VLM/Models/Qwen2VL.swift
@@ -283,11 +283,8 @@ private enum Vision {
         }
 
         func callAsFunction(sequenceLength: Int) -> MLXArray {
-            let inverseFreq =
-                1.0
-                / (pow(
-                    theta,
-                    MLXArray(stride(from: 0, to: dimensions, by: 2)).asType(.float32) / dimensions))
+            let p = MLXArray(stride(from: 0, to: dimensions, by: 2)).asType(.float32) / dimensions
+            let inverseFreq = 1.0 / pow(theta, p)
             let seq = MLXArray(0 ..< sequenceLength).asType(inverseFreq.dtype)
             let freqs = outer(seq, inverseFreq)
             return freqs
@@ -370,16 +367,6 @@ private enum Vision {
             self._proj.wrappedValue = Linear(dims, dims)
         }
 
-        private func makeMask(cuSequenceLengths: MLXArray, sequenceLength: Int) -> MLXArray {
-            let starts = cuSequenceLengths[.newAxis, ..<(-1)]
-            let ends = cuSequenceLengths[.newAxis, 1...]
-            let indices = MLXArray(0 ..< sequenceLength)[0..., .newAxis]
-            var mask = (indices .>= starts) & (indices .< ends)
-            mask = mask.any(axis: -1)
-            mask = mask[.newAxis] & mask[0..., .newAxis]
-            return 1 - mask
-        }
-
         public func callAsFunction(
             _ x: MLXArray, gridThw: [THW], rotaryPositionEmbedding: MLXArray
         ) -> MLXArray {
@@ -391,6 +378,10 @@ private enum Vision {
             let s = split(qkv, parts: 3, axis: 1)
             var (q, k, v) = (s[0], s[1], s[2])
 
+            q = q.reshaped(sequenceLength, numHeads, -1)
+            k = k.reshaped(sequenceLength, numHeads, -1)
+            v = v.reshaped(sequenceLength, numHeads, -1)
+
             q = applyMultimodalRotaryPositionEmbedding(q, freqs: rotaryPositionEmbedding)
             k = applyMultimodalRotaryPositionEmbedding(k, freqs: rotaryPositionEmbedding)
 
@@ -530,8 +521,6 @@ private enum Vision {
             let rotaryPositionEmbedFull = rotaryPositionEmbedding(sequenceLength: maxGridSize)[
                 indices]
 
-            print("rot_pos_emb(), \(maxGridSize) \(gridThw), \(rotaryPositionEmbedFull.shape)")
-
             return rotaryPositionEmbedFull.reshaped(indices.dim(0), -1)
         }
 
@@ -634,16 +623,20 @@ public class Qwen2VLProcessor: UserInputProcessor {
         return (hBar, wBar)
     }
 
-    public func preprocess(images: [CIImage]) throws -> (MLXArray, THW) {
+    public func preprocess(images: [CIImage], processing: UserInput.Processing?) throws -> (
+        MLXArray, THW
+    ) {
         // image_processing_qwen2_vl._preprocess
 
+        let images = images.map { MediaProcessing.apply($0, processing: processing) }
+
         let size = images[0].extent.size
         let (resizedHeight, resizedWidth) = try targetSize(
             height: Int(size.height), width: Int(size.width),
             factor: config.patchSize * config.mergeSize,
             minPixels: config.size.minPixels, maxPixels: config.size.maxPixels)
-        let resizedSize = CGSize(width: resizedHeight, height: resizedWidth)
+        let resizedSize = CGSize(width: resizedWidth, height: resizedHeight)
 
         let processedImages =
             try images
@@ -691,26 +684,68 @@ public class Qwen2VLProcessor: UserInputProcessor {
         return (flattenedPatches, .init(gridT, gridH, gridW))
     }
 
-    public func prepare(input: UserInput) throws -> LMInput {
-        // this doesn't have a chat template so just use the last message
-        let prompt = input.prompt.asMessages().last?["content"] ?? ""
+    public func prepare(prompt: UserInput.Prompt, imageTHW: [THW]?) -> String {
+        // the tokenizer does have a chat template and it expects messages
+        // like this:
+        //
+        // [{'role': 'user', 'content': [{'type': 'text', 'text': 'What are these?'},
+        //     {'type': 'image'}, {'type': 'image'}, {'type': 'image'}]}]
+        //
+        // The output of the prompt template is fed into
+        // image_processing_qwen2_vl.preprocess where it is further augmented
+        // by replacing tokens according to imageTHW.
+        //
+        // Neither the structured content nor the postprocessing of the template
+        // are supported in current Tokenizer/Jinja (swift) so handle that here.
+
+        var messages = prompt.asMessages()
+        if messages[0]["role"] != "system" {
+            messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0)
+        }
+
+        let lastIndex = messages.count - 1
+        var lastMessage = messages[lastIndex]["content"] ?? ""
+
+        // image_processing_qwen2_vl.preprocess -- inject image_pad tokens for each image
+        let mergeLength = config.mergeSize * config.mergeSize
+        for thw in imageTHW ?? [] {
+            lastMessage += "<|vision_start|>"
+            lastMessage += Array(repeating: "<|image_pad|>", count: thw.product / mergeLength)
+                .joined()
+            lastMessage += "<|vision_end|>"
+        }
+
+        messages[lastIndex]["content"] = lastMessage
+
+        return
+            messages
+            .map {
+                "<|im_start|>\($0["role"] ?? "user")\n\($0["content"] ?? "")<|im_end|>"
+            }
+            .joined(separator: "\n")
+            + "\n<|im_start|>assistant\n"
+    }
 
+    public func prepare(input: UserInput) throws -> LMInput {
         if input.images.isEmpty {
             // just a straight text prompt
+            let prompt = prepare(prompt: input.prompt, imageTHW: nil)
             let promptTokens = try tokenizer.encode(text: prompt)
             return LMInput(tokens: MLXArray(promptTokens))
         }
 
-        let promptTokens = try tokenizer.encode(text: prompt)
-        let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
-        let mask = ones(like: promptArray)
-
         // image_processing_qwen2_vl.preprocess
-        let images = try input.images.map { try preprocess(images: [$0.asCIImage()]) }
+        let images = try input.images.map {
+            try preprocess(images: [$0.asCIImage()], processing: input.processing)
+        }
         let pixels = concatenated(images.map { $0.0 })
         let image = LMInput.ProcessedImage(pixels: pixels, imageGridThw: images.map { $0.1 })
 
-        print("image \(image.pixels.shape), \(image.imageGridThw)")
+        // processing_qwen2_vl.Qwen2VLProcessor
+        let prompt = prepare(prompt: input.prompt, imageTHW: image.imageGridThw)
+        let promptTokens = try tokenizer.encode(text: prompt)
+        let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
+        let mask = ones(like: promptArray)
 
         return LMInput(text: .init(tokens: promptArray, mask: mask), image: image)
     }
@@ -773,7 +808,7 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider {
                 imageIndices.append(i)
             }
         }
-        // TODO look at the inputIds -- I think I am missing something here
+
         inputEmbeds[0..., MLXArray(imageIndices), 0...] = imageFeatures
         return inputEmbeds
     }
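To make the template handling above concrete, here is a hedged sketch of the string `prepare(prompt:imageTHW:)` builds. The processor instance, prompt text, and grid values are hypothetical, and `config.mergeSize` is assumed to be 2:

```swift
import MLXLMCommon
import MLXVLM

// Hypothetical: with mergeSize == 2, mergeLength == 4, so one image with
// grid THW = (1, 32, 32) is represented by 1 * 32 * 32 / 4 == 256
// <|image_pad|> tokens in the generated prompt.
func examplePrompt(processor: Qwen2VLProcessor) -> String {
    processor.prepare(
        prompt: .text("What is in this image?"),
        imageTHW: [THW(1, 32, 32)])
}

// Expected shape of the result:
//
//   <|im_start|>system
//   You are a helpful assistant.<|im_end|>
//   <|im_start|>user
//   What is in this image?<|vision_start|><|image_pad|>…(256 of them)…<|vision_end|><|im_end|>
//   <|im_start|>assistant
```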
"")<|im_end|>" + } + .joined(separator: "\n") + + "\n<|im_start|>assistant\n" + } + public func prepare(input: UserInput) throws -> LMInput { if input.images.isEmpty { // just a straight text prompt + let prompt = prepare(prompt: input.prompt, imageTHW: nil) let promptTokens = try tokenizer.encode(text: prompt) return LMInput(tokens: MLXArray(promptTokens)) } - let promptTokens = try tokenizer.encode(text: prompt) - let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0) - let mask = ones(like: promptArray) - // image_processing_qwen2_vl.preprocess - let images = try input.images.map { try preprocess(images: [$0.asCIImage()]) } + let images = try input.images.map { + try preprocess(images: [$0.asCIImage()], processing: input.processing) + } let pixels = concatenated(images.map { $0.0 }) let image = LMInput.ProcessedImage(pixels: pixels, imageGridThw: images.map { $0.1 }) - print("image \(image.pixels.shape), \(image.imageGridThw)") + // processing_qwen2_vl.Qwen2VLProcessor + let prompt = prepare(prompt: input.prompt, imageTHW: image.imageGridThw) + let promptTokens = try tokenizer.encode(text: prompt) + let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0) + let mask = ones(like: promptArray) return LMInput(text: .init(tokens: promptArray, mask: mask), image: image) } @@ -773,7 +808,7 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider { imageIndices.append(i) } } - // TODO look at the inputIds -- I think I am missing something here + inputEmbeds[0..., MLXArray(imageIndices), 0...] = imageFeatures return inputEmbeds } diff --git a/Tools/llm-tool/LLMTool.swift b/Tools/llm-tool/LLMTool.swift index b6b403b5..ec946170 100644 --- a/Tools/llm-tool/LLMTool.swift +++ b/Tools/llm-tool/LLMTool.swift @@ -258,14 +258,27 @@ struct VLMCommand: AsyncParsableCommand { mutating func run() async throws { let modelContainer = try await memory.start { [args] in try await args.load( - defaultModel: "mlx-community/paligemma-3b-mix-448-8bit", + defaultModel: MLXVLM.ModelRegistry.paligemma3bMix4488bit.name, modelFactory: VLMModelFactory.shared) } let modelConfiguration = modelContainer.configuration let prompt = generate.prompt ?? modelConfiguration.defaultPrompt - let input = UserInput(prompt: prompt, images: image.map { .url($0) }) + var input = UserInput(prompt: prompt, images: image.map { .url($0) }) + + if !resize.isEmpty { + let size: CGSize + if resize.count == 1 { + let v = resize[0] + size = CGSize(width: v, height: v) + } else { + let v0 = resize[0] + let v1 = resize[0] + size = CGSize(width: v0, height: v1) + } + input.processing.resize = size + } let result = try await modelContainer.perform { [generate] context in let input = try context.processor.prepare(input: input)