#sdy Only propagate backwards through stablehlo.broadcast_in_dim in the 2nd op-priority iteration.

The rationale for this is that, in case of a conflict, it is always better to insert a reshard on the operand of the broadcast rather than on the result, since the latter is bigger.
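To make the tradeoff concrete, here is a hypothetical sketch (mirroring the new test below; the mesh, axis "a", and shapes are assumed, not taken from the commit):

// The operand is sharded on "a" along dim 0, while a user of the result
// wants "a" on dim 1, so the two shardings conflict at this broadcast.
%0 = stablehlo.broadcast_in_dim %arg0, dims = [0]
    : (tensor<32xf32>) -> tensor<32x16xf32>
// Resharding the operand moves 32 elements; resharding the result would
// move 32x16 = 512. Deferring forward propagation through the broadcast
// until the final iteration lets the result-side sharding win on the op
// itself, so the reshard lands on the small operand.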

PiperOrigin-RevId: 727817413
tomnatan30 authored and copybara-github committed Feb 24, 2025
1 parent 231c438 commit 5e459ea
Showing 2 changed files with 20 additions and 5 deletions.
@@ -51,7 +51,7 @@ namespace {
 using GetDirectionToPropagateFnPtr = PropagationDirection (*)(Operation*,
                                                               int64_t);
 
-PropagationDirection isPassThrough(Operation* op, int64_t) {
+PropagationDirection isPassThroughOp(Operation* op, int64_t) {
   if (isElementwise(op) ||
       isa<stablehlo::ReshapeOp, stablehlo::TransposeOp, DataFlowEdgeOp>(op)) {
     return PropagationDirection::BOTH;
@@ -64,18 +64,21 @@ PropagationDirection isPassThrough(Operation* op, int64_t) {
 
 // NOTE: if the `op` has no sharding rule, then we will assume it uses an
 // identity sharding rule. For example, `DataFlowEdgeOp`.
-PropagationDirection onlyPassThroughFactors(Operation* op,
-                                            int64_t factorIndex) {
+PropagationDirection onlyPassThroughFactorsBroadcastBackward(
+    Operation* op, int64_t factorIndex) {
   if (auto shardingRule =
           op->getAttrOfType<OpShardingRuleAttr>(kShardingRuleAttr);
       shardingRule && !shardingRule.isPassThroughFactor(factorIndex)) {
     return PropagationDirection::NONE;
   }
+  if (isa<stablehlo::BroadcastInDimOp>(op)) {
+    return PropagationDirection::BACKWARD;
+  }
   return PropagationDirection::BOTH;
 }
 
 constexpr std::array<GetDirectionToPropagateFnPtr, 3> opPropagationSchedule = {
-    isPassThrough, onlyPassThroughFactors, propagateAny};
+    isPassThroughOp, onlyPassThroughFactorsBroadcastBackward, propagateAny};
 
 // Returns the direction in which the given operation should be propagated.
 //
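Reading the schedule array together with the commit title suggests that opPropagationSchedule is applied over successive op-priority iterations: isPassThroughOp governs the 1st, onlyPassThroughFactorsBroadcastBackward (carrying the new broadcast restriction) the 2nd, and propagateAny the 3rd, which matches the "2nd op-priority iteration" in the commit message.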
@@ -139,8 +139,20 @@ func.func @pass_through_factor_higher_priority_than_reduction_factor(
     %arg0: tensor<32x1024xf32>,
     %arg1: tensor<1024x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b", ?}]>}
 ) -> (tensor<32x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b", ?}]>}) {
-  // CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %arg0, %arg1, precision = [DEFAULT, DEFAULT] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", ?}, {"b", ?}]>]>} : (tensor<32x1024xf32>, tensor<1024x16xf32>) -> tensor<32x16xf32>
+  // CHECK-NEXT: %[[DOT:.*]] = stablehlo.dot %arg0, %arg1, precision = [DEFAULT, DEFAULT] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", ?}, {"b", ?}]>]>}
   // CHECK-NEXT: return %[[DOT]] : tensor<32x16xf32>
   %0 = stablehlo.dot %arg0, %arg1, precision = [DEFAULT, DEFAULT] : (tensor<32x1024xf32>, tensor<1024x16xf32>) -> tensor<32x16xf32>
   return %0 : tensor<32x16xf32>
 }
+
+// CHECK-LABEL: func @broadcast_forward_higher_priority_than_backwards
+func.func @broadcast_forward_higher_priority_than_backwards(
+    %arg0: tensor<32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}]>}
+) -> (tensor<32x16x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"a"}, {}]>}) {
+  // CHECK-NEXT: %[[BROADCAST_1:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{?}, {"a", ?}]>]>}
+  // CHECK-NEXT: %[[BROADCAST_2:.*]] = stablehlo.broadcast_in_dim %[[BROADCAST_1]], dims = [0, 1] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{?}, {"a", ?}, {?}]>]>}
+  // CHECK-NEXT: return %[[BROADCAST_2]]
+  %0 = stablehlo.broadcast_in_dim %arg0, dims = [0] : (tensor<32xf32>) -> tensor<32x16xf32>
+  %1 = stablehlo.broadcast_in_dim %0, dims = [0, 1] : (tensor<32x16xf32>) -> tensor<32x16x8xf32>
+  return %1 : tensor<32x16x8xf32>
+}
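As the CHECK lines show, the result sharding <@mesh, [{}, {"a"}, {}]> propagates backwards through both broadcasts, placing "a" on dimension 1 of each intermediate value; consistent with the rationale above, the conflicting sharding of %arg0 ("a" on dimension 0) stays on the operand, so the eventual reshard touches only the 32-element tensor rather than the larger results.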
