feat: Add LoRA (Low-Rank Adaptation) for parameter-efficient fine-tuning #4321
base: main
Conversation
Implements a LoRA module in burn-nn, enabling efficient fine-tuning of large models by adding trainable low-rank matrices while keeping base weights frozen. Key additions:

- LoRA layer support for Linear via the `LoraAdaptable` trait
- Adapter persistence for saving/loading trained LoRA weights independently
- Comprehensive configuration (rank, alpha, dropout, bias modes, RSLoRA)
- `merge()` for zero-overhead inference after training
- burn-book documentation and a lora-finetuning example

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
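To ground the discussion below, here is a minimal, self-contained sketch of the mechanism this PR adds, written with plain Vec-based matrices rather than the actual burn-nn types or API (struct and method names here are illustrative only): the base weight W stays frozen, only the low-rank factors A and B are trained, and merging folds the scaled update back into W so inference carries no extra cost.

```rust
// Conceptual sketch of the LoRA idea, NOT the burn-nn API.
type Matrix = Vec<Vec<f64>>;

fn matmul(a: &Matrix, b: &Matrix) -> Matrix {
    let (n, k, m) = (a.len(), b.len(), b[0].len());
    let mut out = vec![vec![0.0; m]; n];
    for i in 0..n {
        for p in 0..k {
            for j in 0..m {
                out[i][j] += a[i][p] * b[p][j];
            }
        }
    }
    out
}

fn add_scaled(base: &Matrix, delta: &Matrix, scaling: f64) -> Matrix {
    base.iter()
        .zip(delta)
        .map(|(b_row, d_row)| b_row.iter().zip(d_row).map(|(b, d)| b + scaling * d).collect())
        .collect()
}

struct LoraLinear {
    weight: Matrix, // frozen base weight W: d_in x d_out
    lora_a: Matrix, // trainable low-rank factor A: d_in x rank
    lora_b: Matrix, // trainable low-rank factor B: rank x d_out
    scaling: f64,   // alpha / rank (or alpha / sqrt(rank) with RSLoRA)
}

impl LoraLinear {
    /// Forward pass during fine-tuning: x @ W + scaling * (x @ A) @ B.
    fn forward(&self, x: &Matrix) -> Matrix {
        let base = matmul(x, &self.weight);
        let delta = matmul(&matmul(x, &self.lora_a), &self.lora_b);
        add_scaled(&base, &delta, self.scaling)
    }

    /// After training, fold scaling * A @ B into W once; inference then
    /// runs a plain Linear with no extra matmuls ("zero-overhead").
    fn merge(&self) -> Matrix {
        let delta = matmul(&self.lora_a, &self.lora_b);
        add_scaled(&self.weight, &delta, self.scaling)
    }
}
```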
Open questions:

i. Unsure if the lora section in the burn-book should be under [...]

ii. Should `LoraConfig` look something like this?

pub struct LoraConfig {
    pub init_a: Initializer,
    pub init_b: Initializer,
    // ...
}
nathanielsimard left a comment:
I believe we should implement a more "generalized" approach to LoRA. Specifically, the logic for the update should reside within the Param<Tensor<B, 2>> type itself. In the current Linear layer implementation, parameters are retrieved using the self.weights.val() function. If this function were responsible for performing the weight update, calculating [...]
Yes, I assumed this would be another approach. So do you mean something like this?

let output = self.weight.val(); // Returns W + A @ B * scaling (if LoRA attached/enabled)

Would lora matrices [...]? Something like this?

pub struct Param<T: Parameter> {
    pub id: ParamId,
    // Base weight W (frozen)
    state: OnceCell<T>,
    // LoRA state
    lora_a: Option<Param<T>>,
    lora_b: Option<Param<T>>,
    lora_scaling: f64,
    // ... other fields
}

But this would drastically change the visiting/traversal logic, and would turn this into a recursive type, which is probably a terrible idea. Did you have something clever/cleaner in mind?
I would not make it recursive, but yes, I would put extra tensors in the parameter, but hardcoded for simplicity:

trait Parameter {
    type LoraParams;
}

pub struct Param<T: Parameter> {
    pub id: ParamId,
    // Base weight W (frozen)
    state: OnceCell<T>,
    // LoRA state
    loras: Option<T::LoraParams>,
    // ... other fields
}

Then when we implement Parameter for [...]
Completely agree with you. I attempted this originally because I thought it wouldn't be such a burden for users, but also because I didn't think we would want LoRA embedded in the parameters directly, since I believed the Parameter shouldn't know about any particular training method. I will look into the architecture a little more in the direction you were thinking and devise something feasible with the same level of control and features, so we can discuss it further before implementation. Thanks for your input.
Checklist
The `cargo run-checks` command has been executed.

Related Issues/PRs
#2943
Changes
Added a new `lora` module in the `burn-nn` crate which implements a considerable set of features inspired by the PEFT library.

Features
- RSLoRA scaling (`alpha / sqrt(rank)` instead of `alpha / rank`; see the sketch just below)
- Bias modes: `None` (fully frozen) or `LoraOnly` (train bias only)
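As a quick illustration of the RSLoRA option above, the only difference from standard LoRA is the effective scaling applied to the low-rank update (a small sketch, not code from this PR):

```rust
/// Effective scaling applied to the low-rank update A @ B.
/// Standard LoRA uses alpha / rank; RSLoRA uses alpha / sqrt(rank),
/// which keeps the update magnitude stable as the rank grows.
fn lora_scaling(alpha: f64, rank: usize, rslora: bool) -> f64 {
    if rslora {
        alpha / (rank as f64).sqrt()
    } else {
        alpha / rank as f64
    }
}

fn main() {
    // With alpha = 16 and rank = 64: 0.25 for standard LoRA, 2.0 for RSLoRA.
    println!("{}", lora_scaling(16.0, 64, false)); // 0.25
    println!("{}", lora_scaling(16.0, 64, true));  // 2.0
}
```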
Not implemented for the sake of concision

I figure these can be implemented in separate PRs; this particular change is to lay out the foundations.
- Only `Linear` supported so far (Conv, Embedding not yet)

Testing
An example crate `examples/lora-finetuning/` is implemented to test the features, and results are visualized using the existing `SupervisedTraining` implementation.

Non-exhaustive list of features tested:
i. Saving adapters, which saves the LoRA matrices and the configuration
ii. Loading adapters, which loads the LoRA matrices and the configuration and applies them to the base model for training
iii. Roundtrip verification, ensuring that a model loaded from disk (base model + adapters) produces identical inference results to the original trained model (a toy illustration follows below)
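For illustration, here is a toy version of that roundtrip check (iii), using plain text files and Vec-based matrices instead of the PR's actual adapter format, just to show the save, load, and compare intent:

```rust
use std::fs;

type Matrix = Vec<Vec<f64>>;

// Toy "adapter persistence": one row per line, values space-separated.
// The real PR presumably goes through Burn's record/recorder machinery;
// this only illustrates the save -> load -> compare roundtrip.
fn save_matrix(path: &str, m: &Matrix) -> std::io::Result<()> {
    let text: Vec<String> = m
        .iter()
        .map(|row| row.iter().map(f64::to_string).collect::<Vec<_>>().join(" "))
        .collect();
    fs::write(path, text.join("\n"))
}

fn load_matrix(path: &str) -> std::io::Result<Matrix> {
    Ok(fs::read_to_string(path)?
        .lines()
        .map(|line| line.split_whitespace().map(|v| v.parse().unwrap()).collect())
        .collect())
}

fn main() -> std::io::Result<()> {
    // Pretend these are trained LoRA factors A (2x1) and B (1x3).
    let lora_a = vec![vec![0.5], vec![-1.0]];
    let lora_b = vec![vec![1.0, 2.0, 3.0]];

    save_matrix("lora_a.txt", &lora_a)?;
    save_matrix("lora_b.txt", &lora_b)?;

    // Roundtrip: loading the adapter must reproduce the exact same matrices,
    // which in turn yields identical inference once applied to the base model.
    assert_eq!(load_matrix("lora_a.txt")?, lora_a);
    assert_eq!(load_matrix("lora_b.txt")?, lora_b);
    println!("adapter roundtrip ok");
    Ok(())
}
```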