[RFC] Add source_target_pairs attribute to send and recv ops in stableHLO #2784
SPMD-based pipeline parallelism relies on optimizations in XLA to pipeline
send/recv operations in such a way that compute and communication are
overlapped. The user expresses this through collective permutes and relies on
XLA to decompose these into send/recv operations, which are then pipelined
This is more for my understanding of the motivation:
The HLO-level `send`/`recv` ops generated by the XLA decomposition do carry the source/target info, using attributes like `frontend_attributes {_xla_send_recv_source_target_pairs={{3,0}}}`. I am not clear how exposing this information at the higher-level StableHLO representation would help, since the only StableHLO operations exported from JAX are collective permutes (not StableHLO send/recv ops).
Per the last sentence in this paragraph, I'm also working on adding `psend` and `precv` to `jax.lax.parallel`. I didn't include that in this RFC because I figured it's not related to StableHLO.
Ah! I missed that. That makes sense.
@rosiezou Is there an RFC available for the JAX side of things?
Not an official RFC but jax-ml/jax#28101 is a WIP of the JAX changes.
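For readers following this thread, here is a minimal plain-Python sketch of what decomposing a collective permute into per-device send/recv operations amounts to. It is not XLA's implementation: `collective_permute` and `decompose_to_send_recv` are hypothetical helper names, devices are modeled as dictionary keys, and the zero-fill for devices that are not a target of any pair follows the StableHLO `collective_permute` specification.

```python
from typing import Dict, List, Tuple

Pair = Tuple[int, int]  # (source device, target device)


def collective_permute(values: Dict[int, int], pairs: List[Pair]) -> Dict[int, int]:
    """Model the result each device sees after a collective permute."""
    # Devices that are not the target of any pair receive zeros.
    result = {device: 0 for device in values}
    for source, target in pairs:
        result[target] = values[source]
    return result


def decompose_to_send_recv(pairs: List[Pair]) -> Dict[int, List[Tuple[str, int]]]:
    """List the send/recv events each device takes part in (hypothetical model)."""
    events: Dict[int, List[Tuple[str, int]]] = {}
    for source, target in pairs:
        events.setdefault(source, []).append(("send", target))
        events.setdefault(target, []).append(("recv", source))
    return events


pairs = [(0, 1), (1, 2)]
values = {0: 10, 1: 20, 2: 30}

permuted = collective_permute(values, pairs)
events = decompose_to_send_recv(pairs)

print(permuted)
print(events)
```

In this model, device 1 both receives from device 0 and sends to device 2, which is the per-device source/target information that the `source_target_pairs` attribute would carry.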
%results0, %results1 = "stablehlo.recv"(%token) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
is_host_transfer = true
Can we keep it `false`, to make it clear that `source_target_pairs` will have some effect in this particular example?
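The concern can be illustrated with a small Python sketch. It rests on an assumption implied by this comment but not confirmed by the RFC text quoted here: that `source_target_pairs` is only consulted for device-to-device transfers and is ignored when `is_host_transfer` is `true`. `recv_source` is a hypothetical helper, not part of StableHLO or XLA.

```python
from typing import List, Optional, Tuple


def recv_source(
    device_id: int,
    source_target_pairs: List[Tuple[int, int]],
    is_host_transfer: bool,
) -> Optional[int]:
    """Return the device that `device_id` receives from, or None if there is none."""
    if is_host_transfer:
        # Assumption: data comes from the host, so the pairs are not consulted.
        return None
    for source, target in source_target_pairs:
        if target == device_id:
            return source
    return None  # This device is not the target of any pair.


pairs = [(0, 1), (1, 2)]

with_host = [recv_source(d, pairs, is_host_transfer=True) for d in range(3)]
without_host = [recv_source(d, pairs, is_host_transfer=False) for d in range(3)]

print(with_host)
print(without_host)
```

Under that assumption, the pairs only change the outcome in the `is_host_transfer=False` case, which is why an example with `false` would show the attribute doing something.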
#### Semantics

Sends `inputs` to a channel `channel_id`. Inputs are then sent to other devices
in the order specified by `source_target_pairs`. The operation produces a
Similar to `recv`, can we defer the mention of "in the order specified by `source_target_pairs`" to the second paragraph (as it is now), unless we somehow capture the case of `is_host_transfer` being `true` in the same paragraph?
My goal here is to keep the wording consistent with `recv`.