feat(gpt-oss): Add CPU offload optimizer, differential LR/WD, and more #2205
base: main
Conversation
…checkpoint loading

This PR introduces several features for memory-efficient fine-tuning of large models:

## CPU Offload Optimizer
- New `CPUOffloadOptimizersContainer` using DeepSpeed's CPUAdam
- Keeps optimizer states (momentum, variance) on CPU to reduce GPU memory
- Supports DTensor/FSDP2 with proper gradient redistribution
- Includes `cleanup()` method for releasing pinned memory resources
- Shape validation for gradient buffers with automatic reallocation

## Differential Learning Rates & Weight Decay
- New `LRMultipliers` and `WeightDecayMultipliers` config dataclasses
- Per-parameter-group settings: embeddings, output, attention, experts, routers, norms, mhc, bias
- `OptimizersContainerWithParamGroups` for standard (non-CPU-offload) differential training
- Pre-compiled regex patterns for efficient parameter classification

## HuggingFace Checkpoint Loading
- Direct safetensors loading bypassing dcp.load() pickle issues with DTensors
- Rank 0 broadcast pattern for distributed loading efficiency
- Proper handling for both distributed and non-distributed modes
- Uses configured `export_dtype` instead of hardcoded bfloat16
- Filters quantization tensor keys when converting unquantized checkpoints

## Loss Function
- Added `ignore_index=-100` to cross_entropy_loss for padding token support
- Exported `CROSS_ENTROPY_IGNORE_INDEX` constant

## Code Quality Improvements
- Extracted `_load_hf_state_dict_to_model()` helper to reduce duplication
- Added `Counter` for efficient dtype counting
- Named constants for magic numbers
- Documented private `_local_tensor` API usage with noqa comments

## Reference Configuration
- Added `persona_zeta_49k.toml` as a known-good config for GPT-OSS 20B fine-tuning
- Demonstrates CPU offload, differential LR/WD, 49K context, sample packing
- Validated on 2x NVIDIA RTX PRO 6000 Blackwell GPUs with FSDP sharding
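For orientation, here is a minimal sketch of the CPU-offload pattern described above. It uses `torch.optim.AdamW` as a stand-in for DeepSpeed's CPUAdam and ignores the DTensor/FSDP2 gradient redistribution the PR handles; all names are illustrative.

```python
import torch

class CPUOffloadAdamSketch:
    """Keep optimizer state on the CPU: step a pinned CPU 'shadow' copy of each
    parameter, then copy the updated values back to the GPU parameters."""

    def __init__(self, gpu_params, lr=1e-4, weight_decay=0.0):
        self.gpu_params = [p for p in gpu_params if p.requires_grad]
        # Pinned CPU shadows so host<->device copies can overlap with compute.
        self.cpu_params = [
            torch.nn.Parameter(p.detach().to("cpu", torch.float32).pin_memory())
            for p in self.gpu_params
        ]
        # The PR uses DeepSpeed's CPUAdam here; AdamW is only a stand-in.
        self.opt = torch.optim.AdamW(self.cpu_params, lr=lr, weight_decay=weight_decay)

    @torch.no_grad()
    def step(self):
        for gpu_p, cpu_p in zip(self.gpu_params, self.cpu_params):
            if gpu_p.grad is None:
                continue
            cpu_p.grad = gpu_p.grad.detach().to("cpu", torch.float32)
        self.opt.step()                      # Adam math (and its state) runs on CPU
        self.opt.zero_grad(set_to_none=True)
        for gpu_p, cpu_p in zip(self.gpu_params, self.cpu_params):
            # Copy the updated shadow back into the (possibly lower-precision) GPU param.
            gpu_p.copy_(cpu_p.to(gpu_p.device, gpu_p.dtype))
```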
Hi @eous! Thank you for your pull request and welcome to our community.

Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA Signed. If you have received this in error or have any questions, please contact us at [email protected]. Thanks!
Pull request overview
This PR adds comprehensive memory-efficient fine-tuning capabilities for large GPT-OSS models, introducing CPU offload optimization, differential learning rates/weight decay, improved HuggingFace checkpoint loading, and sample packing for datasets.
Key Changes:
- CPU offload optimizer using DeepSpeed's CPUAdam to reduce GPU memory by keeping optimizer states on CPU
- Differential learning rates and weight decay per parameter group (embeddings, attention, experts, routers, norms)
- Direct safetensors loading bypassing DCP pickle issues with DTensors
- Sample packing dataset that preserves boundaries with proper padding mask support
- MoE enhancements: router/expert bias support, TopK-then-Score vs Score-then-TopK routing modes, bincount replacing histc
- Attention sink LSE renormalization for GPT-OSS models
- YaRN RoPE scaling for extended context
Reviewed changes
Copilot reviewed 23 out of 23 changed files in this pull request and generated 14 comments.
| File | Description |
|---|---|
| torchtitan/components/optimizer.py | Adds CPUOffloadOptimizersContainer and OptimizersContainerWithParamGroups for differential LR/WD; includes parameter classification logic |
| torchtitan/components/loss.py | Adds CROSS_ENTROPY_IGNORE_INDEX constant and ignore_index parameter to cross_entropy_loss for padding token support |
| torchtitan/components/checkpoint.py | Implements direct safetensors loading with _load_hf_safetensors_unquantized and _load_hf_state_dict_to_model helpers; adds MXFP4 dequantization |
| torchtitan/config/job_config.py | Adds LRMultipliers, WeightDecayMultipliers dataclasses; adds cpu_offload, pack_samples, add_bos_eos, freeze_router_bias config options |
| torchtitan/hf_datasets/text_datasets.py | Implements HuggingFacePackedDataset for sample packing; adds many dataset configurations (hardcoded paths issue) |
| torchtitan/models/moe/moe.py | Adds use_router_bias, topk_before_score, use_expert_bias parameters; replaces histc with bincount for reliability |
| torchtitan/models/gpt_oss/model/model.py | Implements attention sink LSE renormalization; adds YaRN RoPE scaling for extended context |
| torchtitan/models/gpt_oss/model/moe.py | Adds optional expert biases support, compute_dtype parameter, caching optimizations, DeepEP integration |
| torchtitan/models/gpt_oss/model/state_dict_adapter.py | Adds load_hf_safetensors_direct method, MXFP4 dequantization, expert weight transposition handling |
| torchtitan/models/gpt_oss/model/args.py | Adds moe_impl configuration for standard vs DeepEP backend |
| torchtitan/models/gpt_oss/infra/parallelize.py | Adds freeze_router_bias logic, DeepEP support, Float8 tensor parallel variants |
| torchtitan/models/gpt_oss/infra/expert_parallel.py | Adds conditional bias distribution (only when use_expert_bias=True) |
| torchtitan/models/gpt_oss/train_configs/persona_zeta_49k.toml | Reference configuration demonstrating CPU offload and differential LR/WD (has hardcoded paths) |
| torchtitan/models/gpt_oss/test_attention_sink.py | New test file validating LSE renormalization equivalence and edge cases |
| torchtitan/models/gpt_oss/freeze_router_bias.py | Helper utilities for freezing router biases during fine-tuning |
| torchtitan/train.py | Adds differential LR logging and CPU offload optimizer sync after checkpoint load |
| scripts/checkpoint_conversion/convert_to_hf.py | Filters quantization tensor keys when converting unquantized checkpoints |
| torchtitan/distributed/utils.py | Enables autocast for all non-FSDP cases (not just DDP) |
```python
import re

_PARAM_GROUP_PATTERNS = {
    # Bias pattern FIRST - catches all bias parameters
    'bias': re.compile(r'\.bias$'),
```
Copilot AI (Jan 7, 2026):
The pattern for matching biases is too broad: r'\.bias$' will match any parameter ending in .bias, including potentially non-bias parameters. While this is likely correct for the current model architecture, it would be more robust to be more specific (e.g., matching known bias parameter locations like in linear layers, attention, etc.) or to document that this assumes all .bias parameters should receive the bias weight decay treatment.
This is intentional. The pattern matches all bias parameters because standard practice in transformer training is to disable weight decay on all bias terms, not just specific ones. This is well-established (see AdamW paper, HuggingFace Trainer defaults, PyTorch examples). The pattern is placed first in the ordered dict specifically to catch biases before any other pattern matches.
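For illustration, a minimal sketch of this kind of ordered regex classification feeding per-group weight decay (the pattern table and group names here are assumptions, not the PR's exact code):

```python
import re
import torch

# Ordered: the bias pattern comes first so ".bias" wins before any other match.
PARAM_GROUP_PATTERNS = {
    "bias": re.compile(r"\.bias$"),
    "norm": re.compile(r"(^|\.)norm"),
    "embedding": re.compile(r"tok_embeddings"),
}

def build_param_groups(model: torch.nn.Module, base_lr: float, base_wd: float):
    buckets = {name: [] for name in PARAM_GROUP_PATTERNS}
    buckets["default"] = []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        for group, pattern in PARAM_GROUP_PATTERNS.items():
            if pattern.search(name):
                buckets[group].append(p)
                break
        else:
            buckets["default"].append(p)
    # Standard practice: no weight decay on biases (and usually norms).
    wd = {"bias": 0.0, "norm": 0.0, "embedding": base_wd, "default": base_wd}
    return [
        {"params": params, "lr": base_lr, "weight_decay": wd[group]}
        for group, params in buckets.items()
        if params
    ]

# optimizer = torch.optim.AdamW(build_param_groups(model, 3e-4, 0.1))
```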
```python
        Optimizer.__init__(self, all_params, optimizer_kwargs)


class OptimizersContainerWithParamGroups(OptimizersContainer):
```
Copilot AI (Jan 7, 2026):
This class does not call `OptimizersContainer.__init__` during initialization. (`OptimizersContainerWithParamGroups.__init__` may be missing a call to a base class `__init__`.)
The base class `__init__` is intentionally not called because the subclasses have different optimizer creation logic:
- OptimizersContainer creates one optimizer per model_part
- OptimizersContainerWithParamGroups creates a single optimizer with param groups
- CPUOffloadOptimizersContainer creates CPU-based optimizers with shadow params

However, all subclasses:
- Manually set required attributes (self.optimizers, self.model_parts)
- Call self._post_init(), which invokes `Optimizer.__init__(self, all_params, optimizer_kwargs)`

This ensures proper initialization of the Optimizer base class functionality (hooks, state tracking) while allowing different optimizer creation strategies. See lines 375, 448, and 602.
```python
    pass


class CPUOffloadOptimizersContainer(OptimizersContainer):
```
Copilot AI (Jan 7, 2026):
This class does not call `OptimizersContainer.__init__` during initialization. (`CPUOffloadOptimizersContainer.__init__` may be missing a call to a base class `__init__`.)
The base class `__init__` is intentionally not called because the subclasses have different optimizer creation logic:
- OptimizersContainer creates one optimizer per model_part
- OptimizersContainerWithParamGroups creates a single optimizer with param groups
- CPUOffloadOptimizersContainer creates CPU-based optimizers with shadow params

However, all subclasses:
- Manually set required attributes (self.optimizers, self.model_parts)
- Call self._post_init(), which invokes `Optimizer.__init__(self, all_params, optimizer_kwargs)`

This ensures proper initialization of the Optimizer base class functionality (hooks, state tracking) while allowing different optimizer creation strategies. See lines 375, 448, and 602.
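A stripped-down sketch of that `_post_init` pattern (illustrative only, not torchtitan's actual container code): the subclass builds its own optimizers, sets the required attributes, and then calls `Optimizer.__init__` directly so the base-class hook and state machinery still work.

```python
import torch
from torch.optim import Optimizer

class ContainerSketch(Optimizer):
    """Builds one optimizer per model part, then initializes the Optimizer base
    directly (roughly what the _post_init() described above boils down to)."""

    def __init__(self, model_parts, optimizer_cls, optimizer_kwargs):
        self.model_parts = model_parts
        self.optimizers = []
        all_params = []
        for part in model_parts:
            params = [p for p in part.parameters() if p.requires_grad]
            self.optimizers.append(optimizer_cls(params, **optimizer_kwargs))
            all_params.extend(params)
        # Not super().__init__(): subclasses with different creation strategies
        # (param groups, CPU shadows) share only this final step.
        Optimizer.__init__(self, all_params, optimizer_kwargs)

    def step(self, closure=None):
        for opt in self.optimizers:
            opt.step(closure)

    def zero_grad(self, set_to_none=True):
        for opt in self.optimizers:
            opt.zero_grad(set_to_none=set_to_none)
```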
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
- Remove local path datasets from text_datasets.py; rename persona_zeta_49k.toml to 20b_finetune_reference.toml; update config to use generic paths and the c4 dataset
- Config dataclasses existed but no implementation code used them.
- Mask num_padding-1 positions (where the INPUT is padding), not num_padding. The first PAD token should be predicted from the last real token.
- Rename gpu_param -> model_param (clearer for the DTensor case); add debug logging for skipped params without gradients
- Catch specific exceptions (AttributeError, KeyError) instead of bare Exception; add debug logging when the fallback is triggered
wwwjn left a comment:
Thanks for your contribution! It would be better if you could help split this PR into several smaller pieces. It would be easier to review and discuss which feature to accept, for example:
- Optimizer CPU offload (How many users are actually interested in this comparing to other techniques to save memory?)
- HuggingFacePackedDataset (Is this common practice in SFT?)
- Differential learning rate / weight decay
- GPT-OSS configs + MoE fix
- HuggingFace Checkpoint Loading (Why can this not be handled properly with StateDictAdapter + DCP today?)
```python
        router_logits = self.gate(x)

        if self.topk_before_score:
            # TopK-then-Score: Select top-K by raw logits, then apply score_func (GPT-OSS style)
```
@shuhuayu can we double-check if this part is different between GPT-OSS and DeepSeek-V3?
I'll look into it.
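For reference, a small sketch of the two routing orders being compared here (softmax as the score function; not the torchtitan implementation):

```python
import torch
import torch.nn.functional as F

def route(router_logits: torch.Tensor, top_k: int, topk_before_score: bool):
    if topk_before_score:
        # TopK-then-Score (GPT-OSS style): pick experts by raw logits,
        # then normalize only over the selected logits.
        top_vals, top_idx = torch.topk(router_logits, top_k, dim=-1)
        top_scores = F.softmax(top_vals, dim=-1)
    else:
        # Score-then-TopK: normalize over all experts first,
        # then keep the top-k probabilities (which no longer sum to 1).
        scores = F.softmax(router_logits, dim=-1)
        top_scores, top_idx = torch.topk(scores, top_k, dim=-1)
    return top_scores, top_idx
```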
Thanks for contributing, but please split the PR into multiple ones. This PR contains both core component changes and model changes, but many of them are unrelated to each other.
For checkpoint.py, DCP already has the ability to save and load the HF format. If the current implementation in DCP is insufficient, please file a draft PR, an RFC, or a request to PyTorch. TorchTitan's philosophy is to reuse PyTorch core components as much as possible. DCP is one of them.
For optimizer.py, while I can see the need for multiple parameter groups, it is unclear that we need to create a different container. You can have multiple optimizers and multiple parameter groups. The reason why we need one model part per optimizer is the distributed state_dict design. We can probably eliminate the usage of distributed state_dict now, but that has not been done yet. The multiple-parameter-groups support alone deserves its own PR, separate from the CPU offload optimizer.
For loss.py, the change is not needed. Or at least I don't see why the change is required, as it simply documents the PyTorch core API usage.
For maybe_enable_amp, it also needs another PR because it changes the original logic.
Model changes deserve separate PR(s) as well.
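On the "multiple optimizers and multiple parameter groups" point above, a minimal sketch of that layout under the existing one-optimizer-per-model-part design (AdamW and the bias grouping rule are illustrative, not the torchtitan code):

```python
import torch

def build_optimizers(model_parts, base_lr=3e-4, weight_decay=0.1):
    """One optimizer per model part, each with its own param groups,
    so differential LR/WD needs no new container type."""
    optimizers = []
    for part in model_parts:
        decay, no_decay = [], []
        for name, p in part.named_parameters():
            if not p.requires_grad:
                continue
            (no_decay if name.endswith(".bias") else decay).append(p)
        groups = [
            {"params": decay, "lr": base_lr, "weight_decay": weight_decay},
            {"params": no_decay, "lr": base_lr, "weight_decay": 0.0},
        ]
        optimizers.append(torch.optim.AdamW([g for g in groups if g["params"]]))
    return optimizers
```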
```diff
+# Standard PyTorch ignore index for cross-entropy loss (skips padding tokens)
+CROSS_ENTROPY_IGNORE_INDEX = -100


 def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
-    """Common cross-entropy loss function for Transformer models training."""
+    """Common cross-entropy loss function for Transformer models training.
+
+    Labels with CROSS_ENTROPY_IGNORE_INDEX (-100) are ignored, which is the
+    standard PyTorch convention for skipping padding tokens in loss computation.
+    """
     return torch.nn.functional.cross_entropy(
-        pred.flatten(0, 1).float(), labels.flatten(0, 1)
+        pred.flatten(0, 1).float(), labels.flatten(0, 1), ignore_index=CROSS_ENTROPY_IGNORE_INDEX
     )
```
Why do you need this change? The default ignore_index is -100.
This was a mental note I added while I was working on the dataset loader changes, it can be backed out.
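Indeed, a quick check confirms the explicit argument is a no-op, since F.cross_entropy already defaults to ignore_index=-100:

```python
import torch
import torch.nn.functional as F

pred = torch.randn(4, 10)                   # 4 positions, 10-class vocab
labels = torch.tensor([3, 7, -100, -100])   # last two positions are padding
assert torch.allclose(
    F.cross_entropy(pred, labels),
    F.cross_entropy(pred, labels, ignore_index=-100),
)
```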
```python
        return torch.autocast(
            device_type,
            dtype=TORCH_DTYPE_MAP[mixed_precision_param],
        )
```
iirc, AMP doesn't work well with PP or TP
I haven't extensively tested this path (a full finetune), but it seemed to work with TP/EP; it can be backed out.
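For context, a sketch of the kind of conditional being discussed (the helper name and flags are assumptions): FSDP2 applies its own MixedPrecisionPolicy, so autocast is only returned for the non-FSDP cases.

```python
import contextlib
import torch

def maybe_enable_amp(device_type: str, uses_fsdp: bool, param_dtype: torch.dtype):
    if uses_fsdp:
        # FSDP2 casts params/grads itself; wrapping it in autocast would be redundant.
        return contextlib.nullcontext()
    # DDP, TP-only, or single-device runs rely on autocast for mixed precision.
    return torch.autocast(device_type, dtype=param_dtype)
```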
Absolutely. I wasn't sure of the best approach, and after I saw gpt-oss got promoted out of experiments I figured I had run out of time trying to polish my improvements. I will submit these as separate PRs. Additional comments inline below.
Unfortunately, being GPU-poor, I had to get creative to train my persona series of models at interesting context lengths (https://huggingface.co/eousphoros/persona_zeta_20b_49k was trained using the code in this PR).
I can't say how common it is, as I am a hobbyist, but I have found it gives me the best results for the datasets I am working with (the model above was trained with https://huggingface.co/datasets/eousphoros/persona_zeta_train).
SFT fine-tuning optimizations
SFT fine-tuning correctness
This is one of the older changes, so I don't recall all of the reasons for it at the moment, but I do recall some alignment issues with transposed vs. non-transposed weights when loading the HuggingFace MXFP4 and bf16 versions of GPT-OSS into torchtitan.
This monolithic PR has been split into 7 focused PRs for easier review:

- #2210 - Differential Learning Rate / Weight Decay: adds per-parameter-group learning rate and weight decay multipliers for fine-grained control during fine-tuning.
- #2211 - Attention Sink LSE Fix (bug fix): fixes the mathematically incorrect sigmoid-based attention sink implementation with proper LSE renormalization.
- #2212 - MoE Routing Enhancements: adds GPT-OSS style MoE routing with a TopK-then-Score strategy.
- #2213 - AMP Fix for TP: removes the overly conservative restriction that disabled mixed precision for TP-only configurations.
- #2214 - Expert Bias Support: adds optional expert biases (mlp1_bias, mlp2_bias) required for loading GPT-OSS pretrained models.
- #2215 - Router/Expert Bias Freezing: config options to freeze router and/or expert biases during MoE fine-tuning.
- #2216 - YaRN RoPE Extensions: implements YaRN (Yet another RoPE extensioN) for context length extension (4096 → 131072 tokens); see the sketch after this list.

Not Included (per reviewer feedback)
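As a reference for the YaRN item above, a hedged sketch of the frequency blending YaRN performs. The beta defaults and the 0.1·ln(s)+1 attention scale follow the YaRN paper, and scale=32 matches the 4096 → 131072 extension mentioned; the PR's exact parameters are not shown here.

```python
import math
import torch

def yarn_inv_freq(head_dim: int, base: float = 10000.0, scale: float = 32.0,
                  orig_ctx: int = 4096, beta_fast: float = 32.0, beta_slow: float = 1.0):
    """Blend interpolated and original RoPE frequencies per the YaRN ramp."""
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))

    def correction_dim(num_rotations: float) -> float:
        # Dimension index whose wavelength completes `num_rotations` over orig_ctx.
        return (head_dim * math.log(orig_ctx / (num_rotations * 2 * math.pi))
                / (2 * math.log(base)))

    low = math.floor(correction_dim(beta_fast))
    high = math.ceil(correction_dim(beta_slow))
    ramp = ((torch.arange(head_dim // 2).float() - low) / max(high - low, 1)).clamp(0, 1)
    # ramp == 0: high-frequency dims keep their original frequency (extrapolation);
    # ramp == 1: low-frequency dims are fully interpolated (divided by `scale`).
    blended = inv_freq * (1 - ramp) + (inv_freq / scale) * ramp
    mscale = 0.1 * math.log(scale) + 1.0  # attention temperature scaling from YaRN
    return blended, mscale
```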