Skip to content

Commit

Permalink
[SDY] Add equi-sharding op ShardingGroupOp parsing and printing.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 653252805
  • Loading branch information
Varcho authored and copybara-github committed Jul 17, 2024
1 parent 158707d commit efcf98f
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
25 changes: 25 additions & 0 deletions shardy/dialect/sdy/ir/ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,31 @@ def Sdy_ManualComputationOp : Sdy_Op<"manual_computation",
}];
}

def Sdy_ShardingGroupOp : Sdy_Op<"sharding_group",
// Op is non-pure since it modifies the internal representation of the
// sharding group.
[]>{
let summary = "Sharding group operation";
let description = [{
This op provides an interface to assign tensors to sharding groups (
groups of tensors that will be enforced to have identical shardings).
During propagation, as soon as one group element is sharded, all other
members will be sharded in exactly the same way. This operation takes the
argument group ID and returns no result, but instead modifies the internal
sharding group representation to add the input tensor to the group with the
given ID.
}];

let arguments = (ins
AnyRankedTensor:$input,
I64Attr:$group_id);

// Dangling op has no results.
let results = (outs);

let assemblyFormat = "$input `group_id````=```$group_id attr-dict `:` type($input)";
}

def Sdy_ConstantOp : Sdy_Op<"constant",
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Constant operation";
Expand Down
8 changes: 8 additions & 0 deletions shardy/dialect/sdy/ir/test/sharding_group_parse_print.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// RUN: sdy_opt %s 2>&1 | FileCheck %s

// CHECK-LABEL: func @add_to_default_group_type
func.func @add_to_default_group_type(%arg0: tensor<8xf32>) {
// CHECK sdy.sharding_group %arg0 group_id=21 type=AS : tensor<8xf32>
sdy.sharding_group %arg0 group_id=21 : tensor<8xf32>
func.return
}

0 comments on commit efcf98f

Please sign in to comment.