diff --git a/shardy/dialect/sdy/transforms/import/BUILD b/shardy/dialect/sdy/transforms/import/BUILD index f5028c29..8b1a7459 100644 --- a/shardy/dialect/sdy/transforms/import/BUILD +++ b/shardy/dialect/sdy/transforms/import/BUILD @@ -48,6 +48,7 @@ cc_library( "lift_inlined_meshes.cc", "manual_axes_cleanup.cc", "sharding_group_import.cc", + "translate_mesh.cc", ], hdrs = [ "passes.h", diff --git a/shardy/dialect/sdy/transforms/import/passes.td b/shardy/dialect/sdy/transforms/import/passes.td index c2ed7f5e..ddfedc12 100644 --- a/shardy/dialect/sdy/transforms/import/passes.td +++ b/shardy/dialect/sdy/transforms/import/passes.td @@ -153,3 +153,22 @@ def ManualAxesCleanupPass : Pass<"sdy-manual-axes-cleanup", "ModuleOp"> { }]; let dependentDialects = ["mlir::sdy::SdyDialect"]; } + +def TranslateMeshPass : Pass<"sdy-translate-mesh", "ModuleOp"> { + let summary = "Replaces "; + let description = [{ + 1. For any in/out sharding that hasn't specified a manual axis, add that + manual axis to its replicated_axes. This is to ensure manual axes are + always fully specified. + 2. Sorts the manual axes in mesh axis declaration order. + }]; + let dependentDialects = ["mlir::sdy::SdyDialect"]; + + let options = [ + ListOption<"axisNames", "axis-names", "std::string", + "Names of the new axes.">, + Option<"oldMeshName", "old-mesh-name", "std::string", + /*default=*/"", + "The name of the old mesh to be replaced."> + ]; +} diff --git a/shardy/dialect/sdy/transforms/import/test/translate_mesh.mlir b/shardy/dialect/sdy/transforms/import/test/translate_mesh.mlir new file mode 100644 index 00000000..4839da09 --- /dev/null +++ b/shardy/dialect/sdy/transforms/import/test/translate_mesh.mlir @@ -0,0 +1,15 @@ +// RUN: sdy_opt %s -sdy-translate-mesh="old-mesh-name=my_mesh axis-names='data,model'" 2>&1 | FileCheck %s + +// CHECK-LABEL: @my_mesh +// CHECK-SAME{LITERAL}: <["data"=2, "model"=4]> +sdy.mesh @my_mesh = <["a"=2, "b"=4]> + +// CHECK-NOT: <["a"=2, "b"=4]> + +// CHECK-LABEL: @foo +func.func @foo(%arg0 : tensor<8x8xf32>) -> tensor<8x8xf32> { + // CHECK-NEXT: stablehlo.add + // CHECK-SAME{LITERAL}: #sdy.sharding_per_value<[<@my_mesh, [{"data", ?}p1, {}], replicated={"model"}>]> + %0 = stablehlo.add %arg0, %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@my_mesh, [{"a", ?}p1, {}], replicated={"b"}>]>} : tensor<8x8xf32> + return %0 : tensor<8x8xf32> +} diff --git a/shardy/dialect/sdy/transforms/import/translate_mesh.cc b/shardy/dialect/sdy/transforms/import/translate_mesh.cc new file mode 100644 index 00000000..25de9d41 --- /dev/null +++ b/shardy/dialect/sdy/transforms/import/translate_mesh.cc @@ -0,0 +1,133 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include // IWYU pragma: keep +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "shardy/dialect/sdy/ir/dialect.h" +#include "shardy/dialect/sdy/transforms/common/sharding_walker.h" +#include "shardy/dialect/sdy/transforms/import/passes.h" // IWYU pragma: keep + +namespace mlir { +namespace sdy { + +#define GEN_PASS_DEF_TRANSLATEMESHPASS +#include "shardy/dialect/sdy/transforms/import/passes.h.inc" + +namespace { + +LogicalResult translateMesh(ModuleOp moduleOp, + StringRef oldMeshName, + ArrayRef newMeshAxisNames) { + MLIRContext* context = moduleOp.getContext(); + auto oldMeshOp = SymbolTable::lookupNearestSymbolFrom( + moduleOp, SymbolRefAttr::get(context, oldMeshName)); + if (!oldMeshOp) { + return moduleOp.emitError() + << "Mesh " << oldMeshName << " not found in module."; + } + sdy::MeshAttr oldMesh = oldMeshOp.getMesh(); + if (oldMesh.getAxes().size() != newMeshAxisNames.size()) { + return moduleOp.emitError() + << "Both meshes must have the same number of axes."; + } + llvm::StringMap oldToNewAxis; + bool sameMesh = true; + for (auto [oldAxis, newAxisName] : + llvm::zip_equal(oldMesh.getAxes(), newMeshAxisNames)) { + oldToNewAxis[oldAxis.getName()] = newAxisName; + if (oldAxis.getName() != newAxisName) { + sameMesh = false; + } + } + // Exit early if the meshes are the exact same. + if (sameMesh) { + return success(); + } + StringAttr meshName = StringAttr::get(context, oldMeshName); + sdy::transformShardings( + moduleOp, + [&](sdy::TensorShardingAttr oldSharding) -> sdy::TensorShardingAttr { + SmallVector newDimShardings; + for (auto oldDimSharding : oldSharding.getDimShardings()) { + SmallVector newAxisRefs; + llvm::transform(oldDimSharding.getAxes(), + std::back_inserter(newAxisRefs), + [&](sdy::AxisRefAttr oldAxisRef) { + return sdy::AxisRefAttr::get( + context, oldToNewAxis[oldAxisRef.getName()], + oldAxisRef.getSubAxisInfo()); + }); + newDimShardings.push_back(sdy::DimensionShardingAttr::get( + context, newAxisRefs, oldDimSharding.getIsClosed(), + oldDimSharding.getPriority())); + } + SmallVector newReplicatedAxes; + llvm::transform(oldSharding.getReplicatedAxes(), + std::back_inserter(newReplicatedAxes), + [&](sdy::AxisRefAttr oldAxisRef) { + return sdy::AxisRefAttr::get( + context, oldToNewAxis[oldAxisRef.getName()], + oldAxisRef.getSubAxisInfo()); + }); + return sdy::TensorShardingAttr::get(context, meshName, newDimShardings, + newReplicatedAxes); + }); + SmallVector newAxes; + newAxes.reserve(newMeshAxisNames.size()); + for (const auto& [axisName, oldAxis] : + llvm::zip_equal(newMeshAxisNames, oldMesh.getAxes())) { + newAxes.push_back(MeshAxisAttr::get(context, axisName, oldAxis.getSize())); + } + IRRewriter rewriter(moduleOp); + rewriter.setInsertionPoint(oldMeshOp); + SymbolTable symbolTable(moduleOp); + auto newMeshOp = rewriter.create( + moduleOp.getLoc(), oldMeshName, + MeshAttr::get(context, newAxes, oldMesh.getDeviceIds())); + symbolTable.erase(oldMeshOp); + symbolTable.insert(newMeshOp); + return success(); +} + +struct TranslateMeshPass + : public impl::TranslateMeshPassBase { + using TranslateMeshPassBase::TranslateMeshPassBase; + + void runOnOperation() final { + if (translateMesh(getOperation(), oldMeshName, llvm::to_vector(axisNames)) + .failed()) { + signalPassFailure(); + } + } +}; + +} // namespace + +} // namespace sdy +} // namespace mlir