Conversation


@eous eous commented Jan 7, 2026

This PR introduces several features for memory-efficient fine-tuning of large models:

CPU Offload Optimizer

  • New CPUOffloadOptimizersContainer using DeepSpeed's CPUAdam
  • Keeps optimizer states (momentum, variance) on CPU to reduce GPU memory
  • Supports DTensor/FSDP2 with proper gradient redistribution
  • Includes cleanup() method for releasing pinned memory resources
  • Shape validation for gradient buffers with automatic reallocation
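In rough terms, the offload path works like the sketch below. This is illustrative only: the helper names are hypothetical, DTensor/FSDP gradient redistribution is omitted, and it assumes DeepSpeed's `DeepSpeedCPUAdam` is available; the actual `CPUOffloadOptimizersContainer` is more involved.

```python
import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam  # CPU-optimized Adam/AdamW kernel

def build_cpu_offload_optimizer(model: torch.nn.Module, lr: float = 1e-5):
    """Keep optimizer state on CPU; only params and grads live on the GPU."""
    gpu_params = [p for p in model.parameters() if p.requires_grad]
    cpu_params = []
    for p in gpu_params:
        # Pinned fp32 shadow copy so host<->device copies can be asynchronous.
        cp = torch.empty(p.shape, dtype=torch.float32, device="cpu", pin_memory=True)
        cp.copy_(p.detach())
        cp.requires_grad_(True)
        cpu_params.append(cp)
    opt = DeepSpeedCPUAdam(cpu_params, lr=lr, weight_decay=0.0)
    return gpu_params, cpu_params, opt

def offloaded_step(gpu_params, cpu_params, opt):
    # 1) Copy gradients into the pinned CPU shadows (reallocate on shape change).
    for gp, cp in zip(gpu_params, cpu_params):
        if gp.grad is None:
            continue
        if cp.grad is None or cp.grad.shape != gp.grad.shape:
            cp.grad = torch.empty(gp.grad.shape, dtype=torch.float32,
                                  device="cpu", pin_memory=True)
        cp.grad.copy_(gp.grad, non_blocking=True)
    torch.cuda.synchronize()
    # 2) Adam runs entirely on the CPU; momentum/variance never touch the GPU.
    opt.step()
    opt.zero_grad()
    # 3) Copy the updated weights back into the GPU parameters.
    with torch.no_grad():
        for gp, cp in zip(gpu_params, cpu_params):
            gp.copy_(cp, non_blocking=True)
    torch.cuda.synchronize()
```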

Differential Learning Rates & Weight Decay

  • New LRMultipliers and WeightDecayMultipliers config dataclasses
  • Per-parameter-group settings: embeddings, output, attention, experts, routers, norms, mhc, bias
  • OptimizersContainerWithParamGroups for standard (non-CPU-offload) differential training
  • Pre-compiled regex patterns for efficient parameter classification
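As a rough illustration of how per-group multipliers can be turned into optimizer param groups (the group keys and the `classify` helper are assumptions here, not the PR's exact API):

```python
import torch

def build_param_groups(model, base_lr, base_wd, lr_mults, wd_mults, classify):
    """Bucket parameters by role and scale lr/weight_decay per bucket.

    `classify(name) -> str` returns a group key such as "embeddings",
    "attention", "experts", "routers", "norms", "output", or "bias".
    """
    buckets: dict[str, list[torch.nn.Parameter]] = {}
    for name, p in model.named_parameters():
        if p.requires_grad:
            buckets.setdefault(classify(name), []).append(p)
    return [
        {
            "params": params,
            "lr": base_lr * lr_mults.get(group, 1.0),
            "weight_decay": base_wd * wd_mults.get(group, 1.0),
        }
        for group, params in buckets.items()
    ]

# Example: lower LR for embeddings, no weight decay on biases or norms.
# optimizer = torch.optim.AdamW(
#     build_param_groups(model, 3e-5, 0.1,
#                        lr_mults={"embeddings": 0.1},
#                        wd_mults={"bias": 0.0, "norms": 0.0},
#                        classify=classify_fn),
# )
```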

HuggingFace Checkpoint Loading

  • Direct safetensors loading bypassing dcp.load() pickle issues with DTensors
  • Rank 0 broadcast pattern for distributed loading efficiency
  • Proper handling for both distributed and non-distributed modes
  • Uses configured export_dtype instead of hardcoded bfloat16
  • Filter quantization tensor keys when converting unquantized checkpoints
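A minimal sketch of the rank-0 broadcast idea, assuming plain (non-DTensor) parameters and the safetensors package; the function name is hypothetical, and the real code additionally handles sharded placements and dtype conversion policy:

```python
import torch
import torch.distributed as dist
from safetensors.torch import load_file

def load_hf_weights_rank0_broadcast(model, safetensors_path, device="cuda"):
    """Read the file once on rank 0, then broadcast each tensor to all ranks."""
    rank = dist.get_rank() if dist.is_initialized() else 0
    state = load_file(safetensors_path) if rank == 0 else None

    for name, param in model.named_parameters():
        if rank == 0:
            buf = state[name].to(device=device, dtype=param.dtype)
        else:
            # Non-zero ranks allocate an empty buffer of the right shape/dtype.
            buf = torch.empty(param.shape, dtype=param.dtype, device=device)
        if dist.is_initialized():
            dist.broadcast(buf, src=0)
        with torch.no_grad():
            param.copy_(buf)
```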

Loss Function

  • Added ignore_index=-100 to cross_entropy_loss for padding token support
  • Exported CROSS_ENTROPY_IGNORE_INDEX constant

Code Quality Improvements

  • Extracted _load_hf_state_dict_to_model() helper to reduce duplication
  • Added Counter for efficient dtype counting
  • Named constants for magic numbers
  • Documented private _local_tensor API usage with noqa comments

Reference Configuration

  • Added persona_zeta_49k.toml as a known-good config for GPT-OSS 20B fine-tuning
  • Demonstrates CPU offload, differential LR/WD, 49K context, sample packing
  • Validated on 2x NVIDIA RTX PRO 6000 Blackwell GPUs with FSDP sharding

Copilot AI review requested due to automatic review settings January 7, 2026 01:29

meta-cla bot commented Jan 7, 2026

Hi @eous!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!


Copilot AI left a comment


Pull request overview

This PR adds comprehensive memory-efficient fine-tuning capabilities for large GPT-OSS models, introducing CPU offload optimization, differential learning rates/weight decay, improved HuggingFace checkpoint loading, and sample packing for datasets.

Key Changes:

  • CPU offload optimizer using DeepSpeed's CPUAdam to reduce GPU memory by keeping optimizer states on CPU
  • Differential learning rates and weight decay per parameter group (embeddings, attention, experts, routers, norms)
  • Direct safetensors loading bypassing DCP pickle issues with DTensors
  • Sample packing dataset that preserves boundaries with proper padding mask support
  • MoE enhancements: router/expert bias support, TopK-then-Score vs Score-then-TopK routing modes, bincount replacing histc
  • Attention sink LSE renormalization for GPT-OSS models
  • YaRN RoPE scaling for extended context
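For reference, the packing idea mentioned above reduces to something like the toy sketch below. It is illustrative only: the real HuggingFacePackedDataset also builds the padding/attention masks and streams from the tokenizer, and "preserves boundaries" is read here as "never split a sample across two rows".

```python
def pack_samples(samples, seq_len, pad_id):
    """Greedily pack tokenized samples into fixed-length rows without
    splitting a sample across two rows (sample boundaries preserved)."""
    rows, current = [], []
    for sample in samples:
        if len(sample) > seq_len:
            continue  # or truncate; skipped here for simplicity
        if len(current) + len(sample) > seq_len:
            rows.append(current + [pad_id] * (seq_len - len(current)))
            current = []
        current += sample
    if current:
        rows.append(current + [pad_id] * (seq_len - len(current)))
    return rows

# pack_samples([[1, 2, 3], [4, 5], [6, 7, 8, 9]], seq_len=6, pad_id=0)
# -> [[1, 2, 3, 4, 5, 0], [6, 7, 8, 9, 0, 0]]
```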

Reviewed changes

Copilot reviewed 23 out of 23 changed files in this pull request and generated 14 comments.

Summary per file:

| File | Description |
| --- | --- |
| torchtitan/components/optimizer.py | Adds CPUOffloadOptimizersContainer and OptimizersContainerWithParamGroups for differential LR/WD; includes parameter classification logic |
| torchtitan/components/loss.py | Adds CROSS_ENTROPY_IGNORE_INDEX constant and ignore_index parameter to cross_entropy_loss for padding token support |
| torchtitan/components/checkpoint.py | Implements direct safetensors loading with _load_hf_safetensors_unquantized and _load_hf_state_dict_to_model helpers; adds MXFP4 dequantization |
| torchtitan/config/job_config.py | Adds LRMultipliers, WeightDecayMultipliers dataclasses; adds cpu_offload, pack_samples, add_bos_eos, freeze_router_bias config options |
| torchtitan/hf_datasets/text_datasets.py | Implements HuggingFacePackedDataset for sample packing; adds many dataset configurations (hardcoded paths issue) |
| torchtitan/models/moe/moe.py | Adds use_router_bias, topk_before_score, use_expert_bias parameters; replaces histc with bincount for reliability |
| torchtitan/models/gpt_oss/model/model.py | Implements attention sink LSE renormalization; adds YaRN RoPE scaling for extended context |
| torchtitan/models/gpt_oss/model/moe.py | Adds optional expert biases support, compute_dtype parameter, caching optimizations, DeepEP integration |
| torchtitan/models/gpt_oss/model/state_dict_adapter.py | Adds load_hf_safetensors_direct method, MXFP4 dequantization, expert weight transposition handling |
| torchtitan/models/gpt_oss/model/args.py | Adds moe_impl configuration for standard vs DeepEP backend |
| torchtitan/models/gpt_oss/infra/parallelize.py | Adds freeze_router_bias logic, DeepEP support, Float8 tensor parallel variants |
| torchtitan/models/gpt_oss/infra/expert_parallel.py | Adds conditional bias distribution (only when use_expert_bias=True) |
| torchtitan/models/gpt_oss/train_configs/persona_zeta_49k.toml | Reference configuration demonstrating CPU offload and differential LR/WD (has hardcoded paths) |
| torchtitan/models/gpt_oss/test_attention_sink.py | New test file validating LSE renormalization equivalence and edge cases |
| torchtitan/models/gpt_oss/freeze_router_bias.py | Helper utilities for freezing router biases during fine-tuning |
| torchtitan/train.py | Adds differential LR logging and CPU offload optimizer sync after checkpoint load |
| scripts/checkpoint_conversion/convert_to_hf.py | Filters quantization tensor keys when converting unquantized checkpoints |
| torchtitan/distributed/utils.py | Enables autocast for all non-FSDP cases (not just DDP) |


```python
import re
_PARAM_GROUP_PATTERNS = {
    # Bias pattern FIRST - catches all bias parameters
    'bias': re.compile(r'\.bias$'),
```

Copilot AI Jan 7, 2026


The pattern for matching biases is too broad: r'\.bias$' will match any parameter ending in .bias, including potentially non-bias parameters. While this is likely correct for the current model architecture, it would be more robust to be more specific (e.g., matching known bias parameter locations like in linear layers, attention, etc.) or to document that this assumes all .bias parameters should receive the bias weight decay treatment.

Author


This is intentional. The pattern matches all bias parameters because standard practice in transformer training is to disable weight decay on all bias terms, not just specific ones. This is well-established (see AdamW paper, HuggingFace Trainer defaults, PyTorch examples). The pattern is placed first in the ordered dict specifically to catch biases before any other pattern matches.
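Concretely, the ordered, bias-first matching described above looks roughly like the sketch below; the pattern strings other than the bias one are illustrative, not the PR's exact regexes.

```python
import re

# Order matters: the dict is scanned top-down and the first match wins,
# so '.bias' is checked before any other group can claim the parameter.
_PARAM_GROUP_PATTERNS = {
    "bias": re.compile(r"\.bias$"),
    "embeddings": re.compile(r"tok_embeddings\."),     # illustrative
    "norms": re.compile(r"(_norm|\.norm)\.weight$"),   # illustrative
    "routers": re.compile(r"\.gate\."),                # illustrative
}

def classify_parameter(name: str) -> str:
    for group, pattern in _PARAM_GROUP_PATTERNS.items():
        if pattern.search(name):
            return group
    return "default"

assert classify_parameter("layers.0.attention.wq.bias") == "bias"
assert classify_parameter("tok_embeddings.weight") == "embeddings"
```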

```python
        Optimizer.__init__(self, all_params, optimizer_kwargs)


class OptimizersContainerWithParamGroups(OptimizersContainer):
```

Copilot AI Jan 7, 2026


This class does not call OptimizersContainer.__init__ during initialization. (OptimizersContainerWithParamGroups.__init__ may be missing a call to a base class __init__)

Author


The base class __init__ is intentionally not called because the subclasses have different optimizer creation logic:

  • OptimizersContainer creates one optimizer per model_part
  • OptimizersContainerWithParamGroups creates a single optimizer with param groups
  • CPUOffloadOptimizersContainer creates CPU-based optimizers with shadow params

However, all subclasses:

  1. Manually set required attributes (self.optimizers, self.model_parts)
  2. Call self._post_init(), which invokes Optimizer.__init__(self, all_params, optimizer_kwargs)

This ensures proper initialization of the Optimizer base class functionality (hooks, state tracking) while allowing different optimizer creation strategies. See lines 375, 448, and 602.
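In code, the pattern described in this reply looks roughly like the simplified sketch below; AdamW as the inner optimizer and the exact signatures are assumptions, not the PR's actual classes.

```python
import torch
from torch.optim import Optimizer

class OptimizersContainer(Optimizer):  # simplified stand-in for the real container
    def _post_init(self, all_params, optimizer_kwargs):
        # Shared tail of every container's __init__: registers params/defaults
        # with torch.optim.Optimizer so hooks and state handling work uniformly.
        Optimizer.__init__(self, all_params, optimizer_kwargs)

class OptimizersContainerWithParamGroups(OptimizersContainer):
    def __init__(self, model_parts, param_groups, optimizer_kwargs):
        # Deliberately does NOT call OptimizersContainer.__init__: one optimizer
        # is built over explicit param groups instead of one per model part.
        self.model_parts = model_parts
        self.optimizers = [torch.optim.AdamW(param_groups, **optimizer_kwargs)]
        all_params = [p for group in param_groups for p in group["params"]]
        self._post_init(all_params, optimizer_kwargs)
```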

```python
    pass


class CPUOffloadOptimizersContainer(OptimizersContainer):
```

Copilot AI Jan 7, 2026


This class does not call OptimizersContainer.__init__ during initialization. (CPUOffloadOptimizersContainer.__init__ may be missing a call to a base class __init__)

Author


The base class __init__ is intentionally not called because the subclasses have different optimizer creation logic:

  • OptimizersContainer creates one optimizer per model_part
  • OptimizersContainerWithParamGroups creates a single optimizer with param groups
  • CPUOffloadOptimizersContainer creates CPU-based optimizers with shadow params

However, all subclasses:

  1. Manually set required attributes (self.optimizers, self.model_parts)
  2. Call self._post_init(), which invokes Optimizer.__init__(self, all_params, optimizer_kwargs)

This ensures proper initialization of the Optimizer base class functionality (hooks, state tracking) while allowing different optimizer creation strategies. See lines 375, 448, and 602.


meta-cla bot commented Jan 7, 2026

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@meta-cla meta-cla bot added the CLA Signed label Jan 7, 2026
eous added 5 commits January 6, 2026 21:00
- Remove local path datasets from text_datasets.py
- Rename persona_zeta_49k.toml to 20b_finetune_reference.toml
- Update config to use generic paths and c4 dataset
Config dataclasses existed but no implementation code used them.
Mask num_padding-1 positions (where INPUT is padding), not num_padding.
The first PAD token should be predicted from the last real token.
- Rename gpu_param -> model_param (clearer for DTensor case)
- Add debug logging for skipped params without gradients
- Catch specific exceptions (AttributeError, KeyError) instead of bare Exception
- Add debug logging when fallback is triggered
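The padding-mask commit above amounts to the following off-by-one, shown as a toy illustration; the PAD id and the shift-by-one labeling convention are assumptions, not the dataset code itself.

```python
IGNORE = -100
PAD = 0

# A packed row of length 8: five real tokens followed by three PADs.
tokens = [11, 12, 13, 14, 15, PAD, PAD, PAD]
num_padding = 3

inputs = tokens[:-1]   # what the model sees at each position
targets = tokens[1:]   # next-token labels

# Mask label positions whose INPUT is padding (num_padding - 1 = 2 of them).
# The position whose input is the last real token (15) keeps its PAD label,
# so the model still learns to emit padding after the final real token.
labels = [t if inp != PAD else IGNORE for inp, t in zip(inputs, targets)]

assert labels == [12, 13, 14, 15, PAD, IGNORE, IGNORE]
assert labels.count(IGNORE) == num_padding - 1
```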
@wwwjn wwwjn requested a review from shuhuayu January 7, 2026 22:28
Contributor

@wwwjn wwwjn left a comment


Thanks for your contribution! It would be better if you could help split this PR into several smaller pieces. It would be easier to review and discuss which feature to accept, for example:

  • Optimizer CPU offload (How many users are actually interested in this compared to other techniques to save memory?)
  • HuggingFacePackedDataset (Is this common practice in SFT?)
  • Differential learning rate / weight decay
  • GPT-OSS configs + MoE fix
  • HuggingFace Checkpoint Loading (Why can this not be handled properly with StateDictAdapter + DCP today?)

```python
        router_logits = self.gate(x)

        if self.topk_before_score:
            # TopK-then-Score: Select top-K by raw logits, then apply score_func (GPT-OSS style)
```
Contributor


@shuhuayu can we double check if this part is different between GPT-OSS and DeepSeek-V3?

Contributor


I'll look into it.

Contributor

@fegin fegin left a comment


Thanks for contributing, but please split the PR into multiple ones. This PR contains both core component changes and model changes, and many of them are unrelated to each other.

For checkpoint.py, DCP already has the ability to save and load the HF format. If the current implementation in DCP is insufficient, please file a draft PR, RFC, or feature request against PyTorch. TorchTitan's philosophy is to reuse PyTorch core components as much as possible, and DCP is one of them.

For optimizer.py, while I can see the need for multiple parameter groups, it is unclear that we need to create a different container. You can have multiple optimizers and multiple parameter groups. The reason we map one model part to one optimizer is the distributed state_dict design. We can probably eliminate the usage of distributed state_dict now, but that has not been done yet. The multiple-parameter-group support alone deserves a PR, separate from the CPU offload optimizer.

For loss.py, the change is not needed. Or at least I don't see why the change is required, as it simply spells out the PyTorch core API's default usage.

For maybe_enable_amp, it also needs another PR because it changes the original logic.

Model changes deserve their own PR(s) as well.

Comment on lines +18 to 30
```diff
+# Standard PyTorch ignore index for cross-entropy loss (skips padding tokens)
+CROSS_ENTROPY_IGNORE_INDEX = -100


 def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
-    """Common cross-entropy loss function for Transformer models training."""
+    """Common cross-entropy loss function for Transformer models training.
+    Labels with CROSS_ENTROPY_IGNORE_INDEX (-100) are ignored, which is the
+    standard PyTorch convention for skipping padding tokens in loss computation.
+    """
     return torch.nn.functional.cross_entropy(
-        pred.flatten(0, 1).float(), labels.flatten(0, 1)
+        pred.flatten(0, 1).float(), labels.flatten(0, 1), ignore_index=CROSS_ENTROPY_IGNORE_INDEX
     )
```
Contributor


Why do you need this change? The default ignore_index is -100.

Author


This was a mental note I added while I was working on the dataset loader changes; it can be backed out.

```python
    return torch.autocast(
        device_type,
        dtype=TORCH_DTYPE_MAP[mixed_precision_param],
    )
```
Contributor


iirc, AMP doesn't work well with PP or TP

Author


I haven't extensively tested this path (a full finetune), but it seemed to work with TP/EP; it can be backed out.


eous commented Jan 8, 2026

> Thanks for your contribution! It would be better if you could help split this PR into several smaller pieces. It would be easier to review and discuss which feature to accept, for example:

Absolutely. I wasn't sure of the best approach, and after I saw gpt-oss got promoted out of experiments I figured I had run out of time trying to polish my improvements. I will submit these as separate PRs. Additional comments inline below.

> * Optimizer CPU offload (How many users are actually interested in this compared to other techniques to save memory?)

Unfortunately, being GPU-poor, I had to get creative to train my persona series of models at interesting context lengths (https://huggingface.co/eousphoros/persona_zeta_20b_49k was trained using the code in this PR).

> * HuggingFacePackedDataset (Is this common practice in SFT?)

I can't say how common it is, as I am a hobbyist, but I have found it to give the best results for the datasets I am working with (https://huggingface.co/datasets/eousphoros/persona_zeta_train is the dataset the model above was trained with).

> * Differential learning rate / weight decay

SFT fine-tuning optimizations

> * GPT-OSS configs + MoE fix

SFT fine-tuning correctness

> * HuggingFace Checkpoint Loading (Why can this not be handled properly with StateDictAdapter + DCP today?)

This is one of the older changes, so I don't recall all of the reasons for it at the moment, but I do remember some transposed-vs-non-transposed alignment issues when loading the HuggingFace MXFP4 and BF16 versions of gpt-oss into TorchTitan.


eous commented Jan 8, 2026

This monolithic PR has been split into 7 focused PRs for easier review:

#2210 - Differential Learning Rate / Weight Decay

Adds per-parameter-group learning rates and weight decay multipliers for fine-grained control during fine-tuning.

Key changes:

  • LRMultipliers and WeightDecayMultipliers dataclasses in job_config.py
  • classify_parameters_for_groups() with regex-based pattern matching
  • OptimizersContainerWithParamGroups class
  • Pattern groups: bias, embeddings, output, attention, experts, routers, norms
  • Bias matched first, takes precedence for all *.bias parameters

#2211 - Attention Sink LSE Fix (Bug Fix)

Fixes mathematically incorrect sigmoid-based attention sink implementation with proper LSE renormalization.

Key changes:

  • Replace sigmoid(lse - sink) with exp(old_lse - new_lse) renormalization
  • Add clamping [-20, 0] for numerical stability
  • Mathematically equivalent to HuggingFace's concat+softmax approach
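The renormalization step itself is small; a hedged sketch follows (tensor shapes and an attention backend that returns the per-row logsumexp are assumptions, and the function name is hypothetical):

```python
import torch

def renormalize_with_sink(attn_out, lse, sink_logits):
    """Fold a learned per-head sink logit into already-normalized attention output.

    attn_out:    (..., seq, head_dim) output normalized over the real keys only
    lse:         (..., seq) logsumexp of the attention scores over the real keys
    sink_logits: broadcastable to lse, e.g. shape (num_heads, 1)
    """
    # New normalizer includes the sink: logsumexp([scores, sink]).
    new_lse = torch.logaddexp(lse, sink_logits)
    # exp(old_lse - new_lse) <= 1; clamp the exponent for numerical stability.
    scale = torch.exp(torch.clamp(lse - new_lse, min=-20.0, max=0.0))
    return attn_out * scale.unsqueeze(-1)
```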

#2212 - MoE Routing Enhancements

Adds GPT-OSS style MoE routing with TopK-then-Score strategy.

Key changes:

  • topk_before_score: bool - GPT-OSS uses TopK-then-Score (vs DeepSeek's Score-then-TopK)
  • use_router_bias: bool - GPT-OSS has learned router biases
  • use_expert_bias: bool - GPT-OSS has learned expert biases
  • Fix histc → bincount for more reliable token counting
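The difference between the two routing orders, plus the bincount swap, in a hedged sketch (softmax as the score function and the function names are assumptions):

```python
import torch
import torch.nn.functional as F

def route(router_logits, top_k, topk_before_score):
    if topk_before_score:
        # GPT-OSS style: select experts by raw logits, then normalize only
        # over the selected logits.
        top_vals, top_idx = torch.topk(router_logits, top_k, dim=-1)
        top_scores = F.softmax(top_vals, dim=-1)
    else:
        # DeepSeek style: score all experts first, then take the top-k scores.
        scores = F.softmax(router_logits, dim=-1)
        top_scores, top_idx = torch.topk(scores, top_k, dim=-1)
    return top_scores, top_idx

def tokens_per_expert(top_idx, num_experts):
    # bincount on the flat expert indices replaces the histc-based count.
    return torch.bincount(top_idx.flatten(), minlength=num_experts)
```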

#2213 - AMP Fix for TP

Removes overly conservative restriction that disabled mixed precision for TP-only configurations.

Key changes:

  • Enable torch.autocast for TP-only training (operates at operator level, orthogonal to TP)
  • Clarify that PP uses schedule-based execution (unaffected)

#2214 - Expert Bias Support

Adds optional expert biases (mlp1_bias, mlp2_bias) required for loading GPT-OSS pretrained models.

Key changes:

  • use_expert_bias parameter in GptOssGroupedExperts
  • compute_dtype parameter for configurable compute precision
  • Cache invalidation after DTensor parallelization
  • Handle None biases in both for-loop and grouped_mm paths

#2215 - Router/Expert Bias Freezing

Config options to freeze router and/or expert biases during MoE fine-tuning.

Key changes:

  • freeze_router_bias: bool - preserve pretrained routing behavior
  • freeze_expert_bias: bool - preserve expert bias values
  • Warnings when freeze options enabled but no biases found

#2216 - YaRN RoPE Extensions

Implements YaRN (Yet another RoPE extensioN) for context length extension (4096 → 131072 tokens).

Key changes:

  • Frequency correction functions for smooth RoPE interpolation
  • precompute_rope_cache with YaRN scaling (rope_factor, beta_fast/slow)
  • mscale attention scaling factor (0.1 * ln(rope_factor) + 1.0) for numerical stability
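For instance, the mscale term from the last bullet works out as follows (a minimal sketch; the frequency-correction ramp itself is omitted and the helper name is hypothetical):

```python
import math

def yarn_mscale(rope_factor: float) -> float:
    """Attention scaling used with YaRN-extended RoPE: 0.1 * ln(s) + 1.0."""
    if rope_factor <= 1.0:
        return 1.0
    return 0.1 * math.log(rope_factor) + 1.0

# Extending 4096 -> 131072 gives rope_factor = 32, so mscale ≈ 1.347.
```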

Not Included (per reviewer feedback)

  • CPU Offload Optimizer - needs separate discussion
  • HuggingFace Packed Dataset - questions if common practice
  • HuggingFace Checkpoint Loading - DCP should handle this

@wwwjn @fegin
