add VLM support, refactor common LM code into MLXLMCommon. breaking API changes #151
```diff
@@ -1,7 +1,8 @@
 // Copyright © 2024 Apple Inc.
 
-import LLM
 import MLX
+import MLXLLM
+import MLXLMCommon
 import MLXRandom
 import MarkdownUI
 import Metal
```
```diff
@@ -159,7 +160,7 @@ class LLMEvaluator {
 
     /// This controls which model loads. `phi3_5_4bit` is one of the smaller ones, so this will fit on
     /// more devices.
-    let modelConfiguration = ModelConfiguration.phi3_5_4bit
+    let modelConfiguration = ModelRegistry.phi3_5_4bit
 
     /// parameters controlling the output
     let generateParameters = GenerateParameters(temperature: 0.6)
```

Review comment (from the PR description):

```diff
- let modelConfiguration = ModelConfiguration.phi3_5_4bit
+ let modelConfiguration = ModelRegistry.phi3_5_4bit
```
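The registry rename above is the whole migration for picking a preset model. A minimal sketch of what call sites might look like after this PR, assuming `ModelConfiguration(id:)` is still available for models that are not in the registry (the model id below is only an illustration):

```swift
import MLXLLM
import MLXLMCommon

// Presets moved from static members on ModelConfiguration to ModelRegistry.
let preset = ModelRegistry.phi3_5_4bit

// Assumed: custom Hub models can still be described directly by id.
let custom = ModelConfiguration(id: "mlx-community/Mistral-7B-Instruct-v0.3-4bit")
```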
```diff
@@ -185,17 +186,17 @@ class LLMEvaluator {
         // limit the buffer cache
         MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)
 
-        let modelContainer = try await LLM.loadModelContainer(configuration: modelConfiguration)
-        {
+        let modelContainer = try await LLMModelFactory.shared.loadContainer(
+            configuration: modelConfiguration
+        ) {
             [modelConfiguration] progress in
             Task { @MainActor in
                 self.modelInfo =
                     "Downloading \(modelConfiguration.name): \(Int(progress.fractionCompleted * 100))%"
             }
         }
-        let numParams = await modelContainer.perform {
-            [] model, _ in
-            return model.numParameters()
+        let numParams = await modelContainer.perform { context in
+            context.model.numParameters()
         }
 
         self.modelInfo =
```

Review comment (from the PR description):

```diff
- let modelContainer = try await LLM.loadModelContainer(configuration: modelConfiguration)
- {
+ let modelContainer = try await LLMModelFactory.shared.loadContainer(
+     configuration: modelConfiguration
+ ) {
```
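To show the factory-based loading path in one place, here is a sketch of a standalone load function built only from the calls that appear in the diff above; the function name and print-based progress reporting are illustrative, not part of the PR:

```swift
import MLX
import MLXLLM
import MLXLMCommon

// Illustrative helper (not from the PR): load a model container with the new
// LLMModelFactory API and inspect the model through the single-context `perform`.
func loadModel() async throws -> ModelContainer {
    // limit the buffer cache, as the example app does
    MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)

    let container = try await LLMModelFactory.shared.loadContainer(
        configuration: ModelRegistry.phi3_5_4bit
    ) { progress in
        print("Downloading: \(Int(progress.fractionCompleted * 100))%")
    }

    // `perform` now passes one context value (model, tokenizer, processor)
    // instead of separate (model, tokenizer) arguments.
    let numParams = await container.perform { context in
        context.model.numParameters()
    }
    print("Loaded model with \(numParams) parameters")
    return container
}
```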
```diff
@@ -217,22 +218,17 @@ class LLMEvaluator {
         do {
             let modelContainer = try await load()
 
-            let messages = [["role": "user", "content": prompt]]
-            let promptTokens = try await modelContainer.perform { _, tokenizer in
-                try tokenizer.applyChatTemplate(messages: messages)
-            }
-
             // each time you generate you will get something new
             MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))
 
-            let result = await modelContainer.perform { model, tokenizer in
-                LLM.generate(
-                    promptTokens: promptTokens, parameters: generateParameters, model: model,
-                    tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
+            let result = try await modelContainer.perform { context in
+                let input = try await context.processor.prepare(input: .init(prompt: prompt))
+                return try MLXLMCommon.generate(
+                    input: input, parameters: generateParameters, context: context
                 ) { tokens in
                     // update the output -- this will make the view show the text as it generates
                     if tokens.count % displayEveryNTokens == 0 {
-                        let text = tokenizer.decode(tokens: tokens)
+                        let text = context.tokenizer.decode(tokens: tokens)
                         Task { @MainActor in
                             self.output = text
                         }
```

Review comment: See Deprecations section in PR description. The old code still worked, but this switches to the newer methods that support both LLM and VLM.
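For reference, a sketch of the new prompt-to-text flow assembled from the calls above. The wrapper function, the `.more` disposition, and the `result.output` property are assumptions about the MLXLMCommon API rather than code from this PR:

```swift
import MLXLMCommon

// Illustrative wrapper (not from the PR): run one generation against an
// already-loaded container using the LLM/VLM-agnostic entry points.
func respond(to prompt: String, using container: ModelContainer) async throws -> String {
    let parameters = GenerateParameters(temperature: 0.6)

    let result = try await container.perform { context in
        // The processor turns user input (text, and images for VLMs) into model input.
        let input = try await context.processor.prepare(input: .init(prompt: prompt))
        return try MLXLMCommon.generate(
            input: input, parameters: parameters, context: context
        ) { tokens in
            // Assumed disposition-style callback: keep generating until the model stops.
            .more
        }
    }
    // Assumed: the generate result exposes the decoded text as `output`.
    return result.output
}
```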
```diff
@@ -1,6 +1,6 @@
 import Foundation
-import LLM
 import MLX
+import MLXLLM
 
 @Observable
 final class DeviceStat: @unchecked Sendable {
```
```diff
@@ -1,10 +1,10 @@
 // Copyright © 2024 Apple Inc.
 
 import MLX
+import MLXMNIST
 import MLXNN
 import MLXOptimizers
 import MLXRandom
-import MNIST
 import SwiftUI
 
 struct TrainingView: View {
```

Review comment: Just moving to consistent naming with swiftpm.
```diff
@@ -6,8 +6,8 @@
 //
 
 import MLX
+import MLXMNIST
 import MLXNN
-import MNIST
 import SwiftUI
 
 struct Canvas: View {
```
Review comment: See PR description -- split LLM -> LLM and LMCommon. Switched local names to match what people get via swiftpm (MLXLLM, etc.).
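Since that comment is about matching the SwiftPM product names, here is a sketch of what a consuming app's Package.swift might look like after the split. The package URL, branch, platforms, and app name are assumptions, not taken from the PR:

```swift
// swift-tools-version: 5.9
import PackageDescription

// Illustrative manifest for an app that imports MLXLLM / MLXLMCommon under the
// same names used in the diffs above. All details here are placeholders.
let package = Package(
    name: "MyLLMApp",
    platforms: [.macOS(.v14), .iOS(.v16)],
    dependencies: [
        .package(url: "https://github.com/ml-explore/mlx-swift-examples", branch: "main")
    ],
    targets: [
        .executableTarget(
            name: "MyLLMApp",
            dependencies: [
                .product(name: "MLXLLM", package: "mlx-swift-examples"),
                .product(name: "MLXLMCommon", package: "mlx-swift-examples"),
            ]
        )
    ]
)
```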