
Conversation


@eous eous commented Jan 8, 2026

… for MoE fine-tuning

Add config options to freeze router and/or expert biases during MoE fine-tuning, preserving pretrained routing behavior and expert bias values.

Changes to job_config.py:

  • Add freeze_router_bias: bool = False
  • Add freeze_expert_bias: bool = False
  • Document dependency on use_router_bias/use_expert_bias in MoEArgs

Changes to parallelize.py:

  • Add freeze_moe_biases() function
  • Apply freezing before parallelization in parallelize_gptoss()
  • Add warnings when freeze options enabled but no biases found

Note: These options require the model config to have use_router_bias=True and/or use_expert_bias=True in MoEArgs (e.g., GPT-OSS models).
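A minimal sketch of how the two new fields could look in the Training config class in job_config.py; the field names and defaults come from the description above, while the surrounding dataclass layout and the help-text wording are assumptions:

from dataclasses import dataclass

@dataclass
class Training:
    # ... existing training fields elided ...

    freeze_router_bias: bool = False
    """
    Freeze MoE router gate biases during fine-tuning to preserve pretrained
    routing behavior. Only takes effect when the model config sets
    use_router_bias=True in MoEArgs (e.g., GPT-OSS models).
    """

    freeze_expert_bias: bool = False
    """
    Freeze MoE expert biases (mlp1_bias, mlp2_bias) during fine-tuning.
    Only takes effect when the model config sets use_expert_bias=True in MoEArgs.
    """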

Copilot AI review requested due to automatic review settings January 8, 2026 23:50
@meta-cla meta-cla bot added the CLA Signed label Jan 8, 2026

Copilot AI left a comment


Pull request overview

This PR adds configuration options to freeze router and expert biases during MoE (Mixture of Experts) fine-tuning, which helps preserve pretrained routing behavior and prevent instability from bias updates.

Key changes:

  • Added freeze_router_bias and freeze_expert_bias boolean configuration fields to the Training class
  • Implemented freeze_moe_biases() function to freeze MoE bias parameters
  • Integrated the freezing logic into parallelize_gptoss() before model parallelization

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File Description

  • torchtitan/config/job_config.py: Added two new boolean config fields with comprehensive documentation explaining their purpose and dependencies on MoEArgs
  • torchtitan/models/gpt_oss/infra/parallelize.py: Implemented freeze_moe_biases function and integrated it into parallelize_gptoss with appropriate logging and warnings
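For orientation, a hedged sketch of what the integration in parallelize_gptoss could look like based on the change summary above; the function signature, the job_config.training.* attribute paths, and the logger usage are assumptions rather than the PR's verbatim code:

def parallelize_gptoss(model, parallel_dims, job_config):
    # logger: the module-level logger used elsewhere in parallelize.py (assumed).
    # Freeze MoE biases before parallelization so requires_grad=False is set
    # on the plain parameters, before FSDP/TP wrappers shard them.
    freeze_router = job_config.training.freeze_router_bias
    freeze_expert = job_config.training.freeze_expert_bias
    if freeze_router or freeze_expert:
        router_frozen, expert_frozen = freeze_moe_biases(
            model, freeze_router=freeze_router, freeze_expert=freeze_expert
        )
        logger.info(
            f"Froze {router_frozen} router bias and {expert_frozen} expert bias parameters"
        )
        if freeze_router and router_frozen == 0:
            logger.warning(
                "freeze_router_bias is enabled but no router biases were found; "
                "check that use_router_bias=True in MoEArgs"
            )
        if freeze_expert and expert_frozen == 0:
            logger.warning(
                "freeze_expert_bias is enabled but no expert biases were found; "
                "check that use_expert_bias=True in MoEArgs"
            )
    # ... existing TP / FSDP / activation checkpointing setup continues here ...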


        if freeze_router and "moe.router.gate.bias" in name:
            param.requires_grad = False
            router_frozen += 1
        elif freeze_expert and ("experts.mlp1_bias" in name or "experts.mlp2_bias" in name):

Copilot AI Jan 8, 2026


The logic here will miss freezing expert biases when both freeze_router and freeze_expert are True. The elif condition means that if a parameter name contains "moe.router.gate.bias", it will never check for expert bias patterns even if freeze_expert is True. This should be two separate if statements to allow both types of parameters to be checked independently.

Suggested change
- elif freeze_expert and ("experts.mlp1_bias" in name or "experts.mlp2_bias" in name):
+ if freeze_expert and ("experts.mlp1_bias" in name or "experts.mlp2_bias" in name):

Comment on lines +40 to +68
def freeze_moe_biases(
    model: nn.Module,
    freeze_router: bool = False,
    freeze_expert: bool = False,
) -> tuple[int, int]:
    """
    Freeze router gate biases and/or expert biases in all MoE layers.

    This is recommended for fine-tuning MoE models to preserve pretrained
    routing behavior and prevent instability from bias updates.

    Args:
        model: The model containing MoE layers with router gates and experts.
        freeze_router: Whether to freeze router gate biases.
        freeze_expert: Whether to freeze expert biases (mlp1_bias, mlp2_bias).

    Returns:
        Tuple of (router_frozen_count, expert_frozen_count).
    """
    router_frozen = 0
    expert_frozen = 0
    for name, param in model.named_parameters():
        if freeze_router and "moe.router.gate.bias" in name:
            param.requires_grad = False
            router_frozen += 1
        elif freeze_expert and ("experts.mlp1_bias" in name or "experts.mlp2_bias" in name):
            param.requires_grad = False
            expert_frozen += 1
    return router_frozen, expert_frozen

Copilot AI Jan 8, 2026


The repository has comprehensive unit and integration test coverage. The new freeze_moe_biases function and the integration in parallelize_gptoss should have test coverage to verify:

  1. The function correctly freezes router biases when freeze_router=True
  2. The function correctly freezes expert biases when freeze_expert=True
  3. Both types can be frozen simultaneously
  4. Warnings are logged when biases are not found
  5. The counts returned are accurate

Consider adding a unit test for freeze_moe_biases and an integration test that validates the freezing behavior in a training scenario.
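As a starting point, a pytest-style unit test along these lines might look like the following; the tiny stand-in module and its parameter shapes are assumptions chosen only so that the parameter names match the patterns freeze_moe_biases looks for:

import torch
import torch.nn as nn

from torchtitan.models.gpt_oss.infra.parallelize import freeze_moe_biases


class _TinyMoE(nn.Module):
    """Stand-in module whose parameter names mirror the patterns used by freeze_moe_biases."""

    def __init__(self):
        super().__init__()
        self.moe = nn.Module()
        self.moe.router = nn.Module()
        self.moe.router.gate = nn.Linear(8, 4, bias=True)
        self.moe.experts = nn.Module()
        self.moe.experts.mlp1_bias = nn.Parameter(torch.zeros(4, 16))
        self.moe.experts.mlp2_bias = nn.Parameter(torch.zeros(4, 8))


def test_freeze_both_router_and_expert_biases():
    model = _TinyMoE()
    router_frozen, expert_frozen = freeze_moe_biases(
        model, freeze_router=True, freeze_expert=True
    )
    # Counts reflect one router gate bias and two expert bias tensors.
    assert router_frozen == 1
    assert expert_frozen == 2
    assert model.moe.router.gate.bias.requires_grad is False
    assert model.moe.experts.mlp1_bias.requires_grad is False
    assert model.moe.experts.mlp2_bias.requires_grad is False
    # Non-bias parameters stay trainable.
    assert model.moe.router.gate.weight.requires_grad is True

Similar tests with only one flag enabled, and one against a model without biases to exercise the warning path in parallelize_gptoss, would cover the remaining points above.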
