diff --git a/shardy/dialect/sdy/ir/compatibility_test/BUILD b/shardy/dialect/sdy/ir/compatibility_test/BUILD
new file mode 100644
index 00000000..bc9fb898
--- /dev/null
+++ b/shardy/dialect/sdy/ir/compatibility_test/BUILD
@@ -0,0 +1,23 @@
+# Lit tests for the SDY dialect.
+
+load("//shardy:lit.bzl", "glob_lit_tests")
+
+package(default_visibility = ["//visibility:public"])
+
+filegroup(
+    name = "test_data",
+    testonly = True,
+    data = [
+        "//shardy/dialect/sdy/ir/compatibility_test:compatibility_test.mlir.bc",
+        "//shardy/tools:sdy_opt",
+        "//shardy/tools:sdy_translate",
+        "@llvm-project//llvm:FileCheck",
+    ],
+)
+
+glob_lit_tests(
+    name = "compatibility_tests",
+    data = [":test_data"],
+    driver = "@llvm-project//mlir:run_lit.sh",
+    test_file_exts = ["mlir"],
+)
diff --git a/shardy/dialect/sdy/ir/compatibility_test/compatibility_test.mlir b/shardy/dialect/sdy/ir/compatibility_test/compatibility_test.mlir
new file mode 100644
index 00000000..124f0a08
--- /dev/null
+++ b/shardy/dialect/sdy/ir/compatibility_test/compatibility_test.mlir
@@ -0,0 +1,173 @@
+// Smoke test:
+// RUN: sdy_opt %s.bc | FileCheck %s
+// RUN: sdy_opt %s.bc | sdy_translate --serialize | sdy_opt | FileCheck %s
+// RUN: sdy_opt %s.bc | sdy_translate --serialize --strip-debuginfo | sdy_opt | FileCheck %s
+// RUN: sdy_translate --deserialize %s.bc | sdy_opt | FileCheck %s
+//
+// Backward compatibility test:
+// RUN: sdy_translate --serialize %s | sdy_opt > %t.0
+// RUN: sdy_opt %s > %t.1
+// RUN: diff %t.0 %t.1
+//
+// Forward compatibility test:
+// RUN: sdy_translate %s --serialize -strip-debuginfo > %t.2
+// RUN: diff %s.bc %t.2
+
+// CHECK: sdy.mesh @empty_mesh = <[]>
+sdy.mesh @empty_mesh = <[]>
+
+// CHECK: sdy.mesh @maximal_mesh_1 = <[], device_ids=[0]>
+sdy.mesh @maximal_mesh_1 = <[], device_ids=[0]>
+
+// CHECK: sdy.mesh @maximal_mesh_2 = <[], device_ids=[3]>
+sdy.mesh @maximal_mesh_2 = <[], device_ids=[3]>
+
+// CHECK: sdy.mesh @mesh_xy = <["x"=2, "y"=4]>
+sdy.mesh @mesh_xy = <["x"=2, "y"=4]>
+
+// CHECK: sdy.mesh @mesh_x_non_iota_device_ids = <["x"=4], device_ids=[0, 3, 2, 1]>
+sdy.mesh @mesh_x_non_iota_device_ids = <["x"=4], device_ids=[0, 3, 2, 1]>
+
+// CHECK: sdy.mesh @mesh_xyz = <["x"=2, "y"=2, "z"=2]>
+sdy.mesh @mesh_xyz = <["x"=2, "y"=2, "z"=2]>
+
+// CHECK-LABEL: func @sharding_constraint
+func.func @sharding_constraint(%arg0 : tensor<16x8xf32>) -> tensor<16x8xf32> {
+  // CHECK-NEXT: sdy.sharding_constraint %arg0 <@mesh_xy, [{}, {"x"}], replicated={"y"}>
+  %0 = sdy.sharding_constraint %arg0 <@mesh_xy, [{}, {"x"}], replicated={"y"}> : tensor<16x8xf32>
+  return %0 : tensor<16x8xf32>
+}
+
+// CHECK-LABEL: func @reshard
+func.func @reshard(%arg0 : tensor<16x8xf32>) -> tensor<16x8xf32> {
+  // CHECK-NEXT: sdy.reshard %arg0 <@mesh_xy, [{}, {"y"}], replicated={"x"}>
+  %0 = sdy.reshard %arg0 <@mesh_xy, [{}, {"y"}], replicated={"x"}> : tensor<16x8xf32>
+  return %0 : tensor<16x8xf32>
+}
+
+// CHECK-LABEL: func @manual_computation
+func.func @manual_computation(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  // CHECK{LITERAL}: sdy.manual_computation(%arg0) in_shardings=[<@mesh_xy, [{"x", ?}, {?}]>] out_shardings=[<@mesh_xy, [{"x", ?}, {?}]>] manual_axes={"x"} (%arg1: tensor<8x32xf32>) {
+  // CHECK-NEXT:   sdy.return %arg1 : tensor<8x32xf32>
+  // CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<16x32xf32>
+  %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh_xy, [{"x", ?}, {?}]>] out_shardings=[<@mesh_xy, [{"x", ?}, {?}]>] manual_axes={"x"} (%arg1: tensor<8x32xf32>) {
+    sdy.return %arg1 : tensor<8x32xf32>
+  } : (tensor<16x32xf32>) -> tensor<16x32xf32>
+  func.return %0: tensor<16x32xf32>
+}
+
+// CHECK-LABEL: func @sharding_group
+func.func @sharding_group(%arg0: tensor<8xf32>) {
+  // CHECK: sdy.sharding_group %arg0 group_id=21 : tensor<8xf32>
+  sdy.sharding_group %arg0 group_id=21 : tensor<8xf32>
+  func.return
+}
+
+// CHECK-LABEL: func @constant
+func.func @constant() {
+  // CHECK-NEXT: sdy.constant dense<1.000000e+00> : tensor<8x16xf32>
+  %0 = sdy.constant dense<1.000000e+00> : tensor<8x16xf32>
+  func.return
+}
+
+// CHECK-LABEL: func @data_flow_edge
+func.func @data_flow_edge(%arg0: tensor<32x96xf32>, %arg1: tensor<32x96xf32>)
+    -> (tensor<32x96xf32>, tensor<32x96xf32>) {
+  // CHECK-NEXT: sdy.data_flow_edge %arg0
+  // CHECK-NEXT: sdy.data_flow_edge %arg1 sharding=<@mesh_x_non_iota_device_ids, [{"x"}, {?}]>
+  %1 = sdy.data_flow_edge %arg0 : tensor<32x96xf32>
+  %2 = sdy.data_flow_edge %arg1 sharding=<@mesh_x_non_iota_device_ids, [{"x"}, {?}]> : tensor<32x96xf32>
+  return %1, %2 : tensor<32x96xf32>, tensor<32x96xf32>
+}
+
+// CHECK-LABEL: func @propagation_barrier
+func.func @propagation_barrier(%arg0 : tensor<8xf32>, %arg1: tensor<16x8xf32>, %arg2: tensor<8x16xf32>)
+    -> (tensor<8xf32>, tensor<16x8xf32>, tensor<8x16xf32>) {
+  // CHECK-NEXT: sdy.propagation_barrier %arg0 allowed_direction=NONE
+  // CHECK-NEXT: sdy.propagation_barrier %arg1 allowed_direction=FORWARD
+  // CHECK-NEXT: sdy.propagation_barrier %arg2 allowed_direction=BACKWARD
+  %0 = sdy.propagation_barrier %arg0 allowed_direction=NONE : tensor<8xf32>
+  %1 = sdy.propagation_barrier %arg1 allowed_direction=FORWARD : tensor<16x8xf32>
+  %2 = sdy.propagation_barrier %arg2 allowed_direction=BACKWARD : tensor<8x16xf32>
+  return %0, %1, %2 : tensor<8xf32>, tensor<16x8xf32>, tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func @named_computation
+func.func @named_computation(%arg0: tensor<8x2xi32>, %arg1: tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>) {
+  // CHECK-NEXT: %0:2 = sdy.named_computation<"foo">(%arg0, %arg1) (%arg2: tensor<8x2xi32>, %arg3: tensor<4x2xi32>) {
+  // CHECK-NEXT:   sdy.return %arg2, %arg3 : tensor<8x2xi32>, tensor<4x2xi32>
+  // CHECK-NEXT: } : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>)
+  %0:2 = sdy.named_computation<"foo">(%arg0, %arg1) (%arg2: tensor<8x2xi32>, %arg3: tensor<4x2xi32>) {
+    sdy.return %arg2, %arg3 : tensor<8x2xi32>, tensor<4x2xi32>
+  } : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>)
+  return %0#0, %0#1 : tensor<8x2xi32>, tensor<4x2xi32>
+}
+
+// CHECK-LABEL: func @tensor_sharding
+func.func @tensor_sharding(%arg0 : tensor<8x8xf32>, %arg1 : tensor<8x8xf32>) -> (tensor<64xf32>, tensor<8x8xf32>) {
+  // CHECK-NEXT: stablehlo.custom_call @bar(%arg0, %arg1)
+  // CHECK-SAME{LITERAL}: #sdy.sharding_per_value<[<@mesh_xy, [{"x", "y"}]>, <@mesh_xy, [{"x"}p0, {"y":(1)2}p123]>]>
+  %0:2 = stablehlo.custom_call @bar(%arg0, %arg1)
+    {sdy.sharding = #sdy.sharding_per_value<[<@mesh_xy, [{"x", "y"}]>, <@mesh_xy, [{"x"}p0, {"y":(1)2}p123]>]>}
+    : (tensor<8x8xf32>, tensor<8x8xf32>) -> (tensor<64xf32>, tensor<8x8xf32>)
+  return %0#0, %0#1 : tensor<64xf32>, tensor<8x8xf32>
+}
+
+// CHECK-LABEL: func @tensor_sharding_on_parameter_result
+// CHECK-SAME{LITERAL}: (%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{}, {"y"}p2]>}) -> (tensor<64xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x", "y"}]>})
+func.func @tensor_sharding_on_parameter_result(%arg0 : tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{}, {"y"}p2]>})
+    -> (tensor<64xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x", "y"}]>}) {
+  %0 = stablehlo.custom_call @foo(%arg0) : (tensor<8x8xf32>) -> (tensor<64xf32>)
+  return %0 : tensor<64xf32>
+}
+
+// CHECK-LABEL: func @tensor_sharding_scalar
+// CHECK-SAME{LITERAL}: (%arg0: tensor<f32> {sdy.sharding = #sdy.sharding<@mesh_xy, []>}) -> (tensor<64xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x", "y"}]>})
+func.func @tensor_sharding_scalar(%arg0 : tensor<f32> {sdy.sharding = #sdy.sharding<@mesh_xy, []>})
+    -> (tensor<64xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x", "y"}]>}) {
+  %0 = stablehlo.custom_call @foo(%arg0) : (tensor<f32>) -> (tensor<64xf32>)
+  return %0 : tensor<64xf32>
+}
+
+// CHECK-LABEL: func @tensor_sharding_dynamic_shape
+func.func @tensor_sharding_dynamic_shape(%arg0 : tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+  // CHECK-NEXT: stablehlo.custom_call @bar(%arg0)
+  // CHECK-SAME{LITERAL}: #sdy.sharding_per_value<[<@mesh_xyz, [{"x", "y"}, {}], replicated={"z"}>]>
+  %0 = stablehlo.custom_call @bar(%arg0)
+    {sdy.sharding = #sdy.sharding_per_value<[<@mesh_xyz, [{"x", "y"}, {}], replicated={"z"}>]>}
+    : (tensor<?x?xf32>) -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func @sharding_rule_scalar
+func.func @sharding_rule_scalar(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: {sdy.sharding_rule = #sdy.op_sharding_rule<([], [])->([]), custom>}
+  %0 = stablehlo.custom_call @foo(%arg0, %arg0) {sdy.sharding_rule = #sdy.op_sharding_rule<([], [])->([]), custom>} :
+    (tensor<f32>, tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// CHECK-LABEL: func @sharding_rule_tensor
+func.func @sharding_rule_tensor(%arg0: tensor<2x4xf32>) -> tensor<8xf32> {
+  // CHECK: {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j])->([ij]) {i=2, j=4}>}
+  %0 = stablehlo.reshape %arg0 {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j])->([ij]) {i=2, j=4}>} : (tensor<2x4xf32>) -> tensor<8xf32>
+  return %0 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @sharding_rule_tensor_with_many_dimensions
+func.func @sharding_rule_tensor_with_many_dimensions(%arg0: tensor<2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2xf32>) -> tensor<2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x8xf32> {
+  // CHECK: #sdy.op_sharding_rule<([i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z, z_1, z_2, z_3, z_4, z_5, z_6, z_7, z_8, z_9, z_10])
+  // CHECK-SAME: ->([i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z, z_1, z_2, z_3, z_4, z_5, z_6, z_7, z_8z_9z_10])
+  // CHECK-SAME: {i=2, j=2, k=2, l=2, m=2, n=2, o=2, p=2, q=2, r=2, s=2, t=2, u=2, v=2, w=2, x=2, y=2, z=2, z_1=2, z_2=2, z_3=2, z_4=2, z_5=2, z_6=2, z_7=2, z_8=2, z_9=2, z_10=2}>} :
+  %0 = stablehlo.custom_call @foo(%arg0)
+    {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z, z_1, z_2, z_3, z_4, z_5, z_6, z_7, z_8, z_9, z_10])->([i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z, z_1, z_2, z_3, z_4, z_5, z_6, z_7, z_8z_9z_10]) {i=2, j=2, k=2, l=2, m=2, n=2, o=2, p=2, q=2, r=2, s=2, t=2, u=2, v=2, w=2, x=2, y=2, z=2, z_1=2, z_2=2, z_3=2, z_4=2, z_5=2, z_6=2, z_7=2, z_8=2, z_9=2, z_10=2}>}
+    : (tensor<2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2xf32>) -> tensor<2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x8xf32>
+  return %0 : tensor<2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x2x8xf32>
+}
+
+// CHECK-LABEL: func @custom_sharding_rule_custom_call
+func.func @custom_sharding_rule_custom_call(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  // CHECK: {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j])->([i, j]) {i=16, j=32}, custom>}
+  %0 = stablehlo.custom_call @foo(%arg0) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j])->([i, j]) {i=16, j=32}, custom>} : (tensor<16x32xf32>) -> tensor<16x32xf32>
+  func.return %0: tensor<16x32xf32>
+}
diff --git a/shardy/dialect/sdy/ir/compatibility_test/compatibility_test.mlir.bc b/shardy/dialect/sdy/ir/compatibility_test/compatibility_test.mlir.bc
new file mode 100644
index 00000000..2688982f
Binary files /dev/null and b/shardy/dialect/sdy/ir/compatibility_test/compatibility_test.mlir.bc differ
diff --git a/shardy/lit.cfg.py b/shardy/lit.cfg.py
index e8ce83c5..dc90dc91 100644
--- a/shardy/lit.cfg.py
+++ b/shardy/lit.cfg.py
@@ -36,6 +36,7 @@ tools = [
     'FileCheck',
     'sdy_opt',
+    'sdy_translate',
 ]
 tool_dirs = [
     config.llvm_tools_dir,
 ]
diff --git a/shardy/tools/BUILD b/shardy/tools/BUILD
index eb216c64..1cb3b61a 100644
--- a/shardy/tools/BUILD
+++ b/shardy/tools/BUILD
@@ -15,3 +15,24 @@ cc_binary(
         "@llvm-project//mlir:QuantOps",
     ],
 )
+
+cc_binary(
+    name = "sdy_translate",
+    srcs = ["sdy_translate_main.cc"],
+    deps = [
+        "//shardy/dialect/sdy/ir:dialect",
+        "//shardy/dialect/sdy/ir:register",
+        "//shardy/dialect/sdy/transforms:passes",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:AllPassesAndDialects",
+        "@llvm-project//mlir:BytecodeWriter",
+        "@llvm-project//mlir:FuncExtensions",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:MlirOptLib",
+        "@llvm-project//mlir:Parser",
+        "@llvm-project//mlir:QuantOps",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:Transforms",
+        "@llvm-project//mlir:TranslateLib",
+    ],
+)
diff --git a/shardy/tools/sdy_translate_main.cc b/shardy/tools/sdy_translate_main.cc
new file mode 100644
index 00000000..74dac8af
--- /dev/null
+++ b/shardy/tools/sdy_translate_main.cc
@@ -0,0 +1,81 @@
+/* Copyright 2025 The Shardy Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// MLIR `translate` tool enabling SDY dialect bytecode emission.
+//
+// Usage:
+//   sdy_translate -serialize <file.mlir>
+//   sdy_translate -deserialize <file.mlir.bc>
+
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/LogicalResult.h"
+#include "mlir/Bytecode/BytecodeWriter.h"
+#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Tools/mlir-opt/MlirOptMain.h"
+#include "mlir/Tools/mlir-translate/MlirTranslateMain.h"
+#include "mlir/Tools/mlir-translate/Translation.h"
+#include "mlir/Transforms/Passes.h"
+#include "shardy/dialect/sdy/ir/dialect.h"
+#include "shardy/dialect/sdy/ir/register.h"
+
+namespace mlir {
+
+namespace {
+llvm::cl::opt<bool> stripDebuginfoOption(
+    "strip-debuginfo", llvm::cl::desc("Strip debug info from all operations"),
+    llvm::cl::init(false));
+
+void registerDialectsForSdy(DialectRegistry &registry) {
+  mlir::sdy::registerAllDialects(registry);
+  registry.insert<mlir::quant::QuantDialect>();
+}
+
+TranslateFromMLIRRegistration serializeRegistration(
+    "serialize", "Serialize SDY program into a portable artifact",
+    [](mlir::ModuleOp module, llvm::raw_ostream &os) -> llvm::LogicalResult {
+      if (stripDebuginfoOption) {
+        PassManager pm(module->getContext());
+        pm.addPass(createStripDebugInfoPass());
+        if (failed(pm.run(module)))
+          return module.emitError("failed to strip debuginfo");
+      }
+      const auto *producer = "SDY";
+      BytecodeWriterConfig writerConfig(producer);
+      return writeBytecodeToFile(module, os, writerConfig);
+    },
+    [](DialectRegistry &registry) { registerDialectsForSdy(registry); });
+
+TranslateToMLIRRegistration deserializeRegistration(
+    "deserialize", "Deserialize a portable artifact into a SDY program",
+    [](llvm::StringRef input, mlir::MLIRContext *context) {
+      context->loadDialect<mlir::sdy::SdyDialect>();
+      auto module = parseSourceString<ModuleOp>(input, context);
+      return module;
+    },
+    [](DialectRegistry &registry) { registerDialectsForSdy(registry); });
+}  // namespace
+
+}  // namespace mlir
+
+int main(int argc, char **argv) {
+  return mlir::asMainReturnCode(
+      mlir::mlirTranslateMain(argc, argv, "SDY transformation driver\n"));
+}
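A minimal usage sketch for the new tool, mirroring the RUN lines in compatibility_test.mlir above; the file names are illustrative, and the flags follow the translate registrations in sdy_translate_main.cc:

  # Serialize textual MLIR into SDY bytecode, dropping debug locations as the
  # forward-compatibility RUN line does.
  $ sdy_translate --serialize --strip-debuginfo compatibility_test.mlir > compatibility_test.mlir.bc

  # Round-trip: deserialize the bytecode back to text and re-verify it with sdy_opt.
  $ sdy_translate --deserialize compatibility_test.mlir.bc | sdy_opt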