
Conversation


@eous commented on Jan 8, 2026

The previous sigmoid-based attention sink implementation was mathematically incorrect. This fix uses proper LSE (Log-Sum-Exp) renormalization that is equivalent to HuggingFace's concat+softmax approach.

Mathematical equivalence:

  • HF approach: concat the sink logit to the scores, softmax over K+1 positions, drop the sink position
  • Our approach: compute combined_lse = logsumexp([lse, sink]), then renormalize the output by exp(lse - combined_lse) (see the sketch below)

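A minimal sketch of the equivalence, assuming per-head sink logits and standard attention shapes (names and shapes here are illustrative, not the torchtitan code):

```python
import torch

batch, heads, q_len, k_len, head_dim = 2, 4, 8, 8, 16
scores = torch.randn(batch, heads, q_len, k_len)
values = torch.randn(batch, heads, k_len, head_dim)
sinks = torch.randn(heads)  # one sink logit per head

# HF-style reference: append the sink logit as an extra "key" position,
# softmax over K+1, then drop the sink column before multiplying by values.
sink_col = sinks.view(1, heads, 1, 1).expand(batch, heads, q_len, 1)
probs = torch.softmax(torch.cat([scores, sink_col], dim=-1), dim=-1)
out_ref = probs[..., :-1] @ values

# LSE renormalization: run plain attention, then rescale each query's output
# by exp(lse - combined_lse), where combined_lse = logsumexp([lse, sink]).
out = torch.softmax(scores, dim=-1) @ values
lse = torch.logsumexp(scores, dim=-1)                        # (batch, heads, q_len)
combined_lse = torch.logaddexp(lse, sinks.view(1, heads, 1))
out = out * torch.exp(lse - combined_lse).unsqueeze(-1)

torch.testing.assert_close(out, out_ref)
```

Both paths give the sink position probability mass without letting it contribute to the output, since the sink has no value vector.
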
Changes:

  • Replace sigmoid(lse - sink) with proper LSE renormalization
  • Clamp the renormalization exponent to [-20, 0] for numerical stability (sketched after this list)
  • Add a comprehensive test suite validating equivalence to the HF reference

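The clamping change, roughly (same assumptions and names as the sketch above; the exact torchtitan code may differ), replacing the final rescale line: since combined_lse >= lse by construction, the exponent lse - combined_lse is already <= 0, so the upper clamp is a no-op in exact arithmetic and the lower clamp floors the rescale factor at exp(-20).

```python
# Clamp the log-space rescale factor before exponentiating.
# lse - combined_lse <= 0 by construction; the -20 floor keeps the factor
# away from underflow when the sink dominates the real keys.
log_scale = (lse - combined_lse).clamp(min=-20.0, max=0.0)
out = out * torch.exp(log_scale).unsqueeze(-1)
```
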
Reference: HuggingFace transformers/integrations/flex_attention.py lines 309-322

Copilot AI review requested due to automatic review settings on January 8, 2026 at 23:49
@meta-cla bot added the CLA Signed label on Jan 8, 2026

Copilot AI left a comment

Pull request overview

This PR fixes a mathematical error in the attention sink implementation by replacing an incorrect sigmoid-based approach with proper LSE (Log-Sum-Exp) renormalization. The new implementation is mathematically equivalent to HuggingFace's concat+softmax approach and includes comprehensive test coverage.

Key Changes:

  • Replace sigmoid-based attention sink rescaling with LSE renormalization in the attention mechanism
  • Add numerical stability clamping to the renormalization factor computation
  • Implement comprehensive test suite validating mathematical equivalence with HuggingFace reference implementation

Reviewed changes

Copilot reviewed 2 out of 3 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| torchtitan/models/gpt_oss/model/model.py | Replaced the incorrect sigmoid(lse - sink) rescaling with proper LSE renormalization using logsumexp and exp(lse - combined_lse), with clamping for numerical stability |
| torchtitan/models/gpt_oss/tests/test_attention_sink.py | Added a comprehensive test suite with equivalence tests, probability-mass preservation checks (sketched after this table), and edge-case validation against the HuggingFace reference |
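
A rough illustration of the probability-mass check mentioned in the test-file description (hypothetical names, not the actual test code): after renormalization, the mass retained by the real keys plus the mass absorbed by the sink should sum to exactly 1 for every query.

```python
import torch

scores = torch.randn(2, 4, 8, 8)   # (batch, heads, q_len, k_len)
sinks = torch.randn(4)             # per-head sink logit

lse = torch.logsumexp(scores, dim=-1)
combined_lse = torch.logaddexp(lse, sinks.view(1, -1, 1))

# Mass kept by the real keys vs. mass absorbed by the sink, per query.
key_mass = torch.exp(lse - combined_lse)
sink_mass = torch.exp(sinks.view(1, -1, 1) - combined_lse)
torch.testing.assert_close(key_mass + sink_mass, torch.ones_like(key_mass))
```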


values = torch.randn(batch_size, num_heads, seq_len, head_dim)

# Per-head sink weights (typically small negative to positive)
sinks = torch.randn(num_heads) * 2 # Range roughly [-4, 4]

Copilot AI Jan 8, 2026

The comment states the range is "roughly [-4, 4]" but torch.randn() * 2 produces values with mean=0 and std=2, so approximately 99.7% of values fall within [-6, 6] (3 standard deviations). Consider updating the comment to reflect the actual approximate range as "roughly [-6, 6]".

Suggested change:
- sinks = torch.randn(num_heads) * 2  # Range roughly [-4, 4]
+ sinks = torch.randn(num_heads) * 2  # Range roughly [-6, 6]
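
A quick empirical check of the 3-sigma claim above (illustrative only):

```python
import torch

x = torch.randn(1_000_000) * 2        # mean 0, std 2
print((x.abs() <= 6).float().mean())  # ~0.997: within three standard deviations
```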

scores, values, sinks = setup_tensors

# Standard attention
probs = torch.softmax(scores, dim=-1)

Copilot AI Jan 8, 2026

Variable probs is not used.

Suggested change:
- probs = torch.softmax(scores, dim=-1)

- Fix comment: randn()*2 range is [-6, 6] not [-4, 4]
- Remove unused probs variable in test_probability_mass_preserved

@tianyu-l (Contributor)

@wwwjn could you take a look?

@wwwjn self-assigned this on Jan 20, 2026
