[RFC] Add source_target_pairs attribute to send and recv ops in stableHLO #2784

Open
wants to merge 3 commits into
base: main

Conversation

rosiezou
Member

No description provided.

@rosiezou rosiezou requested a review from GleasonK April 23, 2025 05:12
@sdasgup3 sdasgup3 self-requested a review April 23, 2025 17:43
@mjsML
Member

mjsML commented Apr 23, 2025

SPMD-based pipeline parallelism relies on optimizations in XLA to pipeline
send/recv operations in such a way that compute and communication are
overlapped. The user expresses this through collective permutes and relies on
XLA to decompose these into send/recv operations, which are then pipelined
Member

This is mostly to help me understand the motivation:

The HLO-level send/recvs generated by the XLA decomposition do carry the source/target info through attributes such as frontend_attributes {_xla_send_recv_source_target_pairs={{3,0}}}. It is not clear to me how exposing this information in the higher-level StableHLO representation would help, since the only StableHLO operations exported from JAX are collective permutes (not the StableHLO send/recvs).

Member Author

Per the last sentence in this paragraph, I'm also working on adding psend and precv to jax.lax.parallel. I didn't include it in this RFC because I figured it's not related to stableHLO.

Member

Ah! I missed that. That makes sense.

Member

@rosiezou Is there an RFC available for the JAX side of things?

Member Author

Not an official RFC, but jax-ml/jax#28101 is a WIP of the JAX changes.

%results0, %results1 = "stablehlo.recv"(%token) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
is_host_transfer = true
Member

Can we set is_host_transfer to false here, to make it clear that source_target_pairs has an effect in this particular example?

#### Semantics

Sends `inputs` to a channel `channel_id`. Inputs are then sent to other devices
in the order specified by `source_target_pairs`. The operation produces a
Member

@sdasgup3 sdasgup3 Apr 28, 2025

Similar to recv, can we defer the mention of "in the order specified by source_target_pairs" to the second paragraph (as it is now), unless we somehow cover the case where is_host_transfer is true in the same paragraph?

My goal here is to keep the wording consistent with recv.
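
For readers following the thread, here is a minimal pure-Python sketch of one way to read source_target_pairs, assuming each [source, target] pair means "the value held by device `source` is delivered to device `target`". The function name `route`, the dictionary-based device model, and the use of `None` for devices that receive nothing are all hypothetical and chosen only for illustration; this is not the StableHLO implementation, and the RFC text quoted above does not specify what non-target devices observe.

```python
from typing import Dict, Optional, Sequence, Tuple


def route(
    values: Dict[int, str],
    source_target_pairs: Sequence[Tuple[int, int]],
) -> Dict[int, Optional[str]]:
    """Toy model: deliver each source device's value to its target device.

    Assumption (not taken from the RFC): devices that are not the target
    of any pair are reported as receiving None.
    """
    received: Dict[int, Optional[str]] = {device: None for device in values}
    for source, target in source_target_pairs:
        received[target] = values[source]
    return received


# Same pairs as in the recv example quoted earlier: [[0, 1], [1, 2]].
values = {0: "a", 1: "b", 2: "c"}
pairs = [(0, 1), (1, 2)]
print(route(values, pairs))  # {0: None, 1: 'a', 2: 'b'}
```

Under this reading, device 1 receives device 0's value and device 2 receives device 1's value, while device 0 receives nothing.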

3 participants