fix(gpt-oss): correct attention sink from sigmoid to LSE renormalization #2211
base: main
Conversation
The previous sigmoid-based attention sink implementation was mathematically incorrect. This fix uses proper LSE (Log-Sum-Exp) renormalization that is equivalent to HuggingFace's concat+softmax approach.

Mathematical equivalence (see the sketch below):
- HF approach: concat the sink logit onto the scores, softmax over K+1 positions, drop the sink position
- Our approach: compute combined_lse = logsumexp([lse, sink]), then renormalize the output by exp(old_lse - new_lse)

Changes:
- Replace sigmoid(lse - sink) with proper LSE renormalization
- Add clamping to [-20, 0] for numerical stability
- Add a comprehensive test suite validating equivalence to the HF reference

Reference: HuggingFace transformers/integrations/flex_attention.py, lines 309-322
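A minimal, self-contained sketch of that equivalence (the toy sizes, tensor names, and shapes below are illustrative, not the PR's actual code):

```python
import torch

torch.manual_seed(0)
B, H, Q, K, D = 2, 4, 3, 5, 8
scores = torch.randn(B, H, Q, K)       # attention logits
values = torch.randn(B, H, K, D)
sink = torch.randn(H)                  # one sink logit per head

# HF-style: append the sink logit, softmax over K+1, drop the sink column
sink_col = sink.view(1, H, 1, 1).expand(B, H, Q, 1)
probs_hf = torch.softmax(torch.cat([scores, sink_col], dim=-1), dim=-1)[..., :-1]
out_hf = probs_hf @ values

# LSE renormalization: rescale the standard attention output
out = torch.softmax(scores, dim=-1) @ values
lse = torch.logsumexp(scores, dim=-1)                    # old_lse, shape [B, H, Q]
combined_lse = torch.logaddexp(lse, sink.view(1, H, 1))  # logsumexp([lse, sink])
out_lse = out * torch.exp(lse - combined_lse).unsqueeze(-1)

assert torch.allclose(out_hf, out_lse, atol=1e-5)
```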
Pull request overview
This PR fixes a mathematical error in the attention sink implementation by replacing an incorrect sigmoid-based approach with proper LSE (Log-Sum-Exp) renormalization. The new implementation is mathematically equivalent to HuggingFace's concat+softmax approach and includes comprehensive test coverage.
Key Changes:
- Replace sigmoid-based attention sink rescaling with LSE renormalization in the attention mechanism
- Add numerical stability clamping to the renormalization factor computation
- Implement comprehensive test suite validating mathematical equivalence with HuggingFace reference implementation
Reviewed changes
Copilot reviewed 2 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| torchtitan/models/gpt_oss/model/model.py | Replaced incorrect sigmoid(lse - sink) with proper LSE renormalization using logsumexp and exp(lse - combined_lse), with clamping for numerical stability |
| torchtitan/models/gpt_oss/tests/test_attention_sink.py | Added comprehensive test suite with equivalence tests, probability mass preservation checks, and edge case validation against HuggingFace reference |
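For orientation, a hedged sketch of what the rescaling step described in the first row could look like; the function and argument names here are assumptions for illustration, not torchtitan's actual code:

```python
import torch

def rescale_with_sink(attn_out: torch.Tensor,  # [B, H, Q, D] standard attention output
                      lse: torch.Tensor,       # [B, H, Q]    logsumexp of the attention logits
                      sinks: torch.Tensor):    # [H]          per-head sink logit
    # combined_lse = logsumexp([lse, sink]) per head and query position
    combined_lse = torch.logaddexp(lse, sinks.view(1, -1, 1))
    # lse - combined_lse is always <= 0; clamp to [-20, 0] for numerical stability
    log_scale = torch.clamp(lse - combined_lse, min=-20.0, max=0.0)
    return attn_out * torch.exp(log_scale).unsqueeze(-1)
```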
values = torch.randn(batch_size, num_heads, seq_len, head_dim)

# Per-head sink weights (typically small negative to positive)
sinks = torch.randn(num_heads) * 2  # Range roughly [-4, 4]
Copilot AI · Jan 8, 2026
The comment states the range is "roughly [-4, 4]" but torch.randn() * 2 produces values with mean=0 and std=2, so approximately 99.7% of values fall within [-6, 6] (3 standard deviations). Consider updating the comment to reflect the actual approximate range as "roughly [-6, 6]".
Suggested change:
- sinks = torch.randn(num_heads) * 2  # Range roughly [-4, 4]
+ sinks = torch.randn(num_heads) * 2  # Range roughly [-6, 6]
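A quick empirical check of the three-sigma claim above (illustrative only):

```python
import torch

x = torch.randn(1_000_000) * 2                    # mean 0, std 2
frac = ((x >= -6) & (x <= 6)).float().mean().item()
print(f"fraction within [-6, 6]: {frac:.4f}")     # ~0.997
```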
scores, values, sinks = setup_tensors

# Standard attention
probs = torch.softmax(scores, dim=-1)
Copilot AI · Jan 8, 2026
Variable probs is not used.
Suggested change:
- probs = torch.softmax(scores, dim=-1)
- Fix comment: randn() * 2 range is roughly [-6, 6], not [-4, 4]
- Remove unused probs variable in test_probability_mass_preserved
@wwwjn could you take a look?