
[Fix] Improve dim_size handling in SetTransformerAggregation to prevent CUDA crash #10220


Open · wants to merge 6 commits into pyg-team/pytorch_geometric:master from KAVYANSHTYAGI:fix/set-transformer-aggregation-index-check

Conversation

KAVYANSHTYAGI commented:

This PR improves the robustness of SetTransformerAggregation by:

  • Automatically setting dim_size = index.max() + 1 if dim_size is not provided.
  • Raising a clear error if index.max() >= dim_size to avoid CUDA crashes during evaluation.

This is especially helpful for datasets such as PPI, where data.batch may be missing. It replaces hard-to-debug GPU errors with clear, early validation.
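
A minimal, self-contained sketch of the behavior described above; _validate_dim_size is a hypothetical helper for illustration, not the actual patch:

    import torch
    from typing import Optional

    def _validate_dim_size(index: torch.Tensor, dim_size: Optional[int]) -> int:
        # Fall back to the group count implied by the index tensor, then
        # validate before any GPU-side indexing (illustrative only).
        if dim_size is None:
            dim_size = int(index.max()) + 1
        if int(index.max()) >= dim_size:
            raise ValueError(
                f"index.max() = {int(index.max())} is out of range for "
                f"dim_size = {dim_size}; ensure `data.batch` is set or pass "
                f"`dim_size` explicitly.")
        return dim_size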

[Fix] Add dim_size validation and fallback to SetTransformerAggregation
@@ -94,6 +94,15 @@ def forward(
        max_num_elements: Optional[int] = None,
    ) -> Tensor:

        if dim_size is None:
Member:

I think this is already handled in to_dense_batch.
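
For context, to_dense_batch in torch_geometric.utils already infers the batch size when it is not given, roughly along these lines (simplified sketch of the upstream logic, not a verbatim copy; infer_batch_size is a hypothetical helper):

    import torch

    def infer_batch_size(batch: torch.Tensor, batch_size=None) -> int:
        # When batch_size is not passed, to_dense_batch derives it from the
        # largest batch index, so an explicit fallback in the caller is
        # redundant (simplified; upstream details differ).
        if batch_size is None:
            batch_size = int(batch.max()) + 1 if batch.numel() > 0 else 1
        return batch_size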

Comment on lines 100 to 104
        if int(index.max()) >= dim_size:
            raise ValueError(
                f"SetTransformerAggregation error: index.max() = {int(index.max())}, "
                f"but dim_size = {dim_size}. This causes an indexing error on GPU. "
                f"Ensure data.batch is set or dim_size is passed explicitly.")
Member:

This leads to a device sync, so we should avoid this.
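
For reference, the sync comes from converting a CUDA tensor to a Python number: int(index.max()) (like .item()) copies the value to the host and blocks until the GPU work producing it has finished, e.g.:

    import torch

    index = torch.randint(0, 8, (1024,), device='cuda')  # example data
    dim_size = 8
    # int()/.item() on a CUDA tensor transfers the value to the CPU and
    # waits for the device, which is the sync the review points out.
    if int(index.max()) >= dim_size:
        raise ValueError('index out of range for dim_size')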

Fix: Remove device sync and dim_size fallback in SetTransformerAggregation

- Removed redundant dim_size = index.max() + 1 logic (handled in to_dense_batch).
- Added GPU-safe index validation to avoid CUDA crashes.
@KAVYANSHTYAGI (Author) left a comment:

Thank you for the feedback!

I've removed the fallback dim_size = index.max() + 1 to avoid redundancy with to_dense_batch, as suggested.

Also eliminated device sync by replacing the .max() check with a GPU-safe tensor comparison using (index >= dim_size).any().

Let me know if any further simplification is preferred!
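
A sketch of the revised check as described above (variable names assumed from the earlier snippet; note that using the result in a Python `if` still reads one boolean back to the host, so a fully asynchronous check would need something like torch._assert_async):

    import torch

    def check_index(index: torch.Tensor, dim_size: int) -> None:
        # The elementwise comparison and the any() reduction stay on the
        # device; only a single boolean is read back when the `if` branches
        # (hypothetical helper for illustration).
        if bool((index >= dim_size).any()):
            raise ValueError(
                f"Found an index >= dim_size ({dim_size}); ensure `data.batch` "
                f"is set or pass `dim_size` explicitly.")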
