
Commit 519dc56

add the code for rms_norm op
Signed-off-by: sharavana20 <[email protected]>
1 parent 1cb25e9 commit 519dc56

8 files changed: +255 −0 lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 26 additions & 0 deletions
@@ -7454,6 +7454,32 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
   }];
 }
 
+def Torch_AtenRmsNormOp : Torch_Op<"aten.rms_norm", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::rms_norm : (Tensor, int[], Tensor?, float?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchListOfTorchIntType:$normalized_shape,
+    AnyTorchOptionalTensorType:$weight,
+    AnyTorchOptionalFloatType:$eps
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRmsNormOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenRmsNormOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_AtenRenormOp : Torch_Op<"aten.renorm", [
   AllowsTypeRefinement,
   HasValueSemantics,

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 18 additions & 0 deletions
@@ -7326,6 +7326,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.rms_norm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<float>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten._softmax_backward_data\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -12732,6 +12736,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.rms_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<float>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+"    %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number, %arg3: !torch.bool) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 75 additions & 0 deletions
@@ -7485,6 +7485,80 @@ class DecomposeAtenNativeLayerNormOp
 };
 } // namespace
 
+// rms(x) = sqrt(mean(x^2) + eps)
+namespace {
+class DecomposeAtenRMSLayerNormOp : public OpRewritePattern<AtenRmsNormOp> {
+  using OpRewritePattern<AtenRmsNormOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AtenRmsNormOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto context = op.getContext();
+    auto input = op.getInput();
+    auto inputTy = dyn_cast<ValueTensorType>(input.getType());
+    auto outputTy = dyn_cast<ValueTensorType>(op.getType());
+    if (!inputTy.hasSizes())
+      return rewriter.notifyMatchFailure(
+          op, "input tensor should have known size.");
+    if (!outputTy.hasDtype())
+      return rewriter.notifyMatchFailure(op, "output should have a dtype.");
+
+    int64_t inputRank = inputTy.getSizes().size();
+    Value normalizedShape = op.getNormalizedShape();
+    SmallVector<Value> normalizedShapeSizesTorchInt;
+    getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
+
+    int64_t normalize_from_idx =
+        inputRank - normalizedShapeSizesTorchInt.size();
+    auto reduceDimInts =
+        llvm::to_vector<4>(llvm::seq<int64_t>(normalize_from_idx, inputRank));
+    auto sizeListType = ListType::get(IntType::get(context));
+
+    SmallVector<Value> reduceDimVals;
+    for (int64_t dim : reduceDimInts)
+      reduceDimVals.push_back(rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(dim)));
+    Value reduceDimList =
+        rewriter.create<PrimListConstructOp>(loc, sizeListType, reduceDimVals);
+
+    auto inputShape = inputTy.getSizes();
+    SmallVector<int64_t> reducedShape(inputShape.begin(), inputShape.end());
+    for (int64_t i : reduceDimInts)
+      reducedShape[i] = 1;
+
+    auto reducedTy =
+        ValueTensorType::get(context, reducedShape, inputTy.getDtype());
+
+    Value inputSquared = rewriter.create<AtenSquareOp>(loc, inputTy, input);
+    Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
+    Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+    // mean(x^2) over the normalized dims, keeping the reduced dims.
+    Value mean = rewriter.create<AtenMeanDimOp>(loc, reducedTy, inputSquared,
+                                                reduceDimList, cstTrue, none);
+    // mean(x^2) + eps (only when eps is provided)
+    if (!isa<Torch::NoneType>(op.getEps().getType())) {
+      Value one = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(1));
+      mean = rewriter.create<AtenAddScalarOp>(loc, reducedTy, mean, op.getEps(),
+                                              one);
+    }
+    // rsqrt(mean(x^2) + eps)
+    Value invRMS = rewriter.create<AtenRsqrtOp>(loc, reducedTy, mean);
+    // x * rsqrt(mean(x^2) + eps)
+    Value normalized =
+        rewriter.create<AtenMulTensorOp>(loc, inputTy, input, invRMS);
+
+    Value weight = op.getWeight();
+    if (!isa<Torch::NoneType>(weight.getType())) {
+      normalized =
+          rewriter.create<AtenMulTensorOp>(loc, outputTy, normalized, weight);
+    }
+    rewriter.replaceOp(op, normalized);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Decompose `aten.emptyLike` op into `aten.size` and `aten.empty` ops.
 class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
@@ -12070,6 +12144,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenInstanceNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLayerNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNativeLayerNormOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenRMSLayerNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenGroupNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNativeGroupNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNativeBatchNormOp>(patterns);
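The pattern above rewrites aten.rms_norm into square, mean over the normalized dimensions (with keepdim), an optional add of eps, rsqrt, a multiply with the input, and an optional multiply with the weight. A minimal Python sketch of the same computation, as a mental model of what the decomposition emits (the helper name and the comparison are illustrative, not part of the commit; the comparison assumes torch.ops.aten.rms_norm is available):

    import torch

    def rms_norm_reference(x, normalized_shape, weight=None, eps=None):
        # Reduce over the trailing dims covered by normalized_shape, keeping
        # them so the result broadcasts back against x (mirrors keepdim=true).
        dims = tuple(range(x.dim() - len(normalized_shape), x.dim()))
        mean_sq = x.square().mean(dim=dims, keepdim=True)
        if eps is not None:                # eps is only added when provided
            mean_sq = mean_sq + eps
        out = x * torch.rsqrt(mean_sq)     # x * rsqrt(mean(x^2) + eps)
        if weight is not None:             # optional elementwise scale
            out = out * weight
        return out

    x, w = torch.rand(8, 9, 1, 2, 4), torch.rand(1, 2, 4)
    torch.testing.assert_close(
        rms_norm_reference(x, [1, 2, 4], w, 0.5),
        torch.ops.aten.rms_norm(x, [1, 2, 4], w, 0.5),
    )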

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
@@ -435,6 +435,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenInstanceNormOp>();
   target.addIllegalOp<AtenLayerNormOp>();
   target.addIllegalOp<AtenNativeLayerNormOp>();
+  target.addIllegalOp<AtenRmsNormOp>();
   target.addIllegalOp<AtenGroupNormOp>();
   target.addIllegalOp<AtenNativeGroupNormOp>();
   target.addIllegalOp<AtenNativeBatchNormOp>();

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 18 additions & 0 deletions
@@ -1472,6 +1472,10 @@
     "Rot90MultipleRotationsModule_basic",
     "Rot90NegativeEvenRotationsModule_basic",
     "Rot90NegativeOddRotationsModule_basic",
+    "RMSNormModule_basic",
+    "RMSNormWithoutEpsModule_basic",
+    "RMSNormWithoutWeightModule_basic",
+    "RMSNormAllNormalizeModule_basic",
     "RsubInt0d_NumToTensor_Module_basic",
     "ScalarConstantTupleModule_basic",
     "ScalarImplicitFloatModule_basic",
@@ -2330,6 +2334,10 @@
     "IscloseStaticModuleTrue_basic",
     "IscloseStaticModule_basic",
     "LayerNormNormalizeOverAllDimsModule_basic",
+    "RMSNormModule_basic",
+    "RMSNormWithoutEpsModule_basic",
+    "RMSNormWithoutWeightModule_basic",
+    "RMSNormAllNormalizeModule_basic",
     "LeakyReluBackwardModule_basic",
     "LeakyReluBackwardStaticModule_basic",
     "LiftFreshCopyModule_basic",
@@ -3036,6 +3044,11 @@
     "NativeGroupNormBackwardModule_basic",
     "NativeGroupNormModule_basic",
     "NativeLayerNormDynamicModule_basic",
+    "RMSNormModule_basic",
+    "RMSNormWithoutEpsModule_basic",
+    "RMSNormWithoutWeightModule_basic",
+    "RMSNormAllNormalizeModule_basic",
+    "RMSNormDynamicModule_basic",
     "NeFloatIntModule_basic",
     "NeIntModule_basic",
     "NewEmptyStridedModuleDefaultDtype_basic",
@@ -4724,6 +4737,11 @@
     "ReshapeCollapseModule_basic",
     "ReshapeDynamicModule_basic",
     "ReshapeExpandModule_basic",
+    "RMSNormModule_basic",
+    "RMSNormWithoutEpsModule_basic",
+    "RMSNormWithoutWeightModule_basic",
+    "RMSNormAllNormalizeModule_basic",
+    "RMSNormDynamicModule_basic",
     "RollModule_basic",
     "RsubIntModule_noalpha_basic",
     "ScalarConstantTupleModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 10 additions & 0 deletions
@@ -664,6 +664,9 @@ def aten〇gather〡shape(self: List[int], dim: int, index: List[int], sparse_gr
 def aten〇layer_norm〡shape(input: List[int], normalized_shape: List[int], weight: Optional[List[int]] = None, bias: Optional[List[int]] = None, eps: float = 1.0000000000000001e-05, cudnn_enable: bool = True) -> List[int]:
     return upstream_shape_functions.unary(input)
 
+def aten〇rms_norm〡shape(input: List[int], normalized_shape: List[int], weight: Optional[List[int]] = None, eps: Optional[float] = None) -> List[int]:
+    return upstream_shape_functions.unary(input)
+
 def aten〇_softmax_backward_data〡shape(grad_output: List[int], output: List[int], dim: int, input_dtype: int) -> List[int]:
     return upstream_shape_functions.unary(output)
 
@@ -3420,6 +3423,13 @@ def aten〇layer_norm〡dtype(input_rank_dtype: Tuple[int, int], normalized_shap
     assert not is_integer_dtype(input_dtype)
     return input_dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(
+    num_of_tensors=1, error_types={*all_integer_dtypes()}, normalized_shape=[1]))
+def aten〇rms_norm〡dtype(input_rank_dtype: Tuple[int, int], normalized_shape: List[int], weight_rank_dtype: Optional[Tuple[int, int]] = None, eps: Optional[float] = None) -> int:
+    input_rank, input_dtype = input_rank_dtype
+    assert not is_integer_dtype(input_dtype)
+    return input_dtype
+
 @check_dtype_function(_check_two_tensor_op(negative_slope=0.1, self_is_result=False))
 def aten〇leaky_relu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float, complex], self_is_result: bool) -> int:
     grad_output_rank, grad_output_dtype = grad_output_rank_dtype
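Taken together, the shape function says the result keeps the input's shape (upstream_shape_functions.unary) and the dtype function forwards the input dtype while asserting it is not an integer type. A small illustrative check of that contract (assuming torch.ops.aten.rms_norm is available in the PyTorch build; the shapes mirror the RMSNormWithoutEpsModule e2e test below):

    import torch

    x = torch.rand(2, 5, 2, 2, 3, dtype=torch.float32)
    w = torch.rand(2, 2, 3, dtype=torch.float32)
    y = torch.ops.aten.rms_norm(x, [2, 2, 3], w)  # weight given, eps left as None

    assert y.shape == x.shape   # shape fn: unary, output shape == input shape
    assert y.dtype == x.dtype   # dtype fn: non-integer input dtype passes through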

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -640,6 +640,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit(
         "aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
     )
+    emit("aten::rms_norm : (Tensor, int[], Tensor?, float?) -> (Tensor)")
     emit("aten::renorm : (Tensor, Scalar, int, Scalar) -> (Tensor)", has_verifier=True)
     emit("aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)", has_verifier=True)
     emit("aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)")

projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py

Lines changed: 106 additions & 0 deletions
@@ -635,6 +635,112 @@ def AtenInstanceNormModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 2, 1, 3), tu.rand(2), tu.rand(2))
 
 
+# ==============================================================================
+class RMSNormModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([8, 9, 1, 2, 4], torch.float32, True),
+            ([1, 2, 4], torch.float32, True),
+        ]
+    )
+    def forward(self, x, weight):
+        list = [1, 2, 4]
+        return torch.ops.aten.rms_norm(x, list, weight, eps=0.5)
+
+
+@register_test_case(module_factory=lambda: RMSNormModule())
+def RMSNormModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(8, 9, 1, 2, 4), tu.rand(1, 2, 4))
+
+
+class RMSNormWithoutEpsModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 5, 2, 2, 3], torch.float32, True),
+            ([2, 2, 3], torch.float32, True),
+        ]
+    )
+    def forward(self, x, weight):
+        list = [2, 2, 3]
+        return torch.ops.aten.rms_norm(x, list, weight)
+
+
+@register_test_case(module_factory=lambda: RMSNormWithoutEpsModule())
+def RMSNormWithoutEpsModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3))
+
+
+class RMSNormWithoutWeightModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, 2, 3, 4], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        list = [4]
+        return torch.ops.aten.rms_norm(x, list, eps=0.5)
+
+
+@register_test_case(module_factory=lambda: RMSNormWithoutWeightModule())
+def RMSNormWithoutWeightModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 2, 3, 4))
+
+
+class RMSNormAllNormalizeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [None, ([5, 6, 3], torch.float32, True), ([5, 6, 3], torch.float32, True)]
+    )
+    def forward(self, x, weight):
+        list = [5, 6, 3]
+        return torch.ops.aten.rms_norm(x, list, weight, eps=0.7)
+
+
+@register_test_case(module_factory=lambda: RMSNormAllNormalizeModule())
+def RMSNormAllNormalizeModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 6, 3), tu.rand(5, 6, 3))
+
+
+class RMSNormDynamicModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, weight):
+        list = [2, 3, 4]
+        return torch.ops.aten.rms_norm(x, list, weight, eps=0.8)
+
+
+@register_test_case(module_factory=lambda: RMSNormDynamicModule())
+def RMSNormDynamicModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 2, 3, 4), tu.rand(2, 3, 4))
+
+
 # ==============================================================================
 class RenormModuleFloat32(torch.nn.Module):
     def __init__(self):