diff --git a/.circleci/config.yml b/.circleci/config.yml
index 328335c..8ecd4bc 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -15,7 +15,7 @@ jobs:

   mac_build_and_test:
     macos:
-      xcode: 15.2.0
+      xcode: 16.0.0
     resource_class: macos.m1.medium.gen1
     steps:
       - checkout
@@ -35,8 +35,9 @@ jobs:
            xcrun --show-sdk-build-version
            swift --version
            find . -name Package.resolved -exec rm {} \;
-           xcodebuild -skipPackagePluginValidation -scheme llm-tool
-           xcodebuild -skipPackagePluginValidation -scheme mnist-tool
+           xcodebuild -scheme llm-tool
+           xcodebuild -scheme image-tool
+           xcodebuild -scheme mnist-tool

 workflows:
   build_and_test:
diff --git a/Applications/LLMEval/ContentView.swift b/Applications/LLMEval/ContentView.swift
index e72a522..4690bf8 100644
--- a/Applications/LLMEval/ContentView.swift
+++ b/Applications/LLMEval/ContentView.swift
@@ -1,7 +1,8 @@
 // Copyright © 2024 Apple Inc.

-import LLM
 import MLX
+import MLXLLM
+import MLXLMCommon
 import MLXRandom
 import MarkdownUI
 import Metal
@@ -159,7 +160,7 @@ class LLMEvaluator {

     /// This controls which model loads. `phi3_5_4bit` is one of the smaller ones, so this will fit on
     /// more devices.
-    let modelConfiguration = ModelConfiguration.phi3_5_4bit
+    let modelConfiguration = ModelRegistry.phi3_5_4bit

     /// parameters controlling the output
     let generateParameters = GenerateParameters(temperature: 0.6)
@@ -185,17 +186,17 @@ class LLMEvaluator {
            // limit the buffer cache
            MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)

-            let modelContainer = try await LLM.loadModelContainer(configuration: modelConfiguration)
-            {
+            let modelContainer = try await LLMModelFactory.shared.loadContainer(
+                configuration: modelConfiguration
+            ) {
                 [modelConfiguration] progress in
                 Task { @MainActor in
                     self.modelInfo =
                         "Downloading \(modelConfiguration.name): \(Int(progress.fractionCompleted * 100))%"
                 }
             }
-            let numParams = await modelContainer.perform {
-                [] model, _ in
-                return model.numParameters()
+            let numParams = await modelContainer.perform { context in
+                context.model.numParameters()
             }

             self.modelInfo =
@@ -217,22 +218,17 @@ class LLMEvaluator {
         do {
             let modelContainer = try await load()

-            let messages = [["role": "user", "content": prompt]]
-            let promptTokens = try await modelContainer.perform { _, tokenizer in
-                try tokenizer.applyChatTemplate(messages: messages)
-            }
-
             // each time you generate you will get something new
             MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))

-            let result = await modelContainer.perform { model, tokenizer in
-                LLM.generate(
-                    promptTokens: promptTokens, parameters: generateParameters, model: model,
-                    tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
+            let result = try await modelContainer.perform { context in
+                let input = try await context.processor.prepare(input: .init(prompt: prompt))
+                return try MLXLMCommon.generate(
+                    input: input, parameters: generateParameters, context: context
                 ) { tokens in
                     // update the output -- this will make the view show the text as it generates
                     if tokens.count % displayEveryNTokens == 0 {
-                        let text = tokenizer.decode(tokens: tokens)
+                        let text = context.tokenizer.decode(tokens: tokens)
                         Task { @MainActor in
                             self.output = text
                         }
diff --git a/Applications/LLMEval/README.md b/Applications/LLMEval/README.md
index 45a2208..836c43a 100644
--- a/Applications/LLMEval/README.md
+++ b/Applications/LLMEval/README.md
@@ -30,7 +30,7 @@ The example application uses Phi2 model by default, see [ContentView.swift](Cont
 let modelConfiguration = ModelConfiguration.phi4bit
 ```

-There are some
pre-configured models in [LLM/Models.swift](../../Libraries/LLM/Models.swift#L62) +There are some pre-configured models in [MLXLLM/LLMModelFactory.swift](../../Libraries/MLXLLM/LLMModelFactory.swift#L78) and you can load any weights from Hugging Face where there is a model architecture defined and you have enough memory. diff --git a/Applications/LLMEval/ViewModels/DeviceStat.swift b/Applications/LLMEval/ViewModels/DeviceStat.swift index f1dca3c..49a59b0 100644 --- a/Applications/LLMEval/ViewModels/DeviceStat.swift +++ b/Applications/LLMEval/ViewModels/DeviceStat.swift @@ -1,6 +1,6 @@ import Foundation -import LLM import MLX +import MLXLLM @Observable final class DeviceStat: @unchecked Sendable { diff --git a/Applications/LoRATrainingExample/ContentView.swift b/Applications/LoRATrainingExample/ContentView.swift index 7813372..4b59904 100644 --- a/Applications/LoRATrainingExample/ContentView.swift +++ b/Applications/LoRATrainingExample/ContentView.swift @@ -1,7 +1,8 @@ // Copyright © 2024 Apple Inc. -import LLM import MLX +import MLXLLM +import MLXLMCommon import MLXNN import MLXOptimizers import MLXRandom @@ -122,7 +123,7 @@ class LoRAEvaluator { var output = "" - private let modelConfiguration = ModelConfiguration.mistral7B4bit + private let modelConfiguration = ModelRegistry.mistral7B4bit private var model: ModelState = .idle private let loraLayers = 4 @@ -141,8 +142,9 @@ class LoRAEvaluator { progress = .init(title: "Loading \(name)", current: 0, limit: 1) } - let modelContainer = try await LLM.loadModelContainer(configuration: modelConfiguration) - { + let modelContainer = try await LLMModelFactory.shared.loadContainer( + configuration: modelConfiguration + ) { progress in Task { @MainActor in self.progress = .init( @@ -160,7 +162,7 @@ class LoRAEvaluator { private func loadLoRAData(name: String) throws -> [String]? 
{ if let url = Bundle.main.url(forResource: name, withExtension: "jsonl") { - return try LLM.loadLoRAData(url: url) + return try MLXLLM.loadLoRAData(url: url) } return nil } @@ -196,9 +198,9 @@ class LoRAEvaluator { let modelContainer = try await loadModel() // apply LoRA adapters and train - await modelContainer.perform { model, _ in + await modelContainer.perform { context in LoRATrain.convert( - model: model, layers: loraLayers(model: model)) + model: context.model, layers: loraLayers(model: context.model)) } let train = try loadLoRAData(name: "train") @@ -208,11 +210,11 @@ class LoRAEvaluator { return } - try await modelContainer.perform { model, tokenizer in + try await modelContainer.perform { context in let optimizer = Adam(learningRate: learningRate) try LoRATrain.train( - model: model, train: train, validate: valid, optimizer: optimizer, - tokenizer: tokenizer, + model: context.model, train: train, validate: valid, optimizer: optimizer, + tokenizer: context.tokenizer, parameters: parameters ) { progress in Task { @MainActor in @@ -240,9 +242,10 @@ class LoRAEvaluator { return } - let loss = await modelContainer.perform { model, tokenizer in + let loss = await modelContainer.perform { context in LoRATrain.evaluate( - model: model, dataset: test, tokenizer: tokenizer, batchSize: 1, batchCount: 0) + model: context.model, dataset: test, + tokenizer: context.tokenizer, batchSize: 1, batchCount: 0) } self.progress = nil @@ -269,26 +272,20 @@ class LoRAEvaluator { let modelContainer = try await loadModel() - let messages = [["role": "user", "content": prompt]] - let promptTokens = try await modelContainer.perform { _, tokenizer in - try tokenizer.applyChatTemplate(messages: messages) - } - // evaluate - let result = await modelContainer.perform { model, tokenizer in - LLM.generate( - promptTokens: promptTokens, parameters: generateParameters, model: model, - tokenizer: tokenizer, - extraEOSTokens: modelConfiguration.extraEOSTokens, - didGenerate: { tokens in - if tokens.count % evaluateShowEvery == 0 { - let fullOutput = tokenizer.decode(tokens: tokens) - Task { @MainActor in - self.output = fullOutput - } + let result = try await modelContainer.perform { context in + let input = try await context.processor.prepare(input: .init(prompt: prompt)) + return try MLXLMCommon.generate( + input: input, parameters: generateParameters, context: context + ) { tokens in + if tokens.count % evaluateShowEvery == 0 { + let fullOutput = context.tokenizer.decode(tokens: tokens) + Task { @MainActor in + self.output = fullOutput } - return tokens.count >= maxTokens ? .stop : .more - }) + } + return tokens.count >= maxTokens ? .stop : .more + } } self.output = result.output diff --git a/Applications/MNISTTrainer/ContentView.swift b/Applications/MNISTTrainer/ContentView.swift index b8ccdf8..4142f27 100644 --- a/Applications/MNISTTrainer/ContentView.swift +++ b/Applications/MNISTTrainer/ContentView.swift @@ -1,10 +1,10 @@ // Copyright © 2024 Apple Inc. 
import MLX +import MLXMNIST import MLXNN import MLXOptimizers import MLXRandom -import MNIST import SwiftUI struct TrainingView: View { diff --git a/Applications/MNISTTrainer/PredictionView.swift b/Applications/MNISTTrainer/PredictionView.swift index 5aa8353..40b207f 100644 --- a/Applications/MNISTTrainer/PredictionView.swift +++ b/Applications/MNISTTrainer/PredictionView.swift @@ -6,8 +6,8 @@ // import MLX +import MLXMNIST import MLXNN -import MNIST import SwiftUI struct Canvas: View { diff --git a/Libraries/LLM/Configuration.swift b/Libraries/LLM/Configuration.swift deleted file mode 100644 index 0fbdb91..0000000 --- a/Libraries/LLM/Configuration.swift +++ /dev/null @@ -1,162 +0,0 @@ -// Copyright © 2024 Apple Inc. - -import Foundation - -public enum StringOrNumber: Codable, Equatable, Sendable { - case string(String) - case float(Float) - - public init(from decoder: Decoder) throws { - let values = try decoder.singleValueContainer() - - if let v = try? values.decode(Float.self) { - self = .float(v) - } else { - let v = try values.decode(String.self) - self = .string(v) - } - } - - public func encode(to encoder: Encoder) throws { - var container = encoder.singleValueContainer() - switch self { - case .string(let v): try container.encode(v) - case .float(let v): try container.encode(v) - } - } -} - -private class ModelTypeRegistry: @unchecked Sendable { - - // Note: using NSLock as we have very small (just dictionary get/set) - // critical sections and expect no contention. this allows the methods - // to remain synchronous. - private let lock = NSLock() - - @Sendable - private static func createLlamaModel(url: URL) throws -> LLMModel { - let configuration = try JSONDecoder().decode( - LlamaConfiguration.self, from: Data(contentsOf: url)) - return LlamaModel(configuration) - } - - private var creators: [String: @Sendable (URL) throws -> LLMModel] = [ - "mistral": createLlamaModel, - "llama": createLlamaModel, - "phi": { url in - let configuration = try JSONDecoder().decode( - PhiConfiguration.self, from: Data(contentsOf: url)) - return PhiModel(configuration) - }, - "phi3": { url in - let configuration = try JSONDecoder().decode( - Phi3Configuration.self, from: Data(contentsOf: url)) - return Phi3Model(configuration) - }, - "phimoe": { url in - let configuration = try JSONDecoder().decode( - PhiMoEConfiguration.self, from: Data(contentsOf: url)) - return PhiMoEModel(configuration) - }, - "gemma": { url in - let configuration = try JSONDecoder().decode( - GemmaConfiguration.self, from: Data(contentsOf: url)) - return GemmaModel(configuration) - }, - "gemma2": { url in - let configuration = try JSONDecoder().decode( - Gemma2Configuration.self, from: Data(contentsOf: url)) - return Gemma2Model(configuration) - }, - "qwen2": { url in - let configuration = try JSONDecoder().decode( - Qwen2Configuration.self, from: Data(contentsOf: url)) - return Qwen2Model(configuration) - }, - "starcoder2": { url in - let configuration = try JSONDecoder().decode( - Starcoder2Configuration.self, from: Data(contentsOf: url)) - return Starcoder2Model(configuration) - }, - "cohere": { url in - let configuration = try JSONDecoder().decode( - CohereConfiguration.self, from: Data(contentsOf: url)) - return CohereModel(configuration) - }, - "openelm": { url in - let configuration = try JSONDecoder().decode( - OpenElmConfiguration.self, from: Data(contentsOf: url)) - return OpenELMModel(configuration) - }, - "internlm2": { url in - let configuration = try JSONDecoder().decode( - InternLM2Configuration.self, from: 
Data(contentsOf: url)) - return InternLM2Model(configuration) - }, - ] - - public func registerModelType( - _ type: String, creator: @Sendable @escaping (URL) throws -> LLMModel - ) { - lock.withLock { - creators[type] = creator - } - } - - public func createModel(configuration: URL, rawValue: String) throws -> LLMModel { - let creator = lock.withLock { - creators[rawValue] - } - guard let creator else { - throw LLMError(message: "Unsupported model type.") - } - return try creator(configuration) - } - -} - -private let modelTypeRegistry = ModelTypeRegistry() - -public struct ModelType: RawRepresentable, Codable, Sendable { - public let rawValue: String - - public init(rawValue: String) { - self.rawValue = rawValue - } - - public static func registerModelType( - _ type: String, creator: @Sendable @escaping (URL) throws -> LLMModel - ) { - modelTypeRegistry.registerModelType(type, creator: creator) - } - - public func createModel(configuration: URL) throws -> LLMModel { - try modelTypeRegistry.createModel(configuration: configuration, rawValue: rawValue) - } -} - -public struct BaseConfiguration: Codable, Sendable { - public let modelType: ModelType - - public struct Quantization: Codable, Sendable { - public init(groupSize: Int, bits: Int) { - self.groupSize = groupSize - self.bits = bits - } - - let groupSize: Int - let bits: Int - - enum CodingKeys: String, CodingKey { - case groupSize = "group_size" - case bits = "bits" - } - } - - public var quantization: Quantization? - - enum CodingKeys: String, CodingKey { - case modelType = "model_type" - case quantization - } -} diff --git a/Libraries/LLM/Evaluate.swift b/Libraries/LLM/Evaluate.swift deleted file mode 100644 index f681054..0000000 --- a/Libraries/LLM/Evaluate.swift +++ /dev/null @@ -1,318 +0,0 @@ -// Copyright © 2024 Apple Inc. - -import Foundation -import MLX -import MLXRandom -import Tokenizers - -/// Parameters for text generation, see ``TokenIterator`` -public struct GenerateParameters: Sendable { - - /// Step size for processing the prompt - public var prefillStepSize = 512 - - /// sampling temperature - public var temperature: Float = 0.6 - - /// top p sampling - public var topP: Float = 1.0 - - /// penalty factor for repeating tokens - public var repetitionPenalty: Float? - - /// number of tokens to consider for repetition penalty - public var repetitionContextSize: Int = 20 - - public init( - temperature: Float = 0.6, topP: Float = 1.0, repetitionPenalty: Float? 
= nil, - repetitionContextSize: Int = 20 - ) { - self.temperature = temperature - self.topP = topP - self.repetitionPenalty = repetitionPenalty - self.repetitionContextSize = repetitionContextSize - } -} - -struct SampleContext { - - let temp: MLXArray - let topP: MLXArray - let useTopP: Bool - let useArgMax: Bool - - init(parameters: GenerateParameters) { - self.temp = MLXArray(parameters.temperature) - self.topP = MLXArray(parameters.topP) - self.useTopP = parameters.topP > 0 && parameters.topP < 1 - self.useArgMax = parameters.temperature == 0 - } - - private let compiledTopPSampling: (MLXArray, MLXArray, MLXArray) -> MLXArray = { - compile(inputs: [MLXRandom.globalState], outputs: [MLXRandom.globalState]) { - logits, topP, temp in - let probs = softmax(logits / temp, axis: -1) - let sortedIndices = argSort(probs, axis: -1) - - // probs shape is [B,V] and after take it will be [1, B, V], so we squeeze it back to [B, V] - let sortedProbs = take(probs, sortedIndices, axis: -1).squeezed(axis: 0) - - let cumulativeProbs = cumsum(sortedProbs, axis: -1) - - let topProbs = MLX.where( - cumulativeProbs .> (1 - topP), sortedProbs, zeros(like: sortedProbs)) - - let sortedToken = categorical(log(topProbs)) - return sortedIndices.squeezed(axis: 0)[sortedToken] - } - }() - - private let compiledCategorical: (MLXArray, MLXArray) -> MLXArray = { - compile(inputs: [MLXRandom.globalState], outputs: [MLXRandom.globalState]) { logits, temp in - categorical(logits * (1 / temp)) - } - }() - - private func topPSampling(logits: MLXArray) -> MLXArray { - var logits = logits - if logits.dtype == .bfloat16 { - logits = logits.asType(.float32) - } - - return compiledTopPSampling(logits, topP, temp) - } - - func sample(logits: MLXArray) -> MLXArray { - if useArgMax { - return argMax(logits, axis: -1) - } else { - if useTopP { - return topPSampling(logits: logits) - } else { - return compiledCategorical(logits, temp) - } - } - } -} - -/// Encapsulaton of the repetitionPenalty -struct RepetitionContext: Sendable { - /// tokens in the repetition context sliding window - var tokens: [Int] - - /// current write into into the tokens circular array - var index = 0 - - /// penalty factor for repeating tokens - let repetitionPenalty: Float? - - /// number of tokens to consider for repetition penalty - let repetitionContextSize: Int - - init(prompt: MLXArray, parameters: GenerateParameters) { - self.repetitionPenalty = parameters.repetitionPenalty - self.repetitionContextSize = parameters.repetitionContextSize - - if repetitionPenalty != nil && repetitionContextSize > 1 { - if prompt.shape[0] <= repetitionContextSize { - self.tokens = prompt.asArray(Int.self) - } else { - self.tokens = prompt[(-repetitionContextSize)...].asArray(Int.self) - } - } else { - self.tokens = [] - } - } - - func applyRepetitionPenalty(logits: MLXArray) -> MLXArray { - if let penalty = repetitionPenalty, tokens.count > 0 { - let indices = MLXArray(tokens.map { UInt32($0) }) - var selectedLogits = logits[0..., indices] - - selectedLogits = MLX.where( - selectedLogits .< 0, selectedLogits * penalty, selectedLogits / penalty) - - logits[0..., indices] = selectedLogits - return logits - } - - return logits - } - - mutating func append(token: MLXArray) { - if repetitionPenalty != nil { - if tokens.count >= repetitionContextSize { - tokens[index] = token.item(Int.self) - index = (index + 1) % repetitionContextSize - } else { - tokens.append(token.item(Int.self)) - } - } - } -} - -/// Synchronous generator of tokens. 
-/// -/// Tokens are integers that can be passed through a `Tokenizer` or ``StreamingDetokenizer`` to produce Strings. -/// -/// Port of `generate_step()` from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py -/// -/// Note: this uses `asyncEval()` and there may be an async evaluation running after a call to `next()`. -public struct TokenIterator: Sequence, IteratorProtocol { - let model: LLMModel - let parameters: GenerateParameters - - var y: MLXArray - var cache: [KVCache] - var repetitionContext: RepetitionContext - let sampleContext: SampleContext - - public init(prompt: MLXArray, model: LLMModel, parameters: GenerateParameters) { - self.model = model - self.parameters = parameters - self.y = prompt - self.cache = model.newCache(parameters: parameters) - - self.repetitionContext = RepetitionContext(prompt: prompt, parameters: parameters) - self.sampleContext = SampleContext(parameters: parameters) - - // prepare the prompt in chunks if larger than the prefill size - while y.size > parameters.prefillStepSize { - _ = model( - y[.newAxis, .. MLXArray { - var logits: MLXArray - logits = model(previous[.newAxis], cache: cache.isEmpty ? nil : cache) - - logits = logits[0..., -1, 0...] - logits = repetitionContext.applyRepetitionPenalty(logits: logits) - - let y = sampleContext.sample(logits: logits) - - repetitionContext.append(token: y) - - return y - } - - mutating public func next() -> Int? { - // save current value -- this will be returned - let previousY = y - - // compute the next state and async eval the next token - y = step(previous: previousY) - asyncEval(y) - - return previousY.item(Int.self) - } -} - -public struct GenerateResult: Sendable { - /// input tokens - public let promptTokens: [Int] - - /// output tokens - public let tokens: [Int] - - /// output text - public let output: String - - /// time to process the prompt / generate the first token - public let promptTime: TimeInterval - - /// time to generate the remaining tokens - public let generateTime: TimeInterval - - public var promptTokensPerSecond: Double { - Double(promptTokens.count) / promptTime - } - - public var tokensPerSecond: Double { - Double(tokens.count) / generateTime - } - - public func summary() -> String { - """ - Prompt: \(promptTokens.count) tokens, \(promptTokensPerSecond.formatted()) tokens/s - Generation: \(tokens.count) tokens, \(tokensPerSecond.formatted()) tokens/s, \(generateTime.formatted())s - """ - } -} - -public enum GenerateDisposition: Sendable { - case more - case stop -} - -/// Given prompt tokens generate text using the given model and parameters. -/// -/// - Parameters: -/// - promptTokens: tokenized prompt -/// - parameters: generation parameters -/// - model: model to evaluate -/// - tokenizer: tokenizer to convert tokens back into strings and recognizer special tokens -/// - configuration: the model configuration -/// - didGenerate: visitor for the tokens as they are generated -public func generate( - promptTokens: [Int], parameters: GenerateParameters, model: LLMModel, tokenizer: Tokenizer, - extraEOSTokens: Set? = nil, - didGenerate: ([Int]) -> GenerateDisposition -) -> GenerateResult { - var start = Date.timeIntervalSinceReferenceDate - var promptTime: TimeInterval = 0 - - let additionalEOSTokenIds = Set( - (extraEOSTokens ?? 
[]) - .compactMap { - tokenizer.convertTokenToId($0) - }) - - var tokens = [Int]() - - for token in TokenIterator( - prompt: MLXArray(promptTokens), model: model, parameters: parameters) - { - // compute the timing for the prompt - if tokens.isEmpty { - let now = Date.timeIntervalSinceReferenceDate - promptTime = now - start - start = now - } - - if token == tokenizer.unknownTokenId || token == tokenizer.eosTokenId - || additionalEOSTokenIds.contains(token) - { - break - } - tokens.append(token) - - if didGenerate(tokens) == .stop { - break - } - } - - let now = Date.timeIntervalSinceReferenceDate - let generateTime = now - start - - // TokenIterator uses `asyncEval()` to keep the pipeline full. If the caller - // exits the program right away, those tasks will still be executing and will - // hit assertions as the mlx scheduler is torn down. Synchronize with the stream - // to make sure it is complete. - Stream().synchronize() - - return GenerateResult( - promptTokens: promptTokens, tokens: tokens, - output: tokenizer.decode(tokens: tokens), - promptTime: promptTime, generateTime: generateTime) -} diff --git a/Libraries/LLM/LLM.h b/Libraries/LLM/LLM.h deleted file mode 100644 index 8b13789..0000000 --- a/Libraries/LLM/LLM.h +++ /dev/null @@ -1 +0,0 @@ - diff --git a/Libraries/LLM/LLMModel.swift b/Libraries/LLM/LLMModel.swift deleted file mode 100644 index 5999fb5..0000000 --- a/Libraries/LLM/LLMModel.swift +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright © 2024 Apple Inc. - -import Foundation -@preconcurrency import Hub -import MLX -import MLXNN -import Tokenizers - -/// Container for models that guarantees single threaded access. -/// -/// Wrap models used by e.g. the UI in a ModelContainer. Callers can access -/// the model and/or tokenizer: -/// -/// ```swift -/// let messages = [["role": "user", "content": prompt]] -/// let promptTokens = try await modelContainer.perform { _, tokenizer in -/// try tokenizer.applyChatTemplate(messages: messages) -/// } -/// ``` -/// -/// or: -/// -/// ```swift -/// let result = await modelContainer.perform { model, tokenizer in -/// LLM.generate( -/// promptTokens: promptTokens, parameters: generateParameters, model: model, -/// tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens -/// ) { tokens in -/// ... -/// } -/// } -/// ``` -public actor ModelContainer { - let model: LLMModel - let tokenizer: Tokenizer - - public init(model: LLMModel, tokenizer: Tokenizer) { - self.model = model - self.tokenizer = tokenizer - } - - /// build the model and tokenizer without passing non-sendable data over isolation barriers - public init( - hub: HubApi, modelDirectory: URL, configuration: ModelConfiguration - ) async throws { - self.model = try loadSynchronous(modelDirectory: modelDirectory) - - let (tokenizerConfig, tokenizerData) = try await loadTokenizerConfig( - configuration: configuration, hub: hub) - self.tokenizer = try PreTrainedTokenizer( - tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) - } - - /// Perform an action on the model and/or tokenizer. Callers _must_ eval any `MLXArray` before returning as - /// `MLXArray` is not `Sendable`. - public func perform(_ action: @Sendable (LLMModel, Tokenizer) throws -> R) rethrows -> R { - try action(model, tokenizer) - } -} - -extension Module { - - /// Compute the number of parameters in a possibly quantized model - public func numParameters() -> Int { - return leafModules().flattenedValues().map { - mod -> Int in - if let qlin = mod as? 
QuantizedLinear { - return qlin.scales.size * qlin.groupSize - } else if let qemb = mod as? QuantizedEmbedding { - return qemb.scales.size * qemb.groupSize - } else { - return mod.parameters().flattenedValues().reduce( - 0, - { - $0 + $1.size - }) - } - }.reduce(0, +) - } -} - -/// Interface for all LLM Models -public protocol LLMModel: Module { - - var vocabularySize: Int { get } - - func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray - - /// create a new array of ``KVCache`` -- automatic implementation if self - /// implements ``KVCacheDimensionProvider`` - func newCache(parameters: GenerateParameters) -> [KVCache] - - /// Optionally preprocess the weights and modify / remove values as needed. - func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] -} - -/// Optional protocol that can be implemented by ``LLMModel`` and will -/// provide an automatic implementation of ``LLMModel/newCache(parameters:)`` -public protocol KVCacheDimensionProvider { - var kvHeads: [Int] { get } - var headDim: IntOrPair { get } -} - -extension LLMModel { - public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { - weights - } -} - -extension LLMModel where Self: KVCacheDimensionProvider { - public func newCache(parameters: GenerateParameters) -> [KVCache] { - kvHeads.map { n in - KVCacheSimple(headDim: headDim, kvHeads: n) - } - } -} diff --git a/Libraries/LLM/Models.swift b/Libraries/LLM/Models.swift deleted file mode 100644 index 614fcb6..0000000 --- a/Libraries/LLM/Models.swift +++ /dev/null @@ -1,240 +0,0 @@ -// Copyright © 2024 Apple Inc. - -import Foundation -import Hub - -/// Registry of models and any overrides that go with them, e.g. prompt augmentation. -/// If asked for an unknown configuration this will use the model/tokenizer as-is. -/// -/// The python tokenizers have a very rich set of implementations and configuration. The -/// swift-tokenizers code handles a good chunk of that and this is a place to augment that -/// implementation, if needed. -public struct ModelConfiguration: Sendable { - - public enum Identifier: Sendable { - case id(String) - case directory(URL) - } - - public var id: Identifier - - public var name: String { - switch id { - case .id(let string): - string - case .directory(let url): - url.deletingLastPathComponent().lastPathComponent + "/" + url.lastPathComponent - } - } - - /// pull the tokenizer from an alternate id - public let tokenizerId: String? - - /// overrides for TokenizerModel/knownTokenizers -- useful before swift-transformers is updated - public let overrideTokenizer: String? - - /// A reasonable default prompt for the model - public let defaultPrompt: String - - /// Additional tokens to use for end of string - public let extraEOSTokens: Set - - public init( - id: String, tokenizerId: String? = nil, overrideTokenizer: String? = nil, - defaultPrompt: String = "hello", - extraEOSTokens: Set = [], - preparePrompt: (@Sendable (String) -> String)? = nil - ) { - self.id = .id(id) - self.tokenizerId = tokenizerId - self.overrideTokenizer = overrideTokenizer - self.defaultPrompt = defaultPrompt - self.extraEOSTokens = extraEOSTokens - } - - public init( - directory: URL, tokenizerId: String? = nil, overrideTokenizer: String? 
= nil, - defaultPrompt: String = "hello", - extraEOSTokens: Set = [] - ) { - self.id = .directory(directory) - self.tokenizerId = tokenizerId - self.overrideTokenizer = overrideTokenizer - self.defaultPrompt = defaultPrompt - self.extraEOSTokens = extraEOSTokens - } - - public func modelDirectory(hub: HubApi = HubApi()) -> URL { - switch id { - case .id(let id): - // download the model weights and config - let repo = Hub.Repo(id: id) - return hub.localRepoLocation(repo) - - case .directory(let directory): - return directory - } - } - - @MainActor - public static var registry = [String: ModelConfiguration]() - - @MainActor - public static func register(configurations: [ModelConfiguration]) { - bootstrap() - - for c in configurations { - registry[c.name] = c - } - } - - @MainActor - public static func configuration(id: String) -> ModelConfiguration { - bootstrap() - - if let c = registry[id] { - return c - } else { - return ModelConfiguration(id: id) - } - } -} - -extension ModelConfiguration { - public static let smolLM_135M_4bit = ModelConfiguration( - id: "mlx-community/SmolLM-135M-Instruct-4bit", - defaultPrompt: "Tell me about the history of Spain." - ) - - public static let mistralNeMo4bit = ModelConfiguration( - id: "mlx-community/Mistral-Nemo-Instruct-2407-4bit", - defaultPrompt: "Explain quaternions." - ) - - public static let mistral7B4bit = ModelConfiguration( - id: "mlx-community/Mistral-7B-Instruct-v0.3-4bit", - defaultPrompt: "Describe the Swift language." - ) - - public static let codeLlama13b4bit = ModelConfiguration( - id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX", - overrideTokenizer: "PreTrainedTokenizer", - defaultPrompt: "func sortArray(_ array: [Int]) -> String { }" - ) - - public static let phi4bit = ModelConfiguration( - id: "mlx-community/phi-2-hf-4bit-mlx", - // https://www.promptingguide.ai/models/phi-2 - defaultPrompt: "Why is the sky blue?" - ) - - public static let phi3_5_4bit = ModelConfiguration( - id: "mlx-community/Phi-3.5-mini-instruct-4bit", - defaultPrompt: "What is the gravity on Mars and the moon?", - extraEOSTokens: ["<|end|>"] - ) - - public static let phi3_5MoE = ModelConfiguration( - id: "mlx-community/Phi-3.5-MoE-instruct-4bit", - defaultPrompt: "What is the gravity on Mars and the moon?", - extraEOSTokens: ["<|end|>"] - ) { - prompt in - "<|user|>\n\(prompt)<|end|>\n<|assistant|>\n" - } - - public static let gemma2bQuantized = ModelConfiguration( - id: "mlx-community/quantized-gemma-2b-it", - overrideTokenizer: "PreTrainedTokenizer", - // https://www.promptingguide.ai/models/gemma - defaultPrompt: "what is the difference between lettuce and cabbage?" - ) - - public static let gemma_2_9b_it_4bit = ModelConfiguration( - id: "mlx-community/gemma-2-9b-it-4bit", - overrideTokenizer: "PreTrainedTokenizer", - // https://www.promptingguide.ai/models/gemma - defaultPrompt: "What is the difference between lettuce and cabbage?" - ) - - public static let gemma_2_2b_it_4bit = ModelConfiguration( - id: "mlx-community/gemma-2-2b-it-4bit", - overrideTokenizer: "PreTrainedTokenizer", - // https://www.promptingguide.ai/models/gemma - defaultPrompt: "What is the difference between lettuce and cabbage?" - ) - - public static let qwen205b4bit = ModelConfiguration( - id: "mlx-community/Qwen1.5-0.5B-Chat-4bit", - overrideTokenizer: "PreTrainedTokenizer", - defaultPrompt: "why is the sky blue?" 
- ) - - public static let openelm270m4bit = ModelConfiguration( - id: "mlx-community/OpenELM-270M-Instruct", - // https://huggingface.co/apple/OpenELM - defaultPrompt: "Once upon a time there was" - ) - - public static let llama3_1_8B_4bit = ModelConfiguration( - id: "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", - defaultPrompt: "What is the difference between a fruit and a vegetable?" - ) - - public static let llama3_8B_4bit = ModelConfiguration( - id: "mlx-community/Meta-Llama-3-8B-Instruct-4bit", - defaultPrompt: "What is the difference between a fruit and a vegetable?" - ) - - public static let llama3_2_1B_4bit = ModelConfiguration( - id: "mlx-community/Llama-3.2-1B-Instruct-4bit", - defaultPrompt: "What is the difference between a fruit and a vegetable?" - ) - - public static let llama3_2_3B_4bit = ModelConfiguration( - id: "mlx-community/Llama-3.2-3B-Instruct-4bit", - defaultPrompt: "What is the difference between a fruit and a vegetable?" - ) - - private enum BootstrapState: Sendable { - case idle - case bootstrapping - case bootstrapped - } - - @MainActor - static private var bootstrapState = BootstrapState.idle - - @MainActor - static func bootstrap() { - switch bootstrapState { - case .idle: - bootstrapState = .bootstrapping - register(configurations: [ - codeLlama13b4bit, - gemma2bQuantized, - gemma_2_2b_it_4bit, - gemma_2_9b_it_4bit, - llama3_1_8B_4bit, - llama3_2_1B_4bit, - llama3_2_3B_4bit, - llama3_8B_4bit, - mistral7B4bit, - mistralNeMo4bit, - openelm270m4bit, - phi3_5MoE, - phi3_5_4bit, - phi4bit, - qwen205b4bit, - smolLM_135M_4bit, - ]) - bootstrapState = .bootstrapped - - case .bootstrapping: - break - - case .bootstrapped: - break - } - } -} diff --git a/Libraries/LLM/README.md b/Libraries/LLM/README.md deleted file mode 100644 index 10fc034..0000000 --- a/Libraries/LLM/README.md +++ /dev/null @@ -1,41 +0,0 @@ -# LLM - -This is a port of several models from: - -- https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/ - -using the Hugging Face swift transformers package to provide tokenization: - -- https://github.com/huggingface/swift-transformers - -The [Models.swift](Models.swift) provides minor overrides and customization -- -if you require overrides for the tokenizer or prompt customizations they can be -added there. - -This is set up to load models from Hugging Face, e.g. https://huggingface.co/mlx-community - -The following models have been tried: - -- mlx-community/Mistral-7B-v0.1-hf-4bit-mlx -- mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX -- mlx-community/phi-2-hf-4bit-mlx -- mlx-community/quantized-gemma-2b-it - -Currently supported model types are: - -- Llama / Mistral -- Gemma -- Phi - -See [Configuration.swift](Configuration.swift) for more info. - -See [llm-tool](../../Tools/llm-tool) - -# LoRA - -[Lora.swift](Lora.swift) contains an implementation of LoRA based on this example: - -- https://github.com/ml-explore/mlx-examples/tree/main/lora - -See [llm-tool/LoraCommands.swift](../../Tools/llm-tool/LoraCommands.swift) for an example of a driver and -[llm-tool](../../Tools/llm-tool) for examples of how to run it. diff --git a/Libraries/MLXLLM/LLMModel.swift b/Libraries/MLXLLM/LLMModel.swift new file mode 100644 index 0000000..97bcf57 --- /dev/null +++ b/Libraries/MLXLLM/LLMModel.swift @@ -0,0 +1,33 @@ +// Copyright © 2024 Apple Inc. + +import MLX +import MLXLMCommon + +/// Marker protocol for LLMModels +public protocol LLMModel: LanguageModel, LoRAModel { +} + +extension LLMModel { + + /// Default prepare step for ``LLMModel``. 
+    ///
+    /// This will evaluate the prompt in chunks until there is a small number of
+    /// tokens left to feed into the `TokenIterator`.
+    public func prepare(_ input: LMInput, cache: [KVCache], windowSize: Int?) throws
+        -> PrepareResult
+    {
+        let prefillStepSize = windowSize ?? 512
+        var y = input.text
+        var state: LMOutput.State? = nil
+
+        // prepare the prompt in chunks if larger than the prefill size
+        while y.tokens.size > prefillStepSize {
+            let input = y[.newAxis, ..(
+    _ configurationType: C.Type, _ modelInit: @escaping (C) -> M
+) -> (URL) throws -> M {
+    { url in
+        let configuration = try JSONDecoder().decode(
+            C.self, from: Data(contentsOf: url))
+        return modelInit(configuration)
+    }
+}
+
+/// Registry of model type, e.g. 'llama', to functions that can instantiate the model from configuration.
+///
+/// Typically called via ``LLMModelFactory/load(hub:configuration:progressHandler:)``.
+public class ModelTypeRegistry: @unchecked Sendable {
+
+    // Note: using NSLock as we have very small (just dictionary get/set)
+    // critical sections and expect no contention. This allows the methods
+    // to remain synchronous.
+    private let lock = NSLock()
+
+    private var creators: [String: @Sendable (URL) throws -> any LanguageModel] = [
+        "mistral": create(LlamaConfiguration.self, LlamaModel.init),
+        "llama": create(LlamaConfiguration.self, LlamaModel.init),
+        "phi": create(PhiConfiguration.self, PhiModel.init),
+        "phi3": create(Phi3Configuration.self, Phi3Model.init),
+        "phimoe": create(PhiMoEConfiguration.self, PhiMoEModel.init),
+        "gemma": create(GemmaConfiguration.self, GemmaModel.init),
+        "gemma2": create(Gemma2Configuration.self, Gemma2Model.init),
+        "qwen2": create(Qwen2Configuration.self, Qwen2Model.init),
+        "starcoder2": create(Starcoder2Configuration.self, Starcoder2Model.init),
+        "cohere": create(CohereConfiguration.self, CohereModel.init),
+        "openelm": create(OpenElmConfiguration.self, OpenELMModel.init),
+        "internlm2": create(InternLM2Configuration.self, InternLM2Model.init),
+    ]
+
+    /// Add a new model to the type registry.
+    public func registerModelType(
+        _ type: String, creator: @Sendable @escaping (URL) throws -> any LanguageModel
+    ) {
+        lock.withLock {
+            creators[type] = creator
+        }
+    }
+
+    /// Given a `modelType` and configuration file instantiate a new `LanguageModel`.
+    public func createModel(configuration: URL, modelType: String) throws -> LanguageModel {
+        let creator = lock.withLock {
+            creators[modelType]
+        }
+        guard let creator else {
+            throw ModelFactoryError.unsupportedModelType(modelType)
+        }
+        return try creator(configuration)
+    }
+
+}
+
+/// Registry of models and any overrides that go with them, e.g. prompt augmentation.
+/// If asked for an unknown configuration this will use the model/tokenizer as-is.
+///
+/// The python tokenizers have a very rich set of implementations and configuration. The
+/// swift-tokenizers code handles a good chunk of that and this is a place to augment that
+/// implementation, if needed.
+public class ModelRegistry: @unchecked Sendable {
+
+    private let lock = NSLock()
+    private var registry = Dictionary(uniqueKeysWithValues: all().map { ($0.name, $0) })
+
+    static public let smolLM_135M_4bit = ModelConfiguration(
+        id: "mlx-community/SmolLM-135M-Instruct-4bit",
+        defaultPrompt: "Tell me about the history of Spain."
+    )
+
+    static public let mistralNeMo4bit = ModelConfiguration(
+        id: "mlx-community/Mistral-Nemo-Instruct-2407-4bit",
+        defaultPrompt: "Explain quaternions."
+ ) + + static public let mistral7B4bit = ModelConfiguration( + id: "mlx-community/Mistral-7B-Instruct-v0.3-4bit", + defaultPrompt: "Describe the Swift language." + ) + + static public let codeLlama13b4bit = ModelConfiguration( + id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX", + overrideTokenizer: "PreTrainedTokenizer", + defaultPrompt: "func sortArray(_ array: [Int]) -> String { }" + ) + + static public let phi4bit = ModelConfiguration( + id: "mlx-community/phi-2-hf-4bit-mlx", + // https://www.promptingguide.ai/models/phi-2 + defaultPrompt: "Why is the sky blue?" + ) + + static public let phi3_5_4bit = ModelConfiguration( + id: "mlx-community/Phi-3.5-mini-instruct-4bit", + defaultPrompt: "What is the gravity on Mars and the moon?", + extraEOSTokens: ["<|end|>"] + ) + + static public let phi3_5MoE = ModelConfiguration( + id: "mlx-community/Phi-3.5-MoE-instruct-4bit", + defaultPrompt: "What is the gravity on Mars and the moon?", + extraEOSTokens: ["<|end|>"] + ) { + prompt in + "<|user|>\n\(prompt)<|end|>\n<|assistant|>\n" + } + + static public let gemma2bQuantized = ModelConfiguration( + id: "mlx-community/quantized-gemma-2b-it", + overrideTokenizer: "PreTrainedTokenizer", + // https://www.promptingguide.ai/models/gemma + defaultPrompt: "what is the difference between lettuce and cabbage?" + ) + + static public let gemma_2_9b_it_4bit = ModelConfiguration( + id: "mlx-community/gemma-2-9b-it-4bit", + overrideTokenizer: "PreTrainedTokenizer", + // https://www.promptingguide.ai/models/gemma + defaultPrompt: "What is the difference between lettuce and cabbage?" + ) + + static public let gemma_2_2b_it_4bit = ModelConfiguration( + id: "mlx-community/gemma-2-2b-it-4bit", + overrideTokenizer: "PreTrainedTokenizer", + // https://www.promptingguide.ai/models/gemma + defaultPrompt: "What is the difference between lettuce and cabbage?" + ) + + static public let qwen205b4bit = ModelConfiguration( + id: "mlx-community/Qwen1.5-0.5B-Chat-4bit", + overrideTokenizer: "PreTrainedTokenizer", + defaultPrompt: "why is the sky blue?" + ) + + static public let openelm270m4bit = ModelConfiguration( + id: "mlx-community/OpenELM-270M-Instruct", + // https://huggingface.co/apple/OpenELM + defaultPrompt: "Once upon a time there was" + ) + + static public let llama3_1_8B_4bit = ModelConfiguration( + id: "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", + defaultPrompt: "What is the difference between a fruit and a vegetable?" + ) + + static public let llama3_8B_4bit = ModelConfiguration( + id: "mlx-community/Meta-Llama-3-8B-Instruct-4bit", + defaultPrompt: "What is the difference between a fruit and a vegetable?" + ) + + static public let llama3_2_1B_4bit = ModelConfiguration( + id: "mlx-community/Llama-3.2-1B-Instruct-4bit", + defaultPrompt: "What is the difference between a fruit and a vegetable?" + ) + + static public let llama3_2_3B_4bit = ModelConfiguration( + id: "mlx-community/Llama-3.2-3B-Instruct-4bit", + defaultPrompt: "What is the difference between a fruit and a vegetable?" 
+ ) + + private static func all() -> [ModelConfiguration] { + [ + codeLlama13b4bit, + gemma2bQuantized, + gemma_2_2b_it_4bit, + gemma_2_9b_it_4bit, + llama3_1_8B_4bit, + llama3_2_1B_4bit, + llama3_2_3B_4bit, + llama3_8B_4bit, + mistral7B4bit, + mistralNeMo4bit, + openelm270m4bit, + phi3_5MoE, + phi3_5_4bit, + phi4bit, + qwen205b4bit, + smolLM_135M_4bit, + ] + } + + public func register(configurations: [ModelConfiguration]) { + lock.withLock { + for c in configurations { + registry[c.name] = c + } + } + } + + public func configuration(id: String) -> ModelConfiguration { + lock.withLock { + if let c = registry[id] { + return c + } else { + return ModelConfiguration(id: id) + } + } + } +} + +private struct LLMUserInputProcessor: UserInputProcessor { + + let tokenizer: Tokenizer + let configuration: ModelConfiguration + + internal init(tokenizer: any Tokenizer, configuration: ModelConfiguration) { + self.tokenizer = tokenizer + self.configuration = configuration + } + + func prepare(input: UserInput) throws -> LMInput { + do { + let messages = input.prompt.asMessages() + let promptTokens = try tokenizer.applyChatTemplate(messages: messages) + return LMInput(tokens: MLXArray(promptTokens)) + } catch { + // #150 -- it might be a TokenizerError.chatTemplate("No chat template was specified") + // but that is not public so just fall back to text + let prompt = input.prompt + .asMessages() + .compactMap { $0["content"] } + .joined(separator: ". ") + let promptTokens = tokenizer.encode(text: prompt) + return LMInput(tokens: MLXArray(promptTokens)) + } + } +} + +/// Factory for creating new LLMs. +/// +/// Callers can use the `shared` instance or create a new instance if custom configuration +/// is required. +/// +/// ```swift +/// let modelContainer = try await LLMModelFactory.shared.loadContainer( +/// configuration: ModelRegistry.llama3_8B_4bit) +/// ``` +public class LLMModelFactory: ModelFactory { + + public static let shared = LLMModelFactory() + + /// registry of model type, e.g. configuration value `llama` -> configuration and init methods + public let typeRegistry = ModelTypeRegistry() + + /// registry of model id to configuration, e.g. 
`mlx-community/Llama-3.2-3B-Instruct-4bit`
+    public let modelRegistry = ModelRegistry()
+
+    public func configuration(id: String) -> ModelConfiguration {
+        modelRegistry.configuration(id: id)
+    }
+
+    public func _load(
+        hub: HubApi, configuration: ModelConfiguration,
+        progressHandler: @Sendable @escaping (Progress) -> Void
+    ) async throws -> ModelContext {
+        // download weights and config
+        let modelDirectory = try await downloadModel(
+            hub: hub, configuration: configuration, progressHandler: progressHandler)
+
+        // load the generic config to understand which model and how to load the weights
+        let configurationURL = modelDirectory.appending(component: "config.json")
+        let baseConfig = try JSONDecoder().decode(
+            BaseConfiguration.self, from: Data(contentsOf: configurationURL))
+        let model = try typeRegistry.createModel(
+            configuration: configurationURL, modelType: baseConfig.modelType)
+
+        // apply the weights to the bare model
+        try loadWeights(
+            modelDirectory: modelDirectory, model: model, quantization: baseConfig.quantization)
+
+        let tokenizer = try await loadTokenizer(configuration: configuration, hub: hub)
+
+        return .init(
+            configuration: configuration, model: model,
+            processor: LLMUserInputProcessor(tokenizer: tokenizer, configuration: configuration),
+            tokenizer: tokenizer)
+    }
+
+}
diff --git a/Libraries/LLM/Lora+Data.swift b/Libraries/MLXLLM/Lora+Data.swift
similarity index 100%
rename from Libraries/LLM/Lora+Data.swift
rename to Libraries/MLXLLM/Lora+Data.swift
diff --git a/Libraries/LLM/Lora.swift b/Libraries/MLXLLM/LoraTrain.swift
similarity index 66%
rename from Libraries/LLM/Lora.swift
rename to Libraries/MLXLLM/LoraTrain.swift
index 426237a..2dc0d02 100644
--- a/Libraries/LLM/Lora.swift
+++ b/Libraries/MLXLLM/LoraTrain.swift
@@ -2,222 +2,12 @@

 import Foundation
 import MLX
+import MLXLMCommon
 import MLXNN
 import MLXOptimizers
 import MLXRandom
 import Tokenizers

-/// Layers to apply LoRA adapters to.
-///
-/// This is the value returned by ``LoRAModel/loraLinearLayers()``.
-public typealias LoRALinearLayers = [(Module, [String])]
-
-public protocol LoRAModel {
-    /// Return the layers and keys to apply LoRA adapters to.
-    ///
-    /// For example this might apply the adapters to the `q` an `v` projections in the
-    /// Attention layers:
-    ///
-    /// ```swift
-    /// model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
-    /// ```
-    ///
-    /// It is not required that a model implement this protocol to have LoRA adapters applied, but
-    /// the command line driver example uses this to produce the ``LoRALinearLayers``.
-    ///
-    /// ### See Also
-    /// - ``LoRATrain/convert(model:layers:)``
-    func loraLinearLayers() -> LoRALinearLayers
-}
-
-/// Protocol for LoRA implementations that provides a method for converting back to a `Linear`
-/// (or subtype).
-///
-/// This is normally called via ``LoRATrain/fuse(model:layers:deQuantize:)``
-public protocol LoRAConvertToLinear {
-    func toLinear(deQuantize: Bool) -> Linear
-}
-
-/// Implementation of LoRA `Linear` replacement layer.
-///
-/// This layer implements the LoRA capabilities for `Linear` layers, specifically:
-///
-/// - converting `Linear` or `QuantizedLinear` layers to ``LoRALinear`` / ``QLoRALinear``
-/// - converting ``LoRALinear`` back to `Linear` or `QuantizedLinear` (``LoRAConvertToLinear``)
-/// - implementing the LoRA evaluation
-///
-/// ``QLoRALinear`` is the equivalent class for `QuantizedLinear`.
-/// -/// This is not typically used directly -- ``LoRATrain/convert(model:layers:)`` is used to -/// add the adapter layers to a given model. -/// -/// ### See Also -/// - [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) -/// - [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314) -/// - ``QLoRALinear`` -/// - ``LoRATrain/convert(model:layers:)`` -/// - ``LoRATrain/fuse(model:layers:deQuantize:)`` -public class LoRALinear: Linear, LoRAConvertToLinear { - - let scale: Float - - @ParameterInfo(key: "lora_a") var loraA: MLXArray - @ParameterInfo(key: "lora_b") var loraB: MLXArray - - required public init( - _ inputDimensions: Int, _ outputDimensions: Int, rank: Int = 8, bias: Bool = false, - scale: Float = 20.0, linear: Linear - ) { - // Scale for low-rank update - self.scale = scale - - // Low rank lora weights - let loraScale = 1 / sqrt(Float(inputDimensions)) - self._loraA.wrappedValue = MLXRandom.uniform( - low: -loraScale, high: loraScale, [inputDimensions, rank]) - self._loraB.wrappedValue = MLXArray.zeros([rank, outputDimensions]) - - super.init(weight: linear.weight, bias: linear.bias) - - freeze() - } - - /// Freeze all parameters except the lora parameters - public override func freeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false) - throws - { - // realize the keys and omit the lora parameters - let keys = - (keys ?? self.filterMap(filter: Self.filterLocalParameters).flattened().map { $0.0 }) - .filter { - $0 != "lora_a" && $0 != "lora_b" - } - try super.freeze(recursive: recursive, keys: keys, strict: strict) - } - - /// Convert a `Linear` or `QuantizedLinear` layer into a new `Linear` layer - /// that implements the `LoRA` adapter. - /// - /// This is typically called via ``LoRATrain/convert(model:layers:)``. - /// - /// ### See Also - /// - ``LoRATrain/convert(model:layers:)`` - /// - ``QLoRALinear/from(linear:rank:)`` - public static func from(linear: Linear, rank: Int = 8) -> Linear { - if let linear = linear as? QuantizedLinear { - return QLoRALinear.from(linear: linear, rank: rank) - } - let (outputDimensions, inputDimensions) = linear.shape - return LoRALinear(inputDimensions, outputDimensions, rank: rank, linear: linear) - } - - /// Convert back into a fused `Linear` layer. - /// - /// This is typically called via ``LoRATrain/fuse(model:layers:deQuantize:)``. - /// - /// ### See Also - /// - ``LoRATrain/fuse(model:layers:deQuantize:)`` - /// - ``LoRAConvertToLinear`` - /// - ``QLoRALinear/toLinear(deQuantize:)`` - public func toLinear(deQuantize: Bool = false) -> Linear { - let dtype = weight.dtype - let loraB = (scale * loraB.T).asType(dtype) - let loraA = loraA.T.asType(dtype) - return Linear(weight: weight + matmul(loraB, loraA), bias: bias) - } - - public override func callAsFunction(_ x: MLXArray) -> MLXArray { - let y = super.callAsFunction(x.asType(weight.dtype)) - let z = matmul(matmul(x, self.loraA), self.loraB) - return y + scale * z - } -} - -/// Implementation of LoRA `QuantizedLinear` replacement layer. -/// -/// See ``LoRALinear`` (equivalent class for `Linear` layers) for more information. 
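With the adapter layer implementations moving out of this file, client code now drives LoRA through `LoRATrain` and the container's `ModelContext`, as in the LoRATrainingExample changes above. A minimal sketch of that flow, assuming a loaded `ModelContainer` and a model conforming to `LoRAModel` (the helper name `adaptAndEvaluate` and the batch settings are illustrative):

```swift
import MLXLLM
import MLXLMCommon

// Sketch: apply LoRA adapters to a loaded model and measure test loss.
// All model access stays inside `perform`, which serializes access.
func adaptAndEvaluate(container: ModelContainer, test: [String]) async {
    await container.perform { context in
        // the example app derives the layer list itself; here we use the
        // LoRAModel protocol shown above
        guard let lora = context.model as? LoRAModel else { return }
        LoRATrain.convert(model: context.model, layers: lora.loraLinearLayers())
    }
    let loss = await container.perform { context in
        LoRATrain.evaluate(
            model: context.model, dataset: test,
            tokenizer: context.tokenizer, batchSize: 1, batchCount: 0)
    }
    print("test loss: \(loss)")
}
```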
-public class QLoRALinear: QuantizedLinear, LoRAConvertToLinear { - - let scale: Float - - @ParameterInfo(key: "lora_a") var loraA: MLXArray - @ParameterInfo(key: "lora_b") var loraB: MLXArray - - required public init( - _ inputDimensions: Int, _ outputDimensions: Int, rank: Int = 8, bias: Bool = false, - scale: Float = 20.0, linear: QuantizedLinear - ) { - - // Scale for low-rank update - self.scale = scale - - // Low rank lora weights - let loraScale = 1 / sqrt(Float(inputDimensions)) - self._loraA.wrappedValue = MLXRandom.uniform( - low: -loraScale, high: loraScale, [inputDimensions, rank]) - self._loraB.wrappedValue = MLXArray.zeros([rank, outputDimensions]) - - super.init( - weight: linear.weight, bias: linear.bias, scales: linear.scales, biases: linear.biases, - groupSize: linear.groupSize, bits: linear.bits) - - // start frozen except for the lora keys - freeze() - } - - /// Freeze all parameters except the lora parameters - public override func freeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false) - throws - { - // realize the keys and omit the lora parameters - let keys = - (keys ?? self.filterMap(filter: Self.filterLocalParameters).flattened().map { $0.0 }) - .filter { - $0 != "lora_a" && $0 != "lora_b" - } - try super.freeze(recursive: recursive, keys: keys, strict: strict) - } - - /// Convert a `QuantizedLinear` layer into a new `Linear` layer - /// that implements the `LoRA` adapter. - /// - /// This is typically called via ``LoRATrain/convert(model:layers:)``. - /// - /// ### See Also - /// - ``LoRATrain/convert(model:layers:)`` - /// - ``LoRALinear/from(linear:rank:)`` - public static func from(linear: QuantizedLinear, rank: Int = 8) -> Linear { - var (outputDimensions, inputDimensions) = linear.shape - inputDimensions = inputDimensions * 32 / linear.bits - return QLoRALinear(inputDimensions, outputDimensions, rank: rank, linear: linear) - } - - /// Convert back into a fused `QuantizedLinear` layer. - /// - /// This is typically called via ``LoRATrain/fuse(model:layers:deQuantize:)``. - /// - /// ### See Also - /// - ``LoRATrain/fuse(model:layers:deQuantize:)`` - public func toLinear(deQuantize: Bool = false) -> Linear { - // convert back into full weights - let weight = dequantized( - weight, scales: scales, biases: biases, groupSize: groupSize, bits: bits) - - let loraB = (scale * loraB.T).asType(.float16) - let loraA = loraA.T.asType(.float16) - - // convert back into quantized - return QuantizedLinear( - weight: weight + matmul(loraB, loraA), bias: bias, groupSize: groupSize, bits: bits) - } - - public override func callAsFunction(_ x: MLXArray) -> MLXArray { - let y = super.callAsFunction(x.asType(scales.dtype)) - let z = matmul(matmul(x, self.loraA), self.loraB) - return y + scale * z - } -} - /// Equivalent to `lora.py/iterate_batches()`. Used internally by ``LoRATrain``. struct LoRABatchIterator: Sequence, IteratorProtocol { @@ -277,7 +67,6 @@ struct LoRABatchIterator: Sequence, IteratorProtocol { return (batchArray[0..., .stride(to: -1)], batchArray[0..., 1...], MLXArray(lengths)) } - } /// Collection of functions for adding LoRA adapters to an LLM model, training, fusing and saving/loading weights. @@ -437,7 +226,7 @@ public enum LoRATrain { // def loss(model, inputs, targets, lengths): // run model on inputs - let model = model as! LLMModel + let model = model as! 
any LLMModel let logits = model(inputs, cache: nil).asType(.float32) // mask padding tokens diff --git a/Libraries/LLM/Models/Cohere.swift b/Libraries/MLXLLM/Models/Cohere.swift similarity index 98% rename from Libraries/LLM/Models/Cohere.swift rename to Libraries/MLXLLM/Models/Cohere.swift index eff0de5..f57364b 100644 --- a/Libraries/LLM/Models/Cohere.swift +++ b/Libraries/MLXLLM/Models/Cohere.swift @@ -1,6 +1,7 @@ import Foundation import MLX import MLXFast +import MLXLMCommon import MLXNN // port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/cohere.py @@ -149,7 +150,6 @@ public class CohereModel: Module, LLMModel, KVCacheDimensionProvider { public let vocabularySize: Int public let kvHeads: [Int] - public let headDim: IntOrPair let model: CohereModelInner let logitScale: Float @@ -157,7 +157,6 @@ public class CohereModel: Module, LLMModel, KVCacheDimensionProvider { public init(_ args: CohereConfiguration) { self.vocabularySize = args.vocabularySize self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads } - self.headDim = .init(args.hiddenSize / args.attentionHeads) self.model = CohereModelInner(args) self.logitScale = args.logitScale } diff --git a/Libraries/LLM/Models/Gemma.swift b/Libraries/MLXLLM/Models/Gemma.swift similarity index 98% rename from Libraries/LLM/Models/Gemma.swift rename to Libraries/MLXLLM/Models/Gemma.swift index 8475e70..37ca0a4 100644 --- a/Libraries/LLM/Models/Gemma.swift +++ b/Libraries/MLXLLM/Models/Gemma.swift @@ -3,6 +3,7 @@ import Foundation import MLX import MLXFast +import MLXLMCommon import MLXNN // Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma.py @@ -136,7 +137,7 @@ private class TransformerBlock: Module { } } -public class GemmaModelInner: Module { +private class GemmaModelInner: Module { let args: GemmaConfiguration let vocabularySize: Int let numHiddenLayers: Int @@ -179,16 +180,14 @@ public class GemmaModelInner: Module { public class GemmaModel: Module, LLMModel, KVCacheDimensionProvider { public let vocabularySize: Int public let kvHeads: [Int] - public let headDim: IntOrPair let modelType: String - let model: GemmaModelInner + private let model: GemmaModelInner public init(_ args: GemmaConfiguration) { self.modelType = args.modelType self.vocabularySize = args.vocabularySize self.kvHeads = Array(repeating: args.kvHeads, count: args.hiddenLayers) - self.headDim = .init(args.headDimensions) self.model = GemmaModelInner(args) } diff --git a/Libraries/LLM/Models/Gemma2.swift b/Libraries/MLXLLM/Models/Gemma2.swift similarity index 99% rename from Libraries/LLM/Models/Gemma2.swift rename to Libraries/MLXLLM/Models/Gemma2.swift index 30bcd3f..48e0482 100644 --- a/Libraries/LLM/Models/Gemma2.swift +++ b/Libraries/MLXLLM/Models/Gemma2.swift @@ -3,6 +3,7 @@ import Foundation import MLX import MLXFast +import MLXLMCommon import MLXNN // Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma2.py @@ -195,7 +196,6 @@ public class ModelInner: Module { public class Gemma2Model: Module, LLMModel, KVCacheDimensionProvider { public let vocabularySize: Int public let kvHeads: [Int] - public let headDim: IntOrPair let model: ModelInner let logitSoftCap: Float @@ -203,7 +203,6 @@ public class Gemma2Model: Module, LLMModel, KVCacheDimensionProvider { public init(_ args: Gemma2Configuration) { self.vocabularySize = args.vocabularySize self.kvHeads = Array(repeating: args.kvHeads, count: args.hiddenLayers) - self.headDim = .init(args.headDimensions) self.model = 
ModelInner(args) self.logitSoftCap = args.finalLogitSoftcapping } diff --git a/Libraries/LLM/Models/Internlm2.swift b/Libraries/MLXLLM/Models/Internlm2.swift similarity index 97% rename from Libraries/LLM/Models/Internlm2.swift rename to Libraries/MLXLLM/Models/Internlm2.swift index 2bdf5fd..a2d8ebb 100644 --- a/Libraries/LLM/Models/Internlm2.swift +++ b/Libraries/MLXLLM/Models/Internlm2.swift @@ -3,6 +3,7 @@ import Foundation import MLX import MLXFast +import MLXLMCommon import MLXNN // Port of https://github.com/maiqingqiang/mlx-examples/blob/main/llms/mlx_lm/models/internlm2.py @@ -71,11 +72,10 @@ private class Attention: Module { if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"), let factor = ropeScaling["factor"] { - switch factor { - case .string: - fatalError("ropeScaling.factor must be a float") - case .float(let v): + if let v = factor.asFloat() { ropeScale = 1 / v + } else { + fatalError("ropeScaling.factor must be a float") } } else { ropeScale = 1 @@ -200,7 +200,6 @@ private class InternLM2ModelInner: Module { public class InternLM2Model: Module, LLMModel, KVCacheDimensionProvider { public let vocabularySize: Int public let kvHeads: [Int] - public let headDim: IntOrPair fileprivate let model: InternLM2ModelInner @@ -209,7 +208,6 @@ public class InternLM2Model: Module, LLMModel, KVCacheDimensionProvider { public init(_ args: InternLM2Configuration) { self.vocabularySize = args.vocabularySize self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads } - self.headDim = .init(args.hiddenSize / args.attentionHeads) self.model = InternLM2ModelInner(args) if !args.tieWordEmbeddings { self._output.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false) @@ -233,6 +231,12 @@ public class InternLM2Model: Module, LLMModel, KVCacheDimensionProvider { } } +extension InternLM2Model: LoRAModel { + public func loraLinearLayers() -> LoRALinearLayers { + model.layers.map { ($0.attention, ["q_proj", "v_proj"]) } + } +} + public struct InternLM2Configuration: Codable, Sendable { var hiddenSize: Int var hiddenLayers: Int diff --git a/Libraries/LLM/Models/Llama.swift b/Libraries/MLXLLM/Models/Llama.swift similarity index 99% rename from Libraries/LLM/Models/Llama.swift rename to Libraries/MLXLLM/Models/Llama.swift index c2d5e40..08bdc0f 100644 --- a/Libraries/LLM/Models/Llama.swift +++ b/Libraries/MLXLLM/Models/Llama.swift @@ -3,6 +3,7 @@ import Foundation import MLX import MLXFast +import MLXLMCommon import MLXNN // port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/llama.py @@ -286,7 +287,6 @@ public class LlamaModel: Module, LLMModel, KVCacheDimensionProvider { public let vocabularySize: Int public let kvHeads: [Int] - public let headDim: IntOrPair fileprivate let model: LlamaModelInner @@ -295,7 +295,6 @@ public class LlamaModel: Module, LLMModel, KVCacheDimensionProvider { public init(_ args: LlamaConfiguration) { self.vocabularySize = args.vocabularySize self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads } - self.headDim = .init(args.resolvedHeadDimensions) self.model = LlamaModelInner(args) if !args.tieWordEmbeddings { self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false) diff --git a/Libraries/LLM/Models/OpenELM.swift b/Libraries/MLXLLM/Models/OpenELM.swift similarity index 96% rename from Libraries/LLM/Models/OpenELM.swift rename to Libraries/MLXLLM/Models/OpenELM.swift index 9c852cb..df70ed0 100644 --- a/Libraries/LLM/Models/OpenELM.swift +++ 
b/Libraries/MLXLLM/Models/OpenELM.swift @@ -8,6 +8,7 @@ import Foundation import MLX import MLXFast +import MLXLMCommon import MLXNN func computeHeads(modelDim: Int, headDim: Int) -> Int { @@ -144,11 +145,7 @@ private class TransformerDecoderLayer: Module { } } -class OpenELMModelInner: Module, LLMModel, KVCacheDimensionProvider { - let vocabularySize: Int - let kvHeads: [Int] - let headDim: IntOrPair - +class OpenELMModelInner: Module { @ModuleInfo(key: "token_embeddings") var embedTokens: Embedding fileprivate let layers: [TransformerDecoderLayer] @@ -157,11 +154,8 @@ class OpenELMModelInner: Module, LLMModel, KVCacheDimensionProvider { public init(_ args: OpenElmConfiguration) { precondition(args.vocabularySize > 0) - self.vocabularySize = args.vocabularySize - self.kvHeads = args.kvHeads - self.headDim = .init(args.headDimensions) self._embedTokens.wrappedValue = Embedding( - embeddingCount: self.vocabularySize, dimensions: args.modelDim) + embeddingCount: args.vocabularySize, dimensions: args.modelDim) self.layers = (0 ..< args.numTransformerLayers) .map { layerId in @@ -186,7 +180,6 @@ class OpenELMModelInner: Module, LLMModel, KVCacheDimensionProvider { public class OpenELMModel: Module, LLMModel, KVCacheDimensionProvider { public let vocabularySize: Int public let kvHeads: [Int] - public let headDim: IntOrPair let shareInputOutputLayers: Bool let transformer: OpenELMModelInner @@ -196,7 +189,6 @@ public class OpenELMModel: Module, LLMModel, KVCacheDimensionProvider { public init(_ args: OpenElmConfiguration) { self.vocabularySize = args.vocabularySize self.kvHeads = args.kvHeads - self.headDim = .init(args.headDimensions) self.transformer = OpenELMModelInner(args) self.shareInputOutputLayers = args.shareInputOutputLayers diff --git a/Libraries/LLM/Models/Phi.swift b/Libraries/MLXLLM/Models/Phi.swift similarity index 98% rename from Libraries/LLM/Models/Phi.swift rename to Libraries/MLXLLM/Models/Phi.swift index edfe186..963a315 100644 --- a/Libraries/LLM/Models/Phi.swift +++ b/Libraries/MLXLLM/Models/Phi.swift @@ -3,6 +3,7 @@ import Foundation import MLX import MLXFast +import MLXLMCommon import MLXNN // https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/phi.py @@ -155,7 +156,6 @@ public class PhiModel: Module, LLMModel, KVCacheDimensionProvider { public let vocabularySize: Int public let kvHeads: [Int] - public let headDim: IntOrPair fileprivate let model: PhiModelInner @@ -164,7 +164,6 @@ public class PhiModel: Module, LLMModel, KVCacheDimensionProvider { public init(_ args: PhiConfiguration) { self.vocabularySize = args.vocabularySize self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads } - self.headDim = .init(args.hiddenSize / args.attentionHeads) self.model = PhiModelInner(args) self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: true) } diff --git a/Libraries/LLM/Models/Phi3.swift b/Libraries/MLXLLM/Models/Phi3.swift similarity index 99% rename from Libraries/LLM/Models/Phi3.swift rename to Libraries/MLXLLM/Models/Phi3.swift index c7709ed..2c25a50 100644 --- a/Libraries/LLM/Models/Phi3.swift +++ b/Libraries/MLXLLM/Models/Phi3.swift @@ -3,6 +3,7 @@ import Foundation import MLX import MLXFast +import MLXLMCommon import MLXNN private class Attention: Module { @@ -187,7 +188,6 @@ public class Phi3Model: Module, LLMModel, KVCacheDimensionProvider { public let vocabularySize: Int public let kvHeads: [Int] - public let headDim: IntOrPair let model: Phi3ModelInner @@ -196,7 +196,6 @@ public class Phi3Model: Module, 
LLMModel, KVCacheDimensionProvider { public init(_ args: Phi3Configuration) { self.vocabularySize = args.vocabularySize self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads } - self.headDim = .init(args.hiddenSize / args.attentionHeads) self.model = Phi3ModelInner(args) self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false) } diff --git a/Libraries/LLM/Models/PhiMoE.swift b/Libraries/MLXLLM/Models/PhiMoE.swift similarity index 98% rename from Libraries/LLM/Models/PhiMoE.swift rename to Libraries/MLXLLM/Models/PhiMoE.swift index e11542f..fb86e8f 100644 --- a/Libraries/LLM/Models/PhiMoE.swift +++ b/Libraries/MLXLLM/Models/PhiMoE.swift @@ -1,6 +1,7 @@ import Foundation import MLX import MLXFast +import MLXLMCommon import MLXNN import MLXRandom @@ -210,7 +211,6 @@ private class PhiMoEModelInner: Module { public class PhiMoEModel: Module, LLMModel, KVCacheDimensionProvider { public let vocabularySize: Int public let kvHeads: [Int] - public let headDim: IntOrPair fileprivate let model: PhiMoEModelInner @ModuleInfo(key: "lm_head") var lmHead: Linear @@ -218,7 +218,6 @@ public class PhiMoEModel: Module, LLMModel, KVCacheDimensionProvider { public init(_ args: PhiMoEConfiguration) { self.vocabularySize = args.vocabularySize self.kvHeads = Array(repeating: args.kvHeads, count: args.hiddenLayers) - self.headDim = .init(args.hiddenSize / args.attentionHeads) self.model = PhiMoEModelInner(args) self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: true) } diff --git a/Libraries/LLM/Models/Qwen2.swift b/Libraries/MLXLLM/Models/Qwen2.swift similarity index 98% rename from Libraries/LLM/Models/Qwen2.swift rename to Libraries/MLXLLM/Models/Qwen2.swift index 8a754ef..aa07fef 100644 --- a/Libraries/LLM/Models/Qwen2.swift +++ b/Libraries/MLXLLM/Models/Qwen2.swift @@ -8,6 +8,7 @@ import Foundation import MLX import MLXFast +import MLXLMCommon import MLXNN // port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/qwen2.py @@ -42,11 +43,10 @@ private class Attention: Module { if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"), let factor = ropeScaling["factor"] { - switch factor { - case .string: - fatalError("ropeScaling.factor must be a float") - case .float(let v): + if let v = factor.asFloat() { ropeScale = 1 / v + } else { + fatalError("ropeScaling.factor must be a float") } } else { ropeScale = 1 @@ -168,7 +168,6 @@ public class Qwen2ModelInner: Module { public class Qwen2Model: Module, LLMModel, KVCacheDimensionProvider { public let vocabularySize: Int public let kvHeads: [Int] - public let headDim: IntOrPair let model: Qwen2ModelInner let configuration: Qwen2Configuration @@ -179,7 +178,6 @@ public class Qwen2Model: Module, LLMModel, KVCacheDimensionProvider { self.configuration = args self.vocabularySize = args.vocabularySize self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads } - self.headDim = .init(args.hiddenSize / args.attentionHeads) self.model = Qwen2ModelInner(args) _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false) } diff --git a/Libraries/LLM/Models/Starcoder2.swift b/Libraries/MLXLLM/Models/Starcoder2.swift similarity index 98% rename from Libraries/LLM/Models/Starcoder2.swift rename to Libraries/MLXLLM/Models/Starcoder2.swift index 2169881..2e2ec5e 100644 --- a/Libraries/LLM/Models/Starcoder2.swift +++ b/Libraries/MLXLLM/Models/Starcoder2.swift @@ -8,6 +8,7 @@ import Foundation import MLX import MLXFast +import MLXLMCommon import 
MLXNN // port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/starcoder2.py @@ -150,7 +151,6 @@ public class Starcoder2ModelInner: Module { public class Starcoder2Model: Module, LLMModel, KVCacheDimensionProvider { public let vocabularySize: Int public let kvHeads: [Int] - public let headDim: IntOrPair public let tieWordEmbeddings: Bool let model: Starcoder2ModelInner @@ -160,7 +160,6 @@ public class Starcoder2Model: Module, LLMModel, KVCacheDimensionProvider { public init(_ args: Starcoder2Configuration) { self.vocabularySize = args.vocabularySize self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads } - self.headDim = .init(args.hiddenSize / args.attentionHeads) self.model = Starcoder2ModelInner(args) self.tieWordEmbeddings = args.tieWordEmbeddings if !self.tieWordEmbeddings { diff --git a/Libraries/MLXLLM/README.md b/Libraries/MLXLLM/README.md new file mode 100644 index 0000000..c024bf7 --- /dev/null +++ b/Libraries/MLXLLM/README.md @@ -0,0 +1,157 @@ +# MLXLLM + +This is a port of several models from: + +- https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/ + +using the Hugging Face swift-transformers package to provide tokenization: + +- https://github.com/huggingface/swift-transformers + +[LLMModelFactory.swift](LLMModelFactory.swift) provides minor overrides and customization -- +if you require overrides for the tokenizer or prompt customizations, they can be +added there. + +This is set up to load models from Hugging Face, e.g. https://huggingface.co/mlx-community + +The following models have been tried: + +- mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX +- mlx-community/Llama-3.2-1B-Instruct-4bit +- mlx-community/Llama-3.2-3B-Instruct-4bit +- mlx-community/Meta-Llama-3-8B-Instruct-4bit +- mlx-community/Meta-Llama-3.1-8B-Instruct-4bit +- mlx-community/Mistral-7B-Instruct-v0.3-4bit +- mlx-community/Mistral-Nemo-Instruct-2407-4bit +- mlx-community/OpenELM-270M-Instruct +- mlx-community/Phi-3.5-MoE-instruct-4bit +- mlx-community/Phi-3.5-mini-instruct-4bit +- mlx-community/Qwen1.5-0.5B-Chat-4bit +- mlx-community/SmolLM-135M-Instruct-4bit +- mlx-community/gemma-2-2b-it-4bit +- mlx-community/gemma-2-9b-it-4bit +- mlx-community/phi-2-hf-4bit-mlx +- mlx-community/quantized-gemma-2b-it + +Currently supported model types are: + +- Cohere +- Gemma +- Gemma2 +- InternLM2 +- Llama / Mistral +- OpenELM +- Phi +- Phi3 +- PhiMoE +- Qwen2 +- Starcoder2 + +See [llm-tool](../../Tools/llm-tool) + +# Adding a Model + +If the model follows the typical LLM pattern: + +- `config.json`, `tokenizer.json`, and `tokenizer_config.json` +- `*.safetensors` + +You can follow the pattern of the models in the [Models](Models) directory +and create a `.swift` file for your new model: + +## Create a Configuration + +Create a configuration struct to match the `config.json` (with any parameters needed). + +```swift +public struct YourModelConfiguration: Codable, Sendable { + public let hiddenSize: Int + + // use this pattern for values that need defaults + public let _layerNormEps: Float? + public var layerNormEps: Float { _layerNormEps ?? 1e-6 } + + enum CodingKeys: String, CodingKey { + case hiddenSize = "hidden_size" + case _layerNormEps = "layer_norm_eps" + } +} +``` + +## Create the Model Class + +Create the model class. 
The top-level public class should have a +structure something like this: + +```swift +public class YourModel: Module, LLMModel, KVCacheDimensionProvider, LoRAModel { + + public let kvHeads: [Int] + + @ModuleInfo var model: YourModelInner + + public func loraLinearLayers() -> LoRALinearLayers { + // TODO: modify as needed + model.layers.map { ($0.attention, ["q_proj", "v_proj"]) } + } + + public init(_ args: YourModelConfiguration) { + self.kvHeads = Array(repeating: args.kvHeads, count: args.hiddenLayers) + self.model = YourModelInner(args) + } + + public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray { + // TODO: modify as needed + let out = model(inputs, cache: cache) + return model.embedTokens.asLinear(out) + } +} +``` + +## Register the Model + +In [LLMModelFactory.swift](LLMModelFactory.swift), register the model type itself +(this is independent of the model id): + +```swift +public class ModelTypeRegistry: @unchecked Sendable { +... + private var creators: [String: @Sendable (URL) throws -> any LanguageModel] = [ + "yourModel": create(YourModelConfiguration.self, YourModel.init), +``` + +Add a constant for the model in the `ModelRegistry` (not strictly required but useful +for callers to refer to it in code): + +```swift +public class ModelRegistry: @unchecked Sendable { +... + static public let yourModel_4bit = ModelConfiguration( + id: "mlx-community/YourModel-4bit", + defaultPrompt: "What is the gravity on Mars and the moon?" + ) +``` + +and finally add it to the `all()` list -- this will let users find the model +configuration by id: + +```swift + private static func all() -> [ModelConfiguration] { + [ + codeLlama13b4bit, +... + yourModel_4bit, +``` + +# Using a Model + +See [MLXLMCommon/README.md](../MLXLMCommon/README.md#using-a-model). + +# LoRA + +[Lora.swift](Lora.swift) contains an implementation of LoRA based on this example: + +- https://github.com/ml-explore/mlx-examples/tree/main/lora + +See [llm-tool/LoraCommands.swift](../../Tools/llm-tool/LoraCommands.swift) for an example of a driver and +[llm-tool](../../Tools/llm-tool) for examples of how to run it. 
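+
+As a minimal sketch (assuming `model` conforms to `LoRAModel`), adapters can be
+applied and later fused back into the base weights:
+
+```swift
+// convert the last 4 layers' projections (per loraLinearLayers) into adapters
+LoRATrain.convert(model: model, layers: model.loraLinearLayers(4))
+
+// ... train the adapters ...
+
+// fuse the adapters back into plain (quantized) Linear layers
+LoRATrain.fuse(model: model, layers: model.loraLinearLayers(4), deQuantize: false)
+```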
diff --git a/Libraries/LLM/SuScaledRotaryEmbedding.swift b/Libraries/MLXLLM/SuScaledRotaryEmbedding.swift similarity index 86% rename from Libraries/LLM/SuScaledRotaryEmbedding.swift rename to Libraries/MLXLLM/SuScaledRotaryEmbedding.swift index 3dec287..f1a001e 100644 --- a/Libraries/LLM/SuScaledRotaryEmbedding.swift +++ b/Libraries/MLXLLM/SuScaledRotaryEmbedding.swift @@ -5,7 +5,6 @@ import MLXNN public class SuScaledRotaryEmbedding: Module { let dimensions: Int - let base: Float let maxPositionEmbeddings: Int let originalMaxPositionEmbeddings: Int let scale: Float @@ -23,7 +22,6 @@ public class SuScaledRotaryEmbedding: Module { precondition(dimensions % 2 == 0, "Dimensions must be even") self.dimensions = dimensions - self.base = base self.maxPositionEmbeddings = maxPositionEmbeddings self.originalMaxPositionEmbeddings = originalMaxPositionEmbeddings @@ -45,10 +43,10 @@ public class SuScaledRotaryEmbedding: Module { self.scale * x, dimensions: x.shape.last!, traditional: false, - base: self.base, // TODO: After updating to MLX 0.17.0, use `nil` + base: nil, scale: 1.0, - offset: offset - // TODO: After updating to MLX 0.17.0, pass `self._freqs` to `freqs` + offset: offset, + freqs: self._freqs ) } } diff --git a/Libraries/LLM/SwitchLayers.swift b/Libraries/MLXLLM/SwitchLayers.swift similarity index 100% rename from Libraries/LLM/SwitchLayers.swift rename to Libraries/MLXLLM/SwitchLayers.swift diff --git a/Libraries/MLXLMCommon/Evaluate.swift b/Libraries/MLXLMCommon/Evaluate.swift new file mode 100644 index 0000000..b20643b --- /dev/null +++ b/Libraries/MLXLMCommon/Evaluate.swift @@ -0,0 +1,556 @@ +// Copyright © 2024 Apple Inc. + +import Foundation +import MLX +import MLXRandom +import Tokenizers + +/// A `LogitSampler` is responsible for sampling `logits` produced by +/// a ``LanguageModel`` to produce a token. +/// +/// See also: ``LogitProcessor`` +public protocol LogitSampler: Sendable { + + /// Given `logits`, produce a new `MLXArray` with the token. + func sample(logits: MLXArray) -> MLXArray +} + +/// A `LogitProcessor` is an optional visitor of `logits`. +/// +/// The ``LogitProcessor`` is called with the input (prompt) before generating tokens: +/// +/// ```swift +/// processor?.prompt(input.text.tokens) +/// ``` +/// +/// Then for each token generated it has a chance to adjust the logits: +/// +/// ```swift +/// logits = processor?.process(logits: logits) ?? logits +/// let y = sampler.sample(logits: logits) +/// processor?.didSample(token: y) +/// ``` +/// +/// See also: ``LogitSampler`` +public protocol LogitProcessor: Sendable { + + /// called before token generation starts with the text tokens of the prompt + mutating func prompt(_ prompt: MLXArray) + + /// called to visit and possibly modify the logits + func process(logits: MLXArray) -> MLXArray + + /// called to provide the sampled token + mutating func didSample(token: MLXArray) +} + +/// Parameters for text generation, see ``TokenIterator``. +/// +/// This produces: +/// +/// - ``LogitSampler`` +/// - ``LogitProcessor`` +/// +/// for the `TokenIterator`. +public struct GenerateParameters: Sendable { + + /// Step size for processing the prompt + public var prefillStepSize = 512 + + /// sampling temperature + public var temperature: Float = 0.6 + + /// top p sampling + public var topP: Float = 1.0 + + /// penalty factor for repeating tokens + public var repetitionPenalty: Float? 
+ + /// number of tokens to consider for repetition penalty + public var repetitionContextSize: Int = 20 + + public init( + temperature: Float = 0.6, topP: Float = 1.0, repetitionPenalty: Float? = nil, + repetitionContextSize: Int = 20 + ) { + self.temperature = temperature + self.topP = topP + self.repetitionPenalty = repetitionPenalty + self.repetitionContextSize = repetitionContextSize + } + + func sampler() -> LogitSampler { + if temperature == 0 { + return ArgMaxSampler() + } else if topP > 0 && topP < 1 { + return TopPSampler(temperature: temperature, topP: topP) + } else { + return CategoricalSampler(temperature: temperature) + } + } + + func processor() -> LogitProcessor? { + if let repetitionPenalty, repetitionContextSize > 0 { + return RepetitionContext( + repetitionPenalty: repetitionPenalty, repetitionContextSize: repetitionContextSize) + } else { + return nil + } + } +} + +/// Sampler that uses `argMax` (most likely) to sample the logits. +public struct ArgMaxSampler: LogitSampler { + public func sample(logits: MLX.MLXArray) -> MLX.MLXArray { + argMax(logits, axis: -1) + } +} + +/// Sampler that uses `topP` and `temperature` to sample the logits. +public struct TopPSampler: LogitSampler { + let temp: MLXArray + let topP: MLXArray + + init(temperature: Float, topP: Float) { + self.temp = MLXArray(temperature) + self.topP = MLXArray(topP) + } + + private let compiledTopPSampling: (MLXArray, MLXArray, MLXArray) -> MLXArray = { + compile(inputs: [MLXRandom.globalState], outputs: [MLXRandom.globalState]) { + logits, topP, temp in + let probs = softmax(logits / temp, axis: -1) + let sortedIndices = argSort(probs, axis: -1) + + // probs shape is [B,V] and after take it will be [1, B, V], so we squeeze it back to [B, V] + let sortedProbs = take(probs, sortedIndices, axis: -1).squeezed(axis: 0) + + let cumulativeProbs = cumsum(sortedProbs, axis: -1) + + let topProbs = MLX.where( + cumulativeProbs .> (1 - topP), sortedProbs, zeros(like: sortedProbs)) + + let sortedToken = categorical(log(topProbs)) + return sortedIndices.squeezed(axis: 0)[sortedToken] + } + }() + + public func sample(logits: MLXArray) -> MLXArray { + var logits = logits + if logits.dtype == .bfloat16 { + logits = logits.asType(.float32) + } + + return compiledTopPSampling(logits, topP, temp) + } +} + +/// Sampler that uses `temperature` to sample the logits +public struct CategoricalSampler: LogitSampler { + let temp: MLXArray + + init(temperature: Float) { + self.temp = MLXArray(temperature) + } + + private let compiledCategorical: (MLXArray, MLXArray) -> MLXArray = { + compile(inputs: [MLXRandom.globalState], outputs: [MLXRandom.globalState]) { logits, temp in + categorical(logits * (1 / temp)) + } + }() + + public func sample(logits: MLXArray) -> MLXArray { + compiledCategorical(logits, temp) + } +} + +/// Processor that implements a `repetitionPenalty` +public struct RepetitionContext: LogitProcessor { + /// tokens in the repetition context sliding window + var tokens = [Int]() + + /// current write index into the tokens circular array + var index = 0 + + /// penalty factor for repeating tokens + let repetitionPenalty: Float + + /// number of tokens to consider for repetition penalty + let repetitionContextSize: Int + + init(repetitionPenalty: Float, repetitionContextSize: Int) { + precondition(repetitionContextSize > 0) + self.repetitionPenalty = repetitionPenalty + self.repetitionContextSize = repetitionContextSize + } + + mutating public func prompt(_ prompt: MLXArray) { + if prompt.shape[0] <= 
repetitionContextSize { + self.tokens = prompt.asArray(Int.self) + } else { + self.tokens = prompt[(-repetitionContextSize)...].asArray(Int.self) + } + } + + public func process(logits: MLXArray) -> MLXArray { + if tokens.count > 0 { + let indices = MLXArray(tokens.map { UInt32($0) }) + var selectedLogits = logits[0..., indices] + + selectedLogits = MLX.where( + selectedLogits .< 0, selectedLogits * repetitionPenalty, + selectedLogits / repetitionPenalty) + + logits[0..., indices] = selectedLogits + return logits + } + + return logits + } + + mutating public func didSample(token: MLXArray) { + if tokens.count >= repetitionContextSize { + tokens[index] = token.item(Int.self) + index = (index + 1) % repetitionContextSize + } else { + tokens.append(token.item(Int.self)) + } + } +} + +/// Generator of tokens. +/// +/// This is typically used via a call to ``generate(input:parameters:context:didGenerate:)``. +/// +/// To use it directly: +/// +/// ```swift +/// let generateParameters: GenerateParameters +/// let input: LMInput +/// let model: LanguageModel +/// +/// let iterator = try TokenIterator(input: input, model: model, parameters: parameters) +/// +/// for token in iterator { +/// ... +/// } +/// ``` +/// +/// Tokens are integers that can be passed through a `Tokenizer` or ``StreamingDetokenizer`` to produce Strings. +/// +/// Port of `generate_step()` from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py +/// +/// Note: this uses `asyncEval()` and there may be an async evaluation running after a call to `next()`. +public struct TokenIterator: Sequence, IteratorProtocol { + let model: any LanguageModel + var state: LMOutput.State? + + var y: LMInput.Text + var cache: [KVCache] + var processor: LogitProcessor? + let sampler: LogitSampler + + /// Initialize a `TokenIterator` with the given tokens. Note: this has been + /// replaced with ``init(input:model:cache:parameters:)``. + /// + /// - Parameters: + /// - prompt: the prompt tokens + /// - model: the ``LanguageModel`` + /// - cache: optional ``KVCache`` + /// - parameters: the generation parameters + @available(*, deprecated, message: "please use init(input:model:cache:parameters:)") + public init( + prompt: MLXArray, model: any LanguageModel, cache: [KVCache]? = nil, + parameters: GenerateParameters + ) throws { + self.model = model + self.y = .init(tokens: prompt) + self.cache = cache ?? model.newCache(parameters: parameters) + + self.processor = parameters.processor() + self.sampler = parameters.sampler() + + try prepare(input: .init(text: y), windowSize: parameters.prefillStepSize) + } + + /// Initialize a `TokenIterator` with the given input. + /// + /// If more control is needed over the generation, + /// ``init(input:model:cache:processor:sampler:prefillStepSize:)`` + /// allows a caller to specify ``LogitProcessor`` and ``LogitSampler`` + /// directly. + /// + /// - Parameters: + /// - input: language model input + /// - model: the ``LanguageModel`` + /// - cache: optional ``KVCache`` + /// - parameters: the generation parameters + public init( + input: LMInput, model: any LanguageModel, cache: [KVCache]? = nil, + parameters: GenerateParameters + ) throws { + self.model = model + self.y = input.text + self.cache = cache ?? model.newCache(parameters: parameters) + + self.processor = parameters.processor() + self.sampler = parameters.sampler() + + try prepare(input: input, windowSize: parameters.prefillStepSize) + } + + /// Initialize a `TokenIterator` with the given input and logit handling. 
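+    ///
+    /// For example, with a caller-defined sampler (a sketch -- `GreedySampler` is a
+    /// hypothetical conformance and `input`/`model` are prepared elsewhere):
+    ///
+    /// ```swift
+    /// struct GreedySampler: LogitSampler {
+    ///     func sample(logits: MLXArray) -> MLXArray { argMax(logits, axis: -1) }
+    /// }
+    ///
+    /// let iterator = try TokenIterator(
+    ///     input: input, model: model, processor: nil, sampler: GreedySampler())
+    /// ```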
+ /// + /// - Parameters: + /// - input: language model input + /// - model: the ``LanguageModel`` + /// - cache: optional ``KVCache`` + /// - processor: the logit processor + /// - sampler: the logit sampler + /// - prefillStepSize: prefill step size + public init( + input: LMInput, model: any LanguageModel, cache: [KVCache]? = nil, + processor: LogitProcessor?, sampler: LogitSampler, prefillStepSize: Int = 512 + ) throws { + self.model = model + self.y = input.text + self.cache = cache ?? model.newCache(parameters: nil) + + self.processor = processor + self.sampler = sampler + + try prepare(input: input, windowSize: prefillStepSize) + } + + mutating func prepare(input: LMInput, windowSize: Int? = nil) throws { + processor?.prompt(input.text.tokens) + + switch try model.prepare(input, cache: cache, windowSize: windowSize) { + case .tokens(let tokens): + y = tokens + + // evaluate the remainder of the prompt -- this primes the pump + let token = step(previous: y) + y = .init(tokens: token) + asyncEval(y.tokens) + + case .logits(let result): + y = .init(tokens: convertToToken(logits: result.logits)) + asyncEval(y.tokens) + + break + } + } + + mutating func convertToToken(logits: MLXArray) -> MLXArray { + // process the logits (scores for each token in the vocabulary) + var logits = logits[0..., -1, 0...] + logits = processor?.process(logits: logits) ?? logits + + // transform logits back to a token + let y = sampler.sample(logits: logits) + + processor?.didSample(token: y) + + return y + } + + /// Evaluate the next token and return the new token (y), updating cache state + mutating func step(previous: LMInput.Text) -> MLXArray { + let result = model( + previous[text: .newAxis], cache: cache.isEmpty ? nil : cache, state: state) + self.state = result.state + + return convertToToken(logits: result.logits) + } + + mutating public func next() -> Int? { + // save current value -- this will be returned + let previousY = y + + // compute the next state and async eval the next token + let token = step(previous: previousY) + y = .init(tokens: token) + asyncEval(token) + + return previousY.tokens.item(Int.self) + } +} + +/// Result of a call to ``generate(input:parameters:context:didGenerate:)``. +public struct GenerateResult: Sendable { + /// input (prompt, images, etc.) + public let inputText: LMInput.Text + + @available(*, deprecated, message: "use inputText") + public var promptTokens: [Int] { + inputText.tokens.asArray(Int.self) + } + + /// output tokens + public let tokens: [Int] + + /// output text + public let output: String + + /// time to process the prompt / generate the first token + public let promptTime: TimeInterval + + /// time to generate the remaining tokens + public let generateTime: TimeInterval + + public var promptTokensPerSecond: Double { + Double(inputText.tokens.size) / promptTime + } + + public var tokensPerSecond: Double { + Double(tokens.count) / generateTime + } + + public func summary() -> String { + """ + Prompt: \(inputText.tokens.size) tokens, \(promptTokensPerSecond.formatted()) tokens/s + Generation: \(tokens.count) tokens, \(tokensPerSecond.formatted()) tokens/s, \(generateTime.formatted())s + """ + } +} + +/// Action from token visitor callback in ``generate(input:parameters:context:didGenerate:)``. +public enum GenerateDisposition: Sendable { + /// keep producing tokens until an EOS token is produced + case more + + /// stop producing tokens, e.g. 
a token limit has been hit + case stop +} + +/// Given prompt tokens, generate text using the given model and parameters. +/// +/// ``generate(input:parameters:context:didGenerate:)`` is the preferred call. +/// +/// - Parameters: +/// - promptTokens: tokenized prompt +/// - parameters: generation parameters +/// - model: model to evaluate +/// - tokenizer: tokenizer to convert tokens back into strings and recognize special tokens +/// - extraEOSTokens: any additional stop tokens +/// - didGenerate: visitor for the tokens as they are generated +@available(*, deprecated, message: "please use generate(input:parameters:context:didGenerate:)") +public func generate( + promptTokens: [Int], parameters: GenerateParameters, model: any LanguageModel, + tokenizer: Tokenizer, + extraEOSTokens: Set<String>? = nil, + didGenerate: ([Int]) -> GenerateDisposition +) throws -> GenerateResult { + let tokens = MLXArray(promptTokens) + let iterator = try TokenIterator( + prompt: tokens, model: model, parameters: parameters) + + // this is a compatibility cover -- create the required values + // for the iteration + let input = LMInput(tokens: tokens) + let configuration = ModelConfiguration(id: "stand-in", extraEOSTokens: extraEOSTokens ?? []) + let context = ModelContext( + configuration: configuration, model: model, processor: StandInUserInputProcessor(), + tokenizer: tokenizer) + + return generate( + input: input, context: context, iterator: iterator, didGenerate: didGenerate) +} + +/// Generate tokens from an ``LMInput`` and a ``ModelContext``. +/// +/// For example: +/// +/// ```swift +/// let generateParameters: GenerateParameters +/// let input: UserInput +/// let context: ModelContext +/// +/// let lmInput = try context.processor.prepare(input: input) +/// let result = generate(input: lmInput, +/// parameters: generateParameters, +/// context: context) { tokens in +/// .more +/// } +/// ``` +/// +/// Internally this constructs a ``TokenIterator`` and calls +/// ``generate(input:context:iterator:didGenerate:)`` +/// +/// - Parameters: +/// - input: prepared language model input +/// - parameters: parameters controlling the token generation +/// - context: model context (model and tokenizer) +/// - didGenerate: token visitor that can output tokens as they are generated and indicate early stop +/// - Returns: the generated output +public func generate( + input: LMInput, parameters: GenerateParameters, context: ModelContext, + didGenerate: ([Int]) -> GenerateDisposition +) throws -> GenerateResult { + let iterator = try TokenIterator( + input: input, model: context.model, parameters: parameters) + return generate( + input: input, context: context, iterator: iterator, didGenerate: didGenerate) +} + +/// Low level token generation using a ``TokenIterator``. +/// +/// ``generate(input:parameters:context:didGenerate:)`` is the preferred call. +/// +/// - Parameters: +/// - input: prepared language model input +/// - context: model context (model and tokenizer) +/// - iterator: token iterator +/// - didGenerate: token visitor that can output tokens as they are generated and indicate early stop +/// - Returns: the generated output +public func generate( + input: LMInput, context: ModelContext, + iterator: TokenIterator, + didGenerate: ([Int]) -> GenerateDisposition +) -> GenerateResult { + var start = Date.timeIntervalSinceReferenceDate + var promptTime: TimeInterval = 0 + + let additionalEOSTokenIds = Set( + (context.configuration.extraEOSTokens ?? 
[]) + .compactMap { + context.tokenizer.convertTokenToId($0) + }) + + var tokens = [Int]() + + for token in iterator { + // compute the timing for the prompt + if tokens.isEmpty { + let now = Date.timeIntervalSinceReferenceDate + promptTime = now - start + start = now + } + + if token == context.tokenizer.unknownTokenId || token == context.tokenizer.eosTokenId + || additionalEOSTokenIds.contains(token) + { + break + } + tokens.append(token) + + if didGenerate(tokens) == .stop { + break + } + } + + let now = Date.timeIntervalSinceReferenceDate + let generateTime = now - start + + // TokenIterator uses `asyncEval()` to keep the pipeline full. If the caller + // exits the program right away, those tasks will still be executing and will + // hit assertions as the mlx scheduler is torn down. Synchronize with the stream + // to make sure it is complete. + Stream().synchronize() + + return GenerateResult( + inputText: input.text, tokens: tokens, + output: context.tokenizer.decode(tokens: tokens), + promptTime: promptTime, generateTime: generateTime) +} diff --git a/Libraries/LLM/KVCache.swift b/Libraries/MLXLMCommon/KVCache.swift similarity index 87% rename from Libraries/LLM/KVCache.swift rename to Libraries/MLXLMCommon/KVCache.swift index bafcba6..594d80c 100644 --- a/Libraries/LLM/KVCache.swift +++ b/Libraries/MLXLMCommon/KVCache.swift @@ -5,7 +5,7 @@ import MLX /// Interface for Key/Value cache for LLMs. /// -/// See ``LLMModel/newCache(parameters:)-47tyu`` +/// See ``LanguageModel/newCache(parameters:)`` public protocol KVCache: Evaluatable { /// get the current offset @@ -39,28 +39,20 @@ public func createAttentionMask(h: MLXArray, cache: [KVCache]?) -> MLXArray? { } /// See https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/base.py#L11 -class KVCacheSimple: KVCache, Evaluatable { - let kHeadDim: Int - let vHeadDim: Int - let kvHeads: Int - +public class KVCacheSimple: KVCache, Evaluatable { var keys: MLXArray? var values: MLXArray? - var offset = 0 + public var offset = 0 var step = 256 - init(headDim: IntOrPair, kvHeads: Int) { - self.kHeadDim = headDim.first - self.vHeadDim = headDim.second - self.kvHeads = kvHeads - } + public init() {} public func innerState() -> [MLXArray] { [self.keys, self.values].compactMap { $0 } } - func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) { + public func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) { let previous = self.offset let reset = @@ -71,6 +63,10 @@ class KVCacheSimple: KVCache, Evaluatable { } if reset { let B = keys.dim(0) + let kvHeads = keys.dim(1) + let kHeadDim = keys.dim(3) + let vHeadDim = values.dim(3) + let nSteps = (step + keys.dim(2) - 1) / step let kShape = [B, kvHeads, nSteps * step, kHeadDim] let vShape = [B, kvHeads, nSteps * step, vHeadDim] diff --git a/Libraries/MLXLMCommon/LanguageModel.swift b/Libraries/MLXLMCommon/LanguageModel.swift new file mode 100644 index 0000000..1c99ec5 --- /dev/null +++ b/Libraries/MLXLMCommon/LanguageModel.swift @@ -0,0 +1,219 @@ +// Copyright © 2024 Apple Inc. + +import Foundation +import Hub +import MLX +import MLXNN +import Tokenizers + +/// Time/Height/Width struct to represent information about input images. 
+public struct THW: Sendable { + + public let t: Int + public let h: Int + public let w: Int + + public init(_ t: Int, _ h: Int, _ w: Int) { + self.t = t + self.h = h + self.w = w + } + + public var values: (Int, Int, Int) { + (t, h, w) + } + + public var product: Int { t * h * w } +} + +/// Representation of ``LanguageModel`` input. +/// +/// This can contain text (tokens), prepared images (`MLXArray`), or other media as +/// needed. ``LMInput`` is produced by ``UserInputProcessor`` in response +/// to ``UserInput``. +/// +/// The ``ModelContext`` holds the ``UserInputProcessor`` associated with a +/// ``LanguageModel``. +public struct LMInput { + public let text: Text + public let image: ProcessedImage? + + /// Representation of tokenized input text. + public struct Text { + + /// input token array + public let tokens: MLXArray + + /// optional mask array + public let mask: MLXArray? + + public init(tokens: MLXArray, mask: MLXArray? = nil) { + self.tokens = tokens + self.mask = mask + } + + public subscript( + indices: MLXArrayIndex..., stream stream: StreamOrDevice = .default + ) -> Text { + Text(tokens: tokens[indices, stream: stream], mask: mask?[indices, stream: stream]) + } + + public subscript( + text indices: MLXArrayIndex..., stream stream: StreamOrDevice = .default + ) -> Text { + Text(tokens: tokens[indices, stream: stream], mask: mask) + } + } + + /// Representation of prepared input image(s). + public struct ProcessedImage { + + public let pixels: MLXArray + public let imageGridThw: [THW]? + + public init( + pixels: MLXArray, imageGridThw: [THW]? = nil + ) { + self.pixels = pixels + self.imageGridThw = imageGridThw + } + } + + public init(tokens: MLXArray, mask: MLXArray? = nil) { + self.init(text: .init(tokens: tokens, mask: mask)) + } + + public init(text: LMInput.Text, image: LMInput.ProcessedImage? = nil) { + self.text = text + self.image = image + } +} + +/// ``LanguageModel`` step output. This is consumed internally +/// by the ``TokenIterator``. +public struct LMOutput { + + /// logits (scores for each token in the vocabulary) + public let logits: MLXArray + + /// optional ``State`` to carry forward into the next step + public let state: State? + + public struct State { + public let crossAttentionStates: MLXArray? + + public init(crossAttentionStates: MLXArray? = nil) { + self.crossAttentionStates = crossAttentionStates + } + } + + public init(logits: MLXArray, state: LMOutput.State? = nil) { + self.logits = logits + self.state = state + } +} + +/// The result of the call to ``LanguageModel/prepare(_:cache:windowSize:)`` +public enum PrepareResult { + /// tokens to process by the ``TokenIterator`` + case tokens(LMInput.Text) + + /// logits representing the next token + case logits(LMOutput) +} + +/// Interface for all Language Models (e.g. LLM, VLM). +/// +/// The language model is typically called by the ``TokenIterator`` and it: +/// +/// - consumes the ``LMInput`` +/// - calls ``prepare(_:cache:windowSize:)`` to initialize the KVCache and consume the prompt +/// - calls ``callAsFunction(_:cache:state:)-9kuvf`` for each token, producing an ``LMOutput`` +/// - the ``TokenIterator`` accumulates this information into a ``GenerateResult`` +public protocol LanguageModel: Module { + + /// Prepare the cache state and consume the ``LMInput``. 
+ /// + /// This can return: + /// - ``PrepareResult/tokens(_:)`` if the caller should evaluate the (remaining) tokens normally + /// - ``PrepareResult/logits(_:)`` to produce the next token from the prompt + func prepare(_ input: LMInput, cache: [KVCache], windowSize: Int?) throws -> PrepareResult + + /// Primary entry point to produce a step (single token) from the model + func callAsFunction(_ input: LMInput.Text, cache: [KVCache]?, state: LMOutput.State?) + -> LMOutput + + /// Models may implement this simplified interface if they do not produce any ``LMOutput/State`` + func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray + + /// create a new array of ``KVCache`` -- automatic implementation if self + /// implements ``KVCacheDimensionProvider`` + func newCache(parameters: GenerateParameters?) -> [KVCache] + + /// Optionally preprocess the weights and modify / remove values as needed. + func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] +} + +extension LanguageModel { + public func callAsFunction(_ input: LMInput.Text, cache: [KVCache]?, state: LMOutput.State?) + -> LMOutput + { + let logits = callAsFunction(input.tokens, cache: cache) + return .init(logits: logits) + } + + public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray { + fatalError("callAsFunction(inputs:cache:) not implemented for \(Self.self)") + } +} + +extension LanguageModel { + public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { + weights + } +} + +/// Optional protocol that can be implemented by ``LanguageModel`` and will +/// provide an automatic implementation of ``LanguageModel/newCache(parameters:)`` +public protocol KVCacheDimensionProvider { + var kvHeads: [Int] { get } +} + +extension LanguageModel where Self: KVCacheDimensionProvider { + public func newCache(parameters: GenerateParameters?) -> [KVCache] { + kvHeads.map { n in + KVCacheSimple() + } + } +} + +/// Base ``LanguageModel`` configuration -- provides `modelType` +/// and `quantization` (used in loading the model). +/// +/// This is used by ``ModelFactory/load(hub:configuration:progressHandler:)`` +/// to determine the type of model to load. +public struct BaseConfiguration: Codable, Sendable { + public let modelType: String + + public struct Quantization: Codable, Sendable { + public init(groupSize: Int, bits: Int) { + self.groupSize = groupSize + self.bits = bits + } + + public let groupSize: Int + public let bits: Int + + enum CodingKeys: String, CodingKey { + case groupSize = "group_size" + case bits = "bits" + } + } + + public var quantization: Quantization? + + enum CodingKeys: String, CodingKey { + case modelType = "model_type" + case quantization + } +} diff --git a/Libraries/LLM/Load.swift b/Libraries/MLXLMCommon/Load.swift similarity index 54% rename from Libraries/LLM/Load.swift rename to Libraries/MLXLMCommon/Load.swift index 33f99a9..4374512 100644 --- a/Libraries/LLM/Load.swift +++ b/Libraries/MLXLMCommon/Load.swift @@ -1,17 +1,24 @@ // Copyright © 2024 Apple Inc. import Foundation -@preconcurrency import Hub +import Hub import MLX import MLXNN -import MLXRandom import Tokenizers -struct LLMError: Error { - let message: String -} - -func prepareModelDirectory( +/// Download the model using the `HubApi`. +/// +/// This will download `*.safetensors` and `*.json` if the ``ModelConfiguration`` +/// represents a Hub id, e.g. `mlx-community/gemma-2-2b-it-4bit`. 
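+///
+/// For example (a sketch -- progress is ignored here):
+///
+/// ```swift
+/// let modelDirectory = try await downloadModel(
+///     hub: HubApi(),
+///     configuration: ModelConfiguration(id: "mlx-community/gemma-2-2b-it-4bit"),
+///     progressHandler: { _ in })
+/// ```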
+/// +/// This is typically called via ``ModelFactory/load(hub:configuration:progressHandler:)`` +/// +/// - Parameters: +/// - hub: HubApi instance +/// - configuration: the model identifier +/// - progressHandler: callback for progress +/// - Returns: URL for the directory containing downloaded files +public func downloadModel( + hub: HubApi, configuration: ModelConfiguration, progressHandler: @Sendable @escaping (Progress) -> Void ) async throws -> URL { @@ -20,17 +27,19 @@ func prepareModelDirectory( case .id(let id): // download the model weights let repo = Hub.Repo(id: id) - let modelFiles = ["*.safetensors", "config.json"] + let modelFiles = ["*.safetensors", "*.json"] return try await hub.snapshot( from: repo, matching: modelFiles, progressHandler: progressHandler) case .directory(let directory): return directory } + } catch Hub.HubClientError.authorizationRequired { // an authorizationRequired means (typically) that the named repo doesn't exist // on the server, so retry with a local-only configuration return configuration.modelDirectory(hub: hub) + } catch { let nserror = error as NSError if nserror.domain == NSURLErrorDomain && nserror.code == NSURLErrorNotConnectedToInternet { @@ -43,27 +52,15 @@ func prepareModelDirectory( } } -/// Load and return the model and tokenizer -public func load( - hub: HubApi = HubApi(), configuration: ModelConfiguration, - progressHandler: @Sendable @escaping (Progress) -> Void = { _ in } -) async throws -> (LLMModel, Tokenizer) { - let modelDirectory = try await prepareModelDirectory( - hub: hub, configuration: configuration, progressHandler: progressHandler) - let model = try loadSynchronous(modelDirectory: modelDirectory) - let tokenizer = try await loadTokenizer(configuration: configuration, hub: hub) - - return (model, tokenizer) -} - -func loadSynchronous(modelDirectory: URL) throws -> LLMModel { - // create the model (no weights loaded) - let configurationURL = modelDirectory.appending(component: "config.json") - let baseConfig = try JSONDecoder().decode( - BaseConfiguration.self, from: Data(contentsOf: configurationURL)) - - let model = try baseConfig.modelType.createModel(configuration: configurationURL) - +/// Load model weights. +/// +/// This is typically called via ``ModelFactory/load(hub:configuration:progressHandler:)``. +/// This function loads all `.safetensors` files in the given `modelDirectory`, +/// calls ``LanguageModel/sanitize(weights:)``, applies optional quantization, and +/// updates the model with the weights. +public func loadWeights( + modelDirectory: URL, model: LanguageModel, quantization: BaseConfiguration.Quantization? = nil +) throws { // load the weights var weights = [String: MLXArray]() let enumerator = FileManager.default.enumerator( @@ -81,7 +78,7 @@ func loadSynchronous(modelDirectory: URL) throws -> LLMModel { weights = model.sanitize(weights: weights) // quantize if needed - if let quantization = baseConfig.quantization { + if let quantization { quantize(model: model, groupSize: quantization.groupSize, bits: quantization.bits) { path, module in weights["\(path).scales"] != nil @@ -93,18 +90,4 @@ func loadSynchronous(modelDirectory: URL) throws -> LLMModel { try model.update(parameters: parameters, verify: [.all]) eval(model) - - return model -} - -/// Load and return the model and tokenizer wrapped in a ``ModelContainer`` (provides -/// thread safe access). 
-public func loadModelContainer( - hub: HubApi = HubApi(), configuration: ModelConfiguration, - progressHandler: @Sendable @escaping (Progress) -> Void = { _ in } -) async throws -> ModelContainer { - let modelDirectory = try await prepareModelDirectory( - hub: hub, configuration: configuration, progressHandler: progressHandler) - return try await ModelContainer( - hub: hub, modelDirectory: modelDirectory, configuration: configuration) } diff --git a/Libraries/MLXLMCommon/Lora.swift b/Libraries/MLXLMCommon/Lora.swift new file mode 100644 index 0000000..5457d3a --- /dev/null +++ b/Libraries/MLXLMCommon/Lora.swift @@ -0,0 +1,230 @@ +// Copyright © 2024 Apple Inc. + +import Foundation +import MLX +import MLXNN +import MLXOptimizers +import MLXRandom +import Tokenizers + +/// Layers to apply LoRA adapters to. +/// +/// This is the value returned by ``LoRAModel/loraLinearLayers()``. +public typealias LoRALinearLayers = [(Module, [String])] + +public protocol LoRAModel { + /// Return the layers and keys to apply LoRA adapters to. + /// + /// For example, this might apply the adapters to the `q` and `v` projections in the + /// Attention layers: + /// + /// ```swift + /// model.layers.map { ($0.attention, ["q_proj", "v_proj"]) } + /// ``` + /// + /// It is not required that a model implement this protocol to have LoRA adapters applied, but + /// the command line driver example uses this to produce the ``LoRALinearLayers``. + /// + /// ### See Also + /// - ``LoRATrain/convert(model:layers:)`` + func loraLinearLayers() -> LoRALinearLayers + + /// Return a suffix of the layers and keys to apply LoRA adapters to. + /// + /// See ``loraLinearLayers()`` + func loraLinearLayers(_ count: Int) -> LoRALinearLayers +} + +extension LoRAModel { + public func loraLinearLayers(_ count: Int) -> LoRALinearLayers { + loraLinearLayers().suffix(count) + } +} + +/// Protocol for LoRA implementations that provides a method for converting back to a `Linear` +/// (or subtype). +/// +/// This is normally called via ``LoRATrain/fuse(model:layers:deQuantize:)`` +public protocol LoRAConvertToLinear { + func toLinear(deQuantize: Bool) -> Linear +} + +/// Implementation of LoRA `Linear` replacement layer. +/// +/// This layer implements the LoRA capabilities for `Linear` layers, specifically: +/// +/// - converting `Linear` or `QuantizedLinear` layers to ``LoRALinear`` / ``QLoRALinear`` +/// - converting ``LoRALinear`` back to `Linear` or `QuantizedLinear` (``LoRAConvertToLinear``) +/// - implementing the LoRA evaluation +/// +/// ``QLoRALinear`` is the equivalent class for `QuantizedLinear`. +/// +/// This is not typically used directly -- ``LoRATrain/convert(model:layers:)`` is used to +/// add the adapter layers to a given model. 
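+///
+/// As a sketch, a single `Linear` can also be wrapped directly (`layer` is a
+/// placeholder for an existing layer in a model):
+///
+/// ```swift
+/// // replace `layer` with a rank-8 LoRA adapter around the same weights
+/// let adapted = LoRALinear.from(linear: layer, rank: 8)
+/// ```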
+/// +/// ### See Also +/// - [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) +/// - [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314) +/// - ``QLoRALinear`` +/// - ``LoRATrain/convert(model:layers:)`` +/// - ``LoRATrain/fuse(model:layers:deQuantize:)`` +public class LoRALinear: Linear, LoRAConvertToLinear { + + let scale: Float + + @ParameterInfo(key: "lora_a") var loraA: MLXArray + @ParameterInfo(key: "lora_b") var loraB: MLXArray + + required public init( + _ inputDimensions: Int, _ outputDimensions: Int, rank: Int = 8, bias: Bool = false, + scale: Float = 20.0, linear: Linear + ) { + // Scale for low-rank update + self.scale = scale + + // Low rank lora weights + let loraScale = 1 / sqrt(Float(inputDimensions)) + self._loraA.wrappedValue = MLXRandom.uniform( + low: -loraScale, high: loraScale, [inputDimensions, rank]) + self._loraB.wrappedValue = MLXArray.zeros([rank, outputDimensions]) + + super.init(weight: linear.weight, bias: linear.bias) + + freeze() + } + + /// Freeze all parameters except the lora parameters + public override func freeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false) + throws + { + // realize the keys and omit the lora parameters + let keys = + (keys ?? self.filterMap(filter: Self.filterLocalParameters).flattened().map { $0.0 }) + .filter { + $0 != "lora_a" && $0 != "lora_b" + } + try super.freeze(recursive: recursive, keys: keys, strict: strict) + } + + /// Convert a `Linear` or `QuantizedLinear` layer into a new `Linear` layer + /// that implements the `LoRA` adapter. + /// + /// This is typically called via ``LoRATrain/convert(model:layers:)``. + /// + /// ### See Also + /// - ``LoRATrain/convert(model:layers:)`` + /// - ``QLoRALinear/from(linear:rank:)`` + public static func from(linear: Linear, rank: Int = 8) -> Linear { + if let linear = linear as? QuantizedLinear { + return QLoRALinear.from(linear: linear, rank: rank) + } + let (outputDimensions, inputDimensions) = linear.shape + return LoRALinear(inputDimensions, outputDimensions, rank: rank, linear: linear) + } + + /// Convert back into a fused `Linear` layer. + /// + /// This is typically called via ``LoRATrain/fuse(model:layers:deQuantize:)``. + /// + /// ### See Also + /// - ``LoRATrain/fuse(model:layers:deQuantize:)`` + /// - ``LoRAConvertToLinear`` + /// - ``QLoRALinear/toLinear(deQuantize:)`` + public func toLinear(deQuantize: Bool = false) -> Linear { + let dtype = weight.dtype + let loraB = (scale * loraB.T).asType(dtype) + let loraA = loraA.T.asType(dtype) + return Linear(weight: weight + matmul(loraB, loraA), bias: bias) + } + + public override func callAsFunction(_ x: MLXArray) -> MLXArray { + let y = super.callAsFunction(x.asType(weight.dtype)) + let z = matmul(matmul(x, self.loraA), self.loraB) + return y + scale * z + } +} + +/// Implementation of LoRA `QuantizedLinear` replacement layer. +/// +/// See ``LoRALinear`` (equivalent class for `Linear` layers) for more information. 
+public class QLoRALinear: QuantizedLinear, LoRAConvertToLinear { + + let scale: Float + + @ParameterInfo(key: "lora_a") var loraA: MLXArray + @ParameterInfo(key: "lora_b") var loraB: MLXArray + + required public init( + _ inputDimensions: Int, _ outputDimensions: Int, rank: Int = 8, bias: Bool = false, + scale: Float = 20.0, linear: QuantizedLinear + ) { + + // Scale for low-rank update + self.scale = scale + + // Low rank lora weights + let loraScale = 1 / sqrt(Float(inputDimensions)) + self._loraA.wrappedValue = MLXRandom.uniform( + low: -loraScale, high: loraScale, [inputDimensions, rank]) + self._loraB.wrappedValue = MLXArray.zeros([rank, outputDimensions]) + + super.init( + weight: linear.weight, bias: linear.bias, scales: linear.scales, biases: linear.biases, + groupSize: linear.groupSize, bits: linear.bits) + + // start frozen except for the lora keys + freeze() + } + + /// Freeze all parameters except the lora parameters + public override func freeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false) + throws + { + // realize the keys and omit the lora parameters + let keys = + (keys ?? self.filterMap(filter: Self.filterLocalParameters).flattened().map { $0.0 }) + .filter { + $0 != "lora_a" && $0 != "lora_b" + } + try super.freeze(recursive: recursive, keys: keys, strict: strict) + } + + /// Convert a `QuantizedLinear` layer into a new `Linear` layer + /// that implements the `LoRA` adapter. + /// + /// This is typically called via ``LoRATrain/convert(model:layers:)``. + /// + /// ### See Also + /// - ``LoRATrain/convert(model:layers:)`` + /// - ``LoRALinear/from(linear:rank:)`` + public static func from(linear: QuantizedLinear, rank: Int = 8) -> Linear { + var (outputDimensions, inputDimensions) = linear.shape + inputDimensions = inputDimensions * 32 / linear.bits + return QLoRALinear(inputDimensions, outputDimensions, rank: rank, linear: linear) + } + + /// Convert back into a fused `QuantizedLinear` layer. + /// + /// This is typically called via ``LoRATrain/fuse(model:layers:deQuantize:)``. + /// + /// ### See Also + /// - ``LoRATrain/fuse(model:layers:deQuantize:)`` + public func toLinear(deQuantize: Bool = false) -> Linear { + // convert back into full weights + let weight = dequantized( + weight, scales: scales, biases: biases, groupSize: groupSize, bits: bits) + + let loraB = (scale * loraB.T).asType(.float16) + let loraA = loraA.T.asType(.float16) + + // convert back into quantized + return QuantizedLinear( + weight: weight + matmul(loraB, loraA), bias: bias, groupSize: groupSize, bits: bits) + } + + public override func callAsFunction(_ x: MLXArray) -> MLXArray { + let y = super.callAsFunction(x.asType(scales.dtype)) + let z = matmul(matmul(x, self.loraA), self.loraB) + return y + scale * z + } +} diff --git a/Libraries/MLXLMCommon/ModelConfiguration.swift b/Libraries/MLXLMCommon/ModelConfiguration.swift new file mode 100644 index 0000000..bad4c8f --- /dev/null +++ b/Libraries/MLXLMCommon/ModelConfiguration.swift @@ -0,0 +1,75 @@ +// Copyright © 2024 Apple Inc. + +import Foundation +import Hub + +/// Configuration for a given model name with overrides for prompts and tokens. +/// +/// See e.g. `MLXLM.ModelRegistry` for an example of use. 
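+///
+/// For example (a sketch -- the id comes from the list of tried models and the
+/// prompt is illustrative):
+///
+/// ```swift
+/// let configuration = ModelConfiguration(
+///     id: "mlx-community/Mistral-7B-Instruct-v0.3-4bit",
+///     defaultPrompt: "What is the gravity on Mars and the moon?")
+/// ```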
+public struct ModelConfiguration: Sendable { + + public enum Identifier: Sendable { + case id(String) + case directory(URL) + } + + public var id: Identifier + + public var name: String { + switch id { + case .id(let string): + string + case .directory(let url): + url.deletingLastPathComponent().lastPathComponent + "/" + url.lastPathComponent + } + } + + /// pull the tokenizer from an alternate id + public let tokenizerId: String? + + /// overrides for TokenizerModel/knownTokenizers -- useful before swift-transformers is updated + public let overrideTokenizer: String? + + /// A reasonable default prompt for the model + public let defaultPrompt: String + + /// Additional tokens to use for end of string + public let extraEOSTokens: Set<String> + + public init( + id: String, tokenizerId: String? = nil, overrideTokenizer: String? = nil, + defaultPrompt: String = "hello", + extraEOSTokens: Set<String> = [], + preparePrompt: (@Sendable (String) -> String)? = nil + ) { + self.id = .id(id) + self.tokenizerId = tokenizerId + self.overrideTokenizer = overrideTokenizer + self.defaultPrompt = defaultPrompt + self.extraEOSTokens = extraEOSTokens + } + + public init( + directory: URL, tokenizerId: String? = nil, overrideTokenizer: String? = nil, + defaultPrompt: String = "hello", + extraEOSTokens: Set<String> = [] + ) { + self.id = .directory(directory) + self.tokenizerId = tokenizerId + self.overrideTokenizer = overrideTokenizer + self.defaultPrompt = defaultPrompt + self.extraEOSTokens = extraEOSTokens + } + + public func modelDirectory(hub: HubApi = HubApi()) -> URL { + switch id { + case .id(let id): + // download the model weights and config + let repo = Hub.Repo(id: id) + return hub.localRepoLocation(repo) + + case .directory(let directory): + return directory + } + } +} diff --git a/Libraries/MLXLMCommon/ModelContainer.swift b/Libraries/MLXLMCommon/ModelContainer.swift new file mode 100644 index 0000000..b969378 --- /dev/null +++ b/Libraries/MLXLMCommon/ModelContainer.swift @@ -0,0 +1,78 @@ +// Copyright © 2024 Apple Inc. + +import Foundation +import Hub +import MLX +import MLXNN +import Tokenizers + +/// Container for models that guarantees single threaded access. +/// +/// Wrap models used by e.g. the UI in a ModelContainer. Callers can access +/// the model and/or tokenizer (any values from the ``ModelContext``): +/// +/// ```swift +/// let messages = [["role": "user", "content": prompt]] +/// let promptTokens = try await modelContainer.perform { context in +/// try context.tokenizer.applyChatTemplate(messages: messages) +/// } +/// ``` +/// +/// or: +/// +/// ```swift +/// let userInput: UserInput +/// let result = await modelContainer.perform { context in +/// let input = try await context.processor.prepare(input: userInput) +/// return generate( +/// input: input, parameters: generateParameters, context: context +/// ) { tokens in +/// ... +/// } +/// } +/// ``` +public actor ModelContainer { + let context: ModelContext + nonisolated public let configuration: ModelConfiguration + + public init(context: ModelContext) { + self.context = context + self.configuration = context.configuration + } + + /// Perform an action on the model and/or tokenizer. Callers _must_ eval any `MLXArray` before returning as + /// `MLXArray` is not `Sendable`. 
+    @available(*, deprecated, message: "prefer perform(_:) that uses a ModelContext")
+    public func perform<R>(_ action: @Sendable (any LanguageModel, Tokenizer) throws -> R) rethrows
+        -> R
+    {
+        try action(context.model, context.tokenizer)
+    }
+
+    /// Perform an action on the model and/or tokenizer with additional context values.
+    /// Callers _must_ eval any `MLXArray` before returning as
+    /// `MLXArray` is not `Sendable`.
+    @available(*, deprecated, message: "prefer perform(values:_:) that uses a ModelContext")
+    public func perform<V, R>(
+        values: V, _ action: @Sendable (any LanguageModel, Tokenizer, V) throws -> R
+    ) rethrows -> R {
+        try action(context.model, context.tokenizer, values)
+    }
+
+    /// Perform an action on the ``ModelContext``. Callers _must_ eval any `MLXArray` before returning as
+    /// `MLXArray` is not `Sendable`.
+    public func perform<R>(_ action: @Sendable (ModelContext) async throws -> R) async rethrows -> R
+    {
+        try await action(context)
+    }
+
+    /// Perform an action on the ``ModelContext`` with additional context values.
+    /// Callers _must_ eval any `MLXArray` before returning as
+    /// `MLXArray` is not `Sendable`.
+    public func perform<V, R>(
+        values: V, _ action: @Sendable (ModelContext, V) async throws -> R
+    ) async rethrows -> R {
+        try await action(context, values)
+    }
+
+}
diff --git a/Libraries/MLXLMCommon/ModelFactory.swift b/Libraries/MLXLMCommon/ModelFactory.swift
new file mode 100644
index 0000000..2035b7e
--- /dev/null
+++ b/Libraries/MLXLMCommon/ModelFactory.swift
@@ -0,0 +1,93 @@
+// Copyright © 2024 Apple Inc.
+
+import Foundation
+import Hub
+import Tokenizers
+
+public enum ModelFactoryError: Error {
+    case unsupportedModelType(String)
+    case unsupportedProcessorType(String)
+}
+
+/// Context of types that work together to provide a ``LanguageModel``.
+///
+/// A ``ModelContext`` is created by ``ModelFactory/load(hub:configuration:progressHandler:)``.
+/// This contains the following:
+///
+/// - ``ModelConfiguration`` -- identifier for the model
+/// - ``LanguageModel`` -- the model itself, see ``generate(input:parameters:context:didGenerate:)``
+/// - ``UserInputProcessor`` -- can convert ``UserInput`` into ``LMInput``
+/// - `Tokenizer` -- the tokenizer used by ``UserInputProcessor``
+///
+/// See also ``ModelFactory/loadContainer(hub:configuration:progressHandler:)`` and
+/// ``ModelContainer``.
+public struct ModelContext {
+    public let configuration: ModelConfiguration
+    public let model: any LanguageModel
+    public let processor: any UserInputProcessor
+    public let tokenizer: Tokenizer
+
+    public init(
+        configuration: ModelConfiguration, model: any LanguageModel,
+        processor: any UserInputProcessor, tokenizer: any Tokenizer
+    ) {
+        self.configuration = configuration
+        self.model = model
+        self.processor = processor
+        self.tokenizer = tokenizer
+    }
+}
+
+public protocol ModelFactory: Sendable {
+
+    /// Resolve a model identifier, e.g. "mlx-community/Llama-3.2-3B-Instruct-4bit", into
+    /// a ``ModelConfiguration``.
+    ///
+    /// This will either create a new (mostly unconfigured) ``ModelConfiguration`` or
+    /// return a registered instance that matches the id.
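+    ///
+    /// For example (a sketch -- this assumes the id is registered with the factory,
+    /// e.g. `LLMModelFactory.shared` from `MLXLLM`):
+    ///
+    /// ```swift
+    /// let config = LLMModelFactory.shared.configuration(
+    ///     id: "mlx-community/Llama-3.2-3B-Instruct-4bit")
+    /// ```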
+    func configuration(id: String) -> ModelConfiguration
+
+    func _load(
+        hub: HubApi, configuration: ModelConfiguration,
+        progressHandler: @Sendable @escaping (Progress) -> Void
+    ) async throws -> ModelContext
+
+    func _loadContainer(
+        hub: HubApi, configuration: ModelConfiguration,
+        progressHandler: @Sendable @escaping (Progress) -> Void
+    ) async throws -> ModelContainer
+}
+
+extension ModelFactory {
+
+    /// Load a model identified by a ``ModelConfiguration`` and produce a ``ModelContext``.
+    ///
+    /// This method returns a ``ModelContext``. See also
+    /// ``loadContainer(hub:configuration:progressHandler:)`` for a method that
+    /// returns a ``ModelContainer``.
+    public func load(
+        hub: HubApi = HubApi(), configuration: ModelConfiguration,
+        progressHandler: @Sendable @escaping (Progress) -> Void = { _ in }
+    ) async throws -> ModelContext {
+        try await _load(hub: hub, configuration: configuration, progressHandler: progressHandler)
+    }
+
+    /// Load a model identified by a ``ModelConfiguration`` and produce a ``ModelContainer``.
+    public func loadContainer(
+        hub: HubApi = HubApi(), configuration: ModelConfiguration,
+        progressHandler: @Sendable @escaping (Progress) -> Void = { _ in }
+    ) async throws -> ModelContainer {
+        try await _loadContainer(
+            hub: hub, configuration: configuration, progressHandler: progressHandler)
+    }
+
+    public func _loadContainer(
+        hub: HubApi = HubApi(), configuration: ModelConfiguration,
+        progressHandler: @Sendable @escaping (Progress) -> Void = { _ in }
+    ) async throws -> ModelContainer {
+        let context = try await _load(
+            hub: hub, configuration: configuration, progressHandler: progressHandler)
+        return ModelContainer(context: context)
+    }
+
+}
diff --git a/Libraries/MLXLMCommon/Module+Extensions.swift b/Libraries/MLXLMCommon/Module+Extensions.swift
new file mode 100644
index 0000000..885dd27
--- /dev/null
+++ b/Libraries/MLXLMCommon/Module+Extensions.swift
@@ -0,0 +1,24 @@
+// Copyright © 2024 Apple Inc.
+
+import MLXNN
+
+extension Module {
+
+    /// Compute the number of parameters in a possibly quantized model
+    public func numParameters() -> Int {
+        return leafModules().flattenedValues().map {
+            mod -> Int in
+            if let qlin = mod as? QuantizedLinear {
+                return qlin.scales.size * qlin.groupSize
+            } else if let qemb = mod as? QuantizedEmbedding {
+                return qemb.scales.size * qemb.groupSize
+            } else {
+                return mod.parameters().flattenedValues().reduce(
+                    0,
+                    {
+                        $0 + $1.size
+                    })
+            }
+        }.reduce(0, +)
+    }
+}
diff --git a/Libraries/MLXLMCommon/README.md b/Libraries/MLXLMCommon/README.md
new file mode 100644
index 0000000..0f1d815
--- /dev/null
+++ b/Libraries/MLXLMCommon/README.md
@@ -0,0 +1,124 @@
+# MLXLMCommon
+
+MLXLMCommon contains types and code that are generic across many kinds
+of language models, from LLMs to VLMs:
+
+- Evaluation
+- KVCache
+- Loading
+- UserInput
+
+## Loading a Model
+
+A model is typically loaded by using a `ModelFactory` and a `ModelConfiguration`:
+
+```swift
+// e.g. VLMModelFactory.shared
+let modelFactory: ModelFactory
+
+// e.g. MLXVLM.ModelRegistry.paligemma3bMix4488bit
+let modelConfiguration: ModelConfiguration
+
+let container = try await modelFactory.loadContainer(configuration: modelConfiguration)
+```
+
+The `container` provides an isolation context (an `actor`) to run inference with the model.
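+
+For example, once loaded you can run work inside that isolation context
+(a minimal sketch -- `prompt` is a placeholder for your own text):
+
+```swift
+let promptTokens = try await container.perform { context in
+    try context.tokenizer.applyChatTemplate(
+        messages: [["role": "user", "content": prompt]])
+}
+```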
+
+Predefined `ModelConfiguration` instances are provided as static variables
+on the `ModelRegistry` types, or they can be created directly:
+
+```swift
+let modelConfiguration = ModelConfiguration(id: "mlx-community/paligemma-3b-mix-448-8bit")
+```
+
+The flow inside the `ModelFactory` goes like this:
+
+```swift
+public class VLMModelFactory: ModelFactory {
+
+    public func _load(
+        hub: HubApi, configuration: ModelConfiguration,
+        progressHandler: @Sendable @escaping (Progress) -> Void
+    ) async throws -> ModelContext {
+        // download the weight and config using HubApi
+        // load the base configuration
+        // using the typeRegistry create a model (random weights)
+        // load the weights, apply quantization as needed, update the model
+        //   calls model.sanitize() for weight preparation
+        // load the tokenizer
+        // (vlm) load the processor configuration, create the processor
+    }
+}
+```
+
+Callers with specialized requirements can use these individual components to manually
+load models, if needed.
+
+## Evaluation Flow
+
+- Load the Model
+- UserInput
+- LMInput
+- generate()
+  - NaiveStreamingDetokenizer
+  - TokenIterator
+
+## Using a Model
+
+Once a model is loaded you can evaluate a prompt or series of
+messages. Minimally you need to prepare the user input:
+
+```swift
+let prompt = "Describe the image in English"
+var input = UserInput(prompt: prompt, images: image.map { .url($0) })
+input.processing.resize = .init(width: 256, height: 256)
+```
+
+This example shows adding some images and processing instructions -- if the
+model accepts text only, these parts can be omitted. The inference
+calls are the same.
+
+Assuming you are using a `ModelContainer` (an actor that holds
+a `ModelContext`, which is the bundled set of types that implement a
+model), the first step is to convert the `UserInput` into the
+`LMInput` (LanguageModel Input):
+
+```swift
+let generateParameters: GenerateParameters
+let input: UserInput
+
+let result = try await modelContainer.perform { [input] context in
+    let input = try await context.processor.prepare(input: input)
+
+```
+
+Given that `input`, we can call `generate()` to produce a stream
+of tokens. In this example we use a `NaiveStreamingDetokenizer`
+to assist in converting a stream of tokens into text and print it.
+The stream is stopped after we hit a maximum number of tokens:
+
+```swift
+    var detokenizer = NaiveStreamingDetokenizer(tokenizer: context.tokenizer)
+
+    return try MLXLMCommon.generate(
+        input: input, parameters: generateParameters, context: context
+    ) { tokens in
+
+        if let last = tokens.last {
+            detokenizer.append(token: last)
+        }
+
+        if let new = detokenizer.next() {
+            print(new, terminator: "")
+            fflush(stdout)
+        }
+
+        if tokens.count >= maxTokens {
+            return .stop
+        } else {
+            return .more
+        }
+    }
+}
+```
+
diff --git a/Libraries/MLXLMCommon/StringOrNumber.swift b/Libraries/MLXLMCommon/StringOrNumber.swift
new file mode 100644
index 0000000..f547d1f
--- /dev/null
+++ b/Libraries/MLXLMCommon/StringOrNumber.swift
@@ -0,0 +1,96 @@
+// Copyright © 2024 Apple Inc.
+
+import Foundation
+
+/// Representation of a heterogeneous type in a JSON configuration file.
+///
+/// This can be: a string, a numeric value, or an array of numeric values.
+/// There are methods to do unwrapping, see e.g. ``asFloat()`` and
+/// ``asFloats()`` or callers can switch on the enum.
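+///
+/// For example (a sketch -- the value here is illustrative):
+///
+/// ```swift
+/// let value = StringOrNumber.floats([0.5, 0.5, 0.5])
+/// value.asFloats()  // [0.5, 0.5, 0.5]
+/// value.asFloat()   // nil -- more than one element
+/// ```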
+public enum StringOrNumber: Codable, Equatable, Sendable { + case string(String) + case int(Int) + case float(Float) + case ints([Int]) + case floats([Float]) + + public init(from decoder: Decoder) throws { + let values = try decoder.singleValueContainer() + + if let v = try? values.decode(Int.self) { + self = .int(v) + } else if let v = try? values.decode(Float.self) { + self = .float(v) + } else if let v = try? values.decode([Int].self) { + self = .ints(v) + } else if let v = try? values.decode([Float].self) { + self = .floats(v) + } else { + let v = try values.decode(String.self) + self = .string(v) + } + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case .string(let v): try container.encode(v) + case .int(let v): try container.encode(v) + case .float(let v): try container.encode(v) + case .ints(let v): try container.encode(v) + case .floats(let v): try container.encode(v) + } + } + + /// Return the value as an optional array of integers. + /// + /// This will not coerce `Float` or `String` to `Int`. + public func asInts() -> [Int]? { + switch self { + case .string(let string): nil + case .int(let v): [v] + case .float(let float): nil + case .ints(let array): array + case .floats(let array): nil + } + } + + /// Return the value as an optional integer. + /// + /// This will not coerce `Float` or `String` to `Int`. + public func asInt() -> Int? { + switch self { + case .string(let string): nil + case .int(let v): v + case .float(let float): nil + case .ints(let array): array.count == 1 ? array[0] : nil + case .floats(let array): nil + } + } + + /// Return the value as an optional array of floats. + /// + /// This will not coerce `Int` or `String` to `Float`. + public func asFloats() -> [Float]? { + switch self { + case .string(let string): nil + case .int(let v): [Float(v)] + case .float(let float): [float] + case .ints(let array): array.map { Float($0) } + case .floats(let array): array + } + } + + /// Return the value as an optional float. + /// + /// This will not coerce `Int` or `String` to `Float`. + public func asFloat() -> Float? { + switch self { + case .string(let string): nil + case .int(let v): Float(v) + case .float(let float): float + case .ints(let array): array.count == 1 ? Float(array[0]) : nil + case .floats(let array): array.count == 1 ? 
array[0] : nil + } + } +} diff --git a/Libraries/LLM/Tokenizer.swift b/Libraries/MLXLMCommon/Tokenizer.swift similarity index 95% rename from Libraries/LLM/Tokenizer.swift rename to Libraries/MLXLMCommon/Tokenizer.swift index e79782a..e575aa5 100644 --- a/Libraries/LLM/Tokenizer.swift +++ b/Libraries/MLXLMCommon/Tokenizer.swift @@ -4,6 +4,10 @@ import Foundation import Hub import Tokenizers +struct TokenizerError: Error { + let message: String +} + public func loadTokenizer(configuration: ModelConfiguration, hub: HubApi) async throws -> Tokenizer { let (tokenizerConfig, tokenizerData) = try await loadTokenizerConfig( @@ -13,7 +17,7 @@ public func loadTokenizer(configuration: ModelConfiguration, hub: HubApi) async tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) } -func loadTokenizerConfig(configuration: ModelConfiguration, hub: HubApi) async throws -> ( +public func loadTokenizerConfig(configuration: ModelConfiguration, hub: HubApi) async throws -> ( Config, Config ) { // from AutoTokenizer.from() -- this lets us override parts of the configuration @@ -45,7 +49,7 @@ func loadTokenizerConfig(configuration: ModelConfiguration, hub: HubApi) async t } guard var tokenizerConfig = try await config.tokenizerConfig else { - throw LLMError(message: "missing config") + throw TokenizerError(message: "missing config") } let tokenizerData = try await config.tokenizerData diff --git a/Libraries/MLXLMCommon/UserInput.swift b/Libraries/MLXLMCommon/UserInput.swift new file mode 100644 index 0000000..ad3cb7f --- /dev/null +++ b/Libraries/MLXLMCommon/UserInput.swift @@ -0,0 +1,151 @@ +// Copyright © 2024 Apple Inc. + +import CoreImage +import Foundation +import MLX + +/// Container for raw user input. +/// +/// A ``UserInputProcessor`` can convert this to ``LMInput``. +/// See also ``ModelContext``. +public struct UserInput: Sendable { + + /// Representation of a prompt or series of messages (conversation). + public enum Prompt: Sendable, CustomStringConvertible { + case text(String) + case messages([[String: String]]) + + public func asMessages() -> [[String: String]] { + switch self { + case .text(let text): + return [["role": "user", "content": text]] + case .messages(let messages): + return messages + } + } + + public var description: String { + switch self { + case .text(let text): + return text + case .messages(let messages): + return messages.map { $0.description }.joined(separator: "\n") + } + } + } + + /// Representation of a single image. + public enum Image: Sendable { + case ciImage(CIImage) + case url(URL) + case array(MLXArray) + + public func asCIImage() throws -> CIImage { + switch self { + case .ciImage(let image): + return image + + case .url(let url): + if let image = CIImage(contentsOf: url) { + return image + } + throw UserInputError.unableToLoad(url) + + case .array(let array): + guard array.ndim == 3 else { + throw UserInputError.arrayError("array must have 3 dimensions: \(array.ndim)") + } + + var array = array + + // convert to 0 .. 
255
+                if array.max().item(Float.self) <= 1.0 {
+                    array = array * 255
+                }
+
+                // planar -> pixels
+                switch array.dim(0) {
+                case 3, 4:
+                    // channels first (planar)
+                    array = array.transposed(1, 2, 0)
+                default:
+                    break
+                }
+
+                // 4 components per pixel
+                switch array.dim(-1) {
+                case 3:
+                    // pad to 4 bytes per pixel
+                    array = padded(array, widths: [0, 0, [0, 1]], value: MLXArray(255))
+                case 4:
+                    // good
+                    break
+                default:
+                    throw UserInputError.arrayError(
+                        "channel dimension must be last and 3/4: \(array.shape)")
+                }
+
+                let arrayData = array.asData()
+                let (H, W, _) = array.shape3
+                let cs = CGColorSpace(name: CGColorSpace.sRGB)!
+
+                return CIImage(
+                    bitmapData: arrayData.data, bytesPerRow: W * 4,
+                    size: .init(width: W, height: H),
+                    format: .RGBA8, colorSpace: cs)
+            }
+        }
+    }
+
+    /// Representation of processing to apply to media.
+    public struct Processing: Sendable {
+        public var resize: CGSize?
+
+        public init(resize: CGSize? = nil) {
+            self.resize = resize
+        }
+    }
+
+    public var prompt: Prompt
+    public var images = [Image]()
+    public var processing: Processing = .init()
+
+    public init(prompt: String, images: [Image] = [Image]()) {
+        self.prompt = .text(prompt)
+        self.images = images
+    }
+
+    public init(messages: [[String: String]], images: [Image] = [Image]()) {
+        self.prompt = .messages(messages)
+        self.images = images
+    }
+
+    public init(prompt: Prompt, images: [Image] = [Image](), processing: Processing = .init()) {
+        self.prompt = prompt
+        self.images = images
+        self.processing = processing
+    }
+}
+
+/// Protocol for a type that can convert ``UserInput`` to ``LMInput``.
+///
+/// See also ``ModelContext``.
+public protocol UserInputProcessor {
+    func prepare(input: UserInput) async throws -> LMInput
+}
+
+private enum UserInputError: Error {
+    case notImplemented
+    case unableToLoad(URL)
+    case arrayError(String)
+}
+
+/// A do-nothing ``UserInputProcessor``.
+public struct StandInUserInputProcessor: UserInputProcessor {
+    public init() {}
+
+    public func prepare(input: UserInput) throws -> LMInput {
+        throw UserInputError.notImplemented
+    }
+}
diff --git a/Libraries/MNIST/Files.swift b/Libraries/MLXMNIST/Files.swift
similarity index 100%
rename from Libraries/MNIST/Files.swift
rename to Libraries/MLXMNIST/Files.swift
diff --git a/Libraries/MNIST/MNIST.swift b/Libraries/MLXMNIST/MNIST.swift
similarity index 100%
rename from Libraries/MNIST/MNIST.swift
rename to Libraries/MLXMNIST/MNIST.swift
diff --git a/Libraries/MNIST/README.md b/Libraries/MLXMNIST/README.md
similarity index 100%
rename from Libraries/MNIST/README.md
rename to Libraries/MLXMNIST/README.md
diff --git a/Libraries/MNIST/Random.swift b/Libraries/MLXMNIST/Random.swift
similarity index 100%
rename from Libraries/MNIST/Random.swift
rename to Libraries/MLXMNIST/Random.swift
diff --git a/Libraries/MLXVLM/MediaProcessing.swift b/Libraries/MLXVLM/MediaProcessing.swift
new file mode 100644
index 0000000..abfdf3b
--- /dev/null
+++ b/Libraries/MLXVLM/MediaProcessing.swift
@@ -0,0 +1,157 @@
+// Copyright © 2024 Apple Inc.
+
+import CoreImage.CIFilterBuiltins
+import MLX
+import MLXLMCommon
+
+private let context = CIContext()
+
+/// Collection of methods for processing media (images, video, etc.).
+///
+/// A typical image preparation pipeline might look like this:
+///
+/// ```swift
+/// var image: CIImage
+/// image = MediaProcessing.inSRGBToneCurveSpace(image)
+///
+/// // apply user instructions
+/// image = MediaProcessing.apply(image, processing: processing)
+///
+/// image = MediaProcessing.resampleBicubic(image, to: config.size.cgSize)
+/// image = MediaProcessing.normalize(
+///     image, mean: config.imageMeanTuple, std: config.imageStdTuple)
+///
+/// return MediaProcessing.asMLXArray(image)
+/// ```
+///
+/// This is the responsibility of the `UserInputProcessor`.
+public enum MediaProcessing {
+
+    /// VLM media processing is normally done without regard to the colorspace. Many,
+    /// though not all, images are stored in sRGB and this will be the implicit colorspace
+    /// used. This converts to a colorspace with an sRGB tone curve, though not necessarily
+    /// sRGB primaries, etc.
+    ///
+    /// See ``inLinearToneCurveSpace(_:)``
+    static public func inSRGBToneCurveSpace(_ image: CIImage) -> CIImage {
+        let filter = CIFilter.linearToSRGBToneCurve()
+        filter.inputImage = image
+        return filter.outputImage!
+    }
+
+    /// Inverse of ``inSRGBToneCurveSpace(_:)`` (for completeness).
+    static public func inLinearToneCurveSpace(_ image: CIImage) -> CIImage {
+        let filter = CIFilter.sRGBToneCurveToLinear()
+        filter.inputImage = image
+        return filter.outputImage!
+    }
+
+    /// Compute the best fit size of one size in another (respecting aspect ratio).
+    static public func bestFit(_ size: CGSize, in other: CGSize) -> CGSize {
+        let scale = bestFitScale(size, in: other)
+        return CGSize(width: round(size.width * scale), height: round(size.height * scale))
+    }
+
+    /// Compute the best fit scale of one size in another (respecting aspect ratio).
+    static public func bestFitScale(_ size: CGSize, in other: CGSize) -> CGFloat {
+        min(other.width / size.width, other.height / size.height)
+    }
+
+    /// Resample the image using bicubic interpolation.
+    static public func resampleBicubic(_ image: CIImage, to size: CGSize) -> CIImage {
+        let filter = CIFilter.bicubicScaleTransform()
+        let extent = image.extent.size
+
+        filter.inputImage = image
+
+        // set the aspect ratio to match the aspect ratio of the target
+        let inputAspectRatio = extent.width / extent.height
+        let desiredAspectRatio = size.width / size.height
+        filter.aspectRatio = Float(1 / inputAspectRatio * desiredAspectRatio)
+
+        // that image is now the aspect ratio of the target and the size
+        // of the shorter dimension
+        let scale: CGFloat
+        if extent.width < extent.height {
+            scale = size.width / extent.width
+        } else {
+            scale = size.height / extent.height
+        }
+        filter.scale = Float(scale)
+
+        let rescaled = filter.outputImage!
+
+        // the image has a DoD larger than the requested size so crop
+        // it to the desired size
+        return rescaled.cropped(to: CGRect(origin: .zero, size: size))
+    }
+
+    /// Normalize the image using the given mean and standard deviation parameters.
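+    ///
+    /// For example (a sketch -- real values come from the processor configuration,
+    /// e.g. `config.imageMeanTuple` / `config.imageStdTuple`):
+    ///
+    /// ```swift
+    /// let normalized = MediaProcessing.normalize(
+    ///     image, mean: (0.5, 0.5, 0.5), std: (0.5, 0.5, 0.5))
+    /// ```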
+ static public func normalize( + _ image: CIImage, mean: (CGFloat, CGFloat, CGFloat), std: (CGFloat, CGFloat, CGFloat) + ) -> CIImage { + let filter = CIFilter.colorMatrix() + filter.inputImage = image + + // this should match + // https://pytorch.org/vision/main/generated/torchvision.transforms.Normalize.html + // + // output[channel] = (input[channel] - mean[channel]) / std[channel] + // + // The CI filter computes input * factor + bias so we want to do: + // input * 1 / std - mean / std + + filter.rVector = .init(x: 1 / std.0, y: 0, z: 0, w: 0) + filter.gVector = .init(x: 0, y: 1 / std.1, z: 0, w: 0) + filter.bVector = .init(x: 0, y: 0, z: 1 / std.2, w: 0) + + filter.aVector = .init(x: 0, y: 0, z: 0, w: 1) + filter.biasVector = .init(x: -mean.0 / std.0, y: -mean.1 / std.1, z: -mean.2 / std.2, w: 0) + + return filter.outputImage! + } + + /// Convert the CIImage into a planar 3 channel MLXArray `[1, C, H, W]` + static public func asMLXArray(_ image: CIImage, colorSpace: CGColorSpace? = nil) -> MLXArray { + let size = image.extent.size + let w = Int(size.width.rounded()) + let h = Int(size.height.rounded()) + + // probably not strictly necessary, but this is what happens in + // e.g. image_processing_siglip in transformers (float32) + let format = CIFormat.RGBAf + let componentsPerPixel = 4 + let bytesPerPixel = componentsPerPixel * 4 + let bytesPerRow = w * bytesPerPixel + + var data = Data(count: w * h * bytesPerPixel) + data.withUnsafeMutableBytes { ptr in + context.render( + image, toBitmap: ptr.baseAddress!, rowBytes: bytesPerRow, bounds: image.extent, + format: format, colorSpace: colorSpace) + context.clearCaches() + } + + var array = MLXArray(data, [h, w, 4], type: Float32.self) + + // drop 4th channel + array = array[0..., 0..., ..<3] + + // convert to 1, C, H, W + array = array.reshaped(1, h, w, 3).transposed(0, 3, 1, 2) + + return array + } + + /// Apply `UserInput.Processing`, if needed, to the image. + static func apply(_ image: CIImage, processing: UserInput.Processing?) -> CIImage { + var image = image + + if let resize = processing?.resize { + let scale = bestFitScale(image.extent.size, in: resize) + image = image.transformed(by: CGAffineTransform(scaleX: scale, y: scale)) + } + + return image + } +} diff --git a/Libraries/MLXVLM/Models/Paligemma.swift b/Libraries/MLXVLM/Models/Paligemma.swift new file mode 100644 index 0000000..a103ccb --- /dev/null +++ b/Libraries/MLXVLM/Models/Paligemma.swift @@ -0,0 +1,736 @@ +// Copyright © 2024 Apple Inc. 
+ +// port of https://github.com/Blaizzy/mlx-vlm/tree/main/mlx_vlm/models/paligemma + +import CoreImage +import Foundation +import Hub +import MLX +import MLXFast +import MLXLMCommon +import MLXNN +import Tokenizers + +// MARK: - Language + +private enum Language { + + // specialized norm for gemma + fileprivate class RMSNorm: Module, UnaryLayer { + let weight: MLXArray + let eps: Float + + public init(dimensions: Int, eps: Float = 1e-5) { + self.weight = MLXArray.ones([dimensions]).asType(.float16) + self.eps = eps + super.init() + } + + public func callAsFunction(_ x: MLXArray) -> MLXArray { + return MLXFast.rmsNorm(x, weight: 1.0 + self.weight, eps: self.eps) + } + } + + fileprivate class Attention: Module { + + let args: PaliGemmaConfiguration.TextConfiguration + let scale: Float + + @ModuleInfo(key: "q_proj") var wq: Linear + @ModuleInfo(key: "k_proj") var wk: Linear + @ModuleInfo(key: "v_proj") var wv: Linear + @ModuleInfo(key: "o_proj") var wo: Linear + + let rope: RoPE + + public init(_ args: PaliGemmaConfiguration.TextConfiguration) { + self.args = args + + let dim = args.hiddenSize + let heads = args.attentionHeads + let kvHeads = args.kvHeads + + let headDim = args.hiddenSize / heads + self.scale = pow(Float(headDim), -0.5) + + self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false) + self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: false) + self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: false) + self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false) + + self.rope = RoPE( + dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta) + } + + public func callAsFunction( + _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache? + ) -> MLXArray { + let (B, L) = (x.dim(0), x.dim(1)) + + var queries = wq(x) + var keys = wk(x) + var values = wv(x) + + // prepare the queries, keys and values for the attention computation + queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3) + keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3) + values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3) + + if let cache { + queries = rope(queries, offset: cache.offset) + keys = rope(keys, offset: cache.offset) + (keys, values) = cache.update(keys: keys, values: values) + } else { + queries = rope(queries) + keys = rope(keys) + } + + let output = MLXFast.scaledDotProductAttention( + queries: queries, keys: keys, values: values, scale: scale, mask: mask + ) + .transposed(0, 2, 1, 3) + .reshaped(B, L, -1) + + return wo(output) + } + } + + fileprivate class MLP: Module, UnaryLayer { + + @ModuleInfo(key: "gate_proj") var gate: Linear + @ModuleInfo(key: "down_proj") var down: Linear + @ModuleInfo(key: "up_proj") var up: Linear + + public init(dimensions: Int, hiddenDimensions: Int) { + self._gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) + self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false) + self._up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) + } + + public func callAsFunction(_ x: MLXArray) -> MLXArray { + down(gelu(gate(x)) * up(x)) + } + } + + fileprivate class TransformerBlock: Module { + + @ModuleInfo(key: "self_attn") var attention: Attention + let mlp: MLP + + @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm + @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm + + public init(_ args: PaliGemmaConfiguration.TextConfiguration) { + self._attention.wrappedValue = 
Attention(args) + self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize) + self._inputLayerNorm.wrappedValue = RMSNorm( + dimensions: args.hiddenSize, eps: args.rmsNormEps) + self._postAttentionLayerNorm.wrappedValue = RMSNorm( + dimensions: args.hiddenSize, eps: args.rmsNormEps) + } + + public func callAsFunction( + _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache? + ) -> MLXArray { + var r = attention(inputLayerNorm(x), mask: mask, cache: cache) + let h = x + r + r = mlp(postAttentionLayerNorm(h)) + let out = h + r + return out + } + } + + fileprivate class GemmaModel: Module { + + @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding + + fileprivate let layers: [TransformerBlock] + fileprivate let norm: RMSNorm + + let hiddenScale: Float + + public init(_ args: PaliGemmaConfiguration.TextConfiguration) { + precondition(args.vocabularySize > 0) + + self._embedTokens.wrappedValue = Embedding( + embeddingCount: args.vocabularySize, dimensions: args.hiddenSize) + + self.hiddenScale = pow(Float(args.hiddenSize), 0.5) + + self.layers = (0 ..< args.hiddenLayers) + .map { _ in + TransformerBlock(args) + } + self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps) + } + + public func callAsFunction( + _ inputs: MLXArray, cache: [KVCache]? = nil, inputEmbedding: MLXArray? = nil, + mask: MLXArray? = nil + ) -> MLXArray { + var h = inputEmbedding ?? embedTokens(inputs) + h = h * hiddenScale + + let mask: MLXArray? = + if mask == nil || (cache?[0].offset ?? 0) > 0 { + createAttentionMask(h: h, cache: cache) + } else { + nil + } + + for (i, layer) in layers.enumerated() { + h = layer(h, mask: mask, cache: cache?[i]) + } + + return norm(h) + } + } + + fileprivate class LanguageModel: Module, KVCacheDimensionProvider { + @ModuleInfo var model: GemmaModel + + var kvHeads: [Int] + + public init(_ args: PaliGemmaConfiguration.TextConfiguration) { + self.model = GemmaModel(args) + + self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads } + } + + public func callAsFunction( + _ inputs: MLXArray, cache: [KVCache]? = nil, inputEmbedding: MLXArray? = nil, + mask: MLXArray? = nil + ) -> LMOutput { + var out = model(inputs, cache: cache, inputEmbedding: inputEmbedding, mask: mask) + out = model.embedTokens.asLinear(out) + return LMOutput(logits: out) + } + + func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { + weights.filter { + !$0.key.contains("self_attn.rotary_emb.inv_freq") + } + } + } +} + +// MARK: - Vision + +private enum Vision { + fileprivate class Attention: Module { + + let numHeads: Int + let scale: Float + + @ModuleInfo(key: "q_proj") var wq: Linear + @ModuleInfo(key: "k_proj") var wk: Linear + @ModuleInfo(key: "v_proj") var wv: Linear + @ModuleInfo(key: "out_proj") var wo: Linear + + public init(dims: Int, numHeads: Int, bias: Bool = true) { + precondition(dims % numHeads == 0, "Dimensions must be divisible by numHeads") + + self.numHeads = numHeads + let headDim = dims / numHeads + self.scale = pow(Float(headDim), -0.5) + + self._wq.wrappedValue = Linear(dims, dims, bias: bias) + self._wk.wrappedValue = Linear(dims, dims, bias: bias) + self._wv.wrappedValue = Linear(dims, dims, bias: bias) + self._wo.wrappedValue = Linear(dims, dims, bias: bias) + } + + public func callAsFunction( + _ x: MLXArray, mask: MLXArray? 
= nil + ) -> MLXArray { + var queries = wq(x) + var keys = wk(x) + var values = wv(x) + + let (B, L) = (queries.dim(0), queries.dim(1)) + let S = keys.dim(1) + + queries = queries.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3) + keys = keys.reshaped(B, S, numHeads, -1).transposed(0, 2, 1, 3) + values = values.reshaped(B, S, numHeads, -1).transposed(0, 2, 1, 3) + + let output = MLXFast.scaledDotProductAttention( + queries: queries, keys: keys, values: values, scale: scale, mask: mask + ) + .transposed(0, 2, 1, 3) + .reshaped(B, L, -1) + + return wo(output) + } + } + + fileprivate class PhiMLP: Module, UnaryLayer { + + @ModuleInfo var fc1: Linear + @ModuleInfo var fc2: Linear + + public init(_ config: PaliGemmaConfiguration.VisionConfiguration) { + self.fc1 = Linear(config.hiddenSize, config.intermediateSize, bias: true) + self.fc2 = Linear(config.intermediateSize, config.hiddenSize, bias: true) + } + + public func callAsFunction(_ x: MLXArray) -> MLXArray { + fc2(geluApproximate(fc1(x))) + } + } + + fileprivate class EncoderLayer: Module { + + @ModuleInfo(key: "self_attn") var attention: Attention + @ModuleInfo(key: "layer_norm1") var layerNorm1: LayerNorm + @ModuleInfo var mlp: PhiMLP + @ModuleInfo(key: "layer_norm2") var layerNorm2: LayerNorm + + public init(_ config: PaliGemmaConfiguration.VisionConfiguration) { + self._attention.wrappedValue = Attention( + dims: config.hiddenSize, numHeads: config.attentionHeads, bias: true) + self._layerNorm1.wrappedValue = LayerNorm( + dimensions: config.hiddenSize, eps: config.layerNormEps) + self.mlp = PhiMLP(config) + self._layerNorm2.wrappedValue = LayerNorm( + dimensions: config.hiddenSize, eps: config.layerNormEps) + } + + public func callAsFunction(_ x: MLXArray, mask: MLXArray? = nil) -> MLXArray { + var r = attention(layerNorm1(x), mask: mask) + let h = x + r + r = mlp(layerNorm2(h)) + return h + r + } + } + + fileprivate class Encoder: Module { + var layers: [EncoderLayer] + + public init(_ config: PaliGemmaConfiguration.VisionConfiguration) { + self.layers = (0 ..< config.hiddenLayers).map { _ in + EncoderLayer(config) + } + } + + public func callAsFunction( + _ x: MLXArray, outputHiddenStates: Bool = false, mask: MLXArray? = nil + ) -> (MLXArray, [MLXArray]?) { + var encoderStates: [MLXArray]? = outputHiddenStates ? [] : nil + var h = x + var x = x + for l in layers { + x = l(x, mask: mask) + if outputHiddenStates { + encoderStates?.append(x) + } + h = x[0] + } + return (h, encoderStates) + } + } + + fileprivate class VisionEmbeddings: Module, UnaryLayer { + + @ModuleInfo(key: "patch_embedding") var patchEmbedding: Conv2d + @ModuleInfo(key: "position_embedding") var positionEmbedding: Embedding + + let positions: Int + let positionIds: MLXArray + + public init(_ config: PaliGemmaConfiguration.VisionConfiguration) { + self._patchEmbedding.wrappedValue = Conv2d( + inputChannels: config.channels, outputChannels: config.hiddenSize, + kernelSize: .init(config.patchSize), stride: .init(config.patchSize) + ) + let d = config.imageSize / config.patchSize + self.positions = d * d + self._positionEmbedding.wrappedValue = Embedding( + embeddingCount: positions, dimensions: config.hiddenSize + ) + self.positionIds = MLXArray(0 ..< positions)[.newAxis, 0...] 
+ } + + public func callAsFunction(_ x: MLXArray) -> MLXArray { + var patchEmbeddings = self.patchEmbedding(x) + patchEmbeddings = patchEmbeddings.flattened(start: 1, end: 2) + let embeddings = patchEmbeddings + self.positionEmbedding(self.positionIds) + return embeddings + } + } + + fileprivate class SigLipVisionModel: Module { + + @ModuleInfo var embeddings: VisionEmbeddings + @ModuleInfo var encoder: Encoder + @ModuleInfo(key: "post_layernorm") var postLayerNorm: LayerNorm + + public init(_ config: PaliGemmaConfiguration.VisionConfiguration) { + self.embeddings = VisionEmbeddings(config) + self.encoder = Encoder(config) + self._postLayerNorm.wrappedValue = LayerNorm(dimensions: config.hiddenSize) + } + + public func callAsFunction(_ x: MLXArray, outputHiddenStates: Bool = false) -> ( + MLXArray, MLXArray, MLXArray? + ) { + let x = embeddings(x) + + let (encoderOutput, hiddenStates) = encoder(x, outputHiddenStates: outputHiddenStates) + let poolerOutput = postLayerNorm(encoderOutput) + + return (poolerOutput, x, hiddenStates?.last) + } + } + + fileprivate class VisionModel: Module { + + @ModuleInfo(key: "vision_model") var visionModel: SigLipVisionModel + + public init(_ config: PaliGemmaConfiguration.VisionConfiguration) { + precondition( + config.modelType == "siglip_vision_model", + "Unsupported modelType: \(config.modelType)") + self._visionModel.wrappedValue = SigLipVisionModel(config) + } + + public func callAsFunction(_ x: MLXArray, outputHiddenStates: Bool = false) -> ( + MLXArray, MLXArray, MLXArray? + ) { + visionModel(x, outputHiddenStates: outputHiddenStates) + } + + private func isMLXWeight(_ array: MLXArray) -> Bool { + if array.ndim != 4 { + return false + } + + let (outChannels, kH, kW) = (array.dim(0), array.dim(1), array.dim(2)) + return outChannels >= kH && outChannels >= kW && kH == kW + } + + func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { + var sanitizedWeights = [String: MLXArray]() + + for (k, v) in weights { + if k.contains("position_id") { + // Remove unused position_ids + continue + } else if k.contains("patch_embedding.weight") { + // PyTorch conv2d weight tensors have shape: + // [out_channels, in_channels, kH, KW] + // MLX conv2d expects the weight be of shape: + // [out_channels, kH, KW, in_channels] + if isMLXWeight(v) { + sanitizedWeights[k] = v + } else { + sanitizedWeights[k] = v.transposed(0, 2, 3, 1) + } + } else { + sanitizedWeights[k] = v + } + } + + return sanitizedWeights + } + } +} + +// MARK: - Processor + +/// PaliGemma VLM `UserInputProcessor`. +/// +/// This is meant to be used with ``PaliGemma`` and is typically created by ``VLMModelFactory``. +public class PaligGemmaProcessor: UserInputProcessor { + + private let config: PaliGemmaProcessorConfiguration + private let tokenizer: any Tokenizer + + public init(_ config: PaliGemmaProcessorConfiguration, tokenizer: any Tokenizer) { + self.config = config + self.tokenizer = tokenizer + } + + private func prepare(image: CIImage, processing: UserInput.Processing?) 
-> MLXArray {
+        // based on image_processing_siglip from transformers
+        var image = image
+
+        // we want to do all of the image processing in an sRGB tone curve
+        // rather than a linear space as that is what transformers / torch_vision
+        // do (implicitly by using sRGB rasters directly)
+        image = MediaProcessing.inSRGBToneCurveSpace(image)
+
+        // apply user instructions
+        image = MediaProcessing.apply(image, processing: processing)
+
+        image = MediaProcessing.resampleBicubic(image, to: config.size.cgSize)
+        image = MediaProcessing.normalize(
+            image, mean: config.imageMeanTuple, std: config.imageStdTuple)
+
+        return MediaProcessing.asMLXArray(image)
+    }
+
+    public func prepare(input: UserInput) throws -> LMInput {
+        switch input.images.count {
+        case 0: throw VLMError.imageRequired
+        case 1: break
+        default: throw VLMError.singleImageAllowed
+        }
+
+        // this doesn't have a chat template so just use the last message.
+        var prompt = input.prompt.asMessages().last?["content"] ?? ""
+
+        // based on transformers/processing_paligemma
+        let count = input.images.count * config.imageSequenceLength
+        prompt =
+            Array(repeating: "<image>", count: count).joined() + (tokenizer.bosToken ?? "") + prompt
+            + "\n"
+
+        let promptTokens = try tokenizer.encode(text: prompt)
+        let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
+        let mask = ones(like: promptArray).asType(.int8)
+
+        let pixels = try prepare(image: input.images[0].asCIImage(), processing: input.processing)
+
+        return LMInput(text: .init(tokens: promptArray, mask: mask), image: .init(pixels: pixels))
+    }
+
+}
+
+// MARK: - Model
+
+private class PaliGemmaMultiModalProjector: Module, UnaryLayer {
+
+    @ModuleInfo var linear: Linear
+
+    public init(_ config: PaliGemmaConfiguration.VisionConfiguration) {
+        self.linear = Linear(config.hiddenSize, config.projectionDimensions, bias: true)
+    }
+
+    public func callAsFunction(_ x: MLXArray) -> MLXArray {
+        linear(x)
+    }
+}
+
+/// PaliGemma VLM
+///
+/// This is typically created by ``VLMModelFactory``.
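+///
+/// For example (a sketch -- the model id is one of the registered PaliGemma
+/// configurations mentioned in the README):
+///
+/// ```swift
+/// let container = try await VLMModelFactory.shared.loadContainer(
+///     configuration: ModelConfiguration(id: "mlx-community/paligemma-3b-mix-448-8bit"))
+/// ```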
+public class PaliGemma: Module, VLMModel, KVCacheDimensionProvider { + + @ModuleInfo(key: "vision_tower") private var visionModel: Vision.VisionModel + @ModuleInfo(key: "language_model") private var languageModel: Language.LanguageModel + @ModuleInfo(key: "multi_modal_projector") private var multiModalProjector: + PaliGemmaMultiModalProjector + + public let config: PaliGemmaConfiguration + + public var vocabularySize: Int { config.vocabularySize } + public var kvHeads: [Int] { languageModel.kvHeads } + + public func loraLinearLayers() -> MLXLMCommon.LoRALinearLayers { + languageModel.model.layers.map { ($0.attention, ["q_proj", "v_proj"]) } + } + + public init(_ config: PaliGemmaConfiguration) { + self.config = config + self._visionModel.wrappedValue = Vision.VisionModel(config.visionConfiguration) + self._languageModel.wrappedValue = Language.LanguageModel(config.textConfiguration) + self._multiModalProjector.wrappedValue = PaliGemmaMultiModalProjector( + config.visionConfiguration) + } + + private func inputEmbeddings(inputIds: MLXArray, pixelValues: MLXArray?, mask: MLXArray) -> ( + MLXArray, MLXArray + ) { + guard let pixelValues else { + return (inputIds, mask) + } + + let inputEmbedding = languageModel.model.embedTokens(inputIds) + let (hiddenState, _, _) = self.visionModel( + pixelValues.transposed(0, 2, 3, 1).asType(inputEmbedding.dtype), + outputHiddenStates: true + ) + + var imageFeatures = hiddenState[.newAxis, .ellipsis].asType(inputEmbedding.dtype) + imageFeatures = multiModalProjector(imageFeatures) + + return prepareInputsForMultimodal( + imageFeatures: imageFeatures, inputEmbedding: inputEmbedding, + inputIds: inputIds, attentionMask: mask) + } + + private func prepareInputsForMultimodal( + imageFeatures: MLXArray, inputEmbedding: MLXArray, inputIds: MLXArray, + attentionMask: MLXArray + ) -> (MLXArray, MLXArray) { + let embedDimension = imageFeatures.dim(2) + let (batchSize, sequenceLength) = inputIds.shape2 + var scaledImageFeatures = imageFeatures / pow(Float(config.hiddenSize), 0.5) + + let textMask = (inputIds .!= config.imageTokenIndex) & (inputIds .!= config.padTokenId) + let imageMask = inputIds .== config.imageTokenIndex + let padMask = inputIds .== config.padTokenId + + // expand masks to match embedding dimension + var textMaskExpanded = expandedDimensions(textMask, axis: -1) + var padMaskExpanded = expandedDimensions(padMask, axis: -1) + + // insert padding and text token embeddings + var finalEmbedding = which(textMaskExpanded, inputEmbedding, 0) + finalEmbedding = which(padMaskExpanded, 0, finalEmbedding) + + let padSize = finalEmbedding.dim(1) - scaledImageFeatures.dim(1) + scaledImageFeatures = padded(scaledImageFeatures, widths: [0, .init((0, padSize)), 0]) + + // insert image embeddings - the image mask is always less or equal to the sentence in length + var imageMaskExpanded = expandedDimensions(imageMask, axis: -1) + finalEmbedding = which(imageMaskExpanded, scaledImageFeatures, finalEmbedding) + + finalEmbedding = which(padMaskExpanded, 0, finalEmbedding) + + let attentionMaskExpanded1 = expandedDimensions(attentionMask, axis: 1) + let attentionMaskExpanded2 = expandedDimensions(attentionMask, axis: 2) + var finalAttentionMask4d = attentionMaskExpanded1 * attentionMaskExpanded2 + finalAttentionMask4d = expandedDimensions(finalAttentionMask4d, axis: 1) + + return (finalEmbedding, finalAttentionMask4d) + } + + public func prepare(_ input: LMInput, cache: [any KVCache], windowSize: Int?) 
throws
+        -> PrepareResult
+    {
+        guard let image = input.image else { throw VLMError.imageRequired }
+        guard let mask = input.text.mask else { throw VLMError.maskRequired }
+        let inputIds = input.text.tokens
+
+        let (inputEmbedding, finalAttentionMask4d) = inputEmbeddings(
+            inputIds: inputIds, pixelValues: image.pixels, mask: mask)
+
+        let result = languageModel(
+            inputIds, cache: cache, inputEmbedding: inputEmbedding, mask: finalAttentionMask4d)
+
+        return .logits(result)
+    }
+
+    public func callAsFunction(_ inputs: MLXArray, cache: [any KVCache]?) -> MLXArray {
+        languageModel(inputs, cache: cache).logits
+    }
+
+    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
+        languageModel.sanitize(weights: visionModel.sanitize(weights: weights))
+    }
+}
+
+// MARK: - Configuration
+
+/// Configuration for ``PaliGemma``
+public struct PaliGemmaConfiguration: Codable, Sendable {
+
+    public struct TextConfiguration: Codable, Sendable {
+        public let modelType: String
+        public let hiddenSize: Int
+        public let hiddenLayers: Int
+        public let intermediateSize: Int
+        public let attentionHeads: Int
+        public let kvHeads: Int
+        public let vocabularySize: Int
+        private let _rmsNormEps: Float?
+        public var rmsNormEps: Float { _rmsNormEps ?? 1e-6 }
+        private let _ropeTheta: Float?
+        public var ropeTheta: Float { _ropeTheta ?? 10_000 }
+        private let _ropeTraditional: Bool?
+        public var ropeTraditional: Bool { _ropeTraditional ?? false }
+
+        enum CodingKeys: String, CodingKey {
+            case modelType = "model_type"
+            case hiddenSize = "hidden_size"
+            case hiddenLayers = "num_hidden_layers"
+            case intermediateSize = "intermediate_size"
+            case attentionHeads = "num_attention_heads"
+            case kvHeads = "num_key_value_heads"
+            case vocabularySize = "vocab_size"
+            case _rmsNormEps = "rms_norm_eps"
+            case _ropeTheta = "rope_theta"
+            case _ropeTraditional = "rope_traditional"
+        }
+    }
+
+    public struct VisionConfiguration: Codable, Sendable {
+        public let modelType: String
+        public let hiddenSize: Int
+        public let hiddenLayers: Int
+        public let intermediateSize: Int
+        public let attentionHeads: Int
+        public let patchSize: Int
+        public let projectionDimensions: Int
+        public let imageSize: Int
+        private let _channels: Int?
+        public var channels: Int { _channels ?? 3 }
+        private let _layerNormEps: Float?
+        public var layerNormEps: Float { _layerNormEps ??
1e-6 } + + enum CodingKeys: String, CodingKey { + case modelType = "model_type" + case hiddenSize = "hidden_size" + case hiddenLayers = "num_hidden_layers" + case intermediateSize = "intermediate_size" + case attentionHeads = "num_attention_heads" + case patchSize = "patch_size" + case projectionDimensions = "projection_dim" + case imageSize = "image_size" + case _channels = "num_channels" + case _layerNormEps = "layer_norm_eps" + } + } + + public let textConfiguration: TextConfiguration + public let visionConfiguration: VisionConfiguration + public let modelType: String + public let vocabularySize: Int + public let ignoreIndex: Int + public let imageTokenIndex: Int + public let hiddenSize: Int + public let padTokenId: Int + + enum CodingKeys: String, CodingKey { + case textConfiguration = "text_config" + case visionConfiguration = "vision_config" + case modelType = "model_type" + case vocabularySize = "vocab_size" + case ignoreIndex = "ignore_index" + case imageTokenIndex = "image_token_index" + case hiddenSize = "hidden_size" + case padTokenId = "pad_token_id" + } +} + +/// Configuration for ``PaligGemmaProcessor`` +public struct PaliGemmaProcessorConfiguration: Codable, Sendable { + + public struct Size: Codable, Sendable { + public let width: Int + public let height: Int + + var cgSize: CGSize { .init(width: width, height: height) } + } + + public let imageMean: [CGFloat] + public let imageStd: [CGFloat] + public let size: Size + public let imageSequenceLength: Int + + public var imageMeanTuple: (CGFloat, CGFloat, CGFloat) { + (imageMean[0], imageMean[1], imageMean[2]) + } + public var imageStdTuple: (CGFloat, CGFloat, CGFloat) { + (imageStd[0], imageStd[1], imageStd[2]) + } + + enum CodingKeys: String, CodingKey { + case imageMean = "image_mean" + case imageStd = "image_std" + case size + case imageSequenceLength = "image_seq_length" + } +} diff --git a/Libraries/MLXVLM/Models/Qwen2VL.swift b/Libraries/MLXVLM/Models/Qwen2VL.swift new file mode 100644 index 0000000..5c1a7fc --- /dev/null +++ b/Libraries/MLXVLM/Models/Qwen2VL.swift @@ -0,0 +1,1005 @@ +// Copyright © 2024 Apple Inc. + +// port of https://github.com/Blaizzy/mlx-vlm/tree/main/mlx_vlm/models/qwen2_vl + +import CoreImage +import Foundation +import Hub +import MLX +import MLXFast +import MLXLMCommon +import MLXNN +import Tokenizers + +// MARK: - Common + +/// Rotates half the hidden dims of the input +private func rotateHalf(_ x: MLXArray) -> MLXArray { + let index = x.dim(-1) / 2 + let x1 = x[.ellipsis, 0 ..< index] + let x2 = x[.ellipsis, index...] + return concatenated([-x2, x1], axis: -1) +} + +// MARK: - Language + +private enum Language { + + /// Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors + static private func applyMultimodalRotaryPositionEmbedding( + q: MLXArray, k: MLXArray, cos: MLXArray, sin: MLXArray, + positionIds: MLXArray, mropeSection: [Int] + ) -> (MLXArray, MLXArray) { + var cos = cos[positionIds] + var sin = sin[positionIds] + + cos = + concatenated( + // [m[i % 3] for i, m in enumerate(mx.split(cos, mrope_section, axis=-1))] + split(cos, indices: mropeSection, axis: -1).enumerated().map { i, m in m[i % 3] }, + axis: -1 + )[0..., .newAxis, 0..., 0...] + + sin = + concatenated( + split(sin, indices: mropeSection, axis: -1).enumerated().map { i, m in m[i % 3] }, + axis: -1 + )[0..., .newAxis, 0..., 0...] 
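+
+        // the `i % 3` cycling above interleaves the temporal / height / width
+        // sections of the rotary tables (one per mrope section), mirroring the
+        // Python mrope_section logic referenced in the comment above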
+ + // Apply rotary embedding + let qEmbed = (q * cos) + (rotateHalf(q) * sin) + let kEmbed = (k * cos) + (rotateHalf(k) * sin) + return (qEmbed, kEmbed) + } + + fileprivate class Attention: Module { + + let heads: Int + let kvHeads: Int + let headDim: Int + let scale: Float + let mropeSection: [Int] + + @ModuleInfo(key: "q_proj") var wq: Linear + @ModuleInfo(key: "k_proj") var wk: Linear + @ModuleInfo(key: "v_proj") var wv: Linear + @ModuleInfo(key: "o_proj") var wo: Linear + + @ModuleInfo(key: "rotary_emb") var rotaryEmbedding: RoPE + + public init(_ args: Qwen2VLConfiguration.TextConfiguration) { + let dim = args.hiddenSize + self.heads = args.attentionHeads + self.kvHeads = args.kvHeads + self.headDim = dim / heads + self.scale = pow(Float(headDim), -0.5) + + self._wq.wrappedValue = Linear(dim, heads * headDim, bias: true) + self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: true) + self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: true) + self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false) + + if let v = args.ropeScaling?["mrope_section"], let array = v.asInts() { + // mrope_section = np.cumsum(mrope_section * 2)[:-1].tolist() + self.mropeSection = sequence(state: (0, array.makeIterator())) { state in + if let v = state.1.next() { + // note the *2 + state.0 += v * 2 + return state.0 + } else { + return nil + } + }.dropLast() + } else { + fatalError("rope_scaling['mrope_section'] must be an array of integers") + } + + self._rotaryEmbedding.wrappedValue = RoPE( + dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta) + } + + public func callAsFunction( + _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache? + ) -> MLXArray { + let (B, L) = (x.dim(0), x.dim(1)) + + var queries = wq(x) + var keys = wk(x) + var values = wv(x) + + // prepare the queries, keys and values for the attention computation + queries = queries.reshaped(B, L, heads, headDim).transposed(0, 2, 1, 3) + keys = keys.reshaped(B, L, kvHeads, headDim).transposed(0, 2, 1, 3) + values = values.reshaped(B, L, kvHeads, headDim).transposed(0, 2, 1, 3) + + let offset = cache?.offset ?? 
0 + let mask = mask?[0..., 0 ..< keys.dim(-2)] + + queries = rotaryEmbedding(queries, offset: offset) + keys = rotaryEmbedding(keys, offset: offset) + + if let cache { + (keys, values) = cache.update(keys: keys, values: values) + } + + let output = MLXFast.scaledDotProductAttention( + queries: queries, keys: keys, values: values, scale: scale, mask: mask + ) + .transposed(0, 2, 1, 3) + .reshaped(B, L, -1) + + return wo(output) + } + } + + fileprivate class MLP: Module, UnaryLayer { + + @ModuleInfo(key: "gate_proj") var gate: Linear + @ModuleInfo(key: "down_proj") var down: Linear + @ModuleInfo(key: "up_proj") var up: Linear + + public init(dimensions: Int, hiddenDimensions: Int) { + self._gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) + self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false) + self._up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) + } + + public func callAsFunction(_ x: MLXArray) -> MLXArray { + down(silu(gate(x)) * up(x)) + } + } + + fileprivate class Qwen2VLDecoderLayer: Module { + + @ModuleInfo(key: "self_attn") var attention: Attention + let mlp: MLP + + @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm + @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm + + public init(_ args: Qwen2VLConfiguration.TextConfiguration) { + self._attention.wrappedValue = Attention(args) + self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize) + self._inputLayerNorm.wrappedValue = RMSNorm( + dimensions: args.hiddenSize, eps: args.rmsNormEps) + self._postAttentionLayerNorm.wrappedValue = RMSNorm( + dimensions: args.hiddenSize, eps: args.rmsNormEps) + } + + public func callAsFunction( + _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache? + ) -> MLXArray { + var r = attention(inputLayerNorm(x), mask: mask, cache: cache) + let h = x + r + r = mlp(postAttentionLayerNorm(h)) + let out = h + r + return out + } + } + + fileprivate class Qwen2Model: Module { + + @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding + + fileprivate let layers: [Qwen2VLDecoderLayer] + fileprivate let norm: RMSNorm + + public init(_ args: Qwen2VLConfiguration.TextConfiguration) { + precondition(args.vocabularySize > 0) + + self._embedTokens.wrappedValue = Embedding( + embeddingCount: args.vocabularySize, dimensions: args.hiddenSize) + + self.layers = (0 ..< args.hiddenLayers) + .map { _ in + Qwen2VLDecoderLayer(args) + } + self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps) + } + + public func callAsFunction( + _ inputs: MLXArray?, cache: [KVCache]? = nil, inputEmbedding: MLXArray? = nil + ) -> MLXArray { + var h: MLXArray + if let inputEmbedding { + h = inputEmbedding + } else if let inputs { + h = embedTokens(inputs) + } else { + fatalError("one of inputs or inputEmbedding must be non-nil") + } + + let mask = createAttentionMask(h: h, cache: cache) + + for (i, layer) in layers.enumerated() { + h = layer(h, mask: mask, cache: cache?[i]) + } + + return norm(h) + } + } + + fileprivate class LanguageModel: Module, KVCacheDimensionProvider { + @ModuleInfo var model: Qwen2Model + @ModuleInfo(key: "lm_head") var lmHead: Linear? 
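+        // lmHead is nil when the configuration sets `tieWordEmbeddings`; in that
+        // case callAsFunction below projects with model.embedTokens.asLinear(_:)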
+ + var kvHeads: [Int] + + public init(_ args: Qwen2VLConfiguration.TextConfiguration) { + self.model = Qwen2Model(args) + + if !args.tieWordEmbeddings { + _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false) + } + + self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads } + } + + public func callAsFunction( + _ inputs: MLXArray?, cache: [KVCache]? = nil, inputEmbedding: MLXArray? = nil + ) -> LMOutput { + var out = model(inputs, cache: cache, inputEmbedding: inputEmbedding) + if let lmHead { + out = lmHead(out) + } else { + out = model.embedTokens.asLinear(out) + } + return LMOutput(logits: out) + } + } +} + +// MARK: - Vision + +private enum Vision { + + static fileprivate func applyMultimodalRotaryPositionEmbedding( + _ tensor: MLXArray, freqs: MLXArray + ) -> MLXArray { + var cos = cos(freqs) + var sin = sin(freqs) + + cos = expandedDimensions(cos, axis: 1) + cos = tiled(cos, repetitions: [1, 1, 2]) + cos = expandedDimensions(cos, axis: 0) + + sin = expandedDimensions(sin, axis: 1) + sin = tiled(sin, repetitions: [1, 1, 2]) + sin = expandedDimensions(sin, axis: 0) + + let output = (tensor * cos) + (rotateHalf(tensor) * sin) + return output.asType(tensor.dtype) + } + + fileprivate class VisionRotaryEmbedding { + let dimensions: Int + let theta: Float + let inverseFreq: MLXArray + + init(dimensions: Int, theta: Float) { + self.dimensions = dimensions + self.theta = theta + let p = MLXArray(stride(from: 0, to: dimensions, by: 2)).asType(.float32) / dimensions + self.inverseFreq = 1.0 / pow(theta, p) + } + + func callAsFunction(sequenceLength: Int) -> MLXArray { + let seq = MLXArray(0 ..< sequenceLength).asType(inverseFreq.dtype) + let freqs = outer(seq, inverseFreq) + return freqs + } + } + + fileprivate class PatchEmbed: Module, UnaryLayer { + @ModuleInfo var proj: Conv3d + + let patchSize: Int + let temporalPatchSize: Int + let inChannels: Int + let embedDimensions: Int + + init(patchSize: Int, temporalPatchSize: Int, inChannels: Int, embedDimensions: Int) { + self.patchSize = patchSize + self.temporalPatchSize = temporalPatchSize + self.inChannels = inChannels + self.embedDimensions = embedDimensions + + let kernelSize = IntOrTriple([temporalPatchSize, patchSize, patchSize]) + self._proj.wrappedValue = Conv3d( + inputChannels: inChannels, + outputChannels: embedDimensions, + kernelSize: kernelSize, + stride: kernelSize, + bias: false + ) + } + + func callAsFunction(_ hiddenStates: MLXArray) -> MLXArray { + var hiddenStates = hiddenStates.reshaped( + -1, inChannels, temporalPatchSize, patchSize, patchSize + ).movedAxis(source: 1, destination: 4) + + hiddenStates = proj(hiddenStates) + hiddenStates = hiddenStates.reshaped(-1, embedDimensions) + return hiddenStates + } + } + + fileprivate class PatchMerger: Module, UnaryLayer { + let hiddenSize: Int + @ModuleInfo(key: "ln_q") var layerNormQ: LayerNorm + @ModuleInfo var mlp: (Linear, GELU, Linear) + + init(dimensions: Int, contextDimensions: Int, spatialMergeSize: Int) { + self.hiddenSize = contextDimensions * (spatialMergeSize * spatialMergeSize) + self._layerNormQ.wrappedValue = LayerNorm(dimensions: contextDimensions, eps: 1e-6) + self.mlp = ( + Linear(hiddenSize, hiddenSize), + GELU(), + Linear(hiddenSize, dimensions) + ) + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + var x = layerNormQ(x).reshaped(-1, hiddenSize) + x = mlp.0(x) + x = mlp.1(x) + x = mlp.2(x) + return x + } + } + + fileprivate class Attention: Module { + + let numHeads: Int + let scale: Float + + @ModuleInfo(key: "qkv") 
var qkv: Linear + @ModuleInfo(key: "proj") var proj: Linear + + public init(dims: Int, numHeads: Int) { + self.numHeads = numHeads + let headDim = dims / numHeads + self.scale = pow(Float(headDim), -0.5) + + self._qkv.wrappedValue = Linear(dims, 3 * dims, bias: true) + self._proj.wrappedValue = Linear(dims, dims) + } + + public func callAsFunction( + _ x: MLXArray, gridThw: [THW], rotaryPositionEmbedding: MLXArray + ) -> MLXArray { + let sequenceLength = x.dim(0) + let B = gridThw[0].t + let L = sequenceLength / B + + let qkv = qkv(x).reshaped(sequenceLength, 3, -1) + let s = split(qkv, parts: 3, axis: 1) + var (q, k, v) = (s[0], s[1], s[2]) + + q = q.reshaped(sequenceLength, numHeads, -1) + k = k.reshaped(sequenceLength, numHeads, -1) + v = v.reshaped(sequenceLength, numHeads, -1) + + q = applyMultimodalRotaryPositionEmbedding(q, freqs: rotaryPositionEmbedding) + k = applyMultimodalRotaryPositionEmbedding(k, freqs: rotaryPositionEmbedding) + + q = q.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3) + k = k.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3) + v = v.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3) + + let output = MLXFast.scaledDotProductAttention( + queries: q, keys: k, values: v, scale: scale, mask: nil + ) + .transposed(0, 2, 1, 3) + .reshaped(sequenceLength, -1) + + return proj(output) + } + } + + fileprivate class MLP: Module, UnaryLayer { + + @ModuleInfo var activation: GELU + @ModuleInfo var fc1: Linear + @ModuleInfo var fc2: Linear + + public init(dimensions: Int, hiddenDimensions: Int) { + self.activation = GELU(approximation: .fast) + self.fc1 = Linear(dimensions, hiddenDimensions) + self.fc2 = Linear(hiddenDimensions, dimensions) + } + + public func callAsFunction(_ x: MLXArray) -> MLXArray { + fc2(activation(fc1(x))) + } + } + + fileprivate class Qwen2VLVisionBlock: Module { + + @ModuleInfo var norm1: LayerNorm + @ModuleInfo var norm2: LayerNorm + @ModuleInfo(key: "attn") var attention: Attention + @ModuleInfo var mlp: MLP + + public init(_ config: Qwen2VLConfiguration.VisionConfiguration) { + self.norm1 = LayerNorm(dimensions: config.embedDimensions, eps: 1e-6) + self.norm2 = LayerNorm(dimensions: config.embedDimensions, eps: 1e-6) + + self._attention.wrappedValue = Attention( + dims: config.embedDimensions, numHeads: config.numHeads) + + let mlpHiddenDimensions = Int(Float(config.embedDimensions) * config.mlpRatio) + self.mlp = MLP( + dimensions: config.embedDimensions, hiddenDimensions: mlpHiddenDimensions) + } + + func callAsFunction( + _ hiddenStates: MLXArray, gridThw: [THW], rotaryPositionEmbedding: MLXArray + ) -> MLXArray { + var hiddenStates = + hiddenStates + + attention( + norm1(hiddenStates), + gridThw: gridThw, + rotaryPositionEmbedding: rotaryPositionEmbedding + ) + hiddenStates = hiddenStates + mlp(norm2(hiddenStates)) + return hiddenStates + } + } + + fileprivate class VisionModel: Module { + + @ModuleInfo(key: "patch_embed") var patchEmbed: PatchEmbed + @ModuleInfo(key: "rotary_pos_emb") var rotaryPositionEmbedding: VisionRotaryEmbedding + @ModuleInfo(key: "blocks") var blocks: [Qwen2VLVisionBlock] + @ModuleInfo(key: "merger") var patchMerger: PatchMerger + + let spatialMergeSize: Int + + public init(_ config: Qwen2VLConfiguration.VisionConfiguration) { + self.spatialMergeSize = config.spatialMergeSize + + self._patchEmbed.wrappedValue = PatchEmbed( + patchSize: config.patchSize, + temporalPatchSize: config.temporalPatchSize, + inChannels: config.inChannels, + embedDimensions: config.embedDimensions) + + let headDimensions = 
config.embedDimensions / config.numHeads
+            self._rotaryPositionEmbedding.wrappedValue = VisionRotaryEmbedding(
+                dimensions: headDimensions / 2, theta: 10_000)
+
+            self._blocks.wrappedValue = (0 ..< config.depth).map { _ in
+                Qwen2VLVisionBlock(config)
+            }
+            self._patchMerger.wrappedValue = PatchMerger(
+                dimensions: config.hiddenSize, contextDimensions: config.embedDimensions,
+                spatialMergeSize: 2)
+        }
+
+        func rotaryPositionEmbedding(_ gridThw: [THW]) -> MLXArray {
+            var positionIds = [MLXArray]()
+
+            for row in gridThw {
+                let (t, h, w) = row.values
+
+                var hposIds = expandedDimensions(MLXArray(0 ..< h), axis: 1)
+                hposIds = repeated(hposIds, count: w, axis: 1)
+                hposIds =
+                    hposIds
+                    .reshaped(
+                        h / spatialMergeSize,
+                        spatialMergeSize,
+                        w / spatialMergeSize,
+                        spatialMergeSize
+                    )
+                    .transposed(0, 2, 1, 3)
+                    .flattened()
+
+                var wposIds = expandedDimensions(MLXArray(0 ..< w), axis: 0)
+                wposIds = repeated(wposIds, count: h, axis: 0)
+                wposIds =
+                    wposIds
+                    .reshaped(
+                        h / spatialMergeSize,
+                        spatialMergeSize,
+                        w / spatialMergeSize,
+                        spatialMergeSize
+                    )
+                    .transposed(0, 2, 1, 3)
+                    .flattened()
+
+                let stackedPosIds = stacked([hposIds, wposIds], axis: -1)
+                positionIds.append(repeated(stackedPosIds, count: t, axis: 0))
+            }
+
+            let indices = concatenated(positionIds, axis: 0)
+            let maxGridSize = gridThw.lazy.map { max($0.h, $0.w) }.max() ?? 0
+            let rotaryPositionEmbedFull = rotaryPositionEmbedding(sequenceLength: maxGridSize)[
+                indices]
+
+            return rotaryPositionEmbedFull.reshaped(indices.dim(0), -1)
+        }
+
+        public func callAsFunction(_ hiddenStates: MLXArray, gridThw: [THW]) -> MLXArray {
+            var hiddenStates = patchEmbed(hiddenStates)
+            let rotaryPositionEmbedding = rotaryPositionEmbedding(gridThw)
+
+            for block in blocks {
+                hiddenStates = block(
+                    hiddenStates, gridThw: gridThw,
+                    rotaryPositionEmbedding: rotaryPositionEmbedding)
+            }
+
+            return patchMerger(hiddenStates)
+        }
+
+        private func isMLXWeight(_ array: MLXArray) -> Bool {
+            if array.ndim != 4 && array.ndim != 5 {
+                return false
+            }
+
+            if array.dim(-1) == 3 {
+                return true
+            }
+
+            let (outChannels, kH, kW) = (array.dim(1), array.dim(2), array.dim(3))
+            return outChannels >= kH && outChannels >= kW && kH == kW
+        }
+
+        func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
+            var sanitizedWeights = [String: MLXArray]()
+
+            for (k, v) in weights {
+                if k.contains("position_id") {
+                    // Remove unused position_ids
+                    continue
+                } else if k.contains("patch_embed.proj.weight") {
+                    // PyTorch conv3d weight tensors have shape:
+                    //   [out_channels, in_channels, kT, kH, kW]
+                    // MLX conv3d expects the weight to be of shape:
+                    //   [out_channels, kT, kH, kW, in_channels]
+                    if isMLXWeight(v) {
+                        sanitizedWeights[k] = v
+                    } else {
+                        sanitizedWeights[k] = v.transposed(0, 2, 3, 4, 1)
+                    }
+                } else {
+                    sanitizedWeights[k] = v
+                }
+            }
+
+            return sanitizedWeights
+        }
+    }
+}
+
+// MARK: - Processor
+
+/// Qwen2VL VLM `UserInputProcessor`.
+///
+/// This is meant to be used with ``Qwen2VL`` and is typically created by ``VLMModelFactory``.
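+///
+/// A rough usage sketch (the processor is normally obtained from the
+/// `ModelContext` produced by the factory, as in `llm-tool`):
+///
+/// ```swift
+/// let userInput = UserInput(prompt: "Describe the image", images: [.url(imageURL)])
+/// let input = try await context.processor.prepare(input: userInput)
+/// ```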
+public class Qwen2VLProcessor: UserInputProcessor { + + private let config: Qwen2VLProcessorConfiguration + private let tokenizer: any Tokenizer + + public init(_ config: Qwen2VLProcessorConfiguration, tokenizer: any Tokenizer) { + self.config = config + self.tokenizer = tokenizer + } + + // image_processing_qwen2_vl.smart_resize + private func targetSize(height: Int, width: Int, factor: Int, minPixels: Int, maxPixels: Int) + throws -> (Int, Int) + { + if height < factor { + throw VLMError.imageProcessingFailure( + "height: \(height) must be larger than factor: \(factor)") + } + if width < factor { + throw VLMError.imageProcessingFailure( + "width: \(width) must be larger than factor: \(factor)") + } + if max(height, width) / min(height, width) > 200 { + throw VLMError.imageProcessingFailure( + "absolute aspect ratio must be smaller than 200: \(width)x\(height)") + } + + var hBar = Int(round(Float(height) / Float(factor))) * factor + var wBar = Int(round(Float(width) / Float(factor))) * factor + + if hBar * wBar > maxPixels { + let beta = sqrt(Float(height * width) / Float(maxPixels)) + hBar = Int(floor(Float(height) / beta / Float(factor))) * factor + wBar = Int(floor(Float(width) / beta / Float(factor))) * factor + } else if hBar * wBar < minPixels { + let beta = sqrt(Float(minPixels) / Float(height * width)) + hBar = Int(floor(Float(height) * beta / Float(factor))) * factor + wBar = Int(floor(Float(width) * beta / Float(factor))) * factor + } + return (hBar, wBar) + } + + public func preprocess(images: [CIImage], processing: UserInput.Processing?) throws -> ( + MLXArray, THW + ) { + // first apply the user requested resizing, etc. if any + let images = images.map { MediaProcessing.apply($0, processing: processing) } + + // image_processing_qwen2_vl._preprocess + + let size = images[0].extent.size + let (resizedHeight, resizedWidth) = try targetSize( + height: Int(size.height), width: Int(size.width), + factor: config.patchSize * config.mergeSize, + minPixels: config.size.minPixels, maxPixels: config.size.maxPixels) + let resizedSize = CGSize(width: resizedWidth, height: resizedHeight) + + let processedImages = + try images + .map { + MediaProcessing.inSRGBToneCurveSpace($0) + } + .map { + return MediaProcessing.resampleBicubic($0, to: resizedSize) + } + .map { + MediaProcessing.normalize( + $0, mean: config.imageMeanTuple, std: config.imageStdTuple) + } + .map { + MediaProcessing.asMLXArray($0) + } + + var patches = concatenated(processedImages) + if patches.dim(0) != config.temporalPatchSize { + patches = tiled(patches, repetitions: [config.temporalPatchSize, 1, 1, 1]) + } + let channel = patches.dim(1) + let gridT = patches.dim(0) / self.config.temporalPatchSize + let gridH = resizedHeight / self.config.patchSize + let gridW = resizedWidth / self.config.patchSize + + patches = patches.reshaped( + gridT, + config.temporalPatchSize, + channel, + gridH / config.mergeSize, + config.mergeSize, + config.patchSize, + gridW / config.mergeSize, + config.mergeSize, + config.patchSize + ) + patches = patches.transposed(0, 3, 6, 4, 7, 2, 1, 5, 8) + + let flattenedPatches = patches.reshaped( + gridT * gridH * gridW, + channel * config.temporalPatchSize * config.patchSize * config.patchSize + ) + + return (flattenedPatches, .init(gridT, gridH, gridW)) + } + + public func prepare(prompt: UserInput.Prompt, imageTHW: [THW]?) 
-> String { + // the tokenizer does have a chat template and it expects messages + // like this: + // + // [{'role': 'user', 'content': [{'type': 'text', 'text': 'What are these?'}, + // {'type': 'image'}, {'type': 'image'}, {'type': 'image'}]}] + // + // The output of the prompt template is fed into + // image_processing_qwen2_vl.preprocess where it is further augmented + // by replacing tokens according to imageTHW. + // + // Neither the structured content nor the postprocessing of the template + // are supported in current Tokenizer/Jinja (swift) so handle that here. + + var messages = prompt.asMessages() + if messages[0]["role"] != "system" { + messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0) + } + + let lastIndex = messages.count - 1 + var lastMessage = messages[lastIndex]["content"] ?? "" + + // image_processing_qwen2_vl.preprocess -- inject image_pad tokens for each image + let mergeLength = config.mergeSize * config.mergeSize + for thw in imageTHW ?? [] { + lastMessage += "<|vision_start|>" + lastMessage += Array(repeating: "<|image_pad|>", count: thw.product / mergeLength) + .joined() + lastMessage += "<|vision_end|>" + } + + messages[lastIndex]["content"] = lastMessage + + return + messages + .map { + "<|im_start|>\($0["role"] ?? "user")\n\($0["content"] ?? "")<|im_end|>" + } + .joined(separator: "\n") + + "\n<|im_start|>assistant\n" + } + + public func prepare(input: UserInput) throws -> LMInput { + if input.images.isEmpty { + // just a straight text prompt + let prompt = prepare(prompt: input.prompt, imageTHW: nil) + let promptTokens = try tokenizer.encode(text: prompt) + return LMInput(tokens: MLXArray(promptTokens)) + } + + // image_processing_qwen2_vl.preprocess + let images = try input.images.map { + try preprocess(images: [$0.asCIImage()], processing: input.processing) + } + let pixels = concatenated(images.map { $0.0 }) + let image = LMInput.ProcessedImage(pixels: pixels, imageGridThw: images.map { $0.1 }) + + // processing_qwen2_vl.Qwen2VLProcessor + let prompt = prepare(prompt: input.prompt, imageTHW: image.imageGridThw) + let promptTokens = try tokenizer.encode(text: prompt) + let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0) + let mask = ones(like: promptArray).asType(.int8) + + return LMInput(text: .init(tokens: promptArray, mask: mask), image: image) + } + +} + +// MARK: - Model + +/// Qwen2VL VLM +/// +/// This is typically created by ``VLMModelFactory``. +public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider { + + @ModuleInfo(key: "vision_tower") private var visionModel: Vision.VisionModel + @ModuleInfo(key: "language_model") private var languageModel: Language.LanguageModel + + public let config: Qwen2VLConfiguration + + public var vocabularySize: Int { config.baseConfiguration.vocabularySize } + public var kvHeads: [Int] { languageModel.kvHeads } + + public func loraLinearLayers() -> MLXLMCommon.LoRALinearLayers { + languageModel.model.layers.map { ($0.attention, ["q_proj", "v_proj"]) } + } + + public init(_ config: Qwen2VLConfiguration) { + self.config = config + self._visionModel.wrappedValue = Vision.VisionModel(config.visionConfiguration) + self._languageModel.wrappedValue = Language.LanguageModel(config.textConfiguration) + } + + private func inputEmbeddings(inputIds: MLXArray, pixelValues: MLXArray?, gridThw: [THW]?) 
+        -> MLXArray
+    {
+        guard let pixelValues, let gridThw else {
+            // text only prompt -- use the language model embeddings directly
+            return languageModel.model.embedTokens(inputIds)
+        }
+
+        // Get the input embeddings from the language model
+        let inputEmbeds = languageModel.model.embedTokens(inputIds)
+
+        // Get the output hidden states from the vision model
+        var hiddenStates = self.visionModel(pixelValues, gridThw: gridThw)
+
+        if hiddenStates.ndim == 2 {
+            hiddenStates = hiddenStates[.newAxis, 0..., 0...]
+        }
+
+        // Insert special image tokens in the input_ids
+        return mergeInputIdsWithImageFeatures(
+            inputIds: inputIds, inputEmbeds: inputEmbeds, imageFeatures: hiddenStates)
+    }
+
+    private func mergeInputIdsWithImageFeatures(
+        inputIds: MLXArray, inputEmbeds: MLXArray, imageFeatures: MLXArray
+    ) -> MLXArray {
+        let imageTokenIndex = config.baseConfiguration.imageTokenId
+
+        var imageIndices = [Int]()
+        for (i, v) in inputIds.asArray(Int.self).enumerated() {
+            if v == imageTokenIndex {
+                imageIndices.append(i)
+            }
+        }
+
+        inputEmbeds[0..., MLXArray(imageIndices), 0...] = imageFeatures
+        return inputEmbeds
+    }
+
+    public func prepare(_ input: LMInput, cache: [any KVCache], windowSize: Int?) throws
+        -> PrepareResult
+    {
+        let gridThw = input.image?.imageGridThw
+
+        let dtype = visionModel.patchEmbed.proj.weight.dtype
+        let pixels = input.image?.pixels.asType(dtype)
+
+        let inputEmbeddings = self.inputEmbeddings(
+            inputIds: input.text.tokens, pixelValues: pixels, gridThw: gridThw)
+
+        let result = languageModel(nil, cache: cache, inputEmbedding: inputEmbeddings)
+
+        return .logits(result)
+    }
+
+    public func callAsFunction(_ inputs: MLXArray, cache: [any KVCache]?) -> MLXArray {
+        languageModel(inputs, cache: cache).logits
+    }
+
+    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
+        visionModel.sanitize(
+            weights:
+                Dictionary(
+                    uniqueKeysWithValues: weights.map { key, value in
+                        var key = key
+                        if !key.contains("vision_tower") {
+                            key = key.replacingOccurrences(of: "visual", with: "vision_tower")
+                        }
+                        if !key.contains("language_model") {
+                            key = key.replacingOccurrences(
+                                of: "model", with: "language_model.model")
+                            key = key.replacingOccurrences(
+                                of: "lm_head", with: "language_model.lm_head")
+                        }
+
+                        return (key, value)
+                    })
+        )
+    }
+
+}
+
+// MARK: - Configuration
+
+/// Configuration for ``Qwen2VL``
+public struct Qwen2VLConfiguration: Codable, Sendable {
+
+    public struct TextConfiguration: Codable, Sendable {
+        public let modelType: String
+        public let hiddenSize: Int
+        public let hiddenLayers: Int
+        public let intermediateSize: Int
+        public let attentionHeads: Int
+        private let _rmsNormEps: Float?
+        public var rmsNormEps: Float { _rmsNormEps ?? 1e-6 }
+        public let vocabularySize: Int
+        public let kvHeads: Int
+        private let _maxPositionEmbeddings: Int?
+        public var maxPositionEmbeddings: Int { _maxPositionEmbeddings ?? 32768 }
+        private let _ropeTheta: Float?
+        public var ropeTheta: Float { _ropeTheta ?? 1_000_000 }
+        private let _ropeTraditional: Bool?
+        public var ropeTraditional: Bool { _ropeTraditional ?? false }
+        public let ropeScaling: [String: StringOrNumber]?
+        private let _tieWordEmbeddings: Bool?
+        public var tieWordEmbeddings: Bool { _tieWordEmbeddings ??
true } + + enum CodingKeys: String, CodingKey { + case modelType = "model_type" + case hiddenSize = "hidden_size" + case hiddenLayers = "num_hidden_layers" + case intermediateSize = "intermediate_size" + case attentionHeads = "num_attention_heads" + case _rmsNormEps = "rms_norm_eps" + case vocabularySize = "vocab_size" + case kvHeads = "num_key_value_heads" + case _maxPositionEmbeddings = "max_position_embeddings" + case _ropeTheta = "rope_theta" + case _ropeTraditional = "rope_traditional" + case ropeScaling = "rope_scaling" + case _tieWordEmbeddings = "tie_word_embeddings" + } + } + + public struct VisionConfiguration: Codable, Sendable { + public let depth: Int + public let embedDimensions: Int + public let hiddenSize: Int + public let numHeads: Int + public let patchSize: Int + public let mlpRatio: Float + public let _inChannels: Int? + public var inChannels: Int { _inChannels ?? 3 } + public let _layerNormEps: Float? + public var layerNormEps: Float { _layerNormEps ?? 1e-6 } + public let spatialPatchSize: Int + public let spatialMergeSize: Int + public let temporalPatchSize: Int + + enum CodingKeys: String, CodingKey { + case depth + case embedDimensions = "embed_dim" + case hiddenSize = "hidden_size" + case numHeads = "num_heads" + case patchSize = "patch_size" + case mlpRatio = "mlp_ratio" + case _inChannels = "in_channels" + case _layerNormEps = "layer_norm_eps" + case spatialPatchSize = "spatial_patch_size" + case spatialMergeSize = "spatial_merge_size" + case temporalPatchSize = "temporal_patch_size" + } + } + + public struct BaseConfiguration: Codable, Sendable { + public let modelType: String + public let vocabularySize: Int + public let imageTokenId: Int + public let hiddenSize: Int + + enum CodingKeys: String, CodingKey { + case modelType = "model_type" + case vocabularySize = "vocab_size" + case imageTokenId = "image_token_id" + case hiddenSize = "hidden_size" + } + } + + public let textConfiguration: TextConfiguration + public let visionConfiguration: VisionConfiguration + public let baseConfiguration: BaseConfiguration + + enum CodingKeys: String, CodingKey { + case visionConfiguration = "vision_config" + } + + public init(from decoder: any Swift.Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + + // this is a sub-dictionary + self.visionConfiguration = try container.decode( + VisionConfiguration.self, forKey: .visionConfiguration) + + // these are overlaid in the top level + self.textConfiguration = try TextConfiguration(from: decoder) + self.baseConfiguration = try BaseConfiguration(from: decoder) + } +} + +/// Configuration for ``Qwen2VLProcessor`` +public struct Qwen2VLProcessorConfiguration: Codable, Sendable { + + public struct Size: Codable, Sendable { + public let maxPixels: Int + public let minPixels: Int + + enum CodingKeys: String, CodingKey { + case maxPixels = "max_pixels" + case minPixels = "min_pixels" + } + } + + public let imageMean: [CGFloat] + public let imageStd: [CGFloat] + public let size: Size + public let mergeSize: Int + public let patchSize: Int + public let temporalPatchSize: Int + + public var imageMeanTuple: (CGFloat, CGFloat, CGFloat) { + (imageMean[0], imageMean[1], imageMean[2]) + } + public var imageStdTuple: (CGFloat, CGFloat, CGFloat) { + (imageStd[0], imageStd[1], imageStd[2]) + } + + enum CodingKeys: String, CodingKey { + case imageMean = "image_mean" + case imageStd = "image_std" + case size + case mergeSize = "merge_size" + case patchSize = "patch_size" + case temporalPatchSize = 
"temporal_patch_size" + } +} diff --git a/Libraries/MLXVLM/README.md b/Libraries/MLXVLM/README.md new file mode 100644 index 0000000..eb6806d --- /dev/null +++ b/Libraries/MLXVLM/README.md @@ -0,0 +1,377 @@ +# MLXVLM + +This is a port of several models from: + +- https://github.com/Blaizzy/mlx-vlm + +using the Hugging Face swift transformers package to provide tokenization: + +- https://github.com/huggingface/swift-transformers + +The [VLMModelFactory.swift](VLMModelFactory.swift) provides minor overrides and customization -- +if you require overrides for the tokenizer or prompt customizations they can be +added there. + +This is set up to load models from Hugging Face, e.g. https://huggingface.co/mlx-community + +The following models have been tried: + +- mlx-community/paligemma-3b-mix-448-8bit +- mlx-community/Qwen2-VL-2B-Instruct-4bit + +Currently supported model types are: + +- paligemma +- qwen2_vl + +See [llm-tool](../../Tools/llm-tool) + +# Adding a Model + +If the model follows the typical VLM pattern: + +- `config.json`, `tokenizer.json`, and `tokenizer_config.json` +- `*.safetensors` + +You can follow the pattern of the models in the [Models](Models) directory +and create a `.swift` file for your new model: + +## Create a Model Configuration + +Create a configuration struct for both the Text and Vision models +that matches the structure in `config.json`. A struct like this +is recommended: + +```swift +public struct YourModelConfiguration: Codable, Sendable { + public struct TextConfiguration: Codable, Sendable { + public let hiddenSize: Int + + // use this pattern for values that need defaults + public let _layerNormEps: Float? + public var layerNormEps: Float { _layerNormEps ?? 1e-6 } + + enum CodingKeys: String, CodingKey { + case hiddenSize = "hidden_size" + case _layerNormEps = "layer_norm_eps" + } + } + + public struct VisionConfiguration: Codable, Sendable { + ... + } + + public let textConfiguration: TextConfiguration + public let visionConfiguration: VisionConfiguration + public let vocabularySize: Int + + enum CodingKeys: String, CodingKey { + case textConfiguration = "text_config" + case visionConfiguration = "vision_config" + case vocabularySize = "vocab_size" + } +} +``` + +## Create a Processor Configuration + +VLMs also require a image/video preprocessor. Create a configuration to match +the `preprocessor_config.json` file: + +```swift +public struct YourModelProcessorConfiguration: Codable, Sendable { + + public struct Size: Codable, Sendable { + public let width: Int + public let height: Int + + var cgSize: CGSize { .init(width: width, height: height) } + } + + public let imageMean: [CGFloat] + public let imageStd: [CGFloat] + public let size: Size + + public var imageMeanTuple: (CGFloat, CGFloat, CGFloat) { + (imageMean[0], imageMean[1], imageMean[2]) + } + public var imageStdTuple: (CGFloat, CGFloat, CGFloat) { + (imageStd[0], imageStd[1], imageStd[2]) + } + + enum CodingKeys: String, CodingKey { + case imageMean = "image_mean" + case imageStd = "image_std" + case size + } +} +``` + +this will be consumed by: + +```swift +public class YourModelProcessor: UserInputProcessor { +... +``` + +discussed later. + +## Create the Vision, Text and VLM Classes + +VLMs have language and vision models that are aggregated into a single +top level model. 
+
+For namespacing purposes you might put the Language and Vision
+models into an `enum` to create something structured like this:
+
+```swift
+// MARK: - Language
+
+private enum Language {
+
+    fileprivate class Attention: Module {
+        ...
+    }
+
+    ...
+
+    fileprivate class LanguageModel: Module, KVCacheDimensionProvider {
+        @ModuleInfo var model: YourModel
+
+        var kvHeads: [Int]
+        var headDim: MLX.IntOrPair
+
+        public init(_ args: YourModelConfiguration.TextConfiguration) {
+            self.model = YourModel(args)
+
+            self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads }
+        }
+
+        public func callAsFunction(
+            _ inputs: MLXArray, cache: [KVCache]? = nil, inputEmbedding: MLXArray? = nil,
+            mask: MLXArray? = nil
+        ) -> LMOutput {
+            ...
+            return LMOutput(logits: ...)
+        }
+    }
+}
+```
+
+Similarly, the Vision model can go into an `enum` namespace:
+
+```swift
+// MARK: - Vision
+
+private enum Vision {
+
+    fileprivate class Attention: Module {
+        ...
+    }
+
+    fileprivate class VisionModel: Module {
+
+        @ModuleInfo(key: "vision_model") var visionModel: InternalVisionModel
+
+        public init(_ config: YourModelConfiguration.VisionConfiguration) {
+            self._visionModel.wrappedValue = InternalVisionModel(config)
+        }
+
+        public func callAsFunction(_ x: MLXArray, outputHiddenStates: Bool = false) -> (
+            MLXArray, MLXArray, MLXArray?
+        ) {
+            visionModel(x, outputHiddenStates: outputHiddenStates)
+        }
+    }
+}
+```
+
+The exact signatures on the `init()` and `callAsFunction()` can vary as needed --
+these models are not exposed to callers.
+
+The top level model is the only piece of the model with public API, and it
+should implement `VLMModel` (aka `LanguageModel`). Here is an outline of how
+the top level model might work:
+
+```swift
+public class YourModel: Module, VLMModel, KVCacheDimensionProvider {
+
+    @ModuleInfo(key: "vision_tower") private var visionModel: Vision.VisionModel
+    @ModuleInfo(key: "language_model") private var languageModel: Language.LanguageModel
+
+    public let config: YourModelConfiguration
+
+    public var vocabularySize: Int { config.vocabularySize }
+    public var kvHeads: [Int] { languageModel.kvHeads }
+    public var headDim: MLX.IntOrPair { languageModel.headDim }
+
+    public func loraLinearLayers() -> MLXLMCommon.LoRALinearLayers {
+        languageModel.model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
+    }
+
+    public init(_ config: YourModelConfiguration) {
+        self.config = config
+        self._visionModel.wrappedValue = Vision.VisionModel(config.visionConfiguration)
+        self._languageModel.wrappedValue = Language.LanguageModel(config.textConfiguration)
+    }
+
+    public func prepare(_ input: LMInput, cache: [any KVCache], windowSize: Int?) throws
+        -> PrepareResult
+    {
+        // TODO prepare the cache and resulting logits based on the
+        // text prompt and any media assets
+        guard let image = input.image else { throw VLMError.imageRequired }
+        guard let mask = input.text.mask else { throw VLMError.maskRequired }
+        let inputIds = input.text.tokens
+
+        let inputEmbedding = inputEmbeddings(
+            inputIds: inputIds, pixelValues: image.pixels, mask: mask)
+
+        let result = languageModel(
+            inputIds, cache: cache, inputEmbedding: inputEmbedding, mask: mask)
+
+        return .logits(result)
+    }
+
+    public func callAsFunction(_ inputs: MLXArray, cache: [any KVCache]?) -> MLXArray {
+        // TODO evaluate a step in the language model
+        languageModel(inputs, cache: cache).logits
+    }
+}
+```
+
+## Create the UserInputProcessor
+
+VLMs require custom `UserInputProcessor` instances to manipulate the prompts and
+media as needed. For example, a processor might:
+
+- apply resampling and normalization to the images
+- convert the images into an `MLXArray` and build a `THW` struct describing the layout
+- modify the prompt by injecting the image placeholder tokens (e.g. `<image>`) that the model expects
+
+In the Python implementations, much of this code typically lives in the `transformers`
+package from Hugging Face -- inspection will be required to determine which code
+is called and what it does. You can examine the processors in the `Models` directory:
+they reference the files and functions that they are based on.
+
+The `UserInputProcessor` is initialized with the `ProcessorConfiguration` (defined above)
+and has a prepare method:
+
+```swift
+public func prepare(input: UserInput) throws -> LMInput
+```
+
+As an example, this is a slight paraphrase of the PaliGemma `UserInputProcessor`:
+
+```swift
+public class YourModelProcessor: UserInputProcessor {
+
+    private let config: YourModelProcessorConfiguration
+    private let tokenizer: any Tokenizer
+
+    public init(_ config: YourModelProcessorConfiguration, tokenizer: any Tokenizer) {
+        self.config = config
+        self.tokenizer = tokenizer
+    }
+
+    private func prepare(image: CIImage, processing: UserInput.Processing?) -> MLXArray {
+        // based on image_processing_siglip from transformers
+        var image = image
+
+        // we want to do all of the image processing in an sRGB tone curve
+        // rather than a linear space as that is what transformers / torch_vision
+        // do (implicitly by using sRGB rasters directly)
+        image = MediaProcessing.inSRGBToneCurveSpace(image)
+
+        // apply user instructions
+        image = MediaProcessing.apply(image, processing: processing)
+
+        image = MediaProcessing.resampleBicubic(image, to: config.size.cgSize)
+        image = MediaProcessing.normalize(
+            image, mean: config.imageMeanTuple, std: config.imageStdTuple)
+
+        return MediaProcessing.asMLXArray(image)
+    }
+
+    public func prepare(input: UserInput) throws -> LMInput {
+        switch input.images.count {
+        case 0: throw VLMError.imageRequired
+        case 1: break
+        default: throw VLMError.singleImageAllowed
+        }
+
+        // this model doesn't have a chat template, so just use the last message.
+        var prompt = input.prompt.asMessages().last?["content"] ?? ""
+
+        // based on transformers/processing_paligemma
+        let count = input.images.count * config.imageSequenceLength
+        prompt =
+            Array(repeating: "<image>", count: count).joined() + (tokenizer.bosToken ?? "") + prompt
+            + "\n"
+
+        let promptTokens = try tokenizer.encode(text: prompt)
+        let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
+        let mask = ones(like: promptArray)
+
+        let pixels = try prepare(image: input.images[0].asCIImage(), processing: input.processing)
+
+        return LMInput(text: .init(tokens: promptArray, mask: mask), image: .init(pixels: pixels))
+    }
+
+}
+```
+
+Note that the Python code may rely on the chat template to inject the image tokens
+(paligemma does not). This may have to be expressed in Swift code as the current
+interface does not support the structured parameters used for this (see the Qwen2VL
+processor for an example).
+
+## Register the Model
+
+In [VLMModelFactory.swift](VLMModelFactory.swift) register the model type itself
+(this is independent of the model id):
+
+```swift
+public class ModelTypeRegistry: @unchecked Sendable {
+...
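+    // the key is the `model_type` value read from the model's config.json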
+    private var creators: [String: @Sendable (URL) throws -> any LanguageModel] = [
+        "yourModel": create(YourModelConfiguration.self, YourModel.init),
+```
+
+Similarly, register the UserInputProcessor type (`preprocessor_config.json`):
+
+```swift
+public class ProcessorTypeRegistry: @unchecked Sendable {
+...
+    private var creators:
+        [String: @Sendable (URL, any Tokenizer) throws -> any UserInputProcessor] = [
+            "YourModelProcessor": create(
+                YourModelProcessorConfiguration.self, YourModelProcessor.init),
+```
+
+Add a constant for the model in the ModelRegistry (not strictly required but useful
+for callers to refer to it in code):
+
+```swift
+public class ModelRegistry: @unchecked Sendable {
+...
+    static public let yourModel_4bit = ModelConfiguration(
+        id: "mlx-community/YourModel-4bit",
+        defaultPrompt: "Describe the image in English"
+    )
+```
+
+and finally add it to the `all()` list -- this will let users find the model
+configuration by id:
+
+```swift
+    private static func all() -> [ModelConfiguration] {
+        [
+            paligemma3bMix448_8bit,
+...
+            yourModel_4bit,
+```
+
+# Using a Model
+
+See [MLXLMCommon/README.md](../MLXLMCommon/README.md#using-a-model).
diff --git a/Libraries/MLXVLM/VLMModel.swift b/Libraries/MLXVLM/VLMModel.swift
new file mode 100644
index 0000000..2c8de0c
--- /dev/null
+++ b/Libraries/MLXVLM/VLMModel.swift
@@ -0,0 +1,7 @@
+// Copyright © 2024 Apple Inc.
+
+import MLX
+import MLXLMCommon
+
+public protocol VLMModel: LanguageModel, LoRAModel {
+}
diff --git a/Libraries/MLXVLM/VLMModelFactory.swift b/Libraries/MLXVLM/VLMModelFactory.swift
new file mode 100644
index 0000000..51843b1
--- /dev/null
+++ b/Libraries/MLXVLM/VLMModelFactory.swift
@@ -0,0 +1,228 @@
+// Copyright © 2024 Apple Inc.
+
+import Foundation
+import Hub
+import MLX
+import MLXLMCommon
+import Tokenizers
+
+public enum VLMError: Error {
+    case imageRequired
+    case maskRequired
+    case singleImageAllowed
+    case imageProcessingFailure(String)
+}
+
+public struct BaseProcessorConfiguration: Codable, Sendable {
+    public let processorClass: String
+
+    enum CodingKeys: String, CodingKey {
+        case processorClass = "processor_class"
+    }
+}
+
+/// Creates a function that loads a configuration file and instantiates a model with the proper configuration
+private func create<C: Codable, M>(
+    _ configurationType: C.Type, _ modelInit: @escaping (C) -> M
+) -> (URL) throws -> M {
+    { url in
+        let configuration = try JSONDecoder().decode(
+            C.self, from: Data(contentsOf: url))
+        return modelInit(configuration)
+    }
+}
+
+private func create<C: Codable, P>(
+    _ configurationType: C.Type, _ processorInit: @escaping (C, any Tokenizer) -> P
+) -> (URL, any Tokenizer) throws -> P {
+    { url, tokenizer in
+        let configuration = try JSONDecoder().decode(
+            C.self, from: Data(contentsOf: url))
+        return processorInit(configuration, tokenizer)
+    }
+}
+
+/// Registry of model type, e.g. 'llama', to functions that can instantiate the model from configuration.
+///
+/// Typically called via ``VLMModelFactory/load(hub:configuration:progressHandler:)``.
+public class ModelTypeRegistry: @unchecked Sendable {
+
+    // Note: using NSLock as we have very small (just dictionary get/set)
+    // critical sections and expect no contention. This allows the methods
+    // to remain synchronous.
+    private let lock = NSLock()
+
+    private var creators: [String: @Sendable (URL) throws -> any LanguageModel] = [
+        "paligemma": create(PaliGemmaConfiguration.self, PaliGemma.init),
+        "qwen2_vl": create(Qwen2VLConfiguration.self, Qwen2VL.init),
+    ]
+
+    /// Add a new model to the type registry.
+    public func registerModelType(
+        _ type: String, creator: @Sendable @escaping (URL) throws -> any LanguageModel
+    ) {
+        lock.withLock {
+            creators[type] = creator
+        }
+    }
+
+    /// Given a `modelType` and configuration file instantiate a new `LanguageModel`.
+    public func createModel(configuration: URL, modelType: String) throws -> any LanguageModel {
+        let creator = lock.withLock {
+            creators[modelType]
+        }
+        guard let creator else {
+            throw ModelFactoryError.unsupportedModelType(modelType)
+        }
+        return try creator(configuration)
+    }
+
+}
+
+public class ProcessorTypeRegistry: @unchecked Sendable {
+
+    // Note: using NSLock as we have very small (just dictionary get/set)
+    // critical sections and expect no contention. This allows the methods
+    // to remain synchronous.
+    private let lock = NSLock()
+
+    private var creators:
+        [String: @Sendable (URL, any Tokenizer) throws -> any UserInputProcessor] = [
+            "PaliGemmaProcessor": create(
+                PaliGemmaProcessorConfiguration.self, PaligGemmaProcessor.init),
+            "Qwen2VLProcessor": create(
+                Qwen2VLProcessorConfiguration.self, Qwen2VLProcessor.init),
+        ]
+
+    /// Add a new processor type to the registry.
+    public func registerProcessorType(
+        _ type: String,
+        creator: @Sendable @escaping (URL, any Tokenizer) throws -> any UserInputProcessor
+    ) {
+        lock.withLock {
+            creators[type] = creator
+        }
+    }
+
+    /// Given a `processorType` and configuration file instantiate a new `UserInputProcessor`.
+    public func createModel(configuration: URL, processorType: String, tokenizer: any Tokenizer)
+        throws -> any UserInputProcessor
+    {
+        let creator = lock.withLock {
+            creators[processorType]
+        }
+        guard let creator else {
+            throw ModelFactoryError.unsupportedProcessorType(processorType)
+        }
+        return try creator(configuration, tokenizer)
+    }
+
+}
+
+/// Registry of models and any overrides that go with them, e.g. prompt augmentation.
+/// If asked for an unknown configuration this will use the model/tokenizer as-is.
+///
+/// The Python tokenizers have a very rich set of implementations and configuration. The
+/// swift-tokenizers code handles a good chunk of that and this is a place to augment that
+/// implementation, if needed.
+public class ModelRegistry: @unchecked Sendable {
+
+    private let lock = NSLock()
+    private var registry = Dictionary(uniqueKeysWithValues: all().map { ($0.name, $0) })
+
+    static public let paligemma3bMix448_8bit = ModelConfiguration(
+        id: "mlx-community/paligemma-3b-mix-448-8bit",
+        defaultPrompt: "Describe the image in English"
+    )
+
+    static public let qwen2VL2BInstruct4Bit = ModelConfiguration(
+        id: "mlx-community/Qwen2-VL-2B-Instruct-4bit",
+        defaultPrompt: "Describe the image in English"
+    )
+
+    static private func all() -> [ModelConfiguration] {
+        [
+            paligemma3bMix448_8bit,
+            qwen2VL2BInstruct4Bit,
+        ]
+    }
+
+    public func register(configurations: [ModelConfiguration]) {
+        lock.withLock {
+            for c in configurations {
+                registry[c.name] = c
+            }
+        }
+    }
+
+    public func configuration(id: String) -> ModelConfiguration {
+        lock.withLock {
+            if let c = registry[id] {
+                return c
+            } else {
+                return ModelConfiguration(id: id)
+            }
+        }
+    }
+}
+
+/// Factory for creating new VLMs.
+///
+/// Callers can use the `shared` instance or create a new instance if custom configuration
+/// is required.
+///
+/// ```swift
+/// let modelContainer = try await VLMModelFactory.shared.loadContainer(
+///     configuration: ModelRegistry.paligemma3bMix448_8bit)
+/// ```
+public class VLMModelFactory: ModelFactory {
+
+    public static let shared = VLMModelFactory()
+
+    /// registry of model type, e.g. configuration value `paligemma` -> configuration and init methods
+    public let typeRegistry = ModelTypeRegistry()
+
+    /// registry of input processor type, e.g. configuration value `PaliGemmaProcessor` -> configuration and init methods
+    public let processorRegistry = ProcessorTypeRegistry()
+
+    /// registry of model id to configuration, e.g. `mlx-community/paligemma-3b-mix-448-8bit`
+    public let modelRegistry = ModelRegistry()
+
+    public func configuration(id: String) -> ModelConfiguration {
+        modelRegistry.configuration(id: id)
+    }
+
+    public func _load(
+        hub: HubApi, configuration: ModelConfiguration,
+        progressHandler: @Sendable @escaping (Progress) -> Void
+    ) async throws -> ModelContext {
+        // download weights and config
+        let modelDirectory = try await downloadModel(
+            hub: hub, configuration: configuration, progressHandler: progressHandler)
+
+        // load the generic config to understand which model to create and how to load the weights
+        let configurationURL = modelDirectory.appending(component: "config.json")
+        let baseConfig = try JSONDecoder().decode(
+            BaseConfiguration.self, from: Data(contentsOf: configurationURL))
+
+        let model = try typeRegistry.createModel(
+            configuration: configurationURL, modelType: baseConfig.modelType)
+
+        // apply the weights to the bare model
+        try loadWeights(
+            modelDirectory: modelDirectory, model: model, quantization: baseConfig.quantization)
+
+        let tokenizer = try await loadTokenizer(configuration: configuration, hub: hub)
+
+        let processorConfiguration = modelDirectory.appending(component: "preprocessor_config.json")
+        let baseProcessorConfig = try JSONDecoder().decode(
+            BaseProcessorConfiguration.self, from: Data(contentsOf: processorConfiguration))
+        let processor = try processorRegistry.createModel(
+            configuration: processorConfiguration,
+            processorType: baseProcessorConfig.processorClass, tokenizer: tokenizer)
+
+        return .init(
+            configuration: configuration, model: model, processor: processor, tokenizer: tokenizer)
+    }
+
+}
diff --git a/Libraries/MNIST/MNIST.h b/Libraries/MNIST/MNIST.h
deleted file mode 100644
index 8b13789..0000000
--- a/Libraries/MNIST/MNIST.h
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/Libraries/StableDiffusion/Image.swift b/Libraries/StableDiffusion/Image.swift
index 44bc1b1..75d94e1 100644
--- a/Libraries/StableDiffusion/Image.swift
+++ b/Libraries/StableDiffusion/Image.swift
@@ -1,6 +1,7 @@
 // Copyright © 2024 Apple Inc.
 
 import CoreGraphics
+import CoreImage
 import Foundation
 import ImageIO
 import MLX
@@ -113,6 +114,23 @@ public struct Image {
         }
     }
 
+    /// Convert the image data to a CIImage
+    public func asCIImage() -> CIImage {
+        // we need 4 bytes per pixel
+        var raster = data
+        if data.dim(-1) == 3 {
+            raster = padded(raster, widths: [0, 0, [0, 1]], value: MLXArray(255))
+        }
+
+        let arrayData = raster.asData()
+        let (H, W, _) = raster.shape3
+        let cs = CGColorSpace(name: CGColorSpace.sRGB)!
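+        // CIImage has no 3-channel RGB format, hence the alpha padding above;
+        // the raster is tightly packed RGBA8, so bytesPerRow is W * 4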
+ + return CIImage( + bitmapData: arrayData.data, bytesPerRow: W * 4, size: .init(width: W, height: H), + format: .RGBA8, colorSpace: cs) + } + /// Save the image public func save(url: URL) throws { let uti = UTType(filenameExtension: url.pathExtension) ?? UTType.png diff --git a/Libraries/StableDiffusion/Tokenizer.swift b/Libraries/StableDiffusion/Tokenizer.swift index baaad6f..1ce2ea9 100644 --- a/Libraries/StableDiffusion/Tokenizer.swift +++ b/Libraries/StableDiffusion/Tokenizer.swift @@ -36,7 +36,7 @@ struct Bigram: Hashable { class CLIPTokenizer { let pattern = - /<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+/ + #/<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+/# let bpeRanks: [Bigram: Int] let vocabulary: [String: Int] @@ -121,7 +121,7 @@ class CLIPTokenizer { // a much more thorough job here but this should suffice for 95% of // cases. - let clean = text.lowercased().replacing(/\s+/, with: " ") + let clean = text.lowercased().replacing(#/\s+/#, with: " ") let tokens = clean.matches(of: pattern).map { $0.description } // Split the tokens according to the byte-pair merge file diff --git a/Package.swift b/Package.swift index 26a5c72..5628936 100644 --- a/Package.swift +++ b/Package.swift @@ -8,17 +8,23 @@ let package = Package( platforms: [.macOS(.v14), .iOS(.v16)], products: [ .library( - name: "LLM", + name: "MLXLLM", targets: ["MLXLLM"]), .library( - name: "MNIST", + name: "MLXVLM", + targets: ["MLXVLM"]), + .library( + name: "MLXLMCommon", + targets: ["MLXLMCommon"]), + .library( + name: "MLXMNIST", targets: ["MLXMNIST"]), .library( name: "StableDiffusion", targets: ["StableDiffusion"]), ], dependencies: [ - .package(url: "https://github.com/ml-explore/mlx-swift", from: "0.18.1"), + .package(url: "https://github.com/ml-explore/mlx-swift", from: "0.21.2"), .package(url: "https://github.com/huggingface/swift-transformers", from: "0.1.13"), .package(url: "https://github.com/1024jp/GzipSwift", "6.0.1" ... 
"6.0.1"), .package(url: "https://github.com/apple/swift-async-algorithms", from: "1.0.0"), @@ -27,6 +33,7 @@ let package = Package( .target( name: "MLXLLM", dependencies: [ + "MLXLMCommon", .product(name: "MLX", package: "mlx-swift"), .product(name: "MLXFast", package: "mlx-swift"), .product(name: "MLXNN", package: "mlx-swift"), @@ -34,10 +41,48 @@ let package = Package( .product(name: "MLXRandom", package: "mlx-swift"), .product(name: "Transformers", package: "swift-transformers"), ], - path: "Libraries/LLM", + path: "Libraries/MLXLLM", exclude: [ - "README.md", - "LLM.h", + "README.md" + ], + swiftSettings: [ + .enableExperimentalFeature("StrictConcurrency") + ] + ), + .target( + name: "MLXVLM", + dependencies: [ + "MLXLMCommon", + .product(name: "MLX", package: "mlx-swift"), + .product(name: "MLXFast", package: "mlx-swift"), + .product(name: "MLXNN", package: "mlx-swift"), + .product(name: "MLXOptimizers", package: "mlx-swift"), + .product(name: "MLXRandom", package: "mlx-swift"), + .product(name: "Transformers", package: "swift-transformers"), + ], + path: "Libraries/MLXVLM", + exclude: [ + "README.md" + ], + swiftSettings: [ + .enableExperimentalFeature("StrictConcurrency") + ] + ), + .target( + name: "MLXLMCommon", + dependencies: [ + .product(name: "MLX", package: "mlx-swift"), + .product(name: "MLXNN", package: "mlx-swift"), + .product(name: "MLXOptimizers", package: "mlx-swift"), + .product(name: "MLXRandom", package: "mlx-swift"), + .product(name: "Transformers", package: "swift-transformers"), + ], + path: "Libraries/MLXLMCommon", + exclude: [ + "README.md" + ], + swiftSettings: [ + .enableExperimentalFeature("StrictConcurrency") ] ), .target( @@ -65,10 +110,12 @@ let package = Package( .product(name: "Transformers", package: "swift-transformers"), .product(name: "Gzip", package: "GzipSwift"), ], - path: "Libraries/MNIST", + path: "Libraries/MLXMNIST", exclude: [ - "README.md", - "MNIST.h", + "README.md" + ], + swiftSettings: [ + .enableExperimentalFeature("StrictConcurrency") ] ), .target( @@ -77,10 +124,14 @@ let package = Package( .product(name: "MLX", package: "mlx-swift"), .product(name: "MLXNN", package: "mlx-swift"), .product(name: "MLXRandom", package: "mlx-swift"), + .product(name: "Transformers", package: "swift-transformers"), ], path: "Libraries/StableDiffusion", exclude: [ "README.md" + ], + swiftSettings: [ + .enableExperimentalFeature("StrictConcurrency") ] ), ] diff --git a/README.md b/README.md index 9da5af8..b2006f6 100644 --- a/README.md +++ b/README.md @@ -51,13 +51,13 @@ Add the following dependency to your Package.swift .package(url: "https://github.com/ml-explore/mlx-swift-examples/", branch: "main"), ``` -Then add one library or both libraries to the target as a dependency. +Then add one or more libraries to the target as a dependency: ```swift .target( name: "YourTargetName", dependencies: [ - .product(name: "LLM", package: "mlx-swift-examples") + .product(name: "MLXLLM", package: "mlx-swift-examples") ]), ``` diff --git a/Tools/llm-tool/LLMTool.swift b/Tools/llm-tool/LLMTool.swift index 3706f2e..b5ea4da 100644 --- a/Tools/llm-tool/LLMTool.swift +++ b/Tools/llm-tool/LLMTool.swift @@ -1,10 +1,14 @@ // Copyright © 2024 Apple Inc. 
 import ArgumentParser
+import CoreImage
 import Foundation
-import LLM
+import Hub
 import MLX
+import MLXLLM
+import MLXLMCommon
 import MLXRandom
+import MLXVLM
 import Tokenizers
 
 @main
@@ -19,21 +23,22 @@ struct LLMTool: AsyncParsableCommand {
 
 struct ModelArguments: ParsableArguments, Sendable {
 
     @Option(name: .long, help: "Name of the huggingface model or absolute path to directory")
-    var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"
+    var model: String?
 
     @Sendable
-    func load() async throws -> (ModelContainer, ModelConfiguration) {
+    func load(defaultModel: String, modelFactory: ModelFactory) async throws -> ModelContainer {
         let modelConfiguration: ModelConfiguration
 
-        if self.model.hasPrefix("/") {
+        let modelName = self.model ?? defaultModel
+
+        if modelName.hasPrefix("/") {
             // path
-            modelConfiguration = ModelConfiguration(directory: URL(filePath: self.model))
+            modelConfiguration = ModelConfiguration(directory: URL(filePath: modelName))
         } else {
             // identifier
-            modelConfiguration = await ModelConfiguration.configuration(id: model)
+            modelConfiguration = modelFactory.configuration(id: modelName)
         }
-        let modelContainer = try await LLM.loadModelContainer(configuration: modelConfiguration)
-        return (modelContainer, modelConfiguration)
+        return try await modelFactory.loadContainer(configuration: modelConfiguration)
     }
 }
 
@@ -85,16 +90,14 @@ struct GenerateArguments: ParsableArguments, Sendable {
     }
 
     func generate(
-        promptTokens: [Int], model: LLMModel, tokenizer: Tokenizer,
-        extraEOSTokens: Set<String>? = nil
+        input: LMInput, context: ModelContext
     )
-        -> GenerateResult
+        throws -> GenerateResult
     {
-        var detokenizer = NaiveStreamingDetokenizer(tokenizer: tokenizer)
+        var detokenizer = NaiveStreamingDetokenizer(tokenizer: context.tokenizer)
 
-        return LLM.generate(
-            promptTokens: promptTokens, parameters: generateParameters,
-            model: model, tokenizer: tokenizer, extraEOSTokens: extraEOSTokens
+        return try MLXLMCommon.generate(
+            input: input, parameters: generateParameters, context: context
        ) { tokens in
 
            if let last = tokens.last {
@@ -129,7 +132,7 @@ struct MemoryArguments: ParsableArguments, Sendable {
 
     var startMemory: GPU.Snapshot?
 
-    mutating func start<L>(_ load: () async throws -> L) async throws -> L {
+    mutating func start<L>(_ load: @Sendable () async throws -> L) async throws -> L {
         if let cacheSize {
             GPU.set(cacheLimit: cacheSize * 1024 * 1024)
         }
@@ -201,31 +204,74 @@ struct EvaluateCommand: AsyncParsableCommand {
     @OptionGroup var memory: MemoryArguments
     @OptionGroup var generate: GenerateArguments
 
+    @Option(parsing: .upToNextOption, help: "Resize images to this size (width, height)")
+    var resize: [Int] = []
+
+    @Option(parsing: .upToNextOption, help: "Paths or URLs for input images")
+    var image: [URL] = []
+
+    private func userInput(modelConfiguration: ModelConfiguration) -> UserInput {
+        // prompt and images
+        let prompt = generate.prompt ?? modelConfiguration.defaultPrompt
+        let images = image.map { UserInput.Image.url($0) }
+        var input = UserInput(prompt: prompt, images: images)
+
+        // processing instructions
+        if !resize.isEmpty {
+            let size: CGSize
+            if resize.count == 1 {
+                // a single value represents both width and height
+                let v = resize[0]
+                size = CGSize(width: v, height: v)
+            } else {
+                let v0 = resize[0]
+                let v1 = resize[1]
+                size = CGSize(width: v0, height: v1)
+            }
+            input.processing.resize = size
+        }
+
+        return input
+    }
+
     @MainActor
     mutating func run() async throws {
-        let (modelContainer, modelConfiguration) = try await memory.start(args.load)
+        let modelFactory: ModelFactory
+        let defaultModel: ModelConfiguration
+
+        // switch between LLM and VLM
+        let vlm = image.count > 0
+        if vlm {
+            modelFactory = VLMModelFactory.shared
+            defaultModel = MLXVLM.ModelRegistry.qwen2VL2BInstruct4Bit
+        } else {
+            modelFactory = LLMModelFactory.shared
+            defaultModel = MLXLLM.ModelRegistry.mistral7B4bit
+        }
+
+        // load the model
+        let modelContainer = try await memory.start { [args] in
+            try await args.load(defaultModel: defaultModel.name, modelFactory: modelFactory)
+        }
+
+        // get the resolved configuration (this has the default prompt)
+        let modelConfiguration = modelContainer.configuration
 
         if !generate.quiet {
             print("Model loaded -> \(modelConfiguration.id)")
         }
 
-        let prompt = generate.prompt ?? modelConfiguration.defaultPrompt
-        let messages = [["role": "user", "content": prompt]]
-        let promptTokens = try await modelContainer.perform { _, tokenizer in
-            try tokenizer.applyChatTemplate(messages: messages)
-        }
+        let userInput = self.userInput(modelConfiguration: modelConfiguration)
 
         if !generate.quiet {
             print("Starting generation ...")
-            print(prompt, terminator: "")
+            print(userInput.prompt, terminator: " ")
         }
 
-        let result = await modelContainer.perform { [generate] model, tokenizer in
-            generate.generate(
-                promptTokens: promptTokens, model: model, tokenizer: tokenizer,
-                extraEOSTokens: modelConfiguration.extraEOSTokens)
+        let result = try await modelContainer.perform { [generate] context in
+            let input = try await context.processor.prepare(input: userInput)
+            return try generate.generate(input: input, context: context)
         }
-        print()
 
         if !generate.quiet {
             print("------")
diff --git a/Tools/llm-tool/LoraCommands.swift b/Tools/llm-tool/LoraCommands.swift
index a5f668f..4e47604 100644
--- a/Tools/llm-tool/LoraCommands.swift
+++ b/Tools/llm-tool/LoraCommands.swift
@@ -3,8 +3,9 @@
 import ArgumentParser
 import Foundation
 import Hub
-import LLM
 import MLX
+import MLXLLM
+import MLXLMCommon
 import MLXNN
 import MLXOptimizers
 import MLXRandom
@@ -21,6 +22,8 @@ struct LoRACommand: AsyncParsableCommand {
         )
 }
 
+private let defaultModel = MLXLLM.ModelRegistry.mistral7B4bit.name
+
 /// Common arguments for loading a LoRA model with adapter weights
 struct LoRAModelArguments: ParsableArguments, Sendable {
 
@@ -35,26 +38,22 @@ struct LoRAModelArguments: ParsableArguments, Sendable {
 
     /// Load the model and apply the LoRA adapters.
     ///
     /// This does not load the adapter weights as they may not exist yet.
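+    ///
+    /// A sketch of typical use (`defaultModel` and `modelFactory` default to
+    /// the Mistral LLM configuration referenced above):
+    ///
+    /// ```swift
+    /// let modelContainer = try await args.load()
+    /// ```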
- func load() async throws -> (ModelContainer, ModelConfiguration) { - let (modelContainer, modelConfiguration) = try await args.load() + func load( + defaultModel: String = defaultModel, + modelFactory: ModelFactory = LLMModelFactory.shared + ) async throws -> ModelContainer { + let modelContainer = try await args.load( + defaultModel: defaultModel, modelFactory: modelFactory) // convert some of the Linear layers to LoRALinear - await modelContainer.perform { model, _ in - LoRATrain.convert(model: model, layers: loraLayers(model: model)) - } - - return (modelContainer, modelConfiguration) - } - - func loraLayers(model: Module) -> LoRALinearLayers { - guard let layerProvider = model as? LoRAModel else { - // the layerProvider will indicate which Linear layers need to be replaced - fatalError( - "Model \(type(of: model)) (\(args.model)) must implement the LoRALayerProvider protocol" - ) + await modelContainer.perform { context in + guard let lora = context.model as? LoRAModel else { + fatalError("Model \(modelContainer.configuration.name) is not a LoRAModel") + } + LoRATrain.convert(model: context.model, layers: lora.loraLinearLayers(loraLayers)) } - return Array(layerProvider.loraLinearLayers().suffix(loraLayers)) + return modelContainer } func describe(model: Module) { @@ -62,7 +61,7 @@ struct LoRAModelArguments: ParsableArguments, Sendable { let trainableParameterCount = model.trainableParameters() .flattenedValues().map { $0.size }.reduce(0, +) - print("Model: \(args.model)") + print("Model: \(args.model ?? defaultModel)") print("Total parameters: \((totalParameterCount / 1_000_000).formatted())M") print( "Trainable parameters: \((Float(trainableParameterCount) / 1_000_000).formatted(.number.precision(.significantDigits(1 ..< 4))))M" @@ -122,17 +121,17 @@ struct LoRATrainCommand: AsyncParsableCommand { @MainActor mutating func run() async throws { - let (modelContainer, _) = try await args.load() - await modelContainer.perform { [args] model, _ in - args.describe(model: model) + let modelContainer = try await args.load() + await modelContainer.perform { [args] context in + args.describe(model: context.model) } memory.start() if resume { print("Loading pretrained adapters from \(args.adapter.path())") - try await modelContainer.perform { [args] model, _ in - try LoRATrain.loadLoRAWeights(model: model, url: args.adapter) + try await modelContainer.perform { [args] context in + try LoRATrain.loadLoRAWeights(model: context.model, url: args.adapter) } } @@ -148,17 +147,17 @@ struct LoRATrainCommand: AsyncParsableCommand { } // train - try await modelContainer.perform { [args, parameters, learningRate] model, tokenizer in + try await modelContainer.perform { [args, parameters, learningRate] context in let optimizer = Adam(learningRate: learningRate) try LoRATrain.train( - model: model, train: train, validate: valid, optimizer: optimizer, - tokenizer: tokenizer, + model: context.model, train: train, validate: valid, optimizer: optimizer, + tokenizer: context.tokenizer, parameters: parameters ) { progress in print(progress) return .more } - try LoRATrain.saveLoRAWeights(model: model, url: args.adapter) + try LoRATrain.saveLoRAWeights(model: context.model, url: args.adapter) } } } @@ -188,22 +187,27 @@ struct LoRAFuseCommand: AsyncParsableCommand { outputURL = HubApi().localRepoLocation(repo) } - let (modelContainer, modelConfiguration) = try await args.load() + let modelContainer = try await args.load() // load the prepared weights - try await modelContainer.perform { [args] model, _ in - try 
LoRATrain.loadLoRAWeights(model: model, url: args.adapter) + try await modelContainer.perform { [args] context in + try LoRATrain.loadLoRAWeights(model: context.model, url: args.adapter) } // fuse them back into Linear/QuantizedLinear - await modelContainer.perform { [args, deQuantize] model, _ in + await modelContainer.perform { [args, deQuantize] context in + guard let lora = context.model as? LoRAModel else { + fatalError("Model \(modelContainer.configuration.name) is not a LoRAModel") + } + LoRATrain.fuse( - model: model, layers: args.loraLayers(model: model), deQuantize: deQuantize) + model: context.model, layers: lora.loraLinearLayers(args.loraLayers), + deQuantize: deQuantize) } // make the new directory and copy files from source model try FileManager.default.createDirectory(at: outputURL, withIntermediateDirectories: true) - let inputURL = modelConfiguration.modelDirectory() + let inputURL = modelContainer.configuration.modelDirectory() let enumerator = FileManager.default.enumerator( at: inputURL, includingPropertiesForKeys: nil)! for case let url as URL in enumerator { @@ -217,8 +221,8 @@ struct LoRAFuseCommand: AsyncParsableCommand { } // write them back out - try await modelContainer.perform { model, _ in - let weights = Dictionary(uniqueKeysWithValues: model.parameters().flattened()) + try await modelContainer.perform { context in + let weights = Dictionary(uniqueKeysWithValues: context.model.parameters().flattened()) try save(arrays: weights, url: outputURL.appending(component: "weights.safetensors")) } @@ -246,20 +250,21 @@ struct LoRATestCommand: AsyncParsableCommand { @MainActor mutating func run() async throws { - let (modelContainer, _) = try await args.load() - await modelContainer.perform { [args] model, _ in - args.describe(model: model) + let modelContainer = try await args.load() + await modelContainer.perform { [args] context in + args.describe(model: context.model) } - try await modelContainer.perform { [args] model, _ in - try LoRATrain.loadLoRAWeights(model: model, url: args.adapter) + try await modelContainer.perform { [args] context in + try LoRATrain.loadLoRAWeights(model: context.model, url: args.adapter) } memory.start() let test = try loadLoRAData(directory: data, name: "test") - let loss = await modelContainer.perform { [batchSize] model, tokenizer in + let loss = await modelContainer.perform { [batchSize] context in LoRATrain.evaluate( - model: model, dataset: test, tokenizer: tokenizer, batchSize: batchSize, + model: context.model, dataset: test, + tokenizer: context.tokenizer, batchSize: batchSize, batchCount: 0) } @@ -281,21 +286,17 @@ struct LoRAEvalCommand: AsyncParsableCommand { @MainActor mutating func run() async throws { - let (modelContainer, modelConfiguration) = try await args.load() - await modelContainer.perform { [args] model, _ in - args.describe(model: model) + let modelContainer = try await args.load() + await modelContainer.perform { [args] context in + args.describe(model: context.model) } - try await modelContainer.perform { [args] model, _ in - try LoRATrain.loadLoRAWeights(model: model, url: args.adapter) + try await modelContainer.perform { [args] context in + try LoRATrain.loadLoRAWeights(model: context.model, url: args.adapter) } memory.start() - let prompt = generate.prompt ?? modelConfiguration.defaultPrompt - let messages = [["role": "user", "content": prompt]] - let promptTokens = try await modelContainer.perform { _, tokenizer in - try tokenizer.applyChatTemplate(messages: messages) - } + let prompt = generate.prompt ?? 
@@ -281,21 +286,17 @@ struct LoRAEvalCommand: AsyncParsableCommand {
 
     @MainActor
     mutating func run() async throws {
-        let (modelContainer, modelConfiguration) = try await args.load()
-        await modelContainer.perform { [args] model, _ in
-            args.describe(model: model)
+        let modelContainer = try await args.load()
+        await modelContainer.perform { [args] context in
+            args.describe(model: context.model)
         }
 
-        try await modelContainer.perform { [args] model, _ in
-            try LoRATrain.loadLoRAWeights(model: model, url: args.adapter)
+        try await modelContainer.perform { [args] context in
+            try LoRATrain.loadLoRAWeights(model: context.model, url: args.adapter)
         }
 
         memory.start()
 
-        let prompt = generate.prompt ?? modelConfiguration.defaultPrompt
-        let messages = [["role": "user", "content": prompt]]
-        let promptTokens = try await modelContainer.perform { _, tokenizer in
-            try tokenizer.applyChatTemplate(messages: messages)
-        }
+        let prompt = generate.prompt ?? modelContainer.configuration.defaultPrompt
 
         if !generate.quiet {
             print("Starting generation ...")
@@ -303,11 +304,16 @@ struct LoRAEvalCommand: AsyncParsableCommand {
         }
 
         // generate and print the result
-        await modelContainer.perform { [generate] model, tokenizer in
-            let _ = generate.generate(
-                promptTokens: promptTokens, model: model, tokenizer: tokenizer,
-                extraEOSTokens: modelConfiguration.extraEOSTokens)
+        let result = try await modelContainer.perform { [generate] context in
+            let input = try await context.processor.prepare(input: .init(prompt: prompt))
+            return try generate.generate(input: input, context: context)
+        }
+
+        if !generate.quiet {
+            print("------")
+            print(result.summary())
+
+            memory.reportMemoryStatistics()
+        }
-        print()
     }
 }
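Note how prompt tokenization moves out of the caller: the processor now owns input preparation. A sketch of the new generation path on its own, assuming the `MLXLMCommon` API this patch adopts (the streaming callback shape is the one shown in the train hunk above; the prompt is illustrative):

```swift
// Sketch (not part of the patch): prepare input with the context's
// processor, then stream tokens through MLXLMCommon.generate.
import MLXLMCommon

func runEval(prompt: String, container: ModelContainer) async throws {
    let parameters = GenerateParameters(temperature: 0.6)
    let result = try await container.perform { context in
        // the processor replaces manual applyChatTemplate + token handling
        let input = try await context.processor.prepare(input: .init(prompt: prompt))
        return try MLXLMCommon.generate(
            input: input, parameters: parameters, context: context
        ) { tokens in
            // print the partial decode as tokens arrive
            print(context.tokenizer.decode(tokens: tokens))
            return .more
        }
    }
    print(result.summary())
}
```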
diff --git a/Tools/llm-tool/README.md b/Tools/llm-tool/README.md
index 452fe94..f630529 100644
--- a/Tools/llm-tool/README.md
+++ b/Tools/llm-tool/README.md
@@ -2,7 +2,9 @@
 
 See various READMEs:
 
-- [LLM](../../Libraries/LLM/README.md)
+- [MLXLMCommon](../../Libraries/MLXLMCommon/README.md) -- common LM code
+- [MLXLLM](../../Libraries/MLXLLM/README.md) -- large language models
+- [MLXVLM](../../Libraries/MLXVLM/README.md) -- vision language models
 
 ### Building
 
@@ -28,7 +30,7 @@ The model should be a path in the Hugging Face repository, e.g.:
 - `mlx-community/Mistral-7B-v0.1-hf-4bit-mlx`
 - `mlx-community/phi-2-hf-4bit-mlx`
 
-See [LLM](../../Libraries/LLM/README.md) for more info.
+See [MLXLLM](../../Libraries/MLXLLM/README.md) for more info.
 
 ### Running: Command Line
 
diff --git a/Tools/mnist-tool/MNISTTool.swift b/Tools/mnist-tool/MNISTTool.swift
index 7557ec7..bf04617 100644
--- a/Tools/mnist-tool/MNISTTool.swift
+++ b/Tools/mnist-tool/MNISTTool.swift
@@ -3,10 +3,10 @@
 import ArgumentParser
 import Foundation
 import MLX
+import MLXMNIST
 import MLXNN
 import MLXOptimizers
 import MLXRandom
-import MNIST
 
 @main
 struct MNISTTool: AsyncParsableCommand {
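The project file changes that follow replace the in-project LLM, MNIST, and StableDiffusion framework targets with SwiftPM library products (and move the project to Xcode 16's objectVersion 70 with a file-system-synchronized Libraries folder). A sketch of how an external package would consume the new products -- the repository URL is the real one; the package name, platforms, and branch pin are illustrative:

```swift
// Sketch (not part of the patch): depending on the split library products
// from another SwiftPM package.
// swift-tools-version: 5.9
import PackageDescription

let package = Package(
    name: "my-llm-tool",
    platforms: [.macOS(.v14), .iOS(.v16)],
    dependencies: [
        // the examples repo now vends MLXLLM, MLXVLM, MLXLMCommon, MLXMNIST,
        // and StableDiffusion as library products
        .package(url: "https://github.com/ml-explore/mlx-swift-examples/", branch: "main")
    ],
    targets: [
        .executableTarget(
            name: "my-llm-tool",
            dependencies: [
                .product(name: "MLXLLM", package: "mlx-swift-examples"),
                .product(name: "MLXLMCommon", package: "mlx-swift-examples"),
            ])
    ]
)
```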
diff --git a/mlx-swift-examples.xcodeproj/project.pbxproj b/mlx-swift-examples.xcodeproj/project.pbxproj
index 094e4f6..8967789 100644
--- a/mlx-swift-examples.xcodeproj/project.pbxproj
+++ b/mlx-swift-examples.xcodeproj/project.pbxproj
@@ -3,20 +3,13 @@
 	archiveVersion = 1;
 	classes = {
 	};
-	objectVersion = 56;
+	objectVersion = 70;
 	objects = {
 
 /* Begin PBXBuildFile section */
 		12305EAF2B9D864400C92FEE /* PredictionView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 12305EAE2B9D864400C92FEE /* PredictionView.swift */; };
-		1C55317A2C5AAB4E00B07ECD /* Gemma2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1C5531792C5AAB4E00B07ECD /* Gemma2.swift */; };
-		1CD79C702BD80DE100B6C06F /* Phi3.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1CD79C6F2BD80DE100B6C06F /* Phi3.swift */; };
-		525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */; };
-		52A776182B94B5EE00AA6E80 /* Qwen2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 52A776172B94B5EE00AA6E80 /* Qwen2.swift */; };
-		7BBD0D6E2BE044A10019C5D7 /* OpenELM.swift in Sources */ = {isa = PBXBuildFile; fileRef = 7BBD0D6D2BE044A10019C5D7 /* OpenELM.swift */; };
 		81695B412BA373D300F260D8 /* MarkdownUI in Frameworks */ = {isa = PBXBuildFile; productRef = 81695B402BA373D300F260D8 /* MarkdownUI */; };
 		819BEFF82BAF8B4E0002CCEE /* DeviceStat.swift in Sources */ = {isa = PBXBuildFile; fileRef = 819BEFF62BAF8B4E0002CCEE /* DeviceStat.swift */; };
-		927B80422C83769800500C13 /* PhiMoE.swift in Sources */ = {isa = PBXBuildFile; fileRef = 927B80412C83769400500C13 /* PhiMoE.swift */; };
-		927C784E2C7A578A001E5878 /* SuScaledRotaryEmbedding.swift in Sources */ = {isa = PBXBuildFile; fileRef = 927C784D2C7A578A001E5878 /* SuScaledRotaryEmbedding.swift */; };
 		C3056BAE2BCD97B700A31D04 /* LoRATrainingExampleApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3056BAD2BCD97B700A31D04 /* LoRATrainingExampleApp.swift */; };
 		C3056BB02BCD97B700A31D04 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3056BAF2BCD97B700A31D04 /* ContentView.swift */; };
 		C3056BB22BCD97B800A31D04 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C3056BB12BCD97B800A31D04 /* Assets.xcassets */; };
@@ -24,147 +17,45 @@
 		C3056BBA2BCD981900A31D04 /* train.jsonl in Resources */ = {isa = PBXBuildFile; fileRef = C3056BA22BCD973400A31D04 /* train.jsonl */; };
 		C3056BBB2BCD981900A31D04 /* test.jsonl in Resources */ = {isa = PBXBuildFile; fileRef = C3056BA12BCD973400A31D04 /* test.jsonl */; };
 		C3056BBC2BCD981900A31D04 /* valid.jsonl in Resources */ = {isa = PBXBuildFile; fileRef = C3056BA32BCD973400A31D04 /* valid.jsonl */; };
-		C3056BBD2BCD984F00A31D04 /* LLM.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C38935C52B869C7A0037B833 /* LLM.framework */; };
-		C3056BBE2BCD984F00A31D04 /* LLM.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = C38935C52B869C7A0037B833 /* LLM.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
 		C3288D762B6D9313009FF608 /* LinearModelTraining.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3288D752B6D9313009FF608 /* LinearModelTraining.swift */; };
 		C3288D7B2B6D9339009FF608 /* ArgumentParser in Frameworks */ = {isa = PBXBuildFile; productRef = C3288D7A2B6D9339009FF608 /* ArgumentParser */; };
-		C343B2782CC8091B00334888 /* SwitchLayers.swift in Sources */ = {isa = PBXBuildFile; fileRef = C343B2772CC8091B00334888 /* SwitchLayers.swift */; };
+		C32A17FD2CFFB98A0092A5B6 /* MLXLLM in Frameworks */ = {isa = PBXBuildFile; productRef = C32A17FC2CFFB98A0092A5B6 /* MLXLLM */; };
+		C32A17FF2CFFB98A0092A5B6 /* MLXVLM in Frameworks */ = {isa = PBXBuildFile; productRef = C32A17FE2CFFB98A0092A5B6 /* MLXVLM */; };
+		C32A18012CFFD1810092A5B6 /* MLXMNIST in Frameworks */ = {isa = PBXBuildFile; productRef = C32A18002CFFD1810092A5B6 /* MLXMNIST */; };
+		C32A18032CFFD1920092A5B6 /* MLXMNIST in Frameworks */ = {isa = PBXBuildFile; productRef = C32A18022CFFD1920092A5B6 /* MLXMNIST */; };
+		C32A18052CFFD19F0092A5B6 /* MLXLLM in Frameworks */ = {isa = PBXBuildFile; productRef = C32A18042CFFD19F0092A5B6 /* MLXLLM */; };
+		C32A18072CFFD1AA0092A5B6 /* MLXLLM in Frameworks */ = {isa = PBXBuildFile; productRef = C32A18062CFFD1AA0092A5B6 /* MLXLLM */; };
+		C32A18092CFFD1B70092A5B6 /* StableDiffusion in Frameworks */ = {isa = PBXBuildFile; productRef = C32A18082CFFD1B70092A5B6 /* StableDiffusion */; };
+		C32A18462D00E1490092A5B6 /* MLX in Frameworks */ = {isa = PBXBuildFile; productRef = C32A18452D00E1490092A5B6 /* MLX */; };
+		C32A18482D00E1540092A5B6 /* MLX in Frameworks */ = {isa = PBXBuildFile; productRef = C32A18472D00E1540092A5B6 /* MLX */; };
+		C32A184A2D00E1540092A5B6 /* MLXNN in Frameworks */ = {isa = PBXBuildFile; productRef = C32A18492D00E1540092A5B6 /* MLXNN */; };
+		C32A184C2D00E1540092A5B6 /* MLXOptimizers in Frameworks */ = {isa = PBXBuildFile; productRef = C32A184B2D00E1540092A5B6 /* MLXOptimizers */; };
 		C34E48F52B696F0B00FCB841 /* LLMTool.swift in Sources */ = {isa = PBXBuildFile; fileRef = C34E48F42B696F0B00FCB841 /* LLMTool.swift */; };
-		C34E49102B69A92900FCB841 /* MNIST.h in Headers */ = {isa = PBXBuildFile; fileRef = C34E490F2B69A92900FCB841 /* MNIST.h */; settings = {ATTRIBUTES = (Public, ); }; };
-		C34E49152B69C1E300FCB841 /* Files.swift in Sources */ = {isa = PBXBuildFile; fileRef = C34E49142B69C1E300FCB841 /* Files.swift */; };
-		C34E491C2B69C43600FCB841 /* Gzip in Frameworks */ = {isa = PBXBuildFile; productRef = C34E491B2B69C43600FCB841 /* Gzip */; };
 		C34E49242B6A026F00FCB841 /* MNISTTool.swift in Sources */ = {isa = PBXBuildFile; fileRef = C34E49232B6A026F00FCB841 /* MNISTTool.swift */; };
 		C34E49292B6A028100FCB841 /* ArgumentParser in Frameworks */ = {isa = PBXBuildFile; productRef = C34E49282B6A028100FCB841 /* ArgumentParser */; };
-		C34E492A2B6A028800FCB841 /* MNIST.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C34E490D2B69A92900FCB841 /* MNIST.framework */; };
-		C34E492B2B6A028800FCB841 /* MNIST.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = C34E490D2B69A92900FCB841 /* MNIST.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
-		C36BEFB02BBCBAC2002D4AFE /* Lora.swift in Sources */ = {isa = PBXBuildFile; fileRef = C36BEFAF2BBCBAC2002D4AFE /* Lora.swift */; };
-		C36BEFB22BBDE9D0002D4AFE /* MLXOptimizers in Frameworks */ = {isa = PBXBuildFile; productRef = C36BEFB12BBDE9D0002D4AFE /* MLXOptimizers */; };
 		C36BEFB52BBDEAD8002D4AFE /* LoraCommands.swift in Sources */ = {isa = PBXBuildFile; fileRef = C36BEFB32BBDEA69002D4AFE /* LoraCommands.swift */; };
 		C36BEFB82BBDED51002D4AFE /* Arguments.swift in Sources */ = {isa = PBXBuildFile; fileRef = C36BEFB62BBDECBC002D4AFE /* Arguments.swift */; };
-		C36BEFBB2BBF02CC002D4AFE /* Lora+Data.swift in Sources */ = {isa = PBXBuildFile; fileRef = C36BEFBA2BBF02CC002D4AFE /* Lora+Data.swift */; };
-		C36BEFCA2BC09953002D4AFE /* Configuration.swift in Sources */ = {isa = PBXBuildFile; fileRef = C36BEFC92BC09953002D4AFE /* Configuration.swift */; };
-		C36BEFCC2BC09E53002D4AFE /* Sampler.swift in Sources */ = {isa = PBXBuildFile; fileRef = C36BEFCB2BC09E53002D4AFE /* Sampler.swift */; };
-		C36BEFCE2BC0A194002D4AFE /* MLX in Frameworks */ = {isa = PBXBuildFile; productRef = C36BEFCD2BC0A194002D4AFE /* MLX */; };
-		C36BEFD02BC0A194002D4AFE /* MLXNN in Frameworks */ = {isa = PBXBuildFile; productRef = C36BEFCF2BC0A194002D4AFE /* MLXNN */; };
-		C36BEFD22BC0A194002D4AFE /* MLXRandom in Frameworks */ = {isa = PBXBuildFile; productRef = C36BEFD12BC0A194002D4AFE /* MLXRandom */; };
-		C36BEFD42BC0B4A9002D4AFE /* Clip.swift in Sources */ = {isa = PBXBuildFile; fileRef = C36BEFD32BC0B4A9002D4AFE /* Clip.swift */; };
-		C36BEFD82BC0BD9B002D4AFE /* VAE.swift in Sources */ = {isa = PBXBuildFile; fileRef = C36BEFD72BC0BD9B002D4AFE /* VAE.swift */; };
-		C36BEFDA2BC0BDDC002D4AFE /* UNet.swift in Sources */ = {isa = PBXBuildFile; fileRef = C36BEFD92BC0BDDC002D4AFE /* UNet.swift */; };
 		C36BEFE32BC32988002D4AFE /* ImageTool.swift in Sources */ = {isa = PBXBuildFile; fileRef = C36BEFE22BC32988002D4AFE /* ImageTool.swift */; };
-		C36BEFE72BC329AB002D4AFE /* StableDiffusion.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C36BEFC22BC098F3002D4AFE /* StableDiffusion.framework */; };
-		C36BEFE82BC329AB002D4AFE /* StableDiffusion.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = C36BEFC22BC098F3002D4AFE /* StableDiffusion.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
 		C36BEFEF2BC329C5002D4AFE /* ArgumentParser in Frameworks */ = {isa = PBXBuildFile; productRef = C36BEFEE2BC329C5002D4AFE /* ArgumentParser */; };
 		C36BEFF22BC32A9A002D4AFE /* Progress in Frameworks */ = {isa = PBXBuildFile; productRef = C36BEFF12BC32A9A002D4AFE /* Progress */; };
-		C36BEFF42BC349FA002D4AFE /* Load.swift in Sources */ = {isa = PBXBuildFile; fileRef = C36BEFF32BC349FA002D4AFE /* Load.swift */; };
-		C36BEFF62BC34A46002D4AFE /* StableDiffusion.swift in Sources */ = {isa = PBXBuildFile; fileRef = C36BEFF52BC34A46002D4AFE /* StableDiffusion.swift */; };
-		C36BEFF82BC59CE1002D4AFE /* Tokenizer.swift in Sources */ = {isa = PBXBuildFile; fileRef = C36BEFF72BC59CE1002D4AFE /* Tokenizer.swift */; };
-		C36BEFFA2BC5B996002D4AFE /* Transformers in Frameworks */ = {isa = PBXBuildFile; productRef = C36BEFF92BC5B996002D4AFE /* Transformers */; };
-		C36BEFFC2BC5BA79002D4AFE /* Image.swift in Sources */ = {isa = PBXBuildFile; fileRef = C36BEFFB2BC5BA79002D4AFE /* Image.swift */; };
 		C36BF0042BC5CE55002D4AFE /* StableDiffusionExampleApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C36BF0032BC5CE55002D4AFE /* StableDiffusionExampleApp.swift */; };
 		C36BF0062BC5CE55002D4AFE /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C36BF0052BC5CE55002D4AFE /* ContentView.swift */; };
 		C36BF0082BC5CE56002D4AFE /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C36BF0072BC5CE56002D4AFE /* Assets.xcassets */; };
 		C36BF00C2BC5CE56002D4AFE /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C36BF00B2BC5CE56002D4AFE /* Preview Assets.xcassets */; };
-		C36BF0102BC5CF17002D4AFE /* StableDiffusion.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C36BEFC22BC098F3002D4AFE /* StableDiffusion.framework */; };
-		C36BF0112BC5CF17002D4AFE /* StableDiffusion.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = C36BEFC22BC098F3002D4AFE /* StableDiffusion.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
 		C36BF0352BC70F11002D4AFE /* Arguments.swift in Sources */ = {isa = PBXBuildFile; fileRef = C36BF0342BC70F11002D4AFE /* Arguments.swift */; };
-		C373EB732C792DD1004E201E /* KVCache.swift in Sources */ = {isa = PBXBuildFile; fileRef = C373EB722C792DD1004E201E /* KVCache.swift */; };
-		C380FFF72C8F5053006428A3 /* Internlm2.swift in Sources */ = {isa = PBXBuildFile; fileRef = C380FFF62C8F5053006428A3 /* Internlm2.swift */; };
-		C38935C82B869C7A0037B833 /* LLM.h in Headers */ = {isa = PBXBuildFile; fileRef = C38935C72B869C7A0037B833 /* LLM.h */; settings = {ATTRIBUTES = (Public, ); }; };
-		C38935CC2B869C870037B833 /* Llama.swift in Sources */ = {isa = PBXBuildFile; fileRef = C34E48EE2B696E6500FCB841 /* Llama.swift */; };
-		C38935CD2B869C870037B833 /* Configuration.swift in Sources */ = {isa = PBXBuildFile; fileRef = C34E48EF2B696E6500FCB841 /* Configuration.swift */; };
-		C38935CE2B869C870037B833 /* Load.swift in Sources */ = {isa = PBXBuildFile; fileRef = C34E48ED2B696E6500FCB841 /* Load.swift */; };
-		C38935D02B869CC40037B833 /* MLX in Frameworks */ = {isa = PBXBuildFile; productRef = C38935CF2B869CC40037B833 /* MLX */; };
-		C38935D22B869CC40037B833 /* MLXNN in Frameworks */ = {isa = PBXBuildFile; productRef = C38935D12B869CC40037B833 /* MLXNN */; };
-		C38935D42B869CC40037B833 /* MLXRandom in Frameworks */ = {isa = PBXBuildFile; productRef = C38935D32B869CC40037B833 /* MLXRandom */; };
-		C38935D62B869CC40037B833 /* Transformers in Frameworks */ = {isa = PBXBuildFile; productRef = C38935D52B869CC40037B833 /* Transformers */; };
-		C38935D72B869CCD0037B833 /* LLM.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C38935C52B869C7A0037B833 /* LLM.framework */; };
-		C38935D82B869CCD0037B833 /* LLM.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = C38935C52B869C7A0037B833 /* LLM.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
-		C38935DF2B869DD00037B833 /* Phi.swift in Sources */ = {isa = PBXBuildFile; fileRef = C38935DE2B869DD00037B833 /* Phi.swift */; };
-		C38935E12B869F420037B833 /* LLMModel.swift in Sources */ = {isa = PBXBuildFile; fileRef = C38935E02B869F420037B833 /* LLMModel.swift */; };
-		C38935E32B86C0FE0037B833 /* Gemma.swift in Sources */ = {isa = PBXBuildFile; fileRef = C38935E22B86C0FE0037B833 /* Gemma.swift */; };
 		C392737D2B606A1D00368D5D /* Tutorial.swift in Sources */ = {isa = PBXBuildFile; fileRef = C392737C2B606A1D00368D5D /* Tutorial.swift */; };
-		C3932D572B6A060B00A81055 /* MNIST.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3932D562B6A060B00A81055 /* MNIST.swift */; };
-		C3932D592B6A0BE400A81055 /* Random.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3932D582B6A0BE400A81055 /* Random.swift */; };
 		C397C59C2B62C6D0004B084D /* ArgumentParser in Frameworks */ = {isa = PBXBuildFile; productRef = C397C59B2B62C6D0004B084D /* ArgumentParser */; };
-		C3A8B3AC2B9283150002EFB8 /* Models.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3A8B3AB2B9283150002EFB8 /* Models.swift */; };
 		C3A8B3CB2B92951E0002EFB8 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C3A8B3C32B92951E0002EFB8 /* Assets.xcassets */; };
 		C3A8B3CC2B92951E0002EFB8 /* MNISTTrainerApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3A8B3C42B92951E0002EFB8 /* MNISTTrainerApp.swift */; };
 		C3A8B3CD2B92951E0002EFB8 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C3A8B3C62B92951E0002EFB8 /* Preview Assets.xcassets */; };
 		C3A8B3CF2B92951E0002EFB8 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3A8B3C92B92951E0002EFB8 /* ContentView.swift */; };
-		C3A8B3D22B92A0880002EFB8 /* MLXOptimizers in Frameworks */ = {isa = PBXBuildFile; productRef = C3A8B3D12B92A0880002EFB8 /* MLXOptimizers */; };
-		C3A8B3D32B92A0880002EFB8 /* MNIST.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C34E490D2B69A92900FCB841 /* MNIST.framework */; };
-		C3A8B3D42B92A0880002EFB8 /* MNIST.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = C34E490D2B69A92900FCB841 /* MNIST.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
 		C3A8B3F32B92A2A90002EFB8 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C3A8B3EC2B92A2A90002EFB8 /* Assets.xcassets */; };
 		C3A8B3F42B92A2A90002EFB8 /* LLMEvalApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3A8B3ED2B92A2A90002EFB8 /* LLMEvalApp.swift */; };
 		C3A8B3F52B92A2A90002EFB8 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C3A8B3EF2B92A2A90002EFB8 /* Preview Assets.xcassets */; };
 		C3A8B3F72B92A2A90002EFB8 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3A8B3F22B92A2A90002EFB8 /* ContentView.swift */; };
-		C3A8B3F82B92A3360002EFB8 /* LLM.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C38935C52B869C7A0037B833 /* LLM.framework */; };
-		C3A8B3F92B92A3360002EFB8 /* LLM.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = C38935C52B869C7A0037B833 /* LLM.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
-		C3E786AB2B8D1AEC0004D037 /* Evaluate.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3E786AA2B8D1AEC0004D037 /* Evaluate.swift */; };
-		C3E786AD2B8D4AF50004D037 /* Tokenizer.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3E786AC2B8D4AF50004D037 /* Tokenizer.swift */; };
-		C3FBCB212B8520B80007E490 /* MLX in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB202B8520B80007E490 /* MLX */; };
-		C3FBCB292B8520DA0007E490 /* MLX in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB282B8520DA0007E490 /* MLX */; };
-		C3FBCB2B2B8520DA0007E490 /* MLXNN in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB2A2B8520DA0007E490 /* MLXNN */; };
-		C3FBCB2D2B8520E80007E490 /* MLXOptimizers in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB2C2B8520E80007E490 /* MLXOptimizers */; };
-		C3FBCB2F2B8520F20007E490 /* MLX in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB2E2B8520F20007E490 /* MLX */; };
-		C3FBCB312B8520F20007E490 /* MLXNN in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB302B8520F20007E490 /* MLXNN */; };
-		C3FBCB332B8520F20007E490 /* MLXOptimizers in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB322B8520F20007E490 /* MLXOptimizers */; };
-		C3FBCB352B8520F20007E490 /* MLXRandom in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB342B8520F20007E490 /* MLXRandom */; };
-		F24B083A2BAF1A65008C8D19 /* Cohere.swift in Sources */ = {isa = PBXBuildFile; fileRef = F24B08392BAF1A65008C8D19 /* Cohere.swift */; };
+		C3E7D94D2CF6C9B20056C095 /* StableDiffusion in Frameworks */ = {isa = PBXBuildFile; productRef = C3E7D94C2CF6C9B20056C095 /* StableDiffusion */; };
 /* End PBXBuildFile section */
 
-/* Begin PBXContainerItemProxy section */
-		C3056BBF2BCD984F00A31D04 /* PBXContainerItemProxy */ = {
-			isa = PBXContainerItemProxy;
-			containerPortal = C39273682B60697700368D5D /* Project object */;
-			proxyType = 1;
-			remoteGlobalIDString = C38935C42B869C7A0037B833;
-			remoteInfo = LLM;
-		};
-		C34E492C2B6A028800FCB841 /* PBXContainerItemProxy */ = {
-			isa = PBXContainerItemProxy;
-			containerPortal = C39273682B60697700368D5D /* Project object */;
-			proxyType = 1;
-			remoteGlobalIDString = C34E490C2B69A92900FCB841;
-			remoteInfo = MNIST;
-		};
-		C36BEFE92BC329AB002D4AFE /* PBXContainerItemProxy */ = {
-			isa = PBXContainerItemProxy;
-			containerPortal = C39273682B60697700368D5D /* Project object */;
-			proxyType = 1;
-			remoteGlobalIDString = C36BEFC12BC098F3002D4AFE;
-			remoteInfo = StableDiffusion;
-		};
-		C36BF0122BC5CF17002D4AFE /* PBXContainerItemProxy */ = {
-			isa = PBXContainerItemProxy;
-			containerPortal = C39273682B60697700368D5D /* Project object */;
-			proxyType = 1;
-			remoteGlobalIDString = C36BEFC12BC098F3002D4AFE;
-			remoteInfo = StableDiffusion;
-		};
-		C38935D92B869CCD0037B833 /* PBXContainerItemProxy */ = {
-			isa = PBXContainerItemProxy;
-			containerPortal = C39273682B60697700368D5D /* Project object */;
-			proxyType = 1;
-			remoteGlobalIDString = C38935C42B869C7A0037B833;
-			remoteInfo = LLM;
-		};
-		C3A8B3D52B92A0880002EFB8 /* PBXContainerItemProxy */ = {
-			isa = PBXContainerItemProxy;
-			containerPortal = C39273682B60697700368D5D /* Project object */;
-			proxyType = 1;
-			remoteGlobalIDString = C34E490C2B69A92900FCB841;
-			remoteInfo = MNIST;
-		};
-		C3A8B3FA2B92A3360002EFB8 /* PBXContainerItemProxy */ = {
-			isa = PBXContainerItemProxy;
-			containerPortal = C39273682B60697700368D5D /* Project object */;
-			proxyType = 1;
-			remoteGlobalIDString = C38935C42B869C7A0037B833;
-			remoteInfo = LLM;
-		};
-/* End PBXContainerItemProxy section */
 
 /* Begin PBXCopyFilesBuildPhase section */
 		C3056BC12BCD984F00A31D04 /* Embed Frameworks */ = {
 			isa = PBXCopyFilesBuildPhase;
@@ -172,7 +63,6 @@
 			dstPath = "";
 			dstSubfolderSpec = 10;
 			files = (
-				C3056BBE2BCD984F00A31D04 /* LLM.framework in Embed Frameworks */,
 			);
 			name = "Embed Frameworks";
 			runOnlyForDeploymentPostprocessing = 0;
@@ -201,7 +91,6 @@
 			dstPath = "";
 			dstSubfolderSpec = 10;
 			files = (
-				C34E492B2B6A028800FCB841 /* MNIST.framework in Embed Frameworks */,
 			);
 			name = "Embed Frameworks";
 			runOnlyForDeploymentPostprocessing = 0;
@@ -221,7 +110,6 @@
 			dstPath = "";
 			dstSubfolderSpec = 10;
 			files = (
-				C36BEFE82BC329AB002D4AFE /* StableDiffusion.framework in Embed Frameworks */,
 			);
 			name = "Embed Frameworks";
 			runOnlyForDeploymentPostprocessing = 0;
@@ -232,7 +120,6 @@
 			dstPath = "";
 			dstSubfolderSpec = 10;
 			files = (
-				C36BF0112BC5CF17002D4AFE /* StableDiffusion.framework in Embed Frameworks */,
 			);
 			name = "Embed Frameworks";
 			runOnlyForDeploymentPostprocessing = 0;
@@ -243,7 +130,6 @@
 			dstPath = "";
 			dstSubfolderSpec = 10;
 			files = (
-				C38935D82B869CCD0037B833 /* LLM.framework in Embed Frameworks */,
 			);
 			name = "Embed Frameworks";
 			runOnlyForDeploymentPostprocessing = 0;
@@ -272,7 +158,6 @@
 			dstPath = "";
 			dstSubfolderSpec = 10;
 			files = (
-				C3A8B3D42B92A0880002EFB8 /* MNIST.framework in Embed Frameworks */,
 			);
 			name = "Embed Frameworks";
 			runOnlyForDeploymentPostprocessing = 0;
@@ -283,7 +168,6 @@
 			dstPath = "";
 			dstSubfolderSpec = 10;
 			files = (
-				C3A8B3F92B92A3360002EFB8 /* LLM.framework in Embed Frameworks */,
 			);
 			name = "Embed Frameworks";
 			runOnlyForDeploymentPostprocessing = 0;
@@ -292,14 +176,7 @@
 
 /* Begin PBXFileReference section */
 		12305EAE2B9D864400C92FEE /* PredictionView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PredictionView.swift; sourceTree = "<group>"; };
-		1C5531792C5AAB4E00B07ECD /* Gemma2.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Gemma2.swift; sourceTree = "<group>"; };
-		1CD79C6F2BD80DE100B6C06F /* Phi3.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Phi3.swift; sourceTree = "<group>"; };
-		525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Starcoder2.swift; sourceTree = "<group>"; };
-		52A776172B94B5EE00AA6E80 /* Qwen2.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Qwen2.swift; sourceTree = "<group>"; };
-		7BBD0D6D2BE044A10019C5D7 /* OpenELM.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = OpenELM.swift; sourceTree = "<group>"; };
 		819BEFF62BAF8B4E0002CCEE /* DeviceStat.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = DeviceStat.swift; sourceTree = "<group>"; };
-		927B80412C83769400500C13 /* PhiMoE.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PhiMoE.swift; sourceTree = "<group>"; };
-		927C784D2C7A578A001E5878 /* SuScaledRotaryEmbedding.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SuScaledRotaryEmbedding.swift; sourceTree = "<group>"; };
 		C3056BA12BCD973400A31D04 /* test.jsonl */ = {isa = PBXFileReference; lastKnownFileType = text; path = test.jsonl; sourceTree = "<group>"; };
 		C3056BA22BCD973400A31D04 /* train.jsonl */ = {isa = PBXFileReference; lastKnownFileType = text; path = train.jsonl; sourceTree = "<group>"; };
 		C3056BA32BCD973400A31D04 /* valid.jsonl */ = {isa = PBXFileReference; lastKnownFileType = text; path = valid.jsonl; sourceTree = "<group>"; };
@@ -315,34 +192,14 @@
 		C3288D732B6D9313009FF608 /* LinearModelTraining */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = LinearModelTraining; sourceTree = BUILT_PRODUCTS_DIR; };
 		C3288D752B6D9313009FF608 /* LinearModelTraining.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LinearModelTraining.swift; sourceTree = "<group>"; };
 		C3288D842B6D94BD009FF608 /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
-		C343B2772CC8091B00334888 /* SwitchLayers.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SwitchLayers.swift; sourceTree = "<group>"; };
-		C34E48ED2B696E6500FCB841 /* Load.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Load.swift; sourceTree = "<group>"; };
-		C34E48EE2B696E6500FCB841 /* Llama.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Llama.swift; sourceTree = "<group>"; };
-		C34E48EF2B696E6500FCB841 /* Configuration.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Configuration.swift; sourceTree = "<group>"; };
 		C34E48F42B696F0B00FCB841 /* LLMTool.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = LLMTool.swift; sourceTree = "<group>"; };
-		C34E48F62B69832600FCB841 /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
 		C34E48F92B69930300FCB841 /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
-		C34E490D2B69A92900FCB841 /* MNIST.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = MNIST.framework; sourceTree = BUILT_PRODUCTS_DIR; };
-		C34E490F2B69A92900FCB841 /* MNIST.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = MNIST.h; sourceTree = "<group>"; };
-		C34E49142B69C1E300FCB841 /* Files.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Files.swift; sourceTree = "<group>"; };
 		C34E49212B6A026F00FCB841 /* mnist-tool */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = "mnist-tool"; sourceTree = BUILT_PRODUCTS_DIR; };
 		C34E49232B6A026F00FCB841 /* MNISTTool.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MNISTTool.swift; sourceTree = "<group>"; };
-		C36BEFAF2BBCBAC2002D4AFE /* Lora.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Lora.swift; sourceTree = "<group>"; };
 		C36BEFB32BBDEA69002D4AFE /* LoraCommands.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LoraCommands.swift; sourceTree = "<group>"; };
 		C36BEFB62BBDECBC002D4AFE /* Arguments.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Arguments.swift; sourceTree = "<group>"; };
-		C36BEFBA2BBF02CC002D4AFE /* Lora+Data.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = "Lora+Data.swift"; sourceTree = "<group>"; };
-		C36BEFC22BC098F3002D4AFE /* StableDiffusion.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = StableDiffusion.framework; sourceTree = BUILT_PRODUCTS_DIR; };
-		C36BEFC92BC09953002D4AFE /* Configuration.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Configuration.swift; sourceTree = "<group>"; };
-		C36BEFCB2BC09E53002D4AFE /* Sampler.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Sampler.swift; sourceTree = "<group>"; };
-		C36BEFD32BC0B4A9002D4AFE /* Clip.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Clip.swift; sourceTree = "<group>"; };
-		C36BEFD72BC0BD9B002D4AFE /* VAE.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = VAE.swift; sourceTree = "<group>"; };
-		C36BEFD92BC0BDDC002D4AFE /* UNet.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = UNet.swift; sourceTree = "<group>"; };
 		C36BEFE02BC32988002D4AFE /* image-tool */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = "image-tool"; sourceTree = BUILT_PRODUCTS_DIR; };
 		C36BEFE22BC32988002D4AFE /* ImageTool.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ImageTool.swift; sourceTree = "<group>"; };
-		C36BEFF32BC349FA002D4AFE /* Load.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Load.swift; sourceTree = "<group>"; };
-		C36BEFF52BC34A46002D4AFE /* StableDiffusion.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = StableDiffusion.swift; sourceTree = "<group>"; };
-		C36BEFF72BC59CE1002D4AFE /* Tokenizer.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Tokenizer.swift; sourceTree = "<group>"; };
-		C36BEFFB2BC5BA79002D4AFE /* Image.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Image.swift; sourceTree = "<group>"; };
 		C36BF0012BC5CE55002D4AFE /* StableDiffusionExample.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = StableDiffusionExample.app; sourceTree = BUILT_PRODUCTS_DIR; };
 		C36BF0032BC5CE55002D4AFE /* StableDiffusionExampleApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = StableDiffusionExampleApp.swift; sourceTree = "<group>"; };
 		C36BF0052BC5CE55002D4AFE /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = "<group>"; };
@@ -350,19 +207,9 @@
 		C36BF0092BC5CE56002D4AFE /* StableDiffusionExample.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = StableDiffusionExample.entitlements; sourceTree = "<group>"; };
 		C36BF00B2BC5CE56002D4AFE /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = "<group>"; };
 		C36BF0342BC70F11002D4AFE /* Arguments.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Arguments.swift; sourceTree = "<group>"; };
-		C373EB722C792DD1004E201E /* KVCache.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = KVCache.swift; sourceTree = "<group>"; };
-		C380FFF62C8F5053006428A3 /* Internlm2.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Internlm2.swift; sourceTree = "<group>"; };
-		C38935C52B869C7A0037B833 /* LLM.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = LLM.framework; sourceTree = BUILT_PRODUCTS_DIR; };
-		C38935C72B869C7A0037B833 /* LLM.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = LLM.h; sourceTree = "<group>"; };
-		C38935DE2B869DD00037B833 /* Phi.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Phi.swift; sourceTree = "<group>"; };
-		C38935E02B869F420037B833 /* LLMModel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LLMModel.swift; sourceTree = "<group>"; };
-		C38935E22B86C0FE0037B833 /* Gemma.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Gemma.swift; sourceTree = "<group>"; };
 		C39273742B606A0A00368D5D /* Tutorial */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = Tutorial; sourceTree = BUILT_PRODUCTS_DIR; };
 		C392737C2B606A1D00368D5D /* Tutorial.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Tutorial.swift; sourceTree = "<group>"; };
-		C3932D562B6A060B00A81055 /* MNIST.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MNIST.swift; sourceTree = "<group>"; };
-		C3932D582B6A0BE400A81055 /* Random.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Random.swift; sourceTree = "<group>"; };
 		C397C58B2B62C6A9004B084D /* llm-tool */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = "llm-tool"; sourceTree = BUILT_PRODUCTS_DIR; };
-		C3A8B3AB2B9283150002EFB8 /* Models.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Models.swift; sourceTree = "<group>"; };
 		C3A8B3B22B9295090002EFB8 /* MNISTTrainer.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = MNISTTrainer.app; sourceTree = BUILT_PRODUCTS_DIR; };
 		C3A8B3C22B92951E0002EFB8 /* MNISTTrainer-Info.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; path = "MNISTTrainer-Info.plist"; sourceTree = "<group>"; };
 		C3A8B3C32B92951E0002EFB8 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = "<group>"; };
@@ -379,22 +226,21 @@
 		C3A8B3F12B92A2A90002EFB8 /* LLMEval.entitlements */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.entitlements; path = LLMEval.entitlements; sourceTree = "<group>"; };
 		C3A8B3F22B92A2A90002EFB8 /* ContentView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = "<group>"; };
 		C3C3240B2B6CA689007D2D9A /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
-		C3C3240C2B6CA792007D2D9A /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
 		C3C36A6B2CA714600099FFA4 /* Build.xcconfig */ = {isa = PBXFileReference; lastKnownFileType = text.xcconfig; path = Build.xcconfig; sourceTree = "<group>"; };
 		C3D573052C40701E00857A35 /* README.md */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
-		C3D573062C40702D00857A35 /* README.md */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
-		C3E786AA2B8D1AEC0004D037 /* Evaluate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Evaluate.swift; sourceTree = "<group>"; };
-		C3E786AC2B8D4AF50004D037 /* Tokenizer.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Tokenizer.swift; sourceTree = "<group>"; };
-		F24B08392BAF1A65008C8D19 /* Cohere.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Cohere.swift; sourceTree = "<group>"; };
 		F8D7023A2BB4E223003D7CF5 /* Package.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Package.swift; sourceTree = "<group>"; };
 /* End PBXFileReference section */
 
+/* Begin PBXFileSystemSynchronizedRootGroup section */
+		C397D92E2CD440EF00B87EE2 /* Libraries */ = {isa = PBXFileSystemSynchronizedRootGroup; explicitFileTypes = {}; explicitFolders = (); path = Libraries; sourceTree = "<group>"; };
+/* End PBXFileSystemSynchronizedRootGroup section */
+
 /* Begin PBXFrameworksBuildPhase section */
 		C3056BA82BCD97B700A31D04 /* Frameworks */ = {
 			isa = PBXFrameworksBuildPhase;
 			buildActionMask = 2147483647;
 			files = (
-				C3056BBD2BCD984F00A31D04 /* LLM.framework in Frameworks */,
+				C32A18072CFFD1AA0092A5B6 /* MLXLLM in Frameworks */,
 			);
 			runOnlyForDeploymentPostprocessing = 0;
 		};
@@ -402,52 +248,29 @@
 			isa = PBXFrameworksBuildPhase;
 			buildActionMask = 2147483647;
 			files = (
-				C3FBCB332B8520F20007E490 /* MLXOptimizers in Frameworks */,
-				C3FBCB312B8520F20007E490 /* MLXNN in Frameworks */,
-				C3FBCB2F2B8520F20007E490 /* MLX in Frameworks */,
-				C3FBCB352B8520F20007E490 /* MLXRandom in Frameworks */,
+				C32A184C2D00E1540092A5B6 /* MLXOptimizers in Frameworks */,
+				C32A184A2D00E1540092A5B6 /* MLXNN in Frameworks */,
+				C32A18482D00E1540092A5B6 /* MLX in Frameworks */,
 				C3288D7B2B6D9339009FF608 /* ArgumentParser in Frameworks */,
 			);
 			runOnlyForDeploymentPostprocessing = 0;
 		};
-		C34E490A2B69A92900FCB841 /* Frameworks */ = {
-			isa = PBXFrameworksBuildPhase;
-			buildActionMask = 2147483647;
-			files = (
-				C3FBCB2B2B8520DA0007E490 /* MLXNN in Frameworks */,
-				C3FBCB292B8520DA0007E490 /* MLX in Frameworks */,
-				C34E491C2B69C43600FCB841 /* Gzip in Frameworks */,
-			);
-			runOnlyForDeploymentPostprocessing = 0;
-		};
 		C34E491E2B6A026F00FCB841 /* Frameworks */ = {
 			isa = PBXFrameworksBuildPhase;
 			buildActionMask = 2147483647;
 			files = (
-				C3FBCB2D2B8520E80007E490 /* MLXOptimizers in Frameworks */,
-				C34E492A2B6A028800FCB841 /* MNIST.framework in Frameworks */,
+				C32A18012CFFD1810092A5B6 /* MLXMNIST in Frameworks */,
 				C34E49292B6A028100FCB841 /* ArgumentParser in Frameworks */,
 			);
 			runOnlyForDeploymentPostprocessing = 0;
 		};
-		C36BEFBF2BC098F3002D4AFE /* Frameworks */ = {
-			isa = PBXFrameworksBuildPhase;
-			buildActionMask = 2147483647;
-			files = (
-				C36BEFD02BC0A194002D4AFE /* MLXNN in Frameworks */,
-				C36BEFD22BC0A194002D4AFE /* MLXRandom in Frameworks */,
-				C36BEFFA2BC5B996002D4AFE /* Transformers in Frameworks */,
-				C36BEFCE2BC0A194002D4AFE /* MLX in Frameworks */,
-			);
-			runOnlyForDeploymentPostprocessing = 0;
-		};
 		C36BEFDD2BC32988002D4AFE /* Frameworks */ = {
 			isa = PBXFrameworksBuildPhase;
 			buildActionMask = 2147483647;
 			files = (
 				C36BEFEF2BC329C5002D4AFE /* ArgumentParser in Frameworks */,
+				C3E7D94D2CF6C9B20056C095 /* StableDiffusion in Frameworks */,
 				C36BEFF22BC32A9A002D4AFE /* Progress in Frameworks */,
-				C36BEFE72BC329AB002D4AFE /* StableDiffusion.framework in Frameworks */,
 			);
 			runOnlyForDeploymentPostprocessing = 0;
 		};
@@ -455,19 +278,7 @@
 			isa = PBXFrameworksBuildPhase;
 			buildActionMask = 2147483647;
 			files = (
-				C36BF0102BC5CF17002D4AFE /* StableDiffusion.framework in Frameworks */,
-			);
-			runOnlyForDeploymentPostprocessing = 0;
-		};
-		C38935C22B869C7A0037B833 /* Frameworks */ = {
-			isa = PBXFrameworksBuildPhase;
-			buildActionMask = 2147483647;
-			files = (
-				C36BEFB22BBDE9D0002D4AFE /* MLXOptimizers in Frameworks */,
-				C38935D22B869CC40037B833 /* MLXNN in Frameworks */,
-				C38935D42B869CC40037B833 /* MLXRandom in Frameworks */,
-				C38935D62B869CC40037B833 /* Transformers in Frameworks */,
-				C38935D02B869CC40037B833 /* MLX in Frameworks */,
+				C32A18092CFFD1B70092A5B6 /* StableDiffusion in Frameworks */,
 			);
 			runOnlyForDeploymentPostprocessing = 0;
 		};
@@ -475,7 +286,7 @@
 			isa = PBXFrameworksBuildPhase;
 			buildActionMask = 2147483647;
 			files = (
-				C3FBCB212B8520B80007E490 /* MLX in Frameworks */,
+				C32A18462D00E1490092A5B6 /* MLX in Frameworks */,
 			);
 			runOnlyForDeploymentPostprocessing = 0;
 		};
@@ -483,8 +294,9 @@
 			isa = PBXFrameworksBuildPhase;
 			buildActionMask = 2147483647;
 			files = (
+				C32A17FD2CFFB98A0092A5B6 /* MLXLLM in Frameworks */,
 				C397C59C2B62C6D0004B084D /* ArgumentParser in Frameworks */,
-				C38935D72B869CCD0037B833 /* LLM.framework in Frameworks */,
+				C32A17FF2CFFB98A0092A5B6 /* MLXVLM in Frameworks */,
 			);
 			runOnlyForDeploymentPostprocessing = 0;
 		};
@@ -492,8 +304,7 @@
 			isa = PBXFrameworksBuildPhase;
 			buildActionMask = 2147483647;
 			files = (
-				C3A8B3D32B92A0880002EFB8 /* MNIST.framework in Frameworks */,
-				C3A8B3D22B92A0880002EFB8 /* MLXOptimizers in Frameworks */,
+				C32A18032CFFD1920092A5B6 /* MLXMNIST in Frameworks */,
 			);
 			runOnlyForDeploymentPostprocessing = 0;
 		};
@@ -501,7 +312,7 @@
 			isa = PBXFrameworksBuildPhase;
 			buildActionMask = 2147483647;
 			files = (
-				C3A8B3F82B92A3360002EFB8 /* LLM.framework in Frameworks */,
+				C32A18052CFFD19F0092A5B6 /* MLXLLM in Frameworks */,
 				81695B412BA373D300F260D8 /* MarkdownUI in Frameworks */,
 			);
 			runOnlyForDeploymentPostprocessing = 0;
@@ -517,24 +328,6 @@
 			path = ViewModels;
 			sourceTree = "<group>";
 		};
-		92EBB3282CD014B800998339 /* Models */ = {
-			isa = PBXGroup;
-			children = (
-				C380FFF62C8F5053006428A3 /* Internlm2.swift */,
-				7BBD0D6D2BE044A10019C5D7 /* OpenELM.swift */,
-				525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */,
-				C34E48EE2B696E6500FCB841 /* Llama.swift */,
-				C38935E22B86C0FE0037B833 /* Gemma.swift */,
-				1C5531792C5AAB4E00B07ECD /* Gemma2.swift */,
-				C38935DE2B869DD00037B833 /* Phi.swift */,
-				1CD79C6F2BD80DE100B6C06F /* Phi3.swift */,
-				927B80412C83769400500C13 /* PhiMoE.swift */,
-				52A776172B94B5EE00AA6E80 /* Qwen2.swift */,
-				F24B08392BAF1A65008C8D19 /* Cohere.swift */,
-			);
-			path = Models;
-			sourceTree = "<group>";
-		};
 		C3056BA52BCD973400A31D04 /* lora */ = {
 			isa = PBXGroup;
 			children = (
@@ -595,18 +388,6 @@
 			path = "llm-tool";
 			sourceTree = "<group>";
 		};
-		C34E490E2B69A92900FCB841 /* MNIST */ = {
-			isa = PBXGroup;
-			children = (
-				C34E490F2B69A92900FCB841 /* MNIST.h */,
-				C34E49142B69C1E300FCB841 /* Files.swift */,
-				C3932D562B6A060B00A81055 /* MNIST.swift */,
-				C3932D582B6A0BE400A81055 /* Random.swift */,
-				C3C3240C2B6CA792007D2D9A /* README.md */,
-			);
-			path = MNIST;
-			sourceTree = "<group>";
-		};
 		C34E49222B6A026F00FCB841 /* mnist-tool */ = {
 			isa = PBXGroup;
 			children = (
@@ -616,23 +397,6 @@
 			path = "mnist-tool";
 			sourceTree = "<group>";
 		};
-		C36BEFC32BC098F3002D4AFE /* StableDiffusion */ = {
-			isa = PBXGroup;
-			children = (
-				C3D573062C40702D00857A35 /* README.md */,
-				C36BEFD32BC0B4A9002D4AFE /* Clip.swift */,
-				C36BEFC92BC09953002D4AFE /* Configuration.swift */,
-				C36BEFCB2BC09E53002D4AFE /* Sampler.swift */,
-				C36BEFD72BC0BD9B002D4AFE /* VAE.swift */,
-				C36BEFD92BC0BDDC002D4AFE /* UNet.swift */,
-				C36BEFF32BC349FA002D4AFE /* Load.swift */,
-				C36BEFF52BC34A46002D4AFE /* StableDiffusion.swift */,
-				C36BEFF72BC59CE1002D4AFE /* Tokenizer.swift */,
-				C36BEFFB2BC5BA79002D4AFE /* Image.swift */,
-			);
-			path = StableDiffusion;
-			sourceTree = "<group>";
-		};
 		C36BEFE12BC32988002D4AFE /* image-tool */ = {
 			isa = PBXGroup;
 			children = (
@@ -663,27 +427,6 @@
 			path = "Preview Content";
 			sourceTree = "<group>";
 		};
-		C38935C62B869C7A0037B833 /* LLM */ = {
-			isa = PBXGroup;
-			children = (
-				92EBB3282CD014B800998339 /* Models */,
-				C36BEFAF2BBCBAC2002D4AFE /* Lora.swift */,
-				C36BEFBA2BBF02CC002D4AFE /* Lora+Data.swift */,
-				C34E48EF2B696E6500FCB841 /* Configuration.swift */,
-				C3A8B3AB2B9283150002EFB8 /* Models.swift */,
-				C38935C72B869C7A0037B833 /* LLM.h */,
-				C38935E02B869F420037B833 /* LLMModel.swift */,
-				C343B2772CC8091B00334888 /* SwitchLayers.swift */,
-				C34E48F62B69832600FCB841 /* README.md */,
-				C34E48ED2B696E6500FCB841 /* Load.swift */,
-				C3E786AA2B8D1AEC0004D037 /* Evaluate.swift */,
-				C373EB722C792DD1004E201E /* KVCache.swift */,
-				C3E786AC2B8D4AF50004D037 /* Tokenizer.swift */,
-				927C784D2C7A578A001E5878 /* SuScaledRotaryEmbedding.swift */,
-			);
-			path = LLM;
-			sourceTree = "<group>";
-		};
 		C39273672B60697700368D5D = {
 			isa = PBXGroup;
 			children = (
@@ -691,7 +434,7 @@
 				F8D7023A2BB4E223003D7CF5 /* Package.swift */,
 				C3C36A6C2CA714600099FFA4 /* Configuration */,
 				C3056BA62BCD973400A31D04 /* Data */,
-				C39273822B606A9200368D5D /* Libraries */,
+				C397D92E2CD440EF00B87EE2 /* Libraries */,
 				C3A8B3AD2B9294E30002EFB8 /* Applications */,
 				C39273812B606A7400368D5D /* Tools */,
 				C39273752B606A0A00368D5D /* Products */,
@@ -704,14 +447,11 @@
 			children = (
 				C39273742B606A0A00368D5D /* Tutorial */,
 				C397C58B2B62C6A9004B084D /* llm-tool */,
-				C34E490D2B69A92900FCB841 /* MNIST.framework */,
 				C34E49212B6A026F00FCB841 /* mnist-tool */,
 				C3288D732B6D9313009FF608 /* LinearModelTraining */,
-				C38935C52B869C7A0037B833 /* LLM.framework */,
 				C3A8B3B22B9295090002EFB8 /* MNISTTrainer.app */,
 				C3A8B3DC2B92A29E0002EFB8 /* LLMEval.app */,
 				C3056BAB2BCD97B700A31D04 /* LoRATrainingExample.app */,
-				C36BEFC22BC098F3002D4AFE /* StableDiffusion.framework */,
 				C36BEFE02BC32988002D4AFE /* image-tool */,
 				C36BF0012BC5CE55002D4AFE /* StableDiffusionExample.app */,
 			);
@@ -745,16 +485,6 @@
 			path = Tools;
 			sourceTree = "<group>";
 		};
-		C39273822B606A9200368D5D /* Libraries */ = {
-			isa = PBXGroup;
-			children = (
-				C36BEFC32BC098F3002D4AFE /* StableDiffusion */,
-				C38935C62B869C7A0037B833 /* LLM */,
-				C34E490E2B69A92900FCB841 /* MNIST */,
-			);
-			path = Libraries;
-			sourceTree = "<group>";
-		};
 		C3A8B3AD2B9294E30002EFB8 /* Applications */ = {
 			isa = PBXGroup;
 			children = (
@@ -821,32 +551,6 @@
 		};
 /* End PBXGroup section */
 
-/* Begin PBXHeadersBuildPhase section */
-		C34E49082B69A92900FCB841 /* Headers */ = {
-			isa = PBXHeadersBuildPhase;
-			buildActionMask = 2147483647;
-			files = (
-				C34E49102B69A92900FCB841 /* MNIST.h in Headers */,
-			);
-			runOnlyForDeploymentPostprocessing = 0;
-		};
-		C36BEFBD2BC098F3002D4AFE /* Headers */ = {
-			isa = PBXHeadersBuildPhase;
-			buildActionMask = 2147483647;
-			files = (
-			);
-			runOnlyForDeploymentPostprocessing = 0;
-		};
-		C38935C02B869C7A0037B833 /* Headers */ = {
-			isa = PBXHeadersBuildPhase;
-			buildActionMask = 2147483647;
-			files = (
-				C38935C82B869C7A0037B833 /* LLM.h in Headers */,
-			);
-			runOnlyForDeploymentPostprocessing = 0;
-		};
-/* End PBXHeadersBuildPhase section */
-
 /* Begin PBXNativeTarget section */
 		C3056BAA2BCD97B700A31D04 /* LoRATrainingExample */ = {
 			isa = PBXNativeTarget;
@@ -860,10 +564,10 @@
 			buildRules = (
 			);
 			dependencies = (
-				C3056BC02BCD984F00A31D04 /* PBXTargetDependency */,
 			);
 			name = LoRATrainingExample;
 			packageProductDependencies = (
+				C32A18062CFFD1AA0092A5B6 /* MLXLLM */,
 			);
 			productName = LoRATrainingExample;
 			productReference = C3056BAB2BCD97B700A31D04 /* LoRATrainingExample.app */;
@@ -884,38 +588,14 @@
 			name = LinearModelTraining;
 			packageProductDependencies = (
 				C3288D7A2B6D9339009FF608 /* ArgumentParser */,
-				C3FBCB2E2B8520F20007E490 /* MLX */,
-				C3FBCB302B8520F20007E490 /* MLXNN */,
-				C3FBCB322B8520F20007E490 /* MLXOptimizers */,
-				C3FBCB342B8520F20007E490 /* MLXRandom */,
+				C32A18472D00E1540092A5B6 /* MLX */,
+				C32A18492D00E1540092A5B6 /* MLXNN */,
+				C32A184B2D00E1540092A5B6 /* MLXOptimizers */,
 			);
 			productName = LinearFunctionModelTraining;
 			productReference = C3288D732B6D9313009FF608 /* LinearModelTraining */;
 			productType = "com.apple.product-type.tool";
 		};
-		C34E490C2B69A92900FCB841 /* MNIST */ = {
-			isa = PBXNativeTarget;
-			buildConfigurationList = C34E49112B69A92900FCB841 /* Build configuration list for PBXNativeTarget "MNIST" */;
-			buildPhases = (
-				C34E49082B69A92900FCB841 /* Headers */,
-				C34E49092B69A92900FCB841 /* Sources */,
-				C34E490A2B69A92900FCB841 /* Frameworks */,
-				C34E490B2B69A92900FCB841 /* Resources */,
-			);
-			buildRules = (
-			);
-			dependencies = (
-			);
-			name = MNIST;
-			packageProductDependencies = (
-				C34E491B2B69C43600FCB841 /* Gzip */,
-				C3FBCB282B8520DA0007E490 /* MLX */,
-				C3FBCB2A2B8520DA0007E490 /* MLXNN */,
-			);
-			productName = MNIST;
-			productReference = C34E490D2B69A92900FCB841 /* MNIST.framework */;
-			productType = "com.apple.product-type.framework";
-		};
 		C34E49202B6A026F00FCB841 /* mnist-tool */ = {
 			isa = PBXNativeTarget;
 			buildConfigurationList = C34E49252B6A026F00FCB841 /* Build configuration list for PBXNativeTarget "mnist-tool" */;
@@ -928,41 +608,16 @@
 			buildRules = (
 			);
 			dependencies = (
-				C34E492D2B6A028800FCB841 /* PBXTargetDependency */,
 			);
 			name = "mnist-tool";
 			packageProductDependencies = (
 				C34E49282B6A028100FCB841 /* ArgumentParser */,
-				C3FBCB2C2B8520E80007E490 /* MLXOptimizers */,
+				C32A18002CFFD1810092A5B6 /* MLXMNIST */,
 			);
 			productName = "mnist-tool";
 			productReference = C34E49212B6A026F00FCB841 /* mnist-tool */;
 			productType = "com.apple.product-type.tool";
 		};
-		C36BEFC12BC098F3002D4AFE /* StableDiffusion */ = {
-			isa = PBXNativeTarget;
-			buildConfigurationList = C36BEFC82BC098F3002D4AFE /* Build configuration list for PBXNativeTarget "StableDiffusion" */;
-			buildPhases = (
-				C36BEFBD2BC098F3002D4AFE /* Headers */,
-				C36BEFBE2BC098F3002D4AFE /* Sources */,
-				C36BEFBF2BC098F3002D4AFE /* Frameworks */,
-				C36BEFC02BC098F3002D4AFE /* Resources */,
-			);
-			buildRules = (
-			);
-			dependencies = (
-			);
-			name = StableDiffusion;
-			packageProductDependencies = (
-				C36BEFCD2BC0A194002D4AFE /* MLX */,
-				C36BEFCF2BC0A194002D4AFE /* MLXNN */,
-				C36BEFD12BC0A194002D4AFE /* MLXRandom */,
-				C36BEFF92BC5B996002D4AFE /* Transformers */,
-			);
-			productName = Image;
-			productReference = C36BEFC22BC098F3002D4AFE /* StableDiffusion.framework */;
-			productType = "com.apple.product-type.framework";
-		};
 		C36BEFDF2BC32988002D4AFE /* image-tool */ = {
 			isa = PBXNativeTarget;
 			buildConfigurationList = C36BEFE42BC32988002D4AFE /* Build configuration list for PBXNativeTarget "image-tool" */;
@@ -975,12 +630,12 @@
 			buildRules = (
 			);
 			dependencies = (
-				C36BEFEA2BC329AB002D4AFE /* PBXTargetDependency */,
 			);
 			name = "image-tool";
 			packageProductDependencies = (
 				C36BEFEE2BC329C5002D4AFE /* ArgumentParser */,
 				C36BEFF12BC32A9A002D4AFE /* Progress */,
+				C3E7D94C2CF6C9B20056C095 /* StableDiffusion */,
 			);
 			productName = "image-tool";
 			productReference = C36BEFE02BC32988002D4AFE /* image-tool */;
@@ -998,38 +653,12 @@
 			buildRules = (
 			);
 			dependencies = (
-				C36BF0132BC5CF17002D4AFE /* PBXTargetDependency */,
 			);
 			name = StableDiffusionExample;
 			productName = StableDiffusionEval;
 			productReference = C36BF0012BC5CE55002D4AFE /* StableDiffusionExample.app */;
 			productType = "com.apple.product-type.application";
 		};
-		C38935C42B869C7A0037B833 /* LLM */ = {
-			isa = PBXNativeTarget;
-			buildConfigurationList = C38935C92B869C7A0037B833 /* Build configuration list for PBXNativeTarget "LLM" */;
-			buildPhases = (
-				C38935C02B869C7A0037B833 /* Headers */,
-				C38935C12B869C7A0037B833 /* Sources */,
-				C38935C22B869C7A0037B833 /* Frameworks */,
-				C38935C32B869C7A0037B833 /* Resources */,
-			);
-			buildRules = (
-			);
-			dependencies = (
-			);
-			name = LLM;
-			packageProductDependencies = (
-				C38935CF2B869CC40037B833 /* MLX */,
-				C38935D12B869CC40037B833 /* MLXNN */,
-				C38935D32B869CC40037B833 /* MLXRandom */,
-				C38935D52B869CC40037B833 /* Transformers */,
-				C36BEFB12BBDE9D0002D4AFE /* MLXOptimizers */,
-			);
-			productName = LLM;
-			productReference = C38935C52B869C7A0037B833 /* LLM.framework */;
-			productType = "com.apple.product-type.framework";
-		};
 		C39273732B606A0A00368D5D /* Tutorial */ = {
 			isa = PBXNativeTarget;
 			buildConfigurationList = C39273792B606A0A00368D5D /* Build configuration list for PBXNativeTarget "Tutorial" */;
@@ -1044,7 +673,7 @@
 			);
 			name = Tutorial;
 			packageProductDependencies = (
-				C3FBCB202B8520B80007E490 /* MLX */,
+				C32A18452D00E1490092A5B6 /* MLX */,
 			);
 			productName = Tutorial;
 			productReference = C39273742B606A0A00368D5D /* Tutorial */;
@@ -1062,11 +691,12 @@
 			buildRules = (
 			);
 			dependencies = (
-				C38935DA2B869CCD0037B833 /* PBXTargetDependency */,
 			);
 			name = "llm-tool";
 			packageProductDependencies = (
 				C397C59B2B62C6D0004B084D /* ArgumentParser */,
+				C32A17FC2CFFB98A0092A5B6 /* MLXLLM */,
+				C32A17FE2CFFB98A0092A5B6 /* MLXVLM */,
 			);
 			productName = "mistral-tool";
 			productReference = C397C58B2B62C6A9004B084D /* llm-tool */;
@@ -1084,11 +714,10 @@
 			buildRules = (
 			);
 			dependencies = (
-				C3A8B3D62B92A0880002EFB8 /* PBXTargetDependency */,
 			);
 			name = MNISTTrainer;
 			packageProductDependencies = (
-				C3A8B3D12B92A0880002EFB8 /* MLXOptimizers */,
+				C32A18022CFFD1920092A5B6 /* MLXMNIST */,
 			);
 			productName = MNISTTrainer;
 			productReference = C3A8B3B22B9295090002EFB8 /* MNISTTrainer.app */;
@@ -1106,11 +735,11 @@
 			buildRules = (
 			);
 			dependencies = (
-				C3A8B3FB2B92A3360002EFB8 /* PBXTargetDependency */,
 			);
 			name = LLMEval;
 			packageProductDependencies = (
 				81695B402BA373D300F260D8 /* MarkdownUI */,
+				C32A18042CFFD19F0092A5B6 /* MLXLLM */,
 			);
 			productName = LLMEval;
 			productReference = C3A8B3DC2B92A29E0002EFB8 /* LLMEval.app */;
@@ -1132,26 +761,15 @@
 				C3288D722B6D9313009FF608 = {
 					CreatedOnToolsVersion = 15.0.1;
 				};
-				C34E490C2B69A92900FCB841 = {
-					CreatedOnToolsVersion = 15.0.1;
-					LastSwiftMigration = 1500;
-				};
 				C34E49202B6A026F00FCB841 = {
 					CreatedOnToolsVersion = 15.0.1;
 				};
-				C36BEFC12BC098F3002D4AFE = {
-					CreatedOnToolsVersion = 15.3;
-					LastSwiftMigration = 1530;
-				};
 				C36BEFDF2BC32988002D4AFE = {
 					CreatedOnToolsVersion = 15.3;
 				};
 				C36BF0002BC5CE55002D4AFE = {
 					CreatedOnToolsVersion = 15.3;
 				};
-				C38935C42B869C7A0037B833 = {
-					CreatedOnToolsVersion = 15.2;
-				};
 				C39273732B606A0A00368D5D = {
 					CreatedOnToolsVersion = 15.0.1;
 				};
@@ -1178,10 +796,10 @@
 			packageReferences = (
 				C392736E2B60699100368D5D /* XCRemoteSwiftPackageReference "swift-argument-parser" */,
 				C34E491A2B69C43600FCB841 /* XCRemoteSwiftPackageReference "GzipSwift" */,
-				C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */,
-				C38935BB2B866BFA0037B833 /* XCRemoteSwiftPackageReference "swift-transformers" */,
 				81695B3F2BA373D300F260D8 /* XCRemoteSwiftPackageReference "swift-markdown-ui" */,
 				C36BEFF02BC32A8C002D4AFE /* XCRemoteSwiftPackageReference "Progress" */,
+				C397D8F22CD2F60B00B87EE2 /* XCLocalSwiftPackageReference "Source/.." */,
+				C32A18442D00E13E0092A5B6 /* XCRemoteSwiftPackageReference "mlx-swift" */,
 			);
 			productRefGroup = C39273752B606A0A00368D5D /* Products */;
 			projectDirPath = "";
@@ -1189,10 +807,7 @@
 			targets = (
 				C39273732B606A0A00368D5D /* Tutorial */,
 				C397C58A2B62C6A9004B084D /* llm-tool */,
-				C38935C42B869C7A0037B833 /* LLM */,
-				C36BEFC12BC098F3002D4AFE /* StableDiffusion */,
 				C34E49202B6A026F00FCB841 /* mnist-tool */,
-				C34E490C2B69A92900FCB841 /* MNIST */,
 				C3288D722B6D9313009FF608 /* LinearModelTraining */,
 				C3A8B3B12B9295090002EFB8 /* MNISTTrainer */,
 				C3A8B3DB2B92A29D0002EFB8 /* LLMEval */,
@@ -1216,20 +831,6 @@
 			);
 			runOnlyForDeploymentPostprocessing = 0;
 		};
-		C34E490B2B69A92900FCB841 /* Resources */ = {
-			isa = PBXResourcesBuildPhase;
-			buildActionMask = 2147483647;
-			files = (
-			);
-			runOnlyForDeploymentPostprocessing = 0;
-		};
-		C36BEFC02BC098F3002D4AFE /* Resources */ = {
-			isa = PBXResourcesBuildPhase;
-			buildActionMask = 2147483647;
-			files = (
-			);
-			runOnlyForDeploymentPostprocessing = 0;
-		};
 		C36BEFFF2BC5CE55002D4AFE /* Resources */ = {
 			isa = PBXResourcesBuildPhase;
 			buildActionMask = 2147483647;
@@ -1239,13 +840,6 @@
 			);
 			runOnlyForDeploymentPostprocessing = 0;
 		};
-		C38935C32B869C7A0037B833 /* Resources */ = {
-			isa = PBXResourcesBuildPhase;
-			buildActionMask = 2147483647;
-			files = (
-			);
-			runOnlyForDeploymentPostprocessing = 0;
-		};
 		C3A8B3B02B9295090002EFB8 /* Resources */ = {
 			isa = PBXResourcesBuildPhase;
 			buildActionMask = 2147483647;
@@ -1284,16 +878,6 @@
 			);
 			runOnlyForDeploymentPostprocessing = 0;
 		};
-		C34E49092B69A92900FCB841 /* Sources */ = {
-			isa = PBXSourcesBuildPhase;
-			buildActionMask = 2147483647;
-			files = (
-				C34E49152B69C1E300FCB841 /* Files.swift in Sources */,
-				C3932D572B6A060B00A81055 /* MNIST.swift in Sources */,
-				C3932D592B6A0BE400A81055 /* Random.swift in Sources */,
-			);
-			runOnlyForDeploymentPostprocessing = 0;
-		};
 		C34E491D2B6A026F00FCB841 /* Sources */ = {
 			isa = PBXSourcesBuildPhase;
 			buildActionMask = 2147483647;
@@ -1302,22 +886,6 @@
 			);
 			runOnlyForDeploymentPostprocessing = 0;
 		};
-		C36BEFBE2BC098F3002D4AFE /* Sources */ = {
-			isa = PBXSourcesBuildPhase;
-			buildActionMask = 2147483647;
-			files = (
-				C36BEFD82BC0BD9B002D4AFE /* VAE.swift in Sources */,
-				C36BEFFC2BC5BA79002D4AFE /* Image.swift in Sources */,
-				C36BEFF42BC349FA002D4AFE /* Load.swift in Sources */,
-				C36BEFCC2BC09E53002D4AFE /* Sampler.swift in Sources */,
-				C36BEFD42BC0B4A9002D4AFE /* Clip.swift in Sources */,
-				C36BEFCA2BC09953002D4AFE /* Configuration.swift in Sources */,
-				C36BEFF82BC59CE1002D4AFE /* Tokenizer.swift in Sources */,
-				C36BEFDA2BC0BDDC002D4AFE /* UNet.swift in Sources */,
-				C36BEFF62BC34A46002D4AFE /* StableDiffusion.swift in Sources */,
-			);
-			runOnlyForDeploymentPostprocessing = 0;
-		};
 		C36BEFDC2BC32988002D4AFE /* Sources */ = {
 			isa = PBXSourcesBuildPhase;
 			buildActionMask = 2147483647;
@@ -1336,35 +904,6 @@
 			);
 			runOnlyForDeploymentPostprocessing = 0;
 		};
-		C38935C12B869C7A0037B833 /* Sources */ = {
-			isa = PBXSourcesBuildPhase;
-			buildActionMask = 2147483647;
-			files = (
-				C38935E12B869F420037B833 /* LLMModel.swift in Sources */,
-				C380FFF72C8F5053006428A3 /* Internlm2.swift in Sources */,
-				F24B083A2BAF1A65008C8D19 /* Cohere.swift in Sources */,
-				C38935E32B86C0FE0037B833 /* Gemma.swift in Sources */,
-				C373EB732C792DD1004E201E /* KVCache.swift in Sources */,
-				C38935CD2B869C870037B833 /* Configuration.swift in Sources */,
-				1CD79C702BD80DE100B6C06F /* Phi3.swift in Sources */,
-				525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */,
-				C36BEFBB2BBF02CC002D4AFE /* Lora+Data.swift in Sources */,
-				C36BEFB02BBCBAC2002D4AFE /* Lora.swift in Sources */,
-				927B80422C83769800500C13 /* PhiMoE.swift in Sources */,
-				7BBD0D6E2BE044A10019C5D7 /* OpenELM.swift in Sources */,
-				C38935DF2B869DD00037B833 /* Phi.swift in Sources */,
-				C38935CE2B869C870037B833 /* Load.swift in Sources */,
-				C3E786AD2B8D4AF50004D037 /* Tokenizer.swift in Sources */,
-				C3A8B3AC2B9283150002EFB8 /* Models.swift in Sources */,
-				1C55317A2C5AAB4E00B07ECD /* Gemma2.swift in Sources */,
-				C3E786AB2B8D1AEC0004D037 /* Evaluate.swift in Sources */,
-				C38935CC2B869C870037B833 /* Llama.swift in Sources */,
-				C343B2782CC8091B00334888 /* SwitchLayers.swift in Sources */,
-				927C784E2C7A578A001E5878 /* SuScaledRotaryEmbedding.swift in Sources */,
-				52A776182B94B5EE00AA6E80 /* Qwen2.swift in Sources */,
-			);
-			runOnlyForDeploymentPostprocessing = 0;
-		};
 		C39273702B606A0A00368D5D /* Sources */ = {
 			isa = PBXSourcesBuildPhase;
 			buildActionMask = 2147483647;
@@ -1405,44 +944,6 @@
 		};
 /* End PBXSourcesBuildPhase section */
 
-/* Begin PBXTargetDependency section */
-		C3056BC02BCD984F00A31D04 /* PBXTargetDependency */ = {
-			isa = PBXTargetDependency;
-			target = C38935C42B869C7A0037B833 /* LLM */;
-			targetProxy = C3056BBF2BCD984F00A31D04 /* PBXContainerItemProxy */;
-		};
-		C34E492D2B6A028800FCB841 /* PBXTargetDependency */ = {
-			isa = PBXTargetDependency;
-			target = C34E490C2B69A92900FCB841 /* MNIST */;
-			targetProxy = C34E492C2B6A028800FCB841 /* PBXContainerItemProxy */;
-		};
-		C36BEFEA2BC329AB002D4AFE /* PBXTargetDependency */ = {
-			isa = PBXTargetDependency;
-			target = C36BEFC12BC098F3002D4AFE /* StableDiffusion */;
-			targetProxy = C36BEFE92BC329AB002D4AFE /* PBXContainerItemProxy */;
-		};
-		C36BF0132BC5CF17002D4AFE /* PBXTargetDependency */ = {
-			isa = PBXTargetDependency;
-			target = C36BEFC12BC098F3002D4AFE /* StableDiffusion */;
-			targetProxy = C36BF0122BC5CF17002D4AFE /* PBXContainerItemProxy */;
-		};
-		C38935DA2B869CCD0037B833 /* PBXTargetDependency */ = {
-			isa = PBXTargetDependency;
-			target = C38935C42B869C7A0037B833 /* LLM */;
-			targetProxy = C38935D92B869CCD0037B833 /* PBXContainerItemProxy */;
-		};
-		C3A8B3D62B92A0880002EFB8 /* PBXTargetDependency */ = {
-			isa = PBXTargetDependency;
-			target = C34E490C2B69A92900FCB841 /* MNIST */;
-			targetProxy = C3A8B3D52B92A0880002EFB8 /* PBXContainerItemProxy */;
-		};
-		C3A8B3FB2B92A3360002EFB8 /* PBXTargetDependency */ = {
-			isa = PBXTargetDependency;
-			target = C38935C42B869C7A0037B833 /* LLM */;
-			targetProxy = C3A8B3FA2B92A3360002EFB8 /* PBXContainerItemProxy */;
-		};
-/* End PBXTargetDependency section */
-
/* Begin XCBuildConfiguration section */
 		C3056BB82BCD97B800A31D04 /* Debug */ = {
 			isa = XCBuildConfiguration;
@@ -1752,7 +1253,7 @@
 			};
 			name = Release;
 		};
-		C34E49122B69A92900FCB841 /* Debug */ = {
+		C34E49262B6A026F00FCB841 /* Debug */ = {
 			isa = XCBuildConfiguration;
 			buildSettings = {
 				ALWAYS_SEARCH_USER_PATHS = NO;
@@ -1786,16 +1287,9 @@
 				CLANG_WARN_UNREACHABLE_CODE = YES;
 				CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
 				CODE_SIGN_STYLE = Automatic;
-				COMBINE_HIDPI_IMAGES = YES;
 				COPY_PHASE_STRIP = NO;
-				CURRENT_PROJECT_VERSION = 1;
 				DEAD_CODE_STRIPPING = YES;
 				DEBUG_INFORMATION_FORMAT = dwarf;
-				DEFINES_MODULE = YES;
-				DYLIB_COMPATIBILITY_VERSION = 1;
-				DYLIB_CURRENT_VERSION = 1;
-				DYLIB_INSTALL_NAME_BASE = "@rpath";
-				ENABLE_MODULE_VERIFIER = YES;
 				ENABLE_STRICT_OBJC_MSGSEND = YES;
 				ENABLE_TESTABILITY = YES;
 				ENABLE_USER_SCRIPT_SANDBOXING = YES;
@@ -1813,40 +1307,20 @@
 				GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
 				GCC_WARN_UNUSED_FUNCTION = YES;
 				GCC_WARN_UNUSED_VARIABLE = YES;
-				GENERATE_INFOPLIST_FILE = YES;
-				INFOPLIST_KEY_NSHumanReadableCopyright = "";
-				INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
-				IPHONEOS_DEPLOYMENT_TARGET = 17.2;
-				LD_RUNPATH_SEARCH_PATHS = (
-					"$(inherited)",
-					"@executable_path/../Frameworks",
-					"@loader_path/Frameworks",
-				);
 				LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
 				MACOSX_DEPLOYMENT_TARGET = 14.2;
-				MARKETING_VERSION = 1.0;
-				MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++";
-				MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20";
 				MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
 				MTL_FAST_MATH = YES;
-				PRODUCT_BUNDLE_IDENTIFIER = mlx.MNIST;
-				PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
+				PRODUCT_NAME = "$(TARGET_NAME)";
 				SDKROOT = macosx;
-				SKIP_INSTALL = YES;
-				SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx";
-				SUPPORTS_MACCATALYST = NO;
 				SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
-				SWIFT_EMIT_LOC_STRINGS = YES;
 				SWIFT_OPTIMIZATION_LEVEL = "-Onone";
 				SWIFT_STRICT_CONCURRENCY = complete;
 				SWIFT_VERSION = 5.0;
-				TARGETED_DEVICE_FAMILY = "1,2";
-				VERSIONING_SYSTEM = "apple-generic";
-				VERSION_INFO_PREFIX = "";
 			};
 			name = Debug;
 		};
-		C34E49132B69A92900FCB841 /* Release */ = {
+		C34E49272B6A026F00FCB841 /* Release */ = {
 			isa = XCBuildConfiguration;
 			buildSettings = {
 				ALWAYS_SEARCH_USER_PATHS = NO;
@@ -1880,16 +1354,9 @@
 				CLANG_WARN_UNREACHABLE_CODE = YES;
 				CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
 				CODE_SIGN_STYLE = Automatic;
-				COMBINE_HIDPI_IMAGES = YES;
 				COPY_PHASE_STRIP = NO;
-				CURRENT_PROJECT_VERSION = 1;
 				DEAD_CODE_STRIPPING = YES;
 				DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
-				DEFINES_MODULE = YES;
-				DYLIB_COMPATIBILITY_VERSION = 1;
-				DYLIB_CURRENT_VERSION = 1;
-				DYLIB_INSTALL_NAME_BASE = "@rpath";
-				ENABLE_MODULE_VERIFIER = YES;
 				ENABLE_NS_ASSERTIONS = NO;
 				ENABLE_STRICT_OBJC_MSGSEND = YES;
 				ENABLE_USER_SCRIPT_SANDBOXING = YES;
@@ -1901,39 +1368,19 @@
 				GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
 				GCC_WARN_UNUSED_FUNCTION = YES;
 				GCC_WARN_UNUSED_VARIABLE = YES;
-				GENERATE_INFOPLIST_FILE = YES;
-				INFOPLIST_KEY_NSHumanReadableCopyright = "";
-				INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
-				IPHONEOS_DEPLOYMENT_TARGET = 17.2;
-				LD_RUNPATH_SEARCH_PATHS = (
-					"$(inherited)",
-					"@executable_path/../Frameworks",
-					"@loader_path/Frameworks",
-				);
 				LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
 				MACOSX_DEPLOYMENT_TARGET = 14.2;
-				MARKETING_VERSION = 1.0;
-				MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++";
-				MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20";
 				MTL_ENABLE_DEBUG_INFO = NO;
 				MTL_FAST_MATH = YES;
-				PRODUCT_BUNDLE_IDENTIFIER = mlx.MNIST;
-				PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
+				PRODUCT_NAME = "$(TARGET_NAME)";
 				SDKROOT = macosx;
-				SKIP_INSTALL = YES;
-				SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx";
-				SUPPORTS_MACCATALYST = NO;
 				SWIFT_COMPILATION_MODE = wholemodule;
-				SWIFT_EMIT_LOC_STRINGS = YES;
 				SWIFT_STRICT_CONCURRENCY = complete;
 				SWIFT_VERSION = 5.0;
-				TARGETED_DEVICE_FAMILY = "1,2";
-				VERSIONING_SYSTEM = "apple-generic";
-				VERSION_INFO_PREFIX = "";
 			};
 			name = Release;
 		};
-		C34E49262B6A026F00FCB841 /* Debug */ = {
+		C36BEFE52BC32988002D4AFE /* Debug */ = {
 			isa = XCBuildConfiguration;
 			buildSettings = {
 				ALWAYS_SEARCH_USER_PATHS = NO;
@@ -1968,7 +1415,6 @@
 				CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
 				CODE_SIGN_STYLE = Automatic;
 				COPY_PHASE_STRIP = NO;
-				DEAD_CODE_STRIPPING = YES;
 				DEBUG_INFORMATION_FORMAT = dwarf;
 				ENABLE_STRICT_OBJC_MSGSEND = YES;
 				ENABLE_TESTABILITY = YES;
@@ -1988,7 +1434,7 @@
 				GCC_WARN_UNUSED_FUNCTION = YES;
 				GCC_WARN_UNUSED_VARIABLE = YES;
 				LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
-				MACOSX_DEPLOYMENT_TARGET = 14.2;
+				MACOSX_DEPLOYMENT_TARGET = 14.4;
 				MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
 				MTL_FAST_MATH = YES;
 				PRODUCT_NAME = "$(TARGET_NAME)";
@@ -2000,7 +1446,7 @@
 			};
 			name = Debug;
 		};
-		C34E49272B6A026F00FCB841 /* Release */ = {
+		C36BEFE62BC32988002D4AFE /* Release */ = {
 			isa = XCBuildConfiguration;
 			buildSettings = {
 				ALWAYS_SEARCH_USER_PATHS = NO;
@@ -2035,7 +1481,6 @@
 				CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
 				CODE_SIGN_STYLE = Automatic;
 				COPY_PHASE_STRIP = NO;
-				DEAD_CODE_STRIPPING = YES;
 				DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
 				ENABLE_NS_ASSERTIONS = NO;
 				ENABLE_STRICT_OBJC_MSGSEND = YES;
@@ -2049,7 +1494,7 @@
 				GCC_WARN_UNUSED_FUNCTION = YES;
 				GCC_WARN_UNUSED_VARIABLE = YES;
 				LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
-				MACOSX_DEPLOYMENT_TARGET = 14.2;
+				MACOSX_DEPLOYMENT_TARGET = 14.4;
 				MTL_ENABLE_DEBUG_INFO = NO;
 				MTL_FAST_MATH = YES;
 				PRODUCT_NAME = "$(TARGET_NAME)";
@@ -2060,11 +1505,13 @@
 			};
 			name = Release;
 		};
-		C36BEFC62BC098F3002D4AFE /* Debug */ = {
+		C36BF00E2BC5CE56002D4AFE /* Debug */ = {
 			isa = XCBuildConfiguration;
 			buildSettings = {
 				ALWAYS_SEARCH_USER_PATHS = NO;
+				ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
 				ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
+				ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
 				CLANG_ANALYZER_NONNULL = YES;
 				CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
 				CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
@@ -2093,15 +1540,14 @@
 				CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
 				CLANG_WARN_UNREACHABLE_CODE = YES;
 				CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
+				CODE_SIGN_ENTITLEMENTS = Applications/StableDiffusionExample/StableDiffusionExample.entitlements;
 				CODE_SIGN_STYLE = Automatic;
 				COPY_PHASE_STRIP = NO;
 				CURRENT_PROJECT_VERSION = 1;
 				DEBUG_INFORMATION_FORMAT = dwarf;
-				DEFINES_MODULE = YES;
-				DYLIB_COMPATIBILITY_VERSION = 1;
-				DYLIB_CURRENT_VERSION = 1;
-				DYLIB_INSTALL_NAME_BASE = "@rpath";
-				ENABLE_MODULE_VERIFIER = YES;
+				DEVELOPMENT_ASSET_PATHS = "\"Applications/StableDiffusionExample/Preview Content\"";
+				DEVELOPMENT_TEAM = "";
+				ENABLE_PREVIEWS = YES;
 				ENABLE_STRICT_OBJC_MSGSEND = YES;
 				ENABLE_TESTABILITY = YES;
 				ENABLE_USER_SCRIPT_SANDBOXING = YES;
@@ -2120,45 +1566,45 @@
 				GCC_WARN_UNUSED_FUNCTION = YES;
 				GCC_WARN_UNUSED_VARIABLE = YES;
 				GENERATE_INFOPLIST_FILE = YES;
-				INFOPLIST_KEY_NSHumanReadableCopyright = "";
-				INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks";
-				IPHONEOS_DEPLOYMENT_TARGET = 17.0;
-				LD_RUNPATH_SEARCH_PATHS = (
-					"@executable_path/Frameworks",
-					"@loader_path/Frameworks",
-				);
-				"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = (
-					"@executable_path/../Frameworks",
-					"@loader_path/Frameworks",
-				);
+				"INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphoneos*]" = YES;
+				"INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphonesimulator*]" = YES;
+				"INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphoneos*]" = YES;
+				"INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphonesimulator*]" = YES;
+				"INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphoneos*]" = YES;
+				"INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphonesimulator*]" = YES;
+				"INFOPLIST_KEY_UIStatusBarStyle[sdk=iphoneos*]" = UIStatusBarStyleDefault;
+				"INFOPLIST_KEY_UIStatusBarStyle[sdk=iphonesimulator*]" = UIStatusBarStyleDefault;
 				INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft
UIInterfaceOrientationLandscapeRight"; + INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; + IPHONEOS_DEPLOYMENT_TARGET = 17.2; + LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks"; + "LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks"; LOCALIZATION_PREFERS_STRING_CATALOGS = YES; - MACOSX_DEPLOYMENT_TARGET = 14.0; + MACOSX_DEPLOYMENT_TARGET = 14.2; MARKETING_VERSION = 1.0; - MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++"; - MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20"; MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; MTL_FAST_MATH = YES; - PRODUCT_BUNDLE_IDENTIFIER = mlx.Image; - PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)"; + PRODUCT_BUNDLE_IDENTIFIER = "mlx.StableDiffusionExample${DISAMBIGUATOR}"; + PRODUCT_NAME = "$(TARGET_NAME)"; SDKROOT = auto; - SKIP_INSTALL = YES; SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx xros xrsimulator"; + SUPPORTS_MACCATALYST = NO; SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)"; SWIFT_EMIT_LOC_STRINGS = YES; SWIFT_OPTIMIZATION_LEVEL = "-Onone"; SWIFT_STRICT_CONCURRENCY = complete; SWIFT_VERSION = 5.0; TARGETED_DEVICE_FAMILY = "1,2,7"; - VERSIONING_SYSTEM = "apple-generic"; - VERSION_INFO_PREFIX = ""; }; name = Debug; }; - C36BEFC72BC098F3002D4AFE /* Release */ = { + C36BF00F2BC5CE56002D4AFE /* Release */ = { isa = XCBuildConfiguration; buildSettings = { ALWAYS_SEARCH_USER_PATHS = NO; + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; + ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; CLANG_ANALYZER_NONNULL = YES; CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; @@ -2187,20 +1633,20 @@ CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; CLANG_WARN_UNREACHABLE_CODE = YES; CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_ENTITLEMENTS = Applications/StableDiffusionExample/StableDiffusionExample.entitlements; CODE_SIGN_STYLE = Automatic; COPY_PHASE_STRIP = NO; CURRENT_PROJECT_VERSION = 1; DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; - DEFINES_MODULE = YES; - DYLIB_COMPATIBILITY_VERSION = 1; - DYLIB_CURRENT_VERSION = 1; - DYLIB_INSTALL_NAME_BASE = "@rpath"; - ENABLE_MODULE_VERIFIER = YES; + DEVELOPMENT_ASSET_PATHS = "\"Applications/StableDiffusionExample/Preview Content\""; + DEVELOPMENT_TEAM = ""; ENABLE_NS_ASSERTIONS = NO; + ENABLE_PREVIEWS = YES; ENABLE_STRICT_OBJC_MSGSEND = YES; ENABLE_USER_SCRIPT_SANDBOXING = YES; GCC_C_LANGUAGE_STANDARD = gnu17; GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 2; GCC_WARN_64_TO_32_BIT_CONVERSION = YES; GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; GCC_WARN_UNDECLARED_SELECTOR = YES; @@ -2208,57 +1654,48 @@ GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; GENERATE_INFOPLIST_FILE = YES; - INFOPLIST_KEY_NSHumanReadableCopyright = ""; - INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks"; - IPHONEOS_DEPLOYMENT_TARGET = 17.0; - LD_RUNPATH_SEARCH_PATHS = ( - "@executable_path/Frameworks", - "@loader_path/Frameworks", - ); - "LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = ( - "@executable_path/../Frameworks", - "@loader_path/Frameworks", - ); + "INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphoneos*]" = YES; + "INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphonesimulator*]" = YES; + "INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphoneos*]" = YES; + 
"INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphonesimulator*]" = YES; + "INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphoneos*]" = YES; + "INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphonesimulator*]" = YES; + "INFOPLIST_KEY_UIStatusBarStyle[sdk=iphoneos*]" = UIStatusBarStyleDefault; + "INFOPLIST_KEY_UIStatusBarStyle[sdk=iphonesimulator*]" = UIStatusBarStyleDefault; + INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; + INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; + IPHONEOS_DEPLOYMENT_TARGET = 17.2; + LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks"; + "LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks"; LOCALIZATION_PREFERS_STRING_CATALOGS = YES; - MACOSX_DEPLOYMENT_TARGET = 14.0; + MACOSX_DEPLOYMENT_TARGET = 14.2; MARKETING_VERSION = 1.0; - MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++"; - MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20"; MTL_ENABLE_DEBUG_INFO = NO; MTL_FAST_MATH = YES; - PRODUCT_BUNDLE_IDENTIFIER = mlx.Image; - PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)"; + PRODUCT_BUNDLE_IDENTIFIER = "mlx.StableDiffusionExample${DISAMBIGUATOR}"; + PRODUCT_NAME = "$(TARGET_NAME)"; SDKROOT = auto; - SKIP_INSTALL = YES; SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx xros xrsimulator"; + SUPPORTS_MACCATALYST = NO; SWIFT_COMPILATION_MODE = wholemodule; SWIFT_EMIT_LOC_STRINGS = YES; SWIFT_STRICT_CONCURRENCY = complete; SWIFT_VERSION = 5.0; TARGETED_DEVICE_FAMILY = "1,2,7"; - VERSIONING_SYSTEM = "apple-generic"; - VERSION_INFO_PREFIX = ""; }; name = Release; }; - C36BEFE52BC32988002D4AFE /* Debug */ = { + C392736C2B60697700368D5D /* Debug */ = { isa = XCBuildConfiguration; + baseConfigurationReference = C3C36A6B2CA714600099FFA4 /* Build.xcconfig */; buildSettings = { - ALWAYS_SEARCH_USER_PATHS = NO; + ARCHS = "$(ARCHS_STANDARD)"; ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; - CLANG_ANALYZER_NONNULL = YES; - CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; - CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; - CLANG_ENABLE_MODULES = YES; - CLANG_ENABLE_OBJC_ARC = YES; - CLANG_ENABLE_OBJC_WEAK = YES; CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; CLANG_WARN_BOOL_CONVERSION = YES; CLANG_WARN_COMMA = YES; CLANG_WARN_CONSTANT_CONVERSION = YES; CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; - CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; - CLANG_WARN_DOCUMENTATION_COMMENTS = YES; CLANG_WARN_EMPTY_BODY = YES; CLANG_WARN_ENUM_CONVERSION = YES; CLANG_WARN_INFINITE_RECURSION = YES; @@ -2266,502 +1703,21 @@ CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; - CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; CLANG_WARN_STRICT_PROTOTYPES = YES; CLANG_WARN_SUSPICIOUS_MOVE = YES; - CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; CLANG_WARN_UNREACHABLE_CODE = YES; CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; - CODE_SIGN_STYLE = Automatic; - COPY_PHASE_STRIP = NO; - DEBUG_INFORMATION_FORMAT = dwarf; + DEAD_CODE_STRIPPING = YES; ENABLE_STRICT_OBJC_MSGSEND = YES; ENABLE_TESTABILITY = YES; - ENABLE_USER_SCRIPT_SANDBOXING = YES; - GCC_C_LANGUAGE_STANDARD = gnu17; - 
GCC_DYNAMIC_NO_PIC = NO; + EXCLUDED_ARCHS = x86_64; GCC_NO_COMMON_BLOCKS = YES; - GCC_OPTIMIZATION_LEVEL = 0; - GCC_PREPROCESSOR_DEFINITIONS = ( - "DEBUG=1", - "$(inherited)", - ); GCC_WARN_64_TO_32_BIT_CONVERSION = YES; - GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_ABOUT_RETURN_TYPE = YES; GCC_WARN_UNDECLARED_SELECTOR = YES; - GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; - GCC_WARN_UNUSED_FUNCTION = YES; - GCC_WARN_UNUSED_VARIABLE = YES; - LOCALIZATION_PREFERS_STRING_CATALOGS = YES; - MACOSX_DEPLOYMENT_TARGET = 14.4; - MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; - MTL_FAST_MATH = YES; - PRODUCT_NAME = "$(TARGET_NAME)"; - SDKROOT = macosx; - SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)"; - SWIFT_OPTIMIZATION_LEVEL = "-Onone"; - SWIFT_STRICT_CONCURRENCY = complete; - SWIFT_VERSION = 5.0; - }; - name = Debug; - }; - C36BEFE62BC32988002D4AFE /* Release */ = { - isa = XCBuildConfiguration; - buildSettings = { - ALWAYS_SEARCH_USER_PATHS = NO; - ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; - CLANG_ANALYZER_NONNULL = YES; - CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; - CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; - CLANG_ENABLE_MODULES = YES; - CLANG_ENABLE_OBJC_ARC = YES; - CLANG_ENABLE_OBJC_WEAK = YES; - CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; - CLANG_WARN_BOOL_CONVERSION = YES; - CLANG_WARN_COMMA = YES; - CLANG_WARN_CONSTANT_CONVERSION = YES; - CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; - CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; - CLANG_WARN_DOCUMENTATION_COMMENTS = YES; - CLANG_WARN_EMPTY_BODY = YES; - CLANG_WARN_ENUM_CONVERSION = YES; - CLANG_WARN_INFINITE_RECURSION = YES; - CLANG_WARN_INT_CONVERSION = YES; - CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; - CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; - CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; - CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; - CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; - CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; - CLANG_WARN_STRICT_PROTOTYPES = YES; - CLANG_WARN_SUSPICIOUS_MOVE = YES; - CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; - CLANG_WARN_UNREACHABLE_CODE = YES; - CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; - CODE_SIGN_STYLE = Automatic; - COPY_PHASE_STRIP = NO; - DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; - ENABLE_NS_ASSERTIONS = NO; - ENABLE_STRICT_OBJC_MSGSEND = YES; - ENABLE_USER_SCRIPT_SANDBOXING = YES; - GCC_C_LANGUAGE_STANDARD = gnu17; - GCC_NO_COMMON_BLOCKS = YES; - GCC_WARN_64_TO_32_BIT_CONVERSION = YES; - GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; - GCC_WARN_UNDECLARED_SELECTOR = YES; - GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; - GCC_WARN_UNUSED_FUNCTION = YES; - GCC_WARN_UNUSED_VARIABLE = YES; - LOCALIZATION_PREFERS_STRING_CATALOGS = YES; - MACOSX_DEPLOYMENT_TARGET = 14.4; - MTL_ENABLE_DEBUG_INFO = NO; - MTL_FAST_MATH = YES; - PRODUCT_NAME = "$(TARGET_NAME)"; - SDKROOT = macosx; - SWIFT_COMPILATION_MODE = wholemodule; - SWIFT_STRICT_CONCURRENCY = complete; - SWIFT_VERSION = 5.0; - }; - name = Release; - }; - C36BF00E2BC5CE56002D4AFE /* Debug */ = { - isa = XCBuildConfiguration; - buildSettings = { - ALWAYS_SEARCH_USER_PATHS = NO; - ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; - ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; - ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; - CLANG_ANALYZER_NONNULL = YES; - CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; - CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; - CLANG_ENABLE_MODULES = YES; - CLANG_ENABLE_OBJC_ARC = YES; - CLANG_ENABLE_OBJC_WEAK = YES; - 
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; - CLANG_WARN_BOOL_CONVERSION = YES; - CLANG_WARN_COMMA = YES; - CLANG_WARN_CONSTANT_CONVERSION = YES; - CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; - CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; - CLANG_WARN_DOCUMENTATION_COMMENTS = YES; - CLANG_WARN_EMPTY_BODY = YES; - CLANG_WARN_ENUM_CONVERSION = YES; - CLANG_WARN_INFINITE_RECURSION = YES; - CLANG_WARN_INT_CONVERSION = YES; - CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; - CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; - CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; - CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; - CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; - CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; - CLANG_WARN_STRICT_PROTOTYPES = YES; - CLANG_WARN_SUSPICIOUS_MOVE = YES; - CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; - CLANG_WARN_UNREACHABLE_CODE = YES; - CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; - CODE_SIGN_ENTITLEMENTS = Applications/StableDiffusionExample/StableDiffusionExample.entitlements; - CODE_SIGN_STYLE = Automatic; - COPY_PHASE_STRIP = NO; - CURRENT_PROJECT_VERSION = 1; - DEBUG_INFORMATION_FORMAT = dwarf; - DEVELOPMENT_ASSET_PATHS = "\"Applications/StableDiffusionExample/Preview Content\""; - DEVELOPMENT_TEAM = ""; - ENABLE_PREVIEWS = YES; - ENABLE_STRICT_OBJC_MSGSEND = YES; - ENABLE_TESTABILITY = YES; - ENABLE_USER_SCRIPT_SANDBOXING = YES; - GCC_C_LANGUAGE_STANDARD = gnu17; - GCC_DYNAMIC_NO_PIC = NO; - GCC_NO_COMMON_BLOCKS = YES; - GCC_OPTIMIZATION_LEVEL = 0; - GCC_PREPROCESSOR_DEFINITIONS = ( - "DEBUG=1", - "$(inherited)", - ); - GCC_WARN_64_TO_32_BIT_CONVERSION = YES; - GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; - GCC_WARN_UNDECLARED_SELECTOR = YES; - GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; - GCC_WARN_UNUSED_FUNCTION = YES; - GCC_WARN_UNUSED_VARIABLE = YES; - GENERATE_INFOPLIST_FILE = YES; - "INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphoneos*]" = YES; - "INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphonesimulator*]" = YES; - "INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphoneos*]" = YES; - "INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphonesimulator*]" = YES; - "INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphoneos*]" = YES; - "INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphonesimulator*]" = YES; - "INFOPLIST_KEY_UIStatusBarStyle[sdk=iphoneos*]" = UIStatusBarStyleDefault; - "INFOPLIST_KEY_UIStatusBarStyle[sdk=iphonesimulator*]" = UIStatusBarStyleDefault; - INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; - INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; - IPHONEOS_DEPLOYMENT_TARGET = 17.2; - LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks"; - "LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks"; - LOCALIZATION_PREFERS_STRING_CATALOGS = YES; - MACOSX_DEPLOYMENT_TARGET = 14.2; - MARKETING_VERSION = 1.0; - MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; - MTL_FAST_MATH = YES; - PRODUCT_BUNDLE_IDENTIFIER = "mlx.StableDiffusionExample${DISAMBIGUATOR}"; - PRODUCT_NAME = "$(TARGET_NAME)"; - SDKROOT = auto; - SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx xros xrsimulator"; - SUPPORTS_MACCATALYST = NO; - SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)"; - SWIFT_EMIT_LOC_STRINGS = YES; - SWIFT_OPTIMIZATION_LEVEL = 
"-Onone"; - SWIFT_STRICT_CONCURRENCY = complete; - SWIFT_VERSION = 5.0; - TARGETED_DEVICE_FAMILY = "1,2,7"; - }; - name = Debug; - }; - C36BF00F2BC5CE56002D4AFE /* Release */ = { - isa = XCBuildConfiguration; - buildSettings = { - ALWAYS_SEARCH_USER_PATHS = NO; - ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; - ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; - ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; - CLANG_ANALYZER_NONNULL = YES; - CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; - CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; - CLANG_ENABLE_MODULES = YES; - CLANG_ENABLE_OBJC_ARC = YES; - CLANG_ENABLE_OBJC_WEAK = YES; - CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; - CLANG_WARN_BOOL_CONVERSION = YES; - CLANG_WARN_COMMA = YES; - CLANG_WARN_CONSTANT_CONVERSION = YES; - CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; - CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; - CLANG_WARN_DOCUMENTATION_COMMENTS = YES; - CLANG_WARN_EMPTY_BODY = YES; - CLANG_WARN_ENUM_CONVERSION = YES; - CLANG_WARN_INFINITE_RECURSION = YES; - CLANG_WARN_INT_CONVERSION = YES; - CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; - CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; - CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; - CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; - CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; - CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; - CLANG_WARN_STRICT_PROTOTYPES = YES; - CLANG_WARN_SUSPICIOUS_MOVE = YES; - CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; - CLANG_WARN_UNREACHABLE_CODE = YES; - CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; - CODE_SIGN_ENTITLEMENTS = Applications/StableDiffusionExample/StableDiffusionExample.entitlements; - CODE_SIGN_STYLE = Automatic; - COPY_PHASE_STRIP = NO; - CURRENT_PROJECT_VERSION = 1; - DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; - DEVELOPMENT_ASSET_PATHS = "\"Applications/StableDiffusionExample/Preview Content\""; - DEVELOPMENT_TEAM = ""; - ENABLE_NS_ASSERTIONS = NO; - ENABLE_PREVIEWS = YES; - ENABLE_STRICT_OBJC_MSGSEND = YES; - ENABLE_USER_SCRIPT_SANDBOXING = YES; - GCC_C_LANGUAGE_STANDARD = gnu17; - GCC_NO_COMMON_BLOCKS = YES; - GCC_OPTIMIZATION_LEVEL = 2; - GCC_WARN_64_TO_32_BIT_CONVERSION = YES; - GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; - GCC_WARN_UNDECLARED_SELECTOR = YES; - GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; - GCC_WARN_UNUSED_FUNCTION = YES; - GCC_WARN_UNUSED_VARIABLE = YES; - GENERATE_INFOPLIST_FILE = YES; - "INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphoneos*]" = YES; - "INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphonesimulator*]" = YES; - "INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphoneos*]" = YES; - "INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphonesimulator*]" = YES; - "INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphoneos*]" = YES; - "INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphonesimulator*]" = YES; - "INFOPLIST_KEY_UIStatusBarStyle[sdk=iphoneos*]" = UIStatusBarStyleDefault; - "INFOPLIST_KEY_UIStatusBarStyle[sdk=iphonesimulator*]" = UIStatusBarStyleDefault; - INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; - INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; - IPHONEOS_DEPLOYMENT_TARGET = 17.2; - LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks"; - 
"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks"; - LOCALIZATION_PREFERS_STRING_CATALOGS = YES; - MACOSX_DEPLOYMENT_TARGET = 14.2; - MARKETING_VERSION = 1.0; - MTL_ENABLE_DEBUG_INFO = NO; - MTL_FAST_MATH = YES; - PRODUCT_BUNDLE_IDENTIFIER = "mlx.StableDiffusionExample${DISAMBIGUATOR}"; - PRODUCT_NAME = "$(TARGET_NAME)"; - SDKROOT = auto; - SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx xros xrsimulator"; - SUPPORTS_MACCATALYST = NO; - SWIFT_COMPILATION_MODE = wholemodule; - SWIFT_EMIT_LOC_STRINGS = YES; - SWIFT_STRICT_CONCURRENCY = complete; - SWIFT_VERSION = 5.0; - TARGETED_DEVICE_FAMILY = "1,2,7"; - }; - name = Release; - }; - C38935CA2B869C7A0037B833 /* Debug */ = { - isa = XCBuildConfiguration; - buildSettings = { - ALWAYS_SEARCH_USER_PATHS = NO; - ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; - CLANG_ANALYZER_NONNULL = YES; - CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; - CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; - CLANG_ENABLE_MODULES = YES; - CLANG_ENABLE_OBJC_ARC = YES; - CLANG_ENABLE_OBJC_WEAK = YES; - CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; - CLANG_WARN_BOOL_CONVERSION = YES; - CLANG_WARN_COMMA = YES; - CLANG_WARN_CONSTANT_CONVERSION = YES; - CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; - CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; - CLANG_WARN_DOCUMENTATION_COMMENTS = YES; - CLANG_WARN_EMPTY_BODY = YES; - CLANG_WARN_ENUM_CONVERSION = YES; - CLANG_WARN_INFINITE_RECURSION = YES; - CLANG_WARN_INT_CONVERSION = YES; - CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; - CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; - CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; - CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; - CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; - CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; - CLANG_WARN_STRICT_PROTOTYPES = YES; - CLANG_WARN_SUSPICIOUS_MOVE = YES; - CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; - CLANG_WARN_UNREACHABLE_CODE = YES; - CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; - CODE_SIGN_STYLE = Automatic; - COPY_PHASE_STRIP = NO; - CURRENT_PROJECT_VERSION = 1; - DEAD_CODE_STRIPPING = YES; - DEBUG_INFORMATION_FORMAT = dwarf; - DEFINES_MODULE = YES; - DYLIB_COMPATIBILITY_VERSION = 1; - DYLIB_CURRENT_VERSION = 1; - DYLIB_INSTALL_NAME_BASE = "@rpath"; - ENABLE_MODULE_VERIFIER = YES; - ENABLE_STRICT_OBJC_MSGSEND = YES; - ENABLE_TESTABILITY = YES; - ENABLE_USER_SCRIPT_SANDBOXING = YES; - GCC_C_LANGUAGE_STANDARD = gnu17; - GCC_DYNAMIC_NO_PIC = NO; - GCC_NO_COMMON_BLOCKS = YES; - GCC_OPTIMIZATION_LEVEL = 0; - GCC_PREPROCESSOR_DEFINITIONS = ( - "DEBUG=1", - "$(inherited)", - ); - GCC_WARN_64_TO_32_BIT_CONVERSION = YES; - GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; - GCC_WARN_UNDECLARED_SELECTOR = YES; - GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; - GCC_WARN_UNUSED_FUNCTION = YES; - GCC_WARN_UNUSED_VARIABLE = YES; - GENERATE_INFOPLIST_FILE = YES; - INFOPLIST_KEY_NSHumanReadableCopyright = ""; - INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks"; - IPHONEOS_DEPLOYMENT_TARGET = 17.0; - LD_RUNPATH_SEARCH_PATHS = ( - "@executable_path/Frameworks", - "@loader_path/Frameworks", - ); - "LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = ( - "@executable_path/../Frameworks", - "@loader_path/Frameworks", - ); - LOCALIZATION_PREFERS_STRING_CATALOGS = YES; - MACOSX_DEPLOYMENT_TARGET = 14.0; - MARKETING_VERSION = 1.0; - MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++"; - MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20"; - MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; - MTL_FAST_MATH = YES; - 
PRODUCT_BUNDLE_IDENTIFIER = mlx.LLM; - PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)"; - SDKROOT = auto; - SKIP_INSTALL = YES; - SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx xros xrsimulator"; - SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)"; - SWIFT_EMIT_LOC_STRINGS = YES; - SWIFT_OPTIMIZATION_LEVEL = "-Onone"; - SWIFT_STRICT_CONCURRENCY = complete; - SWIFT_VERSION = 5.0; - TARGETED_DEVICE_FAMILY = "1,2,7"; - VERSIONING_SYSTEM = "apple-generic"; - VERSION_INFO_PREFIX = ""; - }; - name = Debug; - }; - C38935CB2B869C7A0037B833 /* Release */ = { - isa = XCBuildConfiguration; - buildSettings = { - ALWAYS_SEARCH_USER_PATHS = NO; - ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; - CLANG_ANALYZER_NONNULL = YES; - CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; - CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; - CLANG_ENABLE_MODULES = YES; - CLANG_ENABLE_OBJC_ARC = YES; - CLANG_ENABLE_OBJC_WEAK = YES; - CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; - CLANG_WARN_BOOL_CONVERSION = YES; - CLANG_WARN_COMMA = YES; - CLANG_WARN_CONSTANT_CONVERSION = YES; - CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; - CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; - CLANG_WARN_DOCUMENTATION_COMMENTS = YES; - CLANG_WARN_EMPTY_BODY = YES; - CLANG_WARN_ENUM_CONVERSION = YES; - CLANG_WARN_INFINITE_RECURSION = YES; - CLANG_WARN_INT_CONVERSION = YES; - CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; - CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; - CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; - CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; - CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; - CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; - CLANG_WARN_STRICT_PROTOTYPES = YES; - CLANG_WARN_SUSPICIOUS_MOVE = YES; - CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; - CLANG_WARN_UNREACHABLE_CODE = YES; - CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; - CODE_SIGN_STYLE = Automatic; - COPY_PHASE_STRIP = NO; - CURRENT_PROJECT_VERSION = 1; - DEAD_CODE_STRIPPING = YES; - DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; - DEFINES_MODULE = YES; - DYLIB_COMPATIBILITY_VERSION = 1; - DYLIB_CURRENT_VERSION = 1; - DYLIB_INSTALL_NAME_BASE = "@rpath"; - ENABLE_MODULE_VERIFIER = YES; - ENABLE_NS_ASSERTIONS = NO; - ENABLE_STRICT_OBJC_MSGSEND = YES; - ENABLE_USER_SCRIPT_SANDBOXING = YES; - GCC_C_LANGUAGE_STANDARD = gnu17; - GCC_NO_COMMON_BLOCKS = YES; - GCC_WARN_64_TO_32_BIT_CONVERSION = YES; - GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; - GCC_WARN_UNDECLARED_SELECTOR = YES; - GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; - GCC_WARN_UNUSED_FUNCTION = YES; - GCC_WARN_UNUSED_VARIABLE = YES; - GENERATE_INFOPLIST_FILE = YES; - INFOPLIST_KEY_NSHumanReadableCopyright = ""; - INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks"; - IPHONEOS_DEPLOYMENT_TARGET = 17.0; - LD_RUNPATH_SEARCH_PATHS = ( - "@executable_path/Frameworks", - "@loader_path/Frameworks", - ); - "LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = ( - "@executable_path/../Frameworks", - "@loader_path/Frameworks", - ); - LOCALIZATION_PREFERS_STRING_CATALOGS = YES; - MACOSX_DEPLOYMENT_TARGET = 14.0; - MARKETING_VERSION = 1.0; - MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++"; - MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20"; - MTL_ENABLE_DEBUG_INFO = NO; - MTL_FAST_MATH = YES; - PRODUCT_BUNDLE_IDENTIFIER = mlx.LLM; - PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)"; - SDKROOT = auto; - SKIP_INSTALL = YES; - SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx xros xrsimulator"; - SWIFT_COMPILATION_MODE = wholemodule; - SWIFT_EMIT_LOC_STRINGS = 
YES; - SWIFT_STRICT_CONCURRENCY = complete; - SWIFT_VERSION = 5.0; - TARGETED_DEVICE_FAMILY = "1,2,7"; - VERSIONING_SYSTEM = "apple-generic"; - VERSION_INFO_PREFIX = ""; - }; - name = Release; - }; - C392736C2B60697700368D5D /* Debug */ = { - isa = XCBuildConfiguration; - baseConfigurationReference = C3C36A6B2CA714600099FFA4 /* Build.xcconfig */; - buildSettings = { - ARCHS = "$(ARCHS_STANDARD)"; - ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; - CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; - CLANG_WARN_BOOL_CONVERSION = YES; - CLANG_WARN_COMMA = YES; - CLANG_WARN_CONSTANT_CONVERSION = YES; - CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; - CLANG_WARN_EMPTY_BODY = YES; - CLANG_WARN_ENUM_CONVERSION = YES; - CLANG_WARN_INFINITE_RECURSION = YES; - CLANG_WARN_INT_CONVERSION = YES; - CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; - CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; - CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; - CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; - CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; - CLANG_WARN_STRICT_PROTOTYPES = YES; - CLANG_WARN_SUSPICIOUS_MOVE = YES; - CLANG_WARN_UNREACHABLE_CODE = YES; - CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; - DEAD_CODE_STRIPPING = YES; - ENABLE_STRICT_OBJC_MSGSEND = YES; - ENABLE_TESTABILITY = YES; - EXCLUDED_ARCHS = x86_64; - GCC_NO_COMMON_BLOCKS = YES; - GCC_WARN_64_TO_32_BIT_CONVERSION = YES; - GCC_WARN_ABOUT_RETURN_TYPE = YES; - GCC_WARN_UNDECLARED_SELECTOR = YES; - GCC_WARN_UNINITIALIZED_AUTOS = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES; GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; ONLY_ACTIVE_ARCH = YES; @@ -3440,15 +2396,6 @@ defaultConfigurationIsVisible = 0; defaultConfigurationName = Release; }; - C34E49112B69A92900FCB841 /* Build configuration list for PBXNativeTarget "MNIST" */ = { - isa = XCConfigurationList; - buildConfigurations = ( - C34E49122B69A92900FCB841 /* Debug */, - C34E49132B69A92900FCB841 /* Release */, - ); - defaultConfigurationIsVisible = 0; - defaultConfigurationName = Release; - }; C34E49252B6A026F00FCB841 /* Build configuration list for PBXNativeTarget "mnist-tool" */ = { isa = XCConfigurationList; buildConfigurations = ( @@ -3458,15 +2405,6 @@ defaultConfigurationIsVisible = 0; defaultConfigurationName = Release; }; - C36BEFC82BC098F3002D4AFE /* Build configuration list for PBXNativeTarget "StableDiffusion" */ = { - isa = XCConfigurationList; - buildConfigurations = ( - C36BEFC62BC098F3002D4AFE /* Debug */, - C36BEFC72BC098F3002D4AFE /* Release */, - ); - defaultConfigurationIsVisible = 0; - defaultConfigurationName = Release; - }; C36BEFE42BC32988002D4AFE /* Build configuration list for PBXNativeTarget "image-tool" */ = { isa = XCConfigurationList; buildConfigurations = ( @@ -3485,15 +2423,6 @@ defaultConfigurationIsVisible = 0; defaultConfigurationName = Release; }; - C38935C92B869C7A0037B833 /* Build configuration list for PBXNativeTarget "LLM" */ = { - isa = XCConfigurationList; - buildConfigurations = ( - C38935CA2B869C7A0037B833 /* Debug */, - C38935CB2B869C7A0037B833 /* Release */, - ); - defaultConfigurationIsVisible = 0; - defaultConfigurationName = Release; - }; C392736B2B60697700368D5D /* Build configuration list for PBXProject "mlx-swift-examples" */ = { isa = XCConfigurationList; buildConfigurations = ( @@ -3541,6 +2470,13 @@ }; /* End XCConfigurationList section */ +/* Begin XCLocalSwiftPackageReference section */ + C397D8F22CD2F60B00B87EE2 /* XCLocalSwiftPackageReference "Source/.." 
*/ = { + isa = XCLocalSwiftPackageReference; + relativePath = Source/..; + }; +/* End XCLocalSwiftPackageReference section */ + /* Begin XCRemoteSwiftPackageReference section */ 81695B3F2BA373D300F260D8 /* XCRemoteSwiftPackageReference "swift-markdown-ui" */ = { isa = XCRemoteSwiftPackageReference; @@ -3550,6 +2486,14 @@ minimumVersion = 2.3.1; }; }; + C32A18442D00E13E0092A5B6 /* XCRemoteSwiftPackageReference "mlx-swift" */ = { + isa = XCRemoteSwiftPackageReference; + repositoryURL = "https://github.com/ml-explore/mlx-swift"; + requirement = { + kind = upToNextMajorVersion; + minimumVersion = 0.21.2; + }; + }; C34E491A2B69C43600FCB841 /* XCRemoteSwiftPackageReference "GzipSwift" */ = { isa = XCRemoteSwiftPackageReference; repositoryURL = "https://github.com/1024jp/GzipSwift"; @@ -3566,14 +2510,6 @@ minimumVersion = 0.4.0; }; }; - C38935BB2B866BFA0037B833 /* XCRemoteSwiftPackageReference "swift-transformers" */ = { - isa = XCRemoteSwiftPackageReference; - repositoryURL = "https://github.com/huggingface/swift-transformers"; - requirement = { - kind = upToNextMajorVersion; - minimumVersion = 0.1.13; - }; - }; C392736E2B60699100368D5D /* XCRemoteSwiftPackageReference "swift-argument-parser" */ = { isa = XCRemoteSwiftPackageReference; repositoryURL = "https://github.com/apple/swift-argument-parser.git"; @@ -3582,14 +2518,6 @@ minimumVersion = 1.4.0; }; }; - C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */ = { - isa = XCRemoteSwiftPackageReference; - repositoryURL = "https://github.com/ml-explore/mlx-swift"; - requirement = { - kind = upToNextMajorVersion; - minimumVersion = 0.18.0; - }; - }; /* End XCRemoteSwiftPackageReference section */ /* Begin XCSwiftPackageProductDependency section */ @@ -3603,120 +2531,85 @@ package = C392736E2B60699100368D5D /* XCRemoteSwiftPackageReference "swift-argument-parser" */; productName = ArgumentParser; }; - C34E491B2B69C43600FCB841 /* Gzip */ = { + C32A17FC2CFFB98A0092A5B6 /* MLXLLM */ = { isa = XCSwiftPackageProductDependency; - package = C34E491A2B69C43600FCB841 /* XCRemoteSwiftPackageReference "GzipSwift" */; - productName = Gzip; + package = C397D8F22CD2F60B00B87EE2 /* XCLocalSwiftPackageReference "Source/.." */; + productName = MLXLLM; }; - C34E49282B6A028100FCB841 /* ArgumentParser */ = { - isa = XCSwiftPackageProductDependency; - package = C392736E2B60699100368D5D /* XCRemoteSwiftPackageReference "swift-argument-parser" */; - productName = ArgumentParser; - }; - C36BEFB12BBDE9D0002D4AFE /* MLXOptimizers */ = { + C32A17FE2CFFB98A0092A5B6 /* MLXVLM */ = { isa = XCSwiftPackageProductDependency; - package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */; - productName = MLXOptimizers; + package = C397D8F22CD2F60B00B87EE2 /* XCLocalSwiftPackageReference "Source/.." */; + productName = MLXVLM; }; - C36BEFCD2BC0A194002D4AFE /* MLX */ = { + C32A18002CFFD1810092A5B6 /* MLXMNIST */ = { isa = XCSwiftPackageProductDependency; - package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */; - productName = MLX; + package = C397D8F22CD2F60B00B87EE2 /* XCLocalSwiftPackageReference "Source/.." */; + productName = MLXMNIST; }; - C36BEFCF2BC0A194002D4AFE /* MLXNN */ = { + C32A18022CFFD1920092A5B6 /* MLXMNIST */ = { isa = XCSwiftPackageProductDependency; - package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */; - productName = MLXNN; + package = C397D8F22CD2F60B00B87EE2 /* XCLocalSwiftPackageReference "Source/.." 
*/; + productName = MLXMNIST; }; - C36BEFD12BC0A194002D4AFE /* MLXRandom */ = { + C32A18042CFFD19F0092A5B6 /* MLXLLM */ = { isa = XCSwiftPackageProductDependency; - package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */; - productName = MLXRandom; + package = C397D8F22CD2F60B00B87EE2 /* XCLocalSwiftPackageReference "Source/.." */; + productName = MLXLLM; }; - C36BEFEE2BC329C5002D4AFE /* ArgumentParser */ = { + C32A18062CFFD1AA0092A5B6 /* MLXLLM */ = { isa = XCSwiftPackageProductDependency; - package = C392736E2B60699100368D5D /* XCRemoteSwiftPackageReference "swift-argument-parser" */; - productName = ArgumentParser; + package = C397D8F22CD2F60B00B87EE2 /* XCLocalSwiftPackageReference "Source/.." */; + productName = MLXLLM; }; - C36BEFF12BC32A9A002D4AFE /* Progress */ = { + C32A18082CFFD1B70092A5B6 /* StableDiffusion */ = { isa = XCSwiftPackageProductDependency; - package = C36BEFF02BC32A8C002D4AFE /* XCRemoteSwiftPackageReference "Progress" */; - productName = Progress; + package = C397D8F22CD2F60B00B87EE2 /* XCLocalSwiftPackageReference "Source/.." */; + productName = StableDiffusion; }; - C36BEFF92BC5B996002D4AFE /* Transformers */ = { + C32A18452D00E1490092A5B6 /* MLX */ = { isa = XCSwiftPackageProductDependency; - package = C38935BB2B866BFA0037B833 /* XCRemoteSwiftPackageReference "swift-transformers" */; - productName = Transformers; + package = C32A18442D00E13E0092A5B6 /* XCRemoteSwiftPackageReference "mlx-swift" */; + productName = MLX; }; - C38935CF2B869CC40037B833 /* MLX */ = { + C32A18472D00E1540092A5B6 /* MLX */ = { isa = XCSwiftPackageProductDependency; - package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */; + package = C32A18442D00E13E0092A5B6 /* XCRemoteSwiftPackageReference "mlx-swift" */; productName = MLX; }; - C38935D12B869CC40037B833 /* MLXNN */ = { + C32A18492D00E1540092A5B6 /* MLXNN */ = { isa = XCSwiftPackageProductDependency; - package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */; + package = C32A18442D00E13E0092A5B6 /* XCRemoteSwiftPackageReference "mlx-swift" */; productName = MLXNN; }; - C38935D32B869CC40037B833 /* MLXRandom */ = { + C32A184B2D00E1540092A5B6 /* MLXOptimizers */ = { isa = XCSwiftPackageProductDependency; - package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */; - productName = MLXRandom; - }; - C38935D52B869CC40037B833 /* Transformers */ = { - isa = XCSwiftPackageProductDependency; - package = C38935BB2B866BFA0037B833 /* XCRemoteSwiftPackageReference "swift-transformers" */; - productName = Transformers; + package = C32A18442D00E13E0092A5B6 /* XCRemoteSwiftPackageReference "mlx-swift" */; + productName = MLXOptimizers; }; - C397C59B2B62C6D0004B084D /* ArgumentParser */ = { + C34E49282B6A028100FCB841 /* ArgumentParser */ = { isa = XCSwiftPackageProductDependency; package = C392736E2B60699100368D5D /* XCRemoteSwiftPackageReference "swift-argument-parser" */; productName = ArgumentParser; }; - C3A8B3D12B92A0880002EFB8 /* MLXOptimizers */ = { - isa = XCSwiftPackageProductDependency; - package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */; - productName = MLXOptimizers; - }; - C3FBCB202B8520B80007E490 /* MLX */ = { - isa = XCSwiftPackageProductDependency; - package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */; - productName = MLX; - }; - C3FBCB282B8520DA0007E490 /* MLX */ = { - isa = XCSwiftPackageProductDependency; - package = C3FBCB1F2B8520B00007E490 /* 
XCRemoteSwiftPackageReference "mlx-swift" */; - productName = MLX; - }; - C3FBCB2A2B8520DA0007E490 /* MLXNN */ = { - isa = XCSwiftPackageProductDependency; - package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */; - productName = MLXNN; - }; - C3FBCB2C2B8520E80007E490 /* MLXOptimizers */ = { - isa = XCSwiftPackageProductDependency; - package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */; - productName = MLXOptimizers; - }; - C3FBCB2E2B8520F20007E490 /* MLX */ = { + C36BEFEE2BC329C5002D4AFE /* ArgumentParser */ = { isa = XCSwiftPackageProductDependency; - package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */; - productName = MLX; + package = C392736E2B60699100368D5D /* XCRemoteSwiftPackageReference "swift-argument-parser" */; + productName = ArgumentParser; }; - C3FBCB302B8520F20007E490 /* MLXNN */ = { + C36BEFF12BC32A9A002D4AFE /* Progress */ = { isa = XCSwiftPackageProductDependency; - package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */; - productName = MLXNN; + package = C36BEFF02BC32A8C002D4AFE /* XCRemoteSwiftPackageReference "Progress" */; + productName = Progress; }; - C3FBCB322B8520F20007E490 /* MLXOptimizers */ = { + C397C59B2B62C6D0004B084D /* ArgumentParser */ = { isa = XCSwiftPackageProductDependency; - package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */; - productName = MLXOptimizers; + package = C392736E2B60699100368D5D /* XCRemoteSwiftPackageReference "swift-argument-parser" */; + productName = ArgumentParser; }; - C3FBCB342B8520F20007E490 /* MLXRandom */ = { + C3E7D94C2CF6C9B20056C095 /* StableDiffusion */ = { isa = XCSwiftPackageProductDependency; - package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */; - productName = MLXRandom; + package = C397D8F22CD2F60B00B87EE2 /* XCLocalSwiftPackageReference "Source/.." */; + productName = StableDiffusion; }; /* End XCSwiftPackageProductDependency section */ }; diff --git a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved index c3e91f5..51193ce 100644 --- a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -1,5 +1,5 @@ { - "originHash" : "6750e2209d3e8ec777b601627ba31d0346e487e6b7d748ce241ec3d5623411a2", + "originHash" : "347ce608ed233db4ed416d22692a515e7f4fd2fd3eed7904f75bb8b35eb5366c", "pins" : [ { "identity" : "gzipswift", @@ -24,8 +24,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/ml-explore/mlx-swift", "state" : { - "revision" : "d649c62b77c487c25012910b0d02b30283d388ca", - "version" : "0.18.1" + "revision" : "70dbb62128a5a1471a5ab80363430adb33470cab", + "version" : "0.21.2" } }, {