[TORCH] Add support for aten.heaviside Op #4220

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 2 commits into base: main
47 changes: 47 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -13077,6 +13077,53 @@ def Torch_AtenWhereScalarSelfOp : Torch_Op<"aten.where.ScalarSelf", [
  let hasFolder = 1;
}

def Torch_AtenHeavisideOp : Torch_Op<"aten.heaviside", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::heaviside : (Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$values
  );
  let results = (outs
    AnyTorchOptionalTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenHeavisideOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 2, 1);
    }
    void AtenHeavisideOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 2, 1);
    }
  }];
}

def Torch_AtenHeaviside_Op : Torch_Op<"aten.heaviside_", [
    IsTrailingUnderscoreInplaceVariant,
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::heaviside_ : (Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    Torch_NonValueTensorType:$self,
    Torch_NonValueTensorType:$values
  );
  let results = (outs
    AnyTorchOptionalNonValueTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenHeaviside_Op::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 2, 1);
    }
    void AtenHeaviside_Op::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 2, 1);
    }
  }];
}

def Torch_AtenNanToNumOp : Torch_Op<"aten.nan_to_num", [
AllowsTypeRefinement,
HasValueSemantics,
12 changes: 12 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -9671,6 +9671,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg2) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.heaviside\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.nan_to_num\"(%arg0: !torch.list<int>, %arg1: !torch.optional<float>, %arg2: !torch.optional<float>, %arg3: !torch.optional<float>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -15192,6 +15196,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.heaviside\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.nan_to_num\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<float>, %arg2: !torch.optional<float>, %arg3: !torch.optional<float>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
71 changes: 71 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -11068,6 +11068,76 @@ class DecomposeAtenSgnOp : public OpRewritePattern<AtenSgnOp> {
};
} // namespace

namespace {
// Decompose aten.heaviside into aten.eq, aten.lt, aten.isnan,
// aten.logical_or, and aten.where ops.
// Heaviside(x, y) returns:
//   0 if x < 0 (NaN inputs are also mapped to 0),
//   y if x == 0,
//   1 if x > 0.
class DecomposeAtenHeaviside : public OpRewritePattern<AtenHeavisideOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenHeavisideOp op,
                                PatternRewriter &rewriter) const override {
    auto input = op.getSelf();
    auto value = op.getValues();
    auto loc = op.getLoc();
    auto inputTy = dyn_cast<BaseTensorType>(input.getType());
    if (!inputTy || !inputTy.hasDtype() || !inputTy.hasSizes())
      return rewriter.notifyMatchFailure(op, "input must have dtype and size.");

    auto valueTy = dyn_cast<BaseTensorType>(value.getType());
    if (!valueTy || !valueTy.hasDtype() || !valueTy.hasSizes())
      return rewriter.notifyMatchFailure(op, "value must have dtype and size.");
    auto resultTy = dyn_cast<BaseTensorType>(op.getType());
    SmallVector<int64_t> broadcastShape;
    SmallVector<Value> broadcastShapeValue;
    computeBroadcastShape(rewriter, loc, input, value, broadcastShape,
                          broadcastShapeValue);

    auto broadcastType = ValueTensorType::get(
        op.getContext(), llvm::ArrayRef(broadcastShape), resultTy.getDtype());
    auto boolBroadcastType = ValueTensorType::get(
        op.getContext(), llvm::ArrayRef(broadcastShape), rewriter.getI1Type());
    Value indexBroadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
        loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
        broadcastShapeValue);
    auto inputBroadcasted = rewriter.create<AtenBroadcastToOp>(
        loc, broadcastType, input, indexBroadcastShapeTorchList);
    auto valueBroadcasted = rewriter.create<AtenBroadcastToOp>(
        loc, broadcastType, value, indexBroadcastShapeTorchList);
Comment on lines +11094 to +11109

vivekkhandelwal1 (Collaborator):
I think this is not needed. Since you are decomposing this op into elementwise ops, the broadcasting part will be handled during the Torch->Linalg lowering.

sharavak (Author), Jun 10, 2025:
@vivekkhandelwal1 You're right, but I ran into an issue in a specific case: when the input shape is [1, 2, 3] and the value shape is [1, 1, 1, 1], the broadcasted result shape becomes [1, 1, 2, 3]. Without explicitly broadcasting the inputs, some intermediate ops (like aten.eq.scalar or aten.isnan) end up producing tensors of shape [1, 2, 3], which causes this error:

'tensor.cast' op operand type 'tensor<1x2x3xi1>' and result type 'tensor<1x1x2x3xi1>' are cast incompatible

So to avoid this mismatch, I added explicit broadcasting to ensure all intermediate results match the final shape.
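
A quick way to see the mismatch in plain PyTorch (illustrative only, not part of the patch; shapes taken from the example above):

import torch

x = torch.randn(1, 2, 3)        # input
v = torch.randn(1, 1, 1, 1)     # values
# The broadcasted result shape has rank 4.
print(torch.broadcast_shapes(x.shape, v.shape))  # torch.Size([1, 1, 2, 3])
# A mask computed only from x keeps rank 3, so it no longer matches
# the rank-4 result shape unless x is broadcast first.
print(torch.isnan(x).shape)                      # torch.Size([1, 2, 3])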


    Value zero = getConstantWithGivenDtypeAndValue(rewriter, loc, 0,
                                                   resultTy.getDtype());
    Value one = getConstantWithGivenDtypeAndValue(rewriter, loc, 1,
                                                  resultTy.getDtype());
    // Compute mask: input == 0
    auto inputEqZero = rewriter
                           .create<AtenEqScalarOp>(loc, boolBroadcastType,
                                                   inputBroadcasted, zero)
                           ->getResult(0);
    // Compute mask: input < 0
    auto inputLtZero = rewriter.create<AtenLtScalarOp>(loc, boolBroadcastType,
                                                       inputBroadcasted, zero);
    // Compute mask: isnan(input)
    auto isNan =
        rewriter.create<AtenIsnanOp>(loc, boolBroadcastType, inputBroadcasted);
Comment on lines +11123 to +11125

vivekkhandelwal1 (Collaborator):
I did not see the mention of this case here: https://docs.pytorch.org/docs/stable/generated/torch.heaviside.html. Can you please share any reference?

sharavak (Author):
Thanks for the review, @vivekkhandelwal1.
I tested this behavior with PyTorch: if the input contains NaN values, they are replaced with 0 in the output. To handle this explicitly, I used AtenIsnanOp to detect NaN values.

input = torch.tensor([[0, float('nan')]])
values = torch.tensor([2], dtype=torch.float32)
torch.heaviside(input, values)

Output:
tensor([[2., 0.]])

Ref: https://github.com/pytorch/pytorch/blob/main/torch/_refs/__init__.py#L1448
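
For reference, the decomposition below follows the same structure as that PyTorch reference; a minimal plain-PyTorch sketch of the logic (illustrative only, not part of the patch):

import torch

def heaviside_decomposed(x, values):
    # 0 where x < 0 or x is NaN, 1 everywhere else.
    zeros_or_ones = torch.where(torch.logical_or(x < 0, torch.isnan(x)),
                                torch.zeros((), dtype=values.dtype),
                                torch.ones((), dtype=values.dtype))
    # Take `values` where x == 0, otherwise keep the 0/1 mask.
    return torch.where(x == 0, values, zeros_or_ones)

x = torch.tensor([1.0, -2.0, 0.0, float('nan')])
values = torch.tensor([5.0])
assert torch.equal(heaviside_decomposed(x, values), torch.heaviside(x, values))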

    // Combine: input < 0 || isnan(input)
    auto inputNegativeOrNan = rewriter.create<AtenLogicalOrOp>(
        loc, boolBroadcastType, inputLtZero, isNan);
    // Select 0 if input < 0 or input is NaN, else 1.
    auto zerosOrOnes = rewriter.create<AtenWhereScalarOp>(
        loc, resultTy, inputNegativeOrNan, zero, one);
    // Final result: if input == 0, take from valueBroadcasted, else take from
    // zerosOrOnes.
    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resultTy, inputEqZero,
                                                 valueBroadcasted, zerosOrOnes);
    return success();
  }
};
} // namespace

namespace {
// Unconditionally decompose `torch.type_as` into `prim.dtype` +
// `torch.to.dtype`.
@@ -12291,6 +12361,7 @@ class DecomposeComplexOpsPass
DecomposeConstantTensorNewLikeOp<AtenNewOnesOp, AtenOnesOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHardtanhOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFullOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHeaviside>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLinearOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMishOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFullLikeOp>(patterns);
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -460,6 +460,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenSquareOp>();
target.addIllegalOp<AtenVarOp>();
target.addIllegalOp<AtenStdOp>();
target.addIllegalOp<AtenHeavisideOp>();
target.addIllegalOp<Aten_UnsafeViewOp>();
target.addIllegalOp<Aten_ReshapeAliasOp>();
target.addIllegalOp<AtenBernoulliOp>();
6 changes: 6 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1249,6 +1249,7 @@
"ElementwiseToDtypeI64ToI8Module_basic",
"ElementwiseToDtypeIdentityModule_basic",
"ElementwiseUnaryModule_basic",
"ElementwiseHeavisideModule_basic",
"EmptyLikeMemoryFormatModule_basic",
"EmptyLikeModule_defaultDtype",
"EmptyLikeModule_falsePinMemory",
@@ -1849,6 +1850,7 @@
"ElementwiseFracModule_basic",
"ElementwiseLdexpModule_basic",
"ElementwiseSignbitIntModule_basic",
"ElementwiseHeavisideModule_basic",
"Exp2StaticIntModule_basic",
"MaxPool1dEmptyStrideStaticModule_basic",
"MaxPool1dStaticCeilModeTrueModule_basic",
@@ -2958,6 +2960,8 @@
"GtFloatIntModule_basic",
"GtIntModule_basic",
"HardtanhBackward_basic",
"ElementwiseHeavisideModule_basic",
"ElementwiseHeavisideIntModule_basic",
"HstackBasicComplexModule_basic",
"HstackBasicFloatModule_basic",
"HstackBasicIntFloatModule_basic",
@@ -3958,6 +3962,8 @@
"ElementwiseRreluWithNoiseEvalStaticModule_basic",
"ElementwiseRreluWithNoiseTrainModule_basic",
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
"ElementwiseHeavisideModule_basic",
"ElementwiseHeavisideIntModule_basic",
"RreluWithNoiseBackwardEvalModule_basic",
"RreluWithNoiseBackwardEvalStaticModule_basic",
"RreluWithNoiseBackwardTrainModule_basic",
@@ -1772,6 +1772,9 @@ def aten〇where〇ScalarOther〡shape(condition: List[int], self: List[int], ot
def aten〇where〇ScalarSelf〡shape(condition: List[int], self: float, other: List[int]) -> List[int]:
    return upstream_shape_functions.broadcast(condition, other)

def aten〇heaviside〡shape(self: List[int], values: List[int]) -> List[int]:
    return upstream_shape_functions.broadcast(self, values)

def aten〇nan_to_num〡shape(self: List[int], nan: Optional[float] = None, posinf: Optional[float] = None, neginf: Optional[float] = None) -> List[int]:
    return upstream_shape_functions.unary(self)

@@ -5069,6 +5072,14 @@ def aten〇where〇ScalarSelf〡dtype(condition_rank_dtype: Tuple[int, int], sel
    dtypes = [get_dtype_of_scalar(self), other_dtype]
    return promote_dtypes(ranks, dtypes)

def aten〇heaviside〡dtype(self_rank_dtype: Tuple[int, int], values_rank_dtype: Tuple[int, int]) -> int:
    self_rank, self_dtype = self_rank_dtype
    values_rank, values_dtype = values_rank_dtype
    ranks: List[Optional[int]] = [self_rank, values_rank]
    dtypes = [self_dtype, values_dtype]
    promoted_dtype = promote_dtypes(ranks, dtypes)
    return promoted_dtype

@check_dtype_function(
    _check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇nan_to_num〡dtype(self_rank_dtype: Tuple[int, int], nan: Optional[float] = None, posinf: Optional[float] = None, neginf: Optional[float] = None) -> int:
@@ -958,6 +958,7 @@ def emit_with_mutating_variants(key, **kwargs):
    emit(
        "aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)", has_folder=True
    )
    emit_with_mutating_variants("aten::heaviside : (Tensor, Tensor) -> (Tensor)")
    emit("aten::nan_to_num : (Tensor, float?, float?, float?) -> (Tensor)")
    emit(
        "aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)",
@@ -298,6 +298,46 @@ def ElementwiseLtFloatScalarModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseHeavisideModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None, ([5], torch.float32, True), ([1], torch.float32, True)])
    def forward(self, x, values):
        return torch.heaviside(x, values)


@register_test_case(module_factory=lambda: ElementwiseHeavisideModule())
def ElementwiseHeavisideModule_basic(module, tu: TestUtils):
    module.forward(
        torch.tensor([1.0, -2.0, torch.inf, torch.nan, -torch.inf]), torch.tensor([5.0])
    )


class ElementwiseHeavisideIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [None, ([-1, -1, -1], torch.int64, True), ([-1, -1, -1, -1], torch.int64, True)]
    )
    def forward(self, x, values):
        return torch.heaviside(x, values)


@register_test_case(module_factory=lambda: ElementwiseHeavisideIntModule())
def ElementwiseHeavisideIntModule_basic(module, tu: TestUtils):
    module.forward(
        tu.randint(1, 2, 3, low=-100, high=1000),
        tu.randint(1, 1, 1, 1, low=-100, high=1000),
    )


# ==============================================================================


class ElementwiseLtIntScalarModule(torch.nn.Module):
    def __init__(self):
        super().__init__()