add VLM support, refactor common LM code into MLXLMCommon, breaking API changes #151

davidkoski merged 14 commits into main
Conversation
```swift
import Foundation

public enum StringOrNumber: Codable, Equatable, Sendable {
```
move to LMCommon
```swift
import Tokenizers
import MLXLMCommon

/// Container for models that guarantees single threaded access.
```
Move to ModelContainer
Libraries/LLM/LLMModel.swift (outdated)
```swift
        }
    }
}

// TODO move? these cause some ambiguity -- how to resolve?
```
I was playing around with these to avoid breaking API -- moving types into LMCommon means callers will need to import LMCommon if they refer to them. This (the aliases) caused more trouble than I think it is worth
```swift
import MLXLMCommon
import MLXNN
import MLXRandom
import Tokenizers
```
Ultimately I would like this to move into LMCommon -- I think it can support both LLM and VLM models, but I didn't get a chance to move this yet.
```swift
import MLXRandom
import Tokenizers

/// Layers to apply LoRA adapters to.
```
Move to LMCommon
```swift
}

/// Equivalent to `lora.py/iterate_batches()`. Used internally by ``LoRATrain``.
struct LoRABatchIterator: Sequence, IteratorProtocol {
```
Ideally the rest of this moves to LMCommon as well -- I think it can.
```swift
    mutating func prompt(_ prompt: MLXArray)
    func process(logits: MLXArray) -> MLXArray
    mutating func didSample(token: MLXArray)
}
```
The generate / step code has been refactored a bit and can now take custom logit samplers and processors
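For illustration, a conforming processor might look like the sketch below. The three member signatures come from the protocol fragment above; the struct name and the advanced-indexing setter are assumptions, not confirmed API.

```swift
import MLX

/// Hypothetical sketch: mask a fixed set of token ids by pushing their
/// logits to -inf before the sampler runs.
struct BannedTokenProcessor {

    /// indices of tokens that must never be sampled
    let banned: MLXArray

    mutating func prompt(_ prompt: MLXArray) {
        // no prompt-time state needed for this processor
    }

    func process(logits: MLXArray) -> MLXArray {
        // MLXArray is a reference type, so this mutates the logits in place;
        // assumes index-array subscript assignment is available
        logits[0..., banned] = MLXArray(-Float.infinity)
        return logits
    }

    mutating func didSample(token: MLXArray) {
        // no per-token state needed
    }
}
```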
```swift
public init(
    prompt: MLXArray, model: any LanguageModel, cache: [KVCache]? = nil,
    parameters: GenerateParameters
) throws {
```
This now takes either a prompt (MLXArray) or an LMInput (text + image + ...) via multiple initializers.
```swift
    }
}

public struct LMInput {
```
A new union type that holds the different inputs to generate() and LanguageModel.prepare()
```swift
    }
}

public struct LMOutput {
```
Union type for the output. Some of the VLMs return additional state, which is represented here.
Libraries/LMCommon/Models.swift (outdated)
```swift
    extraEOSTokens: ["<|end|>"]
)

// TODO the prompt formatter is replaced by the chat template
```
Libraries/LMCommon/Processor.swift (outdated)
```swift
import CoreImage
import Foundation
import MLX
```
This file may be deleted -- it was some notes & thoughts along the way
Libraries/LMCommon/Prompt.swift (outdated)
```swift
// Copyright © 2024 Apple Inc.

import Foundation
import MLX
```
Also to be deleted -- LMInput replaces this
```swift
private let context = CIContext()

// TODO documentation
public enum MediaProcessing {
```
Needs documentation, but see PaliGemmaImageProvider, which implements this `SiglipImageProcessor` configuration from the Python transformers code:

```json
{
  "do_convert_rgb": null,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [0.5, 0.5, 0.5],
  "image_processor_type": "SiglipImageProcessor",
  "image_seq_length": 1024,
  "image_std": [0.5, 0.5, 0.5],
  "processor_class": "PaliGemmaProcessor",
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": { "height": 448, "width": 448 }
}
```
```swift
import MLXNN
import Tokenizers

// MARK: - Language
```
First cut at a port of https://github.com/Blaizzy/mlx-vlm/tree/main/mlx_vlm/models/paligemma

Note: this builds, loads weights, and "runs" but doesn't produce any output -- still needs to be debugged.

It should be usable as an example of the structure I think we need.
Libraries/VLM/Models/Paligemma.swift (outdated)
```swift
    }
}

// TODO does not support multiple images -- how do we represent?
```
We need a protocol for the image and text processing pieces.
Libraries/VLM/Models/Paligemma.swift (outdated)
```swift
image = MediaProcessing.inSRGBToneCurveSpace(image)

image = MediaProcessing.resampleBicubic(image, to: .init(width: size, height: size))
image = MediaProcessing.normalize(image, mean: (0.5, 0.5, 0.5), std: (0.5, 0.5, 0.5))
```
Same `SiglipImageProcessor` configuration as quoted above.
Libraries/VLM/Models/Paligemma.swift (outdated)
```swift
    }
}

private func loadConfiguration(url: URL) throws -> PaliGemma {
```
These next couple of functions are just stubs to let me try it out -- this will work much like the LLM models
```swift
private let _ropeTheta: Float?
public var ropeTheta: Float { _ropeTheta ?? 10_000 }
public let _ropeTraditional: Bool?
public var ropeTraditional: Bool { _ropeTraditional ?? false }
```
Rather than doing the full implementation of Codable I went a simpler route for default values. Less code, cleaner (I think)
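A minimal standalone version of this pattern (the type and key names here are hypothetical): decode the underscored optional, then surface the default through a computed property, so no custom `init(from:)` is needed.

```swift
import Foundation

struct RoPEConfiguration: Codable {
    // stored optional decodes straight from JSON; nil when the key is absent
    private let _ropeTheta: Float?

    // computed property supplies the default
    var ropeTheta: Float { _ropeTheta ?? 10_000 }

    enum CodingKeys: String, CodingKey {
        case _ropeTheta = "rope_theta"
    }
}

// "{}" decodes with the 10_000 default; "{\"rope_theta\": 500000}" overrides it.
```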
Tools/llm-tool/LLMTool.swift (outdated)
```swift
@Option var path: URL

@MainActor
mutating func run() async throws {
```
Just stub code to exercise the model. This still needs the input processing layers, in particular the prompt processing. The image processing is in place but will need to be wrapped up API-wise.
This is now the real code
Force-pushed from e19f736 to 5ffe9b3.
```diff
- import LLM
  import MLX
+ import MLXLLM
+ import MLXLMCommon
```
See PR description -- split LLM -> LLM and LMCommon. Switched local names to match what people get via swiftpm (MLXLLM, etc.).
```diff
  /// This controls which model loads. `phi3_5_4bit` is one of the smaller ones, so this will fit on
  /// more devices.
- let modelConfiguration = ModelConfiguration.phi3_5_4bit
+ let modelConfiguration = ModelRegistry.phi3_5_4bit
```
From the PR description:

> constants for models have moved from `ModelConfiguration` to `ModelRegistry` -- this is `MLXLM.ModelRegistry` and there is also `MLXVLM.ModelRegistry`

```diff
- let modelConfiguration = ModelConfiguration.phi3_5_4bit
+ let modelConfiguration = ModelRegistry.phi3_5_4bit
```

This code is ready for review!
awni left a comment
This is incredibly cool. I barely touched the surface but leaving a small review and going to try running it shortly.
```swift
structure something like this:

public class YourModel: Module, LLMModel, KVCacheDimensionProvider, LoRAModel {
```
Btw I changed the KV cache implementation in mlx-lm to just init the keys and values the first time you call it. There is no need to initialize the KV cache with a head dim etc. so we could probably remove this interface as well. (Just a comment not something that we need to update in this PR)
OK, I will take a look at it -- if it simplifies things it may be worth including here as we are already making some breaking changes.
- revisit KVCache / mlx-lm
Libraries/MLXLLM/README.md (outdated)
```swift
public let kvHeads: [Int]
public let headDim: IntOrPair
```
And e.g. got rid of this which is not necessary
Tools/llm-tool/LLMTool.swift (outdated)
```diff
- let (modelContainer, modelConfiguration) = try await memory.start(args.load)
+ let modelContainer = try await memory.start { [args] in
+     try await args.load(
+         defaultModel: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx",
```
We should update this default model, it's pretty dated. Maybe `mlx-community/Mistral-7B-Instruct-v0.3-4bit` is a good option?
Sure, I will give it a run and make sure it works!
- test this
It is one of the preset models, so good to go
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Tools/llm-tool/LLMTool.swift (outdated)
```diff
  @MainActor
  mutating func run() async throws {
-     let (modelContainer, modelConfiguration) = try await memory.start(args.load)
+     let modelContainer = try await memory.start { [args] in
```
Can we rename this to `LMCommand` with subcommand `lm`, to match `VLMCommand`?
Alternatively (given the complexity) it might be worth using the same subcommand and dispatching to the VLM subroutine depending on whether an image input is provided.
Interesting idea! The default model is different, as is the model factory. We could certainly switch on the presence of an image (or video) to choose, but I wonder if that complicates things over just having the two subcommands?
Let me try the refactor to fold these down into one and see if that looks reasonable.
- try refactor of vlm -> eval (lm) command
Yes it was a slightly off the cuff suggestion. It simplifies the command line but it might not be worth doing at the expense of code complexity.
I think that worked well -- it came down to this (mostly):

```swift
// switch between LLM and VLM
let vlm = image.count > 0
if vlm {
    modelFactory = VLMModelFactory.shared
    defaultModel = MLXVLM.ModelRegistry.paligemma3bMix448_8bit
} else {
    modelFactory = LLMModelFactory.shared
    defaultModel = MLXLLM.ModelRegistry.mistral7B4bit
}
```

````swift
/// ```swift
/// let messages = [["role": "user", "content": prompt]]
/// let promptTokens = try await modelContainer.perform { context in
///     try context.tokenizer.applyChatTemplate(messages: messages)
/// }
/// ```
///
/// or:
///
/// ```swift
/// let result = await modelContainer.perform { context in
///     LLM.generate(
///         promptTokens: promptTokens, parameters: generateParameters, model: context.model,
///         tokenizer: context.tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
///     ) { tokens in
///         ...
///     }
````
yes, thanks for spotting that!
```swift
let inputEmbedding = languageModel.model.embedTokens(inputIds)
let (hiddenState, _, _) = self.visionModel(
    pixelValues.transposed(0, 2, 3, 1).asType(inputEmbedding.dtype),
    outputHiddenStates: true
)
```
We have to be pretty careful with data types in these models cause it's really easy to upcast to fp32 by accident and that can slow things down a lot or use a lot more memory (or both).
One thing I recommend: if you have a test suite that runs the models, make sure the output type is the same as the input type.
Here you cast the pixelValues to the embedding type which is good. But below you cast the output back to the pixelValues type which I'm not sure about.. I would just keep those in the same model type.
Good spot on that!
`inputEmbedding` float16, `hiddenState` float32, `pixelValues` float32
```swift
let embedDimension = imageFeatures.dim(2)
let (batchSize, sequenceLength) = inputIds.shape2
var scaledImageFeatures = imageFeatures / pow(Float(config.hiddenSize), 0.5)
var finalEmbedding = zeros([batchSize, sequenceLength, embedDimension])
```
The default data type of `zeros` is fp32. That will cause anything that works with this `finalEmbedding` to be upcast to fp32.
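A sketch of the fix (variable names echo the fragment above; the `dtype:` parameter on `zeros` is an assumption about the MLX Swift API): pass the source dtype explicitly instead of relying on the float32 default.

```swift
// default is float32 and silently upcasts downstream math:
// var finalEmbedding = zeros([batchSize, sequenceLength, embedDimension])

// match the image-feature dtype instead
var finalEmbedding = zeros(
    [batchSize, sequenceLength, embedDimension],
    dtype: scaledImageFeatures.dtype)
```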
```swift
let (inputEmbedding, finalAttentionMask4d) = inputEmbeddings(
    inputIds: inputIds, pixelValues: image.pixels, mask: mask)
```
We might want to cast the inputEmbedding to the LM dtype as well (get it from the embedding layer weight or something).. just in case they have different types.
handled inside the `inputEmbeddings` function:

```swift
private func inputEmbeddings(inputIds: MLXArray, pixelValues: MLXArray?, mask: MLXArray) -> (
    MLXArray, MLXArray
) {
    guard let pixelValues else {
        return (inputIds, mask)
    }
    let inputEmbedding = languageModel.model.embedTokens(inputIds)
    let (hiddenState, _, _) = self.visionModel(
        pixelValues.transposed(0, 2, 3, 1).asType(inputEmbedding.dtype),
```

```swift
imageMaskExpanded = repeated(imageMaskExpanded, count: embedDimension, axis: -1)
finalEmbedding = which(imageMaskExpanded, scaledImageFeatures, finalEmbedding)

finalEmbedding = which(padMaskExpanded, zeros(like: finalEmbedding), finalEmbedding)
```
In Python it's better to do `mx.where(mask, array, 0.0)` since the 0 will be broadcast and inherit the type of `array`. I think the same is true in Swift?
yes, to avoid the zeros float32 (and maybe faster to boot because of the broadcasting instead of a realized array). done
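The change described can be sketched like this (assuming `which` has a scalar overload that broadcasts and inherits the array's dtype, mirroring Python's `mx.where`):

```swift
// before: materializes a full float32 array of zeros
// finalEmbedding = which(padMaskExpanded, zeros(like: finalEmbedding), finalEmbedding)

// after: a scalar 0 broadcasts and inherits finalEmbedding's dtype
finalEmbedding = which(padMaskExpanded, 0, finalEmbedding)
```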
```swift
// insert image embeddings - the image mask is always less or equal to the sentence in length
var imageMaskExpanded = expandedDimensions(imageMask, axis: -1)
imageMaskExpanded = repeated(imageMaskExpanded, count: embedDimension, axis: -1)
```
There is no need to explicitly repeat these -- just rely on the fact that `which` broadcasts its inputs against one another. The same is true for most of the calls to `repeated` above.
wow, went from ~92 tokens / s -> 112 tokens / s
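The speedup comes from dropping the realized `repeated` array and letting `which` broadcast the trailing size-1 axis (a sketch using the names from the fragments above):

```swift
// before: realize the mask at full embedding width
// var imageMaskExpanded = expandedDimensions(imageMask, axis: -1)
// imageMaskExpanded = repeated(imageMaskExpanded, count: embedDimension, axis: -1)
// finalEmbedding = which(imageMaskExpanded, scaledImageFeatures, finalEmbedding)

// after: [B, L, 1] broadcasts against [B, L, D] inside which
let imageMaskExpanded = expandedDimensions(imageMask, axis: -1)
finalEmbedding = which(imageMaskExpanded, scaledImageFeatures, finalEmbedding)
```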
```swift
// insert padding and text token embeddings
finalEmbedding = which(textMaskExpanded, inputEmbedding, finalEmbedding)
finalEmbedding = which(padMaskExpanded, zeros(like: finalEmbedding), finalEmbedding)
```
This `zeros` should also be a plain scalar and inherit the type of the `finalEmbedding`.
awni left a comment
Massive! Thanks for adding this!
Status: almost ready, just testing and cleaning up. Models are working. I am using a local override of mlx-swift main.
## Xcode 16

Xcode 16 is required to build the example applications and tools. Older Xcode can still build the libraries via swiftpm (so no changes in requirements to any applications or libraries that refer to this).

This change is required because the xcodeproj now refers to the local `Package.swift` file to get builds consistent with external users. If needed we can switch back to using xcodeproj for library builds (internal) and swiftpm for library builds (external) -- if there is a problem please file an issue and it can be considered.

## Additions
There are two new libraries:

- `MLXVLM` contains vision language models that combine images and text prompts to produce text results, e.g. "describe this image"
- `MLXLMCommon` contains the `LanguageModel` code that is shared between `MLXLLM` and `MLXVLM`

The API between `LLM` and `VLM` is identical aside from the preparation of the `UserInput`. VLM example code is available in the `llm-tool` example.

## Breaking Changes
Probably no effect to code external to this repo:

- uses `Package.swift` to build the libraries
- `import LLM` -> `import MLXLLM`
- `LLM` -> `MLXLLM`

Breaking:

- code may need to import both `MLXLLM` and `MLXLMCommon` (particularly code that loads models)
- `MLXLMCommon` contains the common API between LLM and VLM
- constants for models have moved from `ModelConfiguration` to `ModelRegistry` -- this is `MLXLM.ModelRegistry` and there is also `MLXVLM.ModelRegistry`
- the `loadModelContainer()` function is now `LLMModelFactory.shared.loadContainer()`; there is a similar `VLMModelFactory` with identical methods for loading VLMs
- `ModelContainer.perform` is now throwing (and in `MLXLMCommon`)
- `ModelConfiguration` previously had a way to register new configurations. This is now on `LLMModelFactory` (and `VLMModelFactory` has the same)

## Deprecations
An example at the end shows all of these deprecations in context.
Prefer to use the `ModelContext.processor` to prepare prompts. Previously users would pass in a bare `[Int]` of tokens, but in order to support more complex inputs (VLMs) the use of bare `[Int]` is deprecated and callers should use `UserInput` and `LMInput`.

For example, previously callers might have done something like this:
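Based on the earlier examples quoted in this discussion, the deprecated flow looked roughly like this:

```swift
let messages = [["role": "user", "content": prompt]]
let promptTokens = try await modelContainer.perform { context in
    try context.tokenizer.applyChatTemplate(messages: messages)
}
```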
Now that should be:
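A hedged sketch of the replacement (the names `UserInput`, `LMInput`, and `context.processor.prepare` follow the PR description; treat the exact signatures as assumptions):

```swift
let input = try await modelContainer.perform { context in
    try await context.processor.prepare(input: UserInput(prompt: prompt))
}
```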
Which will initialize a `UserInput` from the prompt text and produce an `LMInput` that can be used to generate tokens.

This call to `generate()` is now deprecated:

This consumed the `[Int]` variety of tokens. Now this is preferred:

This method on `ModelContainer` is now deprecated:

```swift
/// Perform an action on the model and/or tokenizer. Callers _must_ eval any `MLXArray` before returning as
/// `MLXArray` is not `Sendable`.
@available(*, deprecated, message: "prefer perform(_:) that uses a ModelContext")
public func perform<R>(_ action: @Sendable (any LanguageModel, Tokenizer) throws -> R) rethrows -> R
```

use this one instead (though the former still works):

```swift
/// Perform an action on the ``ModelContext``. Callers _must_ eval any `MLXArray` before returning as
/// `MLXArray` is not `Sendable`.
public func perform<R>(_ action: @Sendable (ModelContext) async throws -> R) async rethrows -> R
```

## Example
Putting all of these deprecations together, previously you might have generated text like this:
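A sketch of the old flow, assembled from the deprecated pieces quoted earlier in this discussion:

```swift
// deprecated: tokenize by hand, then pass a bare [Int] to generate
let messages = [["role": "user", "content": prompt]]
let promptTokens = try await modelContainer.perform { context in
    try context.tokenizer.applyChatTemplate(messages: messages)
}
let result = await modelContainer.perform { context in
    LLM.generate(
        promptTokens: promptTokens, parameters: generateParameters,
        model: context.model, tokenizer: context.tokenizer,
        extraEOSTokens: modelConfiguration.extraEOSTokens
    ) { tokens in
        .more
    }
}
```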
now do this:
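And a hedged sketch of the new flow (type and method names follow the PR description; the exact signatures, particularly of `generate`, are assumptions):

```swift
// load via the new factory API
let modelContainer = try await LLMModelFactory.shared.loadContainer(
    configuration: ModelRegistry.phi3_5_4bit)

let result = try await modelContainer.perform { context in
    // UserInput -> LMInput via the context's processor
    let input = try await context.processor.prepare(
        input: UserInput(prompt: prompt))
    return try MLXLMCommon.generate(
        input: input, parameters: GenerateParameters(), context: context
    ) { tokens in
        .more
    }
}
```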