
@snowmead snowmead commented Jan 14, 2026

Checklist

  • Confirmed that the `cargo run-checks` command has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

#2943

Changes

Added a new lora module to the burn-nn crate, implementing a feature set inspired by the PEFT library.

Features

  • Training LoRA adapters on linear layers with frozen base weights
  • Configurable rank, alpha scaling, and dropout
  • RSLoRA scaling option (alpha / sqrt(rank) instead of alpha / rank)
  • Multiple initialization methods: Kaiming, Gaussian, Zeros
  • Bias training modes: None (fully frozen) or LoraOnly (train bias only)
  • Merging LoRA weights into the base layer for zero-overhead inference (the scaling and merge math is sketched after this list)
  • Saving adapters (LoRA matrices + config) independently of base model
  • Loading adapters onto fresh base models
  • Runtime adapter swapping without reloading base weights
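
To make the scaling and merge behavior concrete, here is a small self-contained sketch of the underlying math (plain Rust with no burn types; the `merge` name mirrors this PR, everything else is purely illustrative):

    // Toy illustration of the LoRA update and merge (no burn types involved).
    // Shapes: W is [out x in], A is [r x in], B is [out x r]; the adapted weight is
    // W' = W + scaling * (B @ A), with scaling = alpha / r (or alpha / sqrt(r) for RSLoRA).
    type Mat = Vec<Vec<f64>>;

    fn matmul(lhs: &Mat, rhs: &Mat) -> Mat {
        let (n, k, m) = (lhs.len(), rhs.len(), rhs[0].len());
        let mut out = vec![vec![0.0; m]; n];
        for i in 0..n {
            for j in 0..m {
                for p in 0..k {
                    out[i][j] += lhs[i][p] * rhs[p][j];
                }
            }
        }
        out
    }

    fn merge(base: &Mat, a: &Mat, b: &Mat, alpha: f64, rank: usize, rslora: bool) -> Mat {
        let scaling = if rslora { alpha / (rank as f64).sqrt() } else { alpha / rank as f64 };
        let delta = matmul(b, a); // [out x r] @ [r x in] = [out x in]
        base.iter()
            .zip(&delta)
            .map(|(wr, dr)| wr.iter().zip(dr).map(|(w, d)| w + scaling * d).collect())
            .collect()
    }

    fn main() {
        let rank = 2;
        let base = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]]; // frozen W: [2 x 3]
        let a = vec![vec![0.1, 0.2, 0.3], vec![0.0, 0.1, 0.0]];    // trainable A: [r x 3]
        let b = vec![vec![1.0, 0.0], vec![0.0, 1.0]];              // trainable B: [2 x r]
        let merged = merge(&base, &a, &b, 8.0, rank, false);
        println!("{merged:?}"); // W + (alpha / r) * B @ A, ready for zero-overhead inference
    }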

Not implemented (to keep this PR focused)

I figure these can be implemented in separate PRs; this change is meant to lay the foundations.

  • Layers: only Linear is supported (no Conv or Embedding yet)
  • Variants: No QLoRA, DoRA, LoRA+, or AdaLoRA
  • Multi-adapter: No fusion or simultaneous adapters
  • Convenience: No auto-wrap by pattern or architecture presets

Testing

An example crate, examples/lora-finetuning/, is implemented to exercise these features, and training runs are visualized using the existing SupervisedTraining implementation.

Non-exhaustive list of features tested:

i. Saving adapters, which persists the LoRA matrices and the configuration
ii. Loading adapters, which loads the LoRA matrices and configuration and applies them to a base model for training
iii. Roundtrip verification, ensuring that a model loaded from disk (base model + adapters) produces inference results identical to the original trained model (a toy sketch of this roundtrip follows the list)
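
A self-contained toy of the roundtrip idea, for the record (std only, no burn types; the file format and names are purely illustrative, not how adapters are actually serialized):

    use std::fs;

    // Apply W' = W + scaling * (B @ A) on flattened row-major matrices.
    fn apply(base: &[f64], a: &[f64], b: &[f64], out: usize, inp: usize, rank: usize, scaling: f64) -> Vec<f64> {
        let mut w = base.to_vec();
        for o in 0..out {
            for i in 0..inp {
                let mut d = 0.0;
                for r in 0..rank {
                    d += b[o * rank + r] * a[r * inp + i];
                }
                w[o * inp + i] += scaling * d;
            }
        }
        w
    }

    fn main() -> std::io::Result<()> {
        let (out, inp, rank, scaling) = (2, 3, 1, 4.0);
        let base = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0]; // frozen W: [2 x 3]
        let a = vec![0.1, 0.2, 0.3];                   // A: [1 x 3]
        let b = vec![1.0, -1.0];                       // B: [2 x 1]

        // "Save the adapter": only A, B, and the scaling go to disk, never the base weights.
        let text: String = a.iter().chain(&b).chain(std::iter::once(&scaling))
            .map(|v| v.to_string()).collect::<Vec<_>>().join(" ");
        fs::write("adapter.txt", &text)?;

        // "Load the adapter" onto the same frozen base and verify identical results.
        let nums: Vec<f64> = fs::read_to_string("adapter.txt")?
            .split_whitespace().map(|s| s.parse().unwrap()).collect();
        let (a2, rest) = nums.split_at(rank * inp);
        let (b2, s2) = rest.split_at(out * rank);
        assert_eq!(
            apply(&base, &a, &b, out, inp, rank, scaling),
            apply(&base, a2, b2, out, inp, rank, s2[0])
        );
        println!("roundtrip ok");
        Ok(())
    }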

Implements a LoRA module in burn-nn, enabling efficient fine-tuning of large
models by adding trainable low-rank matrices while keeping base weights frozen.

Key additions:
- LoRA layer support for Linear via `LoraAdaptable` trait
- Adapter persistence for saving/loading trained LoRA weights independently
- Comprehensive configuration (rank, alpha, dropout, bias modes, RSLoRA)
- `merge()` for zero-overhead inference after training
- burn-book documentation and lora-finetuning example

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@snowmead snowmead force-pushed the feat/lora-finetuning branch from b667004 to b05cbb4 on January 14, 2026 04:35

@snowmead snowmead (Author) commented

Open questions:

i. I'm unsure whether the LoRA section in the burn-book should live under advanced ("burn-book/src/advanced/lora.md").

ii. Should LoraConfig instead accept an Initializer as the init field type, for complete flexibility and control, or should I keep the LoraInit enum (sketched below for comparison)?

    pub struct LoraConfig {
        pub init_a: Initializer,
        pub init_b: Initializer,
        // ...
    }
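
For comparison, the enum alternative would look roughly like this (variant names taken from the feature list above; the actual definition in the PR may differ):

    pub enum LoraInit {
        // Variant names from the PR description; exact semantics live in the PR code.
        Kaiming,
        Gaussian,
        Zeros,
    }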

@snowmead snowmead marked this pull request as ready for review January 14, 2026 05:30

@nathanielsimard nathanielsimard (Member) left a comment

I believe we should implement a more "generalized" approach to LoRA. Specifically, the logic for the update should reside within the Param<Tensor<B, 2>> type itself. In the current Linear layer implementation, parameters are retrieved using the self.weights.val() function. If this function were responsible for performing the weight update, calculating $W = X + AB$, then any module using 2D tensors as parameters could support LoRA adaptations natively, without requiring modifications to the existing module logic.

@snowmead snowmead (Author) commented Jan 14, 2026

@nathanielsimard

I believe we should implement a more "generalized" approach to LoRA. Specifically, the logic for the update should reside within the Param<Tensor<B, 2>> type itself. In the current Linear layer implementation, parameters are retrieved using the self.weights.val() function. If this function were responsible for performing the weight update, calculating $W = X + AB$, then any module using 2D tensors as parameters could support LoRA adaptations natively, without requiring modifications to the existing module logic.

Yes, I figured this would be another possible approach. So do you mean something like this?

let output = self.weight.val();  // Returns W + A @ B * scaling (if LoRA attached/enabled)

Would the LoRA matrices/Params and configuration be nested inside the base Param? Keeping them one level up would not be as generalized as you describe.

Something like this?

    pub struct Param<T: Parameter> {
        pub id: ParamId,
        // Base weight W (frozen)
        state: OnceCell<T>,

        // LoRA state
        lora_a: Option<Param<T>>,
        lora_b: Option<Param<T>>,
        lora_scaling: f64,

        // ... other fields
    }

But this would drastically change the module-visiting traversal logic and would turn Param into a recursive type, which is probably a terrible idea.

Did you have something clever/cleaner in mind?

@nathanielsimard nathanielsimard (Member) commented

Something like this?

    pub struct Param<T: Parameter> {
        pub id: ParamId,
        // Base weight W (frozen)
        state: OnceCell<T>,

        // LoRA state
        lora_a: Option<Param<T>>,
        lora_b: Option<Param<T>>,
        lora_scaling: f64,

        // ... other fields
    }

But this would drastically change the module-visiting traversal logic and would turn Param into a recursive type, which is probably a terrible idea.

Did you have something clever/cleaner in mind?

I would not make it recursive, but yes, I would put the extra tensors in the parameter itself, hardcoded for simplicity:

    trait Parameter {
        type LoraParams;
    }

    pub struct Param<T: Parameter> {
        pub id: ParamId,
        // Base weight W (frozen)
        state: OnceCell<T>,

        // LoRA state
        loras: Option<T::LoraParams>,
        // ... other fields
    }

Then, when we implement Parameter for Tensor, we can select the right LoRA parameters.
I'm not saying this is an easy feature to implement, but if we can make it work, the generalization is great: users won't have to modify their modules to support LoRA training.
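
A standalone sketch of how the associated-type selection could look (hypothetical names and plain Vec storage, not burn's real traits or tensor types):

    // Each parameter kind declares which LoRA state it can carry; 2-D tensors get
    // A/B factors here, while other kinds could simply use `type LoraParams = ();`.
    trait Parameter {
        type LoraParams;
    }

    struct Tensor2d {
        data: Vec<f64>, // flattened [rows x cols]
        rows: usize,
        cols: usize,
    }

    struct LoraFactors {
        a: Tensor2d, // [rank x cols]
        b: Tensor2d, // [rows x rank]
        scaling: f64,
    }

    impl Parameter for Tensor2d {
        type LoraParams = LoraFactors;
    }

    struct Param<T: Parameter> {
        id: u64,                      // stand-in for ParamId
        state: T,                     // frozen base weight W
        loras: Option<T::LoraParams>, // LoRA state, selected per parameter kind
    }

    impl Param<Tensor2d> {
        // `val()` performs the update W' = W + scaling * (B @ A), so a Linear layer
        // keeps calling its weight's val() and gets LoRA support without changes.
        fn val(&self) -> Vec<f64> {
            let mut w = self.state.data.clone();
            if let Some(l) = &self.loras {
                for o in 0..self.state.rows {
                    for i in 0..self.state.cols {
                        let mut delta = 0.0;
                        for r in 0..l.a.rows {
                            delta += l.b.data[o * l.b.cols + r] * l.a.data[r * l.a.cols + i];
                        }
                        w[o * self.state.cols + i] += l.scaling * delta;
                    }
                }
            }
            w
        }
    }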

@snowmead snowmead (Author) commented Jan 14, 2026

@nathanielsimard

Something like this?

    pub struct Param<T: Parameter> {
        pub id: ParamId,
        // Base weight W (frozen)
        state: OnceCell<T>,

        // LoRA state
        lora_a: Option<Param<T>>,
        lora_b: Option<Param<T>>,
        lora_scaling: f64,

        // ... other fields
    }

But this would drastically change the module-visiting traversal logic and would turn Param into a recursive type, which is probably a terrible idea.
Did you have something clever/cleaner in mind?

I would not make it recursive, but yes, I would put the extra tensors in the parameter itself, hardcoded for simplicity:

    trait Parameter {
        type LoraParams;
    }

    pub struct Param<T: Parameter> {
        pub id: ParamId,
        // Base weight W (frozen)
        state: OnceCell<T>,

        // LoRA state
        loras: Option<T::LoraParams>,
        // ... other fields
    }

Then, when we implement Parameter for Tensor, we can select the right LoRA parameters. I'm not saying this is an easy feature to implement, but if we can make it work, the generalization is great: users won't have to modify their modules to support LoRA training.

Completely agree with you. I originally attempted this approach because I thought it might not be too much of a burden for users, but also because I didn't think we would want LoRA embedded in the parameters directly, since I believed a Parameter shouldn't know about any particular training method.

I will look into the architecture a little more in the direction you're suggesting and devise something feasible with the same level of control and features, so we can discuss it further before implementation.

Thanks for your input.
