Commit b1e99f5

[RFC] Add source_target_pairs attribute to send and recv ops in stableHLO (#2784)
1 parent ba4ab03 commit b1e99f5

1 file changed, 142 additions and 0 deletions

@@ -0,0 +1,142 @@
# [RFC] Add source-target pairs to send/recv ops

Status: In Review<br/>
Initial version: 04/21/2025<br/>
Last updated: 04/21/2025<br/>
Discussion thread: N/A

## Overview

This RFC proposes adding a new attribute `source_target_pairs` to `send` and
`recv` ops. `source_target_pairs` allows users to specify peer-to-peer
communication patterns using global device IDs (zero-indexed integers).
Currently this feature is only available on GPUs.

## Background

SPMD-based pipeline parallelism relies on optimizations in XLA to pipeline
send/recv operations so that compute and communication overlap. The user
expresses this through collective permutes and relies on XLA to decompose
these into send/recv operations, which are then pipelined separately, allowing
for the staggering that is unique to pipeline parallelism. The limitation of
this approach is that it encapsulates the latency-hiding mechanism in the
compiler and leaves the user little control. When this mechanism fails, the
user has little choice but to debug XLA itself. This RFC is proposed in
conjunction with exposing send/recv operations through the JAX `shard_map`
API.

## Proposed Specification

### send

#### Semantics

Sends `inputs` to a channel `channel_id` and produces a `result` token.

If `is_host_transfer` is `true`, then the operation transfers data to the
host. Otherwise, it transfers data to another device based on the values of
`source_target_pairs`. This flag duplicates the information provided in
`channel_type`, so in the future we are planning to only keep one of them
([#666](https://github.com/openxla/stablehlo/issues/666)).

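
As a rough mental model of the device-to-device case (not the XLA
implementation), each row of `source_target_pairs` can be read as "the device
with this source ID sends to the device with this target ID". The Python sketch
below illustrates that reading. The `route` helper and the dictionary of
per-device values are hypothetical names introduced only for this example, and
leaving unpaired devices out of the result is an assumption, since this RFC
does not describe them.

```python
def route(values, source_target_pairs):
    """Model device-to-device transfer: map each target ID to its source's value.

    `values` maps a global device ID to the data that device would send.
    This is an illustrative model only, not how XLA executes send/recv.
    """
    received = {}
    for source, target in source_target_pairs:
        # The device with ID `source` sends; the device with ID `target` receives.
        received[target] = values[source]
    return received


# Three devices with global IDs 0, 1 and 2, each holding one value.
values = {0: "a", 1: "b", 2: "c"}
pairs = [[0, 1], [1, 2]]
print(route(values, pairs))  # {1: 'a', 2: 'b'}
```

In this model, device 0 only sends and device 2 only receives, which matches
the staggered, stage-to-stage pattern that pipeline parallelism needs.
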
#### Inputs

| Label | Name                  | Type                                            | Constraints |
|-------|-----------------------|-------------------------------------------------|-------------|
| (I1)  | `inputs`              | variadic number of tensors or quantized tensors |             |
| (I2)  | `token`               | `token`                                         |             |
| (I3)  | `source_target_pairs` | 2-dimensional tensor constant of type `si64`    | (C1-C4)     |
| (I4)  | `channel_id`          | constant of type `si64`                         |             |
| (I5)  | `channel_type`        | enum of `DEVICE_TO_DEVICE` and `DEVICE_TO_HOST` | (C5)        |
| (I6)  | `is_host_transfer`    | constant of type `i1`                           | (C5)        |

#### Outputs

| Name     | Type    |
|----------|---------|
| `result` | `token` |

#### Constraints

* (C1) `dim(source_target_pairs, 1) = 2`.
* (C2) `is_unique(source_target_pairs[:, 0])`.
* (C3) `is_unique(source_target_pairs[:, 1])`.
* (C4) `0 <= source_target_pairs < N`, where `N` is defined as:
  * `num_replicas` if `cross_replica` is used.
  * `num_partitions` if `cross_partition` is used.
* (C5) `channel_type` is defined as:
  * `DEVICE_TO_HOST` if `is_host_transfer = true`,
  * `DEVICE_TO_DEVICE` otherwise.

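
The `send` constraints above can be sketched as a small checker. This is a
minimal illustration in plain Python, not the StableHLO verifier: the function
names are hypothetical, `pairs` is assumed to be a list of rows, and
`num_devices` stands in for `N` (the caller is assumed to pass `num_replicas`
or `num_partitions` as appropriate).

```python
def validate_source_target_pairs(pairs, num_devices):
    """Check (C1)-(C4) on a list of [source, target] rows; raise ValueError on failure."""
    # (C1) The second dimension has size 2: every row is a [source, target] pair.
    if any(len(row) != 2 for row in pairs):
        raise ValueError("C1: each row must be a [source, target] pair")
    sources = [row[0] for row in pairs]
    targets = [row[1] for row in pairs]
    # (C2) No device ID appears more than once as a source.
    if len(set(sources)) != len(sources):
        raise ValueError("C2: duplicate source ID")
    # (C3) No device ID appears more than once as a target.
    if len(set(targets)) != len(targets):
        raise ValueError("C3: duplicate target ID")
    # (C4) Every ID lies in the half-open range [0, N).
    if any(not 0 <= device_id < num_devices for device_id in sources + targets):
        raise ValueError("C4: device ID out of range")


def send_channel_type(is_host_transfer):
    """(C5) for send: the channel type implied by `is_host_transfer`."""
    return "DEVICE_TO_HOST" if is_host_transfer else "DEVICE_TO_DEVICE"


# The pairs used in this RFC's examples are valid for three devices.
validate_source_target_pairs([[0, 1], [1, 2]], num_devices=3)
```

With the same three devices, `[[0, 1], [0, 2]]` would be rejected under (C2)
because device 0 appears twice as a source.
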
#### Examples

```mlir
%result = "stablehlo.send"(%operand, %token) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
  is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token
```

### recv

#### Semantics

Receives data from a channel with `channel_id` and produces `results`.

If `is_host_transfer` is `true`, then the operation transfers data from the
host. Otherwise, it transfers data from another device based on the values of
`source_target_pairs`. This flag duplicates the information provided in
`channel_type`, so in the future we are planning to only keep one of them
([#666](https://github.com/openxla/stablehlo/issues/666)).

`results` consist of payload values which come first and a token which comes
last. In the future, we are planning to split the payload and the token into two
separate outputs to improve clarity
([#670](https://github.com/openxla/stablehlo/issues/670)).

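
The layout of `results` described above (payload values first, token last) can
be illustrated with a short Python sketch. The `split_recv_results` helper is a
hypothetical name used only here, and treating an empty `results` as an error
is an assumption that follows from the token always being present.

```python
def split_recv_results(results):
    """Split recv results into (payload, token): payload values first, token last."""
    if not results:
        # Assumption: recv always yields at least the trailing token.
        raise ValueError("results must contain at least a token")
    *payload, token = results
    return payload, token


payload, token = split_recv_results(["tensor0", "token"])
print(payload, token)  # ['tensor0'] token
```
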
#### Inputs

| Label | Name                  | Type                                            | Constraints |
|-------|-----------------------|-------------------------------------------------|-------------|
| (I1)  | `token`               | `token`                                         |             |
| (I2)  | `source_target_pairs` | 2-dimensional tensor constant of type `si64`    | (C1-C4)     |
| (I3)  | `channel_id`          | constant of type `si64`                         |             |
| (I4)  | `channel_type`        | enum of `DEVICE_TO_DEVICE` and `HOST_TO_DEVICE` | (C5)        |
| (I5)  | `is_host_transfer`    | constant of type `i1`                           | (C5)        |

#### Outputs

| Name      | Type                                                    | Constraints |
|-----------|---------------------------------------------------------|-------------|
| `results` | variadic number of tensors, quantized tensors or tokens | (C2-C4)     |

#### Constraints

* (C1) `dim(source_target_pairs, 1) = 2`.
* (C2) `is_unique(source_target_pairs[:, 0])`.
* (C3) `is_unique(source_target_pairs[:, 1])`.
* (C4) `0 <= source_target_pairs < N`, where `N` is defined as:
  * `num_replicas` if `cross_replica` is used.
  * `num_partitions` if `cross_partition` is used.
* (C5) `channel_type` is defined as:
  * `HOST_TO_DEVICE` if `is_host_transfer = true`,
  * `DEVICE_TO_DEVICE` otherwise.

#### Examples

```mlir
%results0, %results1 = "stablehlo.recv"(%token) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>,
  is_host_transfer = false
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
```

## A Note On Backward Compatibility

The feature introduced in this RFC technically makes the semantics of `send`
and `recv` stricter: any serialized instance of
`send(is_host_transfer=false)` will no longer be deserializable. However, this
is unlikely to affect existing users, because such an op already had undefined
behavior.
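
To make the affected case concrete, the sketch below flags ops that would fall
under this note. It is purely illustrative: the dictionary representation of
an op and the `needs_migration` helper are hypothetical, not part of StableHLO
or its serialization tooling, and treating a missing `is_host_transfer` as
`false` is an assumption.

```python
def needs_migration(op):
    """Return True for a device-to-device send/recv that lacks source_target_pairs.

    `op` is a hypothetical dict with a "name" and an "attributes" dict.
    """
    attrs = op["attributes"]
    return (
        op["name"] in ("stablehlo.send", "stablehlo.recv")
        # Assumption: a missing is_host_transfer is treated as false.
        and not attrs.get("is_host_transfer", False)
        and "source_target_pairs" not in attrs
    )


old_send = {"name": "stablehlo.send", "attributes": {"is_host_transfer": False}}
print(needs_migration(old_send))  # True
```
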
