diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index c851f82dba72..d2aefb4db9a1 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -13077,6 +13077,53 @@ def Torch_AtenWhereScalarSelfOp : Torch_Op<"aten.where.ScalarSelf", [
   let hasFolder = 1;
 }
 
+def Torch_AtenHeavisideOp : Torch_Op<"aten.heaviside", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::heaviside : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$values
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenHeavisideOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenHeavisideOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenHeaviside_Op : Torch_Op<"aten.heaviside_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::heaviside_ : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    Torch_NonValueTensorType:$self,
+    Torch_NonValueTensorType:$values
+  );
+  let results = (outs
+    AnyTorchOptionalNonValueTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenHeaviside_Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenHeaviside_Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenNanToNumOp : Torch_Op<"aten.nan_to_num", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 9c355f4ea4a8..242d2f08a38d 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -9671,6 +9671,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg2) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.heaviside\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.nan_to_num\"(%arg0: !torch.list<int>, %arg1: !torch.optional<float>, %arg2: !torch.optional<float>, %arg3: !torch.optional<float>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -15192,6 +15196,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
 "    return %4 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.heaviside\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+"    %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %4 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.nan_to_num\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<float>, %arg2: !torch.optional<float>, %arg3: !torch.optional<float>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 62ce02df50a6..0efdc2cd575b 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -11068,6 +11068,76 @@ class DecomposeAtenSgnOp : public OpRewritePattern<AtenSgnOp> {
 };
 } // namespace
 
+namespace {
+// Decompose aten.heaviside op into a combination of aten.eq,
+// aten.lt, aten.isnan, aten.logical_or, and aten.where ops.
+// Heaviside(x, y) returns
+//   0 if x < 0
+//   y if x == 0
+//   1 if x > 0
+class DecomposeAtenHeaviside : public OpRewritePattern<AtenHeavisideOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenHeavisideOp op,
+                                PatternRewriter &rewriter) const override {
+    auto input = op.getSelf();
+    auto value = op.getValues();
+    auto loc = op.getLoc();
+    auto inputTy = dyn_cast<BaseTensorType>(input.getType());
+    if (!inputTy || !inputTy.hasDtype() || !inputTy.hasSizes())
+      return rewriter.notifyMatchFailure(op, "input must have dtype and size.");
+
+    auto valueTy = dyn_cast<BaseTensorType>(value.getType());
+    if (!valueTy || !valueTy.hasDtype() || !valueTy.hasSizes())
+      return rewriter.notifyMatchFailure(op, "value must have dtype and size.");
+    auto resultTy = dyn_cast<BaseTensorType>(op.getType());
+    SmallVector<int64_t> broadcastShape;
+    SmallVector<Value> broadcastShapeValue;
+    computeBroadcastShape(rewriter, loc, input, value, broadcastShape,
+                          broadcastShapeValue);
+
+    auto broadcastType = ValueTensorType::get(
+        op.getContext(), llvm::ArrayRef(broadcastShape), resultTy.getDtype());
+    auto boolBroadcastType = ValueTensorType::get(
+        op.getContext(), llvm::ArrayRef(broadcastShape), rewriter.getI1Type());
+    Value indexBroadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
+        broadcastShapeValue);
+    auto inputBroadcasted = rewriter.create<AtenBroadcastToOp>(
+        loc, broadcastType, input, indexBroadcastShapeTorchList);
+    auto valueBroadcasted = rewriter.create<AtenBroadcastToOp>(
+        loc, broadcastType, value, indexBroadcastShapeTorchList);
+
+    Value zero = getConstantWithGivenDtypeAndValue(rewriter, loc, 0,
+                                                   resultTy.getDtype());
+    Value one = getConstantWithGivenDtypeAndValue(rewriter, loc, 1,
+                                                  resultTy.getDtype());
+    // Compute mask: input == 0
+    auto inputEqZero = rewriter
+                           .create<AtenEqScalarOp>(loc, boolBroadcastType,
+                                                   inputBroadcasted, zero)
+                           ->getResult(0);
+    // Compute mask: input < 0
+    auto inputLtZero = rewriter.create<AtenLtScalarOp>(loc, boolBroadcastType,
+                                                       inputBroadcasted, zero);
+    // Compute mask: isnan(input)
+    auto isNan =
+        rewriter.create<AtenIsnanOp>(loc, boolBroadcastType, inputBroadcasted);
+    // Combine: input < 0 || isnan(input)
+    auto inputNegativeOrNan = rewriter.create<AtenLogicalOrOp>(
+        loc, boolBroadcastType, inputLtZero, isNan);
+    // Select 0 if input < 0 or input is nan, else 1
+    auto zerosOrOnes = rewriter.create<AtenWhereScalarOp>(
+        loc, resultTy, inputNegativeOrNan, zero, one);
+    // Final result: if input == 0, take from valueBroadcasted, else take from
+    // zerosOrOnes.
+    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resultTy, inputEqZero,
+                                                 valueBroadcasted, zerosOrOnes);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Unconditionally decompose `torch.type_as` into `prim.dtype` +
 // `torch.to.dtype`.
@@ -12291,6 +12361,7 @@ class DecomposeComplexOpsPass
       DecomposeConstantTensorNewLikeOp>(patterns);
   addPatternIfTargetOpIsIllegal(patterns);
   addPatternIfTargetOpIsIllegal(patterns);
+  addPatternIfTargetOpIsIllegal<DecomposeAtenHeaviside>(patterns);
   addPatternIfTargetOpIsIllegal(patterns);
   addPatternIfTargetOpIsIllegal(patterns);
   addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index 13d5a1f2ab8b..afa237ab8d9a 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -460,6 +460,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenHeavisideOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 3b5f651c8903..729a178ec10d 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1249,6 +1249,7 @@
     "ElementwiseToDtypeI64ToI8Module_basic",
     "ElementwiseToDtypeIdentityModule_basic",
     "ElementwiseUnaryModule_basic",
+    "ElementwiseHeavisideModule_basic",
     "EmptyLikeMemoryFormatModule_basic",
     "EmptyLikeModule_defaultDtype",
     "EmptyLikeModule_falsePinMemory",
@@ -1849,6 +1850,7 @@
     "ElementwiseFracModule_basic",
     "ElementwiseLdexpModule_basic",
     "ElementwiseSignbitIntModule_basic",
+    "ElementwiseHeavisideModule_basic",
    "Exp2StaticIntModule_basic",
     "MaxPool1dEmptyStrideStaticModule_basic",
     "MaxPool1dStaticCeilModeTrueModule_basic",
@@ -2958,6 +2960,8 @@
     "GtFloatIntModule_basic",
     "GtIntModule_basic",
     "HardtanhBackward_basic",
+    "ElementwiseHeavisideModule_basic",
+    "ElementwiseHeavisideIntModule_basic",
     "HstackBasicComplexModule_basic",
     "HstackBasicFloatModule_basic",
     "HstackBasicIntFloatModule_basic",
@@ -3958,6 +3962,8 @@
     "ElementwiseRreluWithNoiseEvalStaticModule_basic",
     "ElementwiseRreluWithNoiseTrainModule_basic",
     "ElementwiseRreluWithNoiseTrainStaticModule_basic",
+    "ElementwiseHeavisideModule_basic",
+    "ElementwiseHeavisideIntModule_basic",
     "RreluWithNoiseBackwardEvalModule_basic",
     "RreluWithNoiseBackwardEvalStaticModule_basic",
     "RreluWithNoiseBackwardTrainModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index a069550ec669..fea545dea21a 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -1772,6 +1772,9 @@ def aten〇where〇ScalarOther〡shape(condition: List[int], self: List[int], ot
 def aten〇where〇ScalarSelf〡shape(condition: List[int], self: float, other: List[int]) -> List[int]:
     return upstream_shape_functions.broadcast(condition, other)
 
+def aten〇heaviside〡shape(self: List[int], values: List[int]) -> List[int]:
+    return upstream_shape_functions.broadcast(self, values)
+
 def aten〇nan_to_num〡shape(self: List[int], nan: Optional[float] = None, posinf: Optional[float] = None, neginf: Optional[float] = None) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -5069,6 +5072,14 @@ def aten〇where〇ScalarSelf〡dtype(condition_rank_dtype: Tuple[int, int], sel
     dtypes = [get_dtype_of_scalar(self), other_dtype]
     return promote_dtypes(ranks, dtypes)
 
+def aten〇heaviside〡dtype(self_rank_dtype: Tuple[int, int], values_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    values_rank, values_dtype = values_rank_dtype
+    ranks: List[Optional[int]] = [self_rank, values_rank]
+    dtypes = [self_dtype, values_dtype]
+    promoted_dtype = promote_dtypes(ranks, dtypes)
+    return promoted_dtype
+
 @check_dtype_function(
     _check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇nan_to_num〡dtype(self_rank_dtype: Tuple[int, int], nan: Optional[float] = None, posinf: Optional[float] = None, neginf: Optional[float] = None) -> int:
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 6817d285faea..03f66d4d9f07 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -958,6 +958,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit(
         "aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)", has_folder=True
     )
+    emit_with_mutating_variants("aten::heaviside : (Tensor, Tensor) -> (Tensor)")
     emit("aten::nan_to_num : (Tensor, float?, float?, float?) -> (Tensor)")
     emit(
         "aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py
index 304bc422e4d2..def51d606a6b 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py
@@ -298,6 +298,46 @@ def ElementwiseLtFloatScalarModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseHeavisideModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([5], torch.float32, True), ([1], torch.float32, True)])
+    def forward(self, x, values):
+        return torch.heaviside(x, values)
+
+
+@register_test_case(module_factory=lambda: ElementwiseHeavisideModule())
+def ElementwiseHeavisideModule_basic(module, tu: TestUtils):
+    module.forward(
+        torch.tensor([1.0, -2.0, torch.inf, torch.nan, -torch.inf]), torch.tensor([5.0])
+    )
+
+
+class ElementwiseHeavisideIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [None, ([-1, -1, -1], torch.int64, True), ([-1, -1, -1, -1], torch.int64, True)]
+    )
+    def forward(self, x, values):
+        return torch.heaviside(x, values)
+
+
+@register_test_case(module_factory=lambda: ElementwiseHeavisideIntModule())
+def ElementwiseHeavisideIntModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.randint(1, 2, 3, low=-100, high=1000),
+        tu.randint(1, 1, 1, 1, low=-100, high=1000),
+    )
+
+
+# ==============================================================================
+
+
 class ElementwiseLtIntScalarModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
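
For reference, the DecomposeAtenHeaviside pattern above implements, on the
broadcast operands, the identity heaviside(x, v) = where(x == 0, v,
where((x < 0) or isnan(x), 0, 1)), which is why NaN inputs map to 0 rather
than propagating. The eager-mode Python sketch below is not part of the
patch; heaviside_decomposed is an invented name used only for illustration,
but every torch call in it is stock, so it can be checked directly against
torch.heaviside with the same data the two new e2e tests use:

    import torch

    def heaviside_decomposed(x, values):
        # Broadcast both operands to the common shape, as the pattern
        # does with aten.broadcast_to.
        x_b, v_b = torch.broadcast_tensors(x, values)
        # 0 where x < 0 or x is NaN, 1 elsewhere; the pattern emits this
        # step as aten.where.Scalar over aten.lt.Scalar / aten.isnan /
        # aten.logical_or masks.
        zeros_or_ones = torch.where(
            torch.logical_or(x_b < 0, torch.isnan(x_b)),
            torch.zeros_like(v_b),
            torch.ones_like(v_b),
        )
        # Take the broadcast `values` where x == 0 (aten.eq.Scalar plus
        # the final aten.where.self).
        return torch.where(x_b == 0, v_b, zeros_or_ones)

    # Mirrors ElementwiseHeavisideModule_basic: inf and nan go down the
    # zeros-or-ones path (nan yields 0, matching eager torch.heaviside).
    x = torch.tensor([1.0, -2.0, torch.inf, torch.nan, -torch.inf])
    v = torch.tensor([5.0])
    assert torch.equal(heaviside_decomposed(x, v), torch.heaviside(x, v))

    # Mirrors ElementwiseHeavisideIntModule_basic: a rank-3 input is
    # broadcast against a rank-4 values tensor.
    xi = torch.randint(-100, 1000, (1, 2, 3))
    vi = torch.randint(-100, 1000, (1, 1, 1, 1))
    assert torch.equal(heaviside_decomposed(xi, vi), torch.heaviside(xi, vi))

Since eager torch.heaviside requires self and values to share a dtype, the
promote_dtypes call in the new dtype function is a conservative choice that
reduces to that common dtype for valid inputs.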