|
| 1 | +# [RFC] Add source-target pairs to send/recv ops |
| 2 | + |
| 3 | +Status: In Review<br/> |
| 4 | +Initial version: 04/21/2025<br/> |
| 5 | +Last updated: 04/21/2025<br/> |
| 6 | +Discussion thread: N/A |
| 7 | + |
| 8 | +## Overview |
| 9 | + |
| 10 | +This RFC proposes adding a new attribute `source_target_pairs` to `send` and |
| 11 | +`recv` ops. `source_target_pairs` allows users to specify peer-to-peer |
| 12 | +communication patterns using global device IDs (zero-indexed integers). |
| 13 | +Currently this feature is only available on GPUs. |
| 14 | + |
| 15 | +## Background |
| 16 | + |
| 17 | +SPMD-based pipeline parallelism relies on optimizations in XLA to pipeline |
| 18 | +send/recv operations in such a way that compute and communication are |
| 19 | +overlapped. The user expresses this through collective permutes and relies on |
| 20 | +XLA to decompose these into send/recv operations, which are then pipelined |
| 21 | +separately, allowing for the staggering that is unique to pipeline parallelism. |
| 22 | +The limitation of this approach is that it encapsulates the latency-hiding |
| 23 | +mechanism in the compiler and gives the user little control over it. When this |
| 24 | +mechanism fails, the user has little choice but to debug XLA itself. This RFC is |
| 25 | +proposed in conjunction with exposing send/recv operations through the JAX |
| 26 | +`shard_map` API. |
| 27 | + |
| 28 | +## Proposed Specification |
| 29 | + |
| 30 | +### send |
| 31 | + |
| 32 | +#### Semantics |
| 33 | + |
| 34 | +Sends `inputs` to a channel `channel_id` and produces a `result` token. |
| 35 | + |
| 36 | +If `is_host_transfer` is `true`, then the operation transfers data to the |
| 37 | +host. Otherwise, it transfers data to another device based on the values of |
| 38 | +`source_target_pairs`. This flag duplicates the information provided in |
| 39 | +`channel_type`, so in the future we are planning to only keep one of them |
| 40 | +([#666](https://github.com/openxla/stablehlo/issues/666)). |
| 41 | + |
| 42 | +#### Inputs |
| 43 | + |
| 44 | +| Label | Name | Type | Constraints | |
| 45 | +|-------|-----------------------|-------------------------------------------------|-------------| |
| 46 | +| (I1) | `inputs` | variadic number of tensors or quantized tensors | | |
| 47 | +| (I2) | `token` | `token` | | |
| 48 | +| (I3) | `source_target_pairs` | 2-dimensional tensor constant of type `si64` | (C1-C4) | |
| 49 | +| (I4) | `channel_id` | constant of type `si64` | | |
| 50 | +| (I5) | `channel_type` | enum of `DEVICE_TO_DEVICE` and `DEVICE_TO_HOST` | (C5) | |
| 51 | +| (I6) | `is_host_transfer` | constant of type `i1` | (C5) | |
| 52 | + |
| 53 | +#### Outputs |
| 54 | + |
| 55 | +| Name | Type | |
| 56 | +|----------|---------| |
| 57 | +| `result` | `token` | |
| 58 | + |
| 59 | +#### Constraints |
| 60 | + |
| 61 | +* (C1) `dim(source_target_pairs, 1) = 2`. |
| 62 | +* (C2) `is_unique(source_target_pairs[:, 0])`. |
| 63 | +* (C3) `is_unique(source_target_pairs[:, 1])`. |
| 64 | +* (C4) `0 <= source_target_pairs < N`, where `N` is defined as: |
| 65 | + * `num_replicas` if `cross_replica` is used. |
| 66 | + * `num_partitions` if `cross_partition` is used. |
| 67 | +* (C5) `channel_type` is defined as: |
| 68 | + * `DEVICE_TO_HOST` if `is_host_transfer = true`, |
| 69 | + * `DEVICE_TO_DEVICE` otherwise. |
| 70 | + |
| 71 | +#### Examples |
| 72 | + |
| 73 | +```mlir |
| 74 | +%result = "stablehlo.send"(%operand, %token) { |
| 75 | + source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>, |
| 76 | + channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>, |
| 77 | + is_host_transfer = true |
| 78 | +} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token |
| 79 | +``` |
| 80 | + |
| 81 | +### recv |
| 82 | + |
| 83 | +#### Semantics |
| 84 | + |
| 85 | +Receives data from a channel with `channel_id` and produces `results`. |
| 86 | + |
| 87 | +If `is_host_transfer` is `true`, then the operation transfers data from the |
| 88 | +host. Otherwise, it transfers data from another device based on the values of |
| 89 | +`source_target_pairs`. This flag duplicates the information provided in |
| 90 | +`channel_type`, so in the future we are planning to only keep one of them |
| 91 | +([#666](https://github.com/openxla/stablehlo/issues/666)). |
| 92 | + |
| 93 | +`results` consist of payload values which come first and a token which comes |
| 94 | +last. In the future, we are planning to split the payload and the token into two |
| 95 | +separate outputs to improve clarity |
| 96 | +([#670](https://github.com/openxla/stablehlo/issues/670)). |
| 97 | + |
| 98 | +#### Inputs |
| 99 | + |
| 100 | +| Label | Name | Type | Constraints | |
| 101 | +|-------|-----------------------|-------------------------------------------------|-------------| |
| 102 | +| (I1) | `token` | `token` | | |
| 103 | +| (I2) | `source_target_pairs` | 2-dimensional tensor constant of type `si64` | (C1-C4) | |
| 104 | +| (I3) | `channel_id` | constant of type `si64` | | |
| 105 | +| (I4) | `channel_type` | enum of `DEVICE_TO_DEVICE` and `HOST_TO_DEVICE` | (C5) |
| 106 | +| (I5) | `is_host_transfer` | constant of type `i1` | (C5) | |
| 107 | + |
| 108 | +#### Outputs |
| 109 | + |
| 110 | +| Name | Type | Constraints | |
| 111 | +|-----------|---------------------------------------------------------|-------------| |
| 112 | +| `results` | variadic number of tensors, quantized tensors or tokens | (C2-C4) | |
| 113 | + |
| 114 | +#### Constraints |
| 115 | + |
| 116 | +* (C1) `dim(source_target_pairs, 1) = 2`. |
| 117 | +* (C2) `is_unique(source_target_pairs[:, 0])`. |
| 118 | +* (C3) `is_unique(source_target_pairs[:, 1])`. |
| 119 | +* (C4) `0 <= source_target_pairs < N`, where `N` is defined as: |
| 120 | + * `num_replicas` if `cross_replica` is used. |
| 121 | + * `num_partitions` if `cross_partition` is used. |
| 122 | +* (C5) `channel_type` is defined as: |
| 123 | + * `HOST_TO_DEVICE` if `is_host_transfer = true`,
| 124 | + * `DEVICE_TO_DEVICE` otherwise. |
| 125 | + |
| 126 | +#### Examples |
| 127 | + |
| 128 | +```mlir |
| 129 | +%results0, %results1 = "stablehlo.recv"(%token) { |
| 130 | + source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>, |
| 131 | + channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>,
| 132 | + is_host_transfer = false |
| 133 | +} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token) |
| 134 | +``` |
| 135 | + |
| 136 | +## A Note On Backward Compatibility |
| 137 | + |
| 138 | +The feature introduced in this RFC technically makes the semantics of `send` |
| 139 | +and `recv` stricter: any previously serialized instance of |
| 140 | +`send(is_host_transfer=false)` will no longer be deserializable. However, this |
| 141 | +is unlikely to affect existing users, because such usage was already undefined |
| 142 | +behavior. |
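
The constraints (C1) through (C4) that this RFC places on `source_target_pairs` can be illustrated with a small checker. This is only a sketch in plain Python, not part of StableHLO or XLA: the function name `check_source_target_pairs` and the `num_devices` parameter (standing in for `N`, i.e. `num_replicas` or `num_partitions`) are assumptions made for illustration.

```python
from typing import List, Sequence


def check_source_target_pairs(
    pairs: Sequence[Sequence[int]], num_devices: int
) -> List[str]:
    """Return the labels of violated constraints (empty list if valid).

    Hypothetical helper: `num_devices` plays the role of N in (C4).
    """
    errors: List[str] = []

    # (C1) dim(source_target_pairs, 1) = 2: every row is a (source, target) pair.
    if any(len(pair) != 2 for pair in pairs):
        # The remaining checks assume well-formed rows, so stop here.
        return ["C1"]

    sources = [pair[0] for pair in pairs]
    targets = [pair[1] for pair in pairs]

    # (C2) is_unique(source_target_pairs[:, 0]): no device sends twice.
    if len(set(sources)) != len(sources):
        errors.append("C2")

    # (C3) is_unique(source_target_pairs[:, 1]): no device receives twice.
    if len(set(targets)) != len(targets):
        errors.append("C3")

    # (C4) 0 <= source_target_pairs < N for every entry.
    if any(not 0 <= device < num_devices for device in sources + targets):
        errors.append("C4")

    return errors
```

Under these assumptions, the pairs used in the examples above, `[[0, 1], [1, 2]]`, pass all checks when at least three devices are available, and fail (C4) with only two.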