[Fix] Improve dim_size handling in SetTransformerAggregation to prevent CUDA crash #10220
base: master
Conversation
[Fix] Add dim_size validation and fallback to SetTransformerAggregation
base repository: pyg-team/pytorch_geometric ← compare: KAVYANSHTYAGI:fix/set-transformer-aggregation-index-check
for more information, see https://pre-commit.ci
@@ -94,6 +94,15 @@ def forward(
        max_num_elements: Optional[int] = None,
    ) -> Tensor:

        if dim_size is None:
I think this is already handled in to_dense_batch.
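For context, a minimal sketch of the fallback that to_dense_batch already performs when no batch size is supplied (simplified for illustration; the actual logic lives in torch_geometric.utils.to_dense_batch):

import torch

def infer_batch_size(batch: torch.Tensor, batch_size: int = None) -> int:
    # to_dense_batch falls back to the largest batch index + 1 when the
    # caller does not pass an explicit batch size.
    if batch_size is None:
        batch_size = int(batch.max()) + 1
    return batch_size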
        if int(index.max()) >= dim_size:
            raise ValueError(
                f"SetTransformerAggregation error: index.max() = {int(index.max())}, "
                f"but dim_size = {dim_size}. This causes an indexing error on GPU. "
                f"Ensure data.batch is set or dim_size is passed explicitly.")
This leads to a device sync, so we should avoid this.
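To illustrate the reviewer's point, a small sketch (not part of the PR diff; the CUDA device and sizes are assumptions) of why the original check forces a device-to-host synchronization while a tensor comparison stays on the GPU:

import torch

index = torch.randint(0, 8, (1024,), device='cuda')  # assumes a CUDA device
dim_size = 8

# int(index.max()) copies a scalar back to the host and blocks until all
# queued GPU work has finished -- this is the device sync being flagged.
host_max = int(index.max())

# A tensor comparison keeps the result on the GPU as a boolean tensor.
invalid = (index >= dim_size).any()

# Note: converting `invalid` to a Python bool (e.g. in an `if` statement)
# would synchronize as well, so the benefit only holds while the result
# stays on the device.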
Fix: Remove device sync and dim_size fallback in SetTransformerAggregation
- Removed the redundant dim_size = index.max() + 1 logic (already handled in to_dense_batch).
- Added GPU-safe index validation to avoid CUDA crashes.
Thank you for the feedback!
I've removed the fallback dim_size = index.max() + 1 to avoid redundancy with to_dense_batch, as suggested.
I've also eliminated the device sync by replacing the int(index.max()) check with a GPU-side tensor comparison, (index >= dim_size).any().
Let me know if any further simplification is preferred!
for more information, see https://pre-commit.ci
This PR improves the robustness of SetTransformerAggregation by:
- Falling back to dim_size = index.max() + 1 if dim_size is not provided.
- Checking index.max() >= dim_size and raising a descriptive error, to avoid CUDA crashes during evaluation.

This is especially helpful for datasets like PPI where data.batch may be missing. It replaces hard-to-debug GPU errors with clear and early validation.
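For readers unfamiliar with the layer, a minimal usage sketch showing where dim_size enters the picture (the channel count, head count, and toy tensors below are illustrative assumptions, not taken from the PR):

import torch
from torch_geometric.nn.aggr import SetTransformerAggregation

aggr = SetTransformerAggregation(channels=16, heads=2)

x = torch.randn(10, 16)                                # one feature row per node
index = torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2])   # graph id per node (e.g. data.batch)

# Passing dim_size explicitly guards against a missing or too-small batch size;
# if it is omitted, the downstream to_dense_batch call infers it from index.
out = aggr(x, index, dim_size=3)  # expected shape: [3, 16] with the default single seed point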