#sdy avoid sideways operand propagation for elementwise ops which can't propagate to the result.

We don't want to propagate sideways through operands if the element-wise factor is used in the result and is not sharded in the same way. This avoids the following situation, which can happen when a `sharding_constraint` is added onto the operand during Shardy import:

```mlir
%arg0: [{"a", ?}]
%arg1: [{?}]
%0 = add %arg0, %arg1 : [{}]
```

We don't want to do an all-gather on both `%arg0` and `%arg1` due to "a" propagating sideways. Instead, with this change, since "a" can't propagate to `%0`, we only do an all-gather on `%arg0`.
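For illustration, here is a hypothetical before/after of the propagated shardings (a walkthrough of the behavior described above, not output taken from the compiler):

```mlir
// Previous behavior: "a" propagates sideways from %arg0 to %arg1, so both
// operands end up sharded on "a" and both must be all-gathered to match %0.
%arg0: [{"a", ?}]
%arg1: [{"a", ?}]
%0 = add %arg0, %arg1 : [{}]

// New behavior: since "a" can't propagate to %0, it is not propagated to
// %arg1 either, and only %arg0 needs an all-gather.
%arg0: [{"a", ?}]
%arg1: [{?}]
%0 = add %arg0, %arg1 : [{}]
```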

Long term, we should undo this and allow sideways propagation, and instead have our explicit reshard pass make sure the result is all-gathered rather than both operands.
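As a sketch of that longer-term direction (hypothetical; not implemented by this change, and `reshard` is used here as pseudo-syntax), sideways propagation would be allowed and a single reshard of the result would replace the per-operand all-gathers:

```mlir
// Hypothetical future behavior: "a" propagates sideways, the add runs sharded
// on "a", and the explicit reshard pass all-gathers the result once.
%arg0: [{"a", ?}]
%arg1: [{"a", ?}]
%partial = add %arg0, %arg1 : [{"a", ?}]
%0 = reshard %partial : [{}]
```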

PiperOrigin-RevId: 726920093
bartchr808 authored and copybara-github committed Feb 14, 2025
1 parent 1023b94 commit 4fbb266
Showing 5 changed files with 105 additions and 18 deletions.
1 change: 1 addition & 0 deletions shardy/dialect/sdy/transforms/propagation/BUILD
@@ -260,6 +260,7 @@ cc_library(
":utils",
"//shardy/dialect/sdy/ir:dialect",
"//shardy/dialect/sdy/transforms/common:macros",
"//shardy/dialect/sdy/transforms/common:op_properties",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
@@ -27,6 +27,7 @@ limitations under the License.
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/transforms/propagation/factor_propagation.h"
#include "shardy/dialect/sdy/transforms/propagation/sharding_projection.h"
#include "shardy/dialect/sdy/transforms/common/op_properties.h"

namespace mlir {
namespace sdy {
@@ -124,12 +125,8 @@ UpdateTensorShardings AggressiveFactorPropagation::propagateFactorShardings(
factorToSourceTensor[j].index, j);
});

// The propagation on each tensor is independent. This strategy can propagate
// different shardings to different tensors along the same factor. Examples
// are provided in the docstring of this class.
for (const auto& [tensorIndex, tensorFactorShardings] :
llvm::enumerate(llvm::concat<const TensorFactorShardings>(
projection.getOperands(), projection.getResults()))) {
llvm::enumerate(projection.getResults())) {
const FactorIndexToSharding& factorIndexToSharding =
tensorFactorShardings.factorIndexToSharding;

@@ -171,18 +168,90 @@ UpdateTensorShardings AggressiveFactorPropagation::propagateFactorShardings(
axisRef, factorIndexToSharding, factorIndex);
},
mesh, conservativePropagation);
tensorUpdated |=
expandTensorSharding(projection, tensorIndex, factorIndex, newAxes);
tensorUpdated |= expandTensorSharding(
projection, tensorIndex + projection.getNumOperands(), factorIndex,
newAxes);
}
result.updateResults[tensorIndex] = tensorUpdated;
}

for (const auto& [tensorIndex, tensorFactorShardings] :
llvm::enumerate(projection.getOperands())) {
const FactorIndexToSharding& factorIndexToSharding =
tensorFactorShardings.factorIndexToSharding;

if (tensorIndex < projection.getNumOperands()) {
result.updateOperands[tensorIndex] = tensorUpdated;
} else {
result.updateResults[tensorIndex - projection.getNumOperands()] =
tensorUpdated;
// Propagate the axes obtained in Step 1, resolving conflicts between factors by
// following the order of preference in `sortedFactorIndices`.
bool tensorUpdated = false;
for (int64_t factorIndex : sortedFactorIndices) {
auto factorShardingIt = factorIndexToSharding.find(factorIndex);
if (factorShardingIt == factorIndexToSharding.end()) {
continue;
}
const FactorSharding& factorSharding = factorShardingIt->second;
SmallVector<AxisRefAttr> newAxes = axesPerFactor[factorIndex];

// Resolve conflicts within a factor.
truncateAxesByRemovingConflicts(
newAxes,
[&, factorIndex = factorIndex,
&tensorFactorShardings = tensorFactorShardings](
AxisRefAttr axisRef, int64_t prevShardedSize) {
return compatiblePrefixNoConflictsWithinFactor(
axisRef, tensorFactorShardings.replicatedAxes, factorSharding,
prevShardedSize, factorSizes[factorIndex], mesh);
},
mesh, conservativePropagation);
if (!isStrictPrefix(factorSharding.axisRefs, newAxes)) {
continue;
}

// Resolve conflicts (overlapping sharding axes) between factors.
//
// Note that we pass `factorIndexToSharding`, which might have been
// updated for a previous factor (previous iteration), thus we are
// checking for conflicts w.r.t. the updated state of this tensor.
truncateAxesByRemovingConflicts(
newAxes,
[&, factorIndex = factorIndex](AxisRefAttr axisRef, int64_t) {
return compatiblePrefixNoConflictsAcrossFactors(
axisRef, factorIndexToSharding, factorIndex);
},
mesh, conservativePropagation);

// Do not propagate sideways through operands if the element-wise factor
// is used in the result and is not sharded in the same way. This avoids
// the following situation, which can happen when a
// `sharding_constraint` is added onto the operand during Shardy import:
// ```
// %arg0: [{"a", ?}]
// %arg1: [{?}]
// %0 = add %arg0, %arg1 : [{}]
// ```
// We don't want to do an all-gather on both %arg0 and %arg1 due to "a"
// propagating sideways. Instead with the code below, since "a" can't
// propagate to `%0`, we will only do an all-gather on %arg0.
//
// TODO(b/396642774): Long term we should undo this and allow sideways
// propagation, but have our explicit reshard pass make sure the result is
// all-gathered instead of both operands.
if (op && isElementwise(op)) {
for (const TensorFactorShardings& resultFactorSharding :
projection.getResults()) {
if (auto it =
resultFactorSharding.factorIndexToSharding.find(factorIndex);
it != resultFactorSharding.factorIndexToSharding.end() &&
it->getSecond().axisRefs != newAxes) {
newAxes.clear();
break;
}
}
}
tensorUpdated |=
expandTensorSharding(projection, tensorIndex, factorIndex, newAxes);
}
result.updateOperands[tensorIndex] = tensorUpdated;
}

return result;
}

@@ -32,8 +32,8 @@ func.func @input_output_source_sharding(
// CHECK-LABEL: partial_axes_match
// CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", "b", ?}, {?}]>,
// CHECK-SAME: sdy.sharding_origins = {a = "self", b = "self"}},
// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b", ?}]>,
// CHECK-SAME: sdy.sharding_origins = {a = "input: 0", b = "self"}}
// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{?}, {"b", ?}]>,
// CHECK-SAME: sdy.sharding_origins = {b = "self"}}
// CHECK-SAME: ) -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", "b", ?}, {?}]>,
// CHECK-SAME: sdy.sharding_origins = {a = "input: 0", b = "input: 0"}}) {
func.func @partial_axes_match(
@@ -150,3 +150,17 @@ func.func @multiple_conflicts_across_factors(
(tensor<2x8x4xf32>, tensor<2x4x16xf32>) -> tensor<2x8x16xf32>
return %0 : tensor<2x8x16xf32>
}


// CHECK-LABEL: func @avoid_sideways_propagation_if_conflicting_with_result(
// CHECK-SAME: %arg0: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a"}]>},
// CHECK-SAME: %arg1: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{?}]>})
// CHECK-SAME: -> tensor<8xf32>
func.func @avoid_sideways_propagation_if_conflicting_with_result(
%arg0: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{"a"}]>},
%arg1: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh_a_2_b_2, [{?}]>})
-> tensor<8xf32> {
// CHECK-NEXT: stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_a_2_b_2, [{}]>]>}
%0 = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_a_2_b_2, [{}]>]>} : tensor<8xf32>
return %0 : tensor<8xf32>
}
@@ -53,7 +53,7 @@ func.func @arg_lower_priority_than_return_value(
// CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}]>},
// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}, {"b", ?}]>},
// CHECK-SAME: %arg2: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}, {"b", ?}]>},
// CHECK-SAME: %arg3: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}, {"b", ?}]>})
// CHECK-SAME: %arg3: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}, {?}]>})
// CHECK-SAME: -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}, {?}]>}) {
func.func @arg_lower_priority_than_return_value_with_replicated(
%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p1, {"b"}p1]>},
@@ -72,7 +72,7 @@ func.func @arg_lower_priority_than_return_value_with_replicated(
// CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b", ?}]>},
// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}]>},
// CHECK-SAME: %arg2: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b", ?}]>},
// CHECK-SAME: %arg3: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b", ?}]>})
// CHECK-SAME: %arg3: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{?}, {"b", ?}]>})
// CHECK-SAME: -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}, {"b", ?}]>}) {
func.func @arg_higher_priority_than_return_value(
%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p0, {"b"}p0]>},
@@ -143,10 +143,13 @@ func.func @dim_with_lower_priority_gets_further_sharded_by_higher(
return %1, %2 : tensor<8x8xf32>, tensor<8x8xf32>
}

// TODO(b/396642774): `%arg2` should be sharded on `[{"a", ?}, {"b", ?}] once we
// allow sideways propagation by reverting cl/726920093. Note that this behavior
// is matching GSPMD, but like described in b/396642774 is not ideal.
// CHECK-LABEL: func @different_priorities_with_closed_empty_dim(
// CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}]>},
// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b", ?}]>},
// CHECK-SAME: %arg2: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b", ?}]>},
// CHECK-SAME: %arg2: tensor<8x8xf32>,
// CHECK-SAME: %arg3: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}, {?}]>})
// CHECK-SAME: -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c", ?}, {?}]>}) {
func.func @different_priorities_with_closed_empty_dim(
