
add VLM support, refactor common LM code into MLXLMCommon. breaking API changes #151

Merged: 14 commits, Dec 10, 2024
7 changes: 4 additions & 3 deletions .circleci/config.yml
@@ -15,7 +15,7 @@ jobs:

mac_build_and_test:
macos:
-      xcode: 15.2.0
+      xcode: 16.0.0
resource_class: macos.m1.medium.gen1
steps:
- checkout
@@ -35,8 +35,9 @@ jobs:
xcrun --show-sdk-build-version
swift --version
find . -name Package.resolved -exec rm {} \;
-          xcodebuild -skipPackagePluginValidation -scheme llm-tool
-          xcodebuild -skipPackagePluginValidation -scheme mnist-tool
+          xcodebuild -scheme llm-tool
+          xcodebuild -scheme image-tool
+          xcodebuild -scheme mnist-tool

workflows:
build_and_test:
30 changes: 13 additions & 17 deletions Applications/LLMEval/ContentView.swift
@@ -1,7 +1,8 @@
// Copyright © 2024 Apple Inc.

-import LLM
import MLX
+import MLXLLM
+import MLXLMCommon
Collaborator (author) commented:

See PR description -- split LLM -> LLM and LMCommon. Switched local names to match what people get via swiftpm (MLXLLM, etc.).
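
For reference, a minimal sketch of what each module now provides, as used in this file (the comments are inferred from this PR, not a full module listing):

```swift
import MLXLMCommon  // shared LM infrastructure: ModelContainer, GenerateParameters, generate()
import MLXLLM       // text-only models, LLMModelFactory, and the LLM ModelRegistry
// import MLXVLM    // vision-language models and VLMModelFactory, when needed
```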

import MLXRandom
import MarkdownUI
import Metal
@@ -159,7 +160,7 @@ class LLMEvaluator {

/// This controls which model loads. `phi3_5_4bit` is one of the smaller ones, so this will fit on
/// more devices.
-    let modelConfiguration = ModelConfiguration.phi3_5_4bit
+    let modelConfiguration = ModelRegistry.phi3_5_4bit
Collaborator (author) commented:

From PR description:

  • constants for models have moved from ModelConfiguration to ModelRegistry
  • this is MLXLLM.ModelRegistry and there is also MLXVLM.ModelRegistry
-    let modelConfiguration = ModelConfiguration.phi3_5_4bit
+    let modelConfiguration = ModelRegistry.phi3_5_4bit
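
Since both modules define a `ModelRegistry`, qualified names disambiguate when both are imported; a sketch (the VLM constant is an assumed example, not taken from this diff):

```swift
let llmModel = MLXLLM.ModelRegistry.phi3_5_4bit            // text-only model
let vlmModel = MLXVLM.ModelRegistry.qwen2VL2BInstruct4Bit  // vision-language model (assumed name)
```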


/// parameters controlling the output
let generateParameters = GenerateParameters(temperature: 0.6)
@@ -185,17 +186,17 @@ class LLMEvaluator {
// limit the buffer cache
MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)

-        let modelContainer = try await LLM.loadModelContainer(configuration: modelConfiguration)
-        {
+        let modelContainer = try await LLMModelFactory.shared.loadContainer(
+            configuration: modelConfiguration
+        ) {
Collaborator (author) commented:

From PR description:

  • the loadModelContainer() function is now LLMModelFactory.shared.loadContainer()
  • there is a new VLMModelFactory with identical methods for loading VLMs
-     let modelContainer = try await LLM.loadModelContainer(configuration: modelConfiguration)
-    {
+     let modelContainer = try await LLMModelFactory.shared.loadContainer(
+          configuration: modelConfiguration
+    ) {
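
A sketch of the mirrored VLM path; the factory comes from this PR, while the registry constant is an assumed example:

```swift
let vlmContainer = try await VLMModelFactory.shared.loadContainer(
    configuration: MLXVLM.ModelRegistry.qwen2VL2BInstruct4Bit  // assumed constant
) { progress in
    print("Downloading: \(Int(progress.fractionCompleted * 100))%")
}
```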

[modelConfiguration] progress in
Task { @MainActor in
self.modelInfo =
"Downloading \(modelConfiguration.name): \(Int(progress.fractionCompleted * 100))%"
}
}
-        let numParams = await modelContainer.perform {
-            [] model, _ in
-            return model.numParameters()
+        let numParams = await modelContainer.perform { context in
+            context.model.numParameters()
}

self.modelInfo =
@@ -217,22 +218,17 @@ class LLMEvaluator {
do {
let modelContainer = try await load()

-            let messages = [["role": "user", "content": prompt]]
-            let promptTokens = try await modelContainer.perform { _, tokenizer in
-                try tokenizer.applyChatTemplate(messages: messages)
-            }

// each time you generate you will get something new
MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))

-            let result = await modelContainer.perform { model, tokenizer in
-                LLM.generate(
-                    promptTokens: promptTokens, parameters: generateParameters, model: model,
-                    tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
+            let result = try await modelContainer.perform { context in
+                let input = try await context.processor.prepare(input: .init(prompt: prompt))
+                return try MLXLMCommon.generate(
+                    input: input, parameters: generateParameters, context: context
Collaborator (author) commented:

See Deprecations section in PR description. Old code still worked but this switches to the newer methods that support both LLM and VLM. A consolidated sketch of the new flow appears after this file's diff.

) { tokens in
// update the output -- this will make the view show the text as it generates
if tokens.count % displayEveryNTokens == 0 {
-                    let text = tokenizer.decode(tokens: tokens)
+                    let text = context.tokenizer.decode(tokens: tokens)
Task { @MainActor in
self.output = text
}
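As referenced in the comment above, a consolidated sketch of the new load-and-generate flow assembled from this file's changes (prompt, token cap, and print are illustrative):

```swift
import MLXLLM
import MLXLMCommon

func runExample() async throws {
    let container = try await LLMModelFactory.shared.loadContainer(
        configuration: ModelRegistry.phi3_5_4bit
    ) { _ in }

    let result = try await container.perform { context in
        // the context's processor applies the chat template and tokenizes
        let input = try await context.processor.prepare(input: .init(prompt: "Hello!"))
        return try MLXLMCommon.generate(
            input: input, parameters: GenerateParameters(temperature: 0.6), context: context
        ) { tokens in
            tokens.count >= 240 ? .stop : .more  // cap generation at 240 tokens
        }
    }
    print(result.output)
}
```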
2 changes: 1 addition & 1 deletion Applications/LLMEval/README.md
@@ -30,7 +30,7 @@
The example application uses Phi2 model by default, see [ContentView.swift](ContentView.swift)
let modelConfiguration = ModelConfiguration.phi4bit
```

-There are some pre-configured models in [LLM/Models.swift](../../Libraries/LLM/Models.swift#L62)
+There are some pre-configured models in [MLXLLM/LLMModelFactory.swift](../../Libraries/MLXLLM/LLMModelFactory.swift#L78)
and you can load any weights from Hugging Face where there
is a model architecture defined and you have enough
memory.
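A sketch of loading weights straight from a Hugging Face repo id, per the README text above (the repo id is illustrative; the architecture must be one the library defines):

```swift
let configuration = ModelConfiguration(id: "mlx-community/Mistral-7B-Instruct-v0.3-4bit")
let container = try await LLMModelFactory.shared.loadContainer(
    configuration: configuration
) { _ in }
```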
2 changes: 1 addition & 1 deletion Applications/LLMEval/ViewModels/DeviceStat.swift
@@ -1,6 +1,6 @@
import Foundation
-import LLM
import MLX
+import MLXLLM

@Observable
final class DeviceStat: @unchecked Sendable {
57 changes: 27 additions & 30 deletions Applications/LoRATrainingExample/ContentView.swift
@@ -1,7 +1,8 @@
// Copyright © 2024 Apple Inc.

-import LLM
import MLX
+import MLXLLM
+import MLXLMCommon
import MLXNN
import MLXOptimizers
import MLXRandom
@@ -122,7 +123,7 @@ class LoRAEvaluator {

var output = ""

-    private let modelConfiguration = ModelConfiguration.mistral7B4bit
+    private let modelConfiguration = ModelRegistry.mistral7B4bit
private var model: ModelState = .idle

private let loraLayers = 4
@@ -141,8 +142,9 @@ class LoRAEvaluator {
progress = .init(title: "Loading \(name)", current: 0, limit: 1)
}

-        let modelContainer = try await LLM.loadModelContainer(configuration: modelConfiguration)
-        {
+        let modelContainer = try await LLMModelFactory.shared.loadContainer(
+            configuration: modelConfiguration
+        ) {
progress in
Task { @MainActor in
self.progress = .init(
Expand All @@ -160,7 +162,7 @@ class LoRAEvaluator {

private func loadLoRAData(name: String) throws -> [String]? {
if let url = Bundle.main.url(forResource: name, withExtension: "jsonl") {
-            return try LLM.loadLoRAData(url: url)
+            return try MLXLLM.loadLoRAData(url: url)
}
return nil
}
@@ -196,9 +198,9 @@ class LoRAEvaluator {
let modelContainer = try await loadModel()

// apply LoRA adapters and train
-        await modelContainer.perform { model, _ in
+        await modelContainer.perform { context in
            LoRATrain.convert(
-                model: model, layers: loraLayers(model: model))
+                model: context.model, layers: loraLayers(model: context.model))
}

let train = try loadLoRAData(name: "train")
@@ -208,11 +210,11 @@ class LoRAEvaluator {
return
}

-        try await modelContainer.perform { model, tokenizer in
+        try await modelContainer.perform { context in
            let optimizer = Adam(learningRate: learningRate)
            try LoRATrain.train(
-                model: model, train: train, validate: valid, optimizer: optimizer,
-                tokenizer: tokenizer,
+                model: context.model, train: train, validate: valid, optimizer: optimizer,
+                tokenizer: context.tokenizer,
parameters: parameters
) { progress in
Task { @MainActor in
@@ -240,9 +242,10 @@ class LoRAEvaluator {
return
}

-        let loss = await modelContainer.perform { model, tokenizer in
+        let loss = await modelContainer.perform { context in
            LoRATrain.evaluate(
-                model: model, dataset: test, tokenizer: tokenizer, batchSize: 1, batchCount: 0)
+                model: context.model, dataset: test,
+                tokenizer: context.tokenizer, batchSize: 1, batchCount: 0)
}

self.progress = nil
@@ -269,26 +272,20 @@ class LoRAEvaluator {

let modelContainer = try await loadModel()

-        let messages = [["role": "user", "content": prompt]]
-        let promptTokens = try await modelContainer.perform { _, tokenizer in
-            try tokenizer.applyChatTemplate(messages: messages)
-        }

// evaluate
-        let result = await modelContainer.perform { model, tokenizer in
-            LLM.generate(
-                promptTokens: promptTokens, parameters: generateParameters, model: model,
-                tokenizer: tokenizer,
-                extraEOSTokens: modelConfiguration.extraEOSTokens,
-                didGenerate: { tokens in
-                    if tokens.count % evaluateShowEvery == 0 {
-                        let fullOutput = tokenizer.decode(tokens: tokens)
-                        Task { @MainActor in
-                            self.output = fullOutput
-                        }
-                    }
-                    return tokens.count >= maxTokens ? .stop : .more
-                })
+        let result = try await modelContainer.perform { context in
+            let input = try await context.processor.prepare(input: .init(prompt: prompt))
+            return try MLXLMCommon.generate(
+                input: input, parameters: generateParameters, context: context
+            ) { tokens in
+                if tokens.count % evaluateShowEvery == 0 {
+                    let fullOutput = context.tokenizer.decode(tokens: tokens)
+                    Task { @MainActor in
+                        self.output = fullOutput
+                    }
+                }
+                return tokens.count >= maxTokens ? .stop : .more
+            }
        }

self.output = result.output
2 changes: 1 addition & 1 deletion Applications/MNISTTrainer/ContentView.swift
@@ -1,10 +1,10 @@
// Copyright © 2024 Apple Inc.

import MLX
+import MLXMNIST
import MLXNN
import MLXOptimizers
import MLXRandom
-import MNIST
Collaborator (author) commented:

Just moving to consistent naming with swiftpm

import SwiftUI

struct TrainingView: View {
2 changes: 1 addition & 1 deletion Applications/MNISTTrainer/PredictionView.swift
@@ -6,8 +6,8 @@
//

import MLX
+import MLXMNIST
import MLXNN
-import MNIST
import SwiftUI

struct Canvas: View {