
Commit 820d714

committed: supported the _logcumsumexp Op

Signed-off-by: sharavana20 <[email protected]>
1 parent c6624b1, commit 820d714

File tree

6 files changed: +102 -36 lines changed

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 19 additions & 3 deletions
@@ -9368,6 +9368,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "  func.func @\"__torch_mlir_shape_fn.aten.logcumsumexp\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten._logcumsumexp\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
+"    return %arg0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
@@ -12547,9 +12550,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    return %1 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.logcumsumexp\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
-"    %int1 = torch.constant.int 1\n"
-"    %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
-"    return %0 : !torch.int\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten._logcumsumexp\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+"    %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    return %0#1 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.cumprod\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.int {\n"
 "    %int4 = torch.constant.int 4\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 24 additions & 31 deletions
@@ -2962,58 +2962,48 @@ class DecomposeAten_LogSoftmaxOp : public OpRewritePattern<Aten_LogSoftmaxOp> {
 };
 } // namespace
 
-// Decompose AtenLogCumsumExpOp to:
-//   AtenExpOp
-//   AtenCumsumOp
-//   AtenLogOp
+// Decompose AtenLogCumsumExpOp into: AtenExpOp,
+// AtenCumsumOp and AtenLogOp
+// logcumsumexp(x)[i][j] = log(sum_{k=0}^{j} exp(x[i][k]))
+
 namespace {
+template <typename OpTy>
 
-class DecomposeAtenLogCumsumExpOp
-    : public OpRewritePattern<AtenLogcumsumexpOp> {
+class DecomposeAtenLogCumsumExpOp : public OpRewritePattern<OpTy> {
 public:
-  using OpRewritePattern<AtenLogcumsumexpOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenLogcumsumexpOp op,
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value input = op.getSelf();
 
     auto inputType = dyn_cast<BaseTensorType>(input.getType());
-    if (!inputType || !inputType.getDtype())
+    auto resultType = dyn_cast<BaseTensorType>(op.getType());
+    if (!inputType)
       return rewriter.notifyMatchFailure(op, "Supports only tensor type");
 
     if (!inputType.hasDtype() || !isa<mlir::FloatType>(inputType.getDtype()))
-      return rewriter.notifyMatchFailure(op, "Support only floating type");
-
-    Type elementType = inputType.getDtype();
-    torch_upstream::ScalarType scalarType;
-    // logcumsumexp is only supported for Float datatype
-    if (elementType.isF16())
-      scalarType = torch_upstream::ScalarType::Half;
-    else if (elementType.isF32())
-      scalarType = torch_upstream::ScalarType::Float;
-    else
-      scalarType = torch_upstream::ScalarType::Double;
-
-    int64_t scalarVal = static_cast<int64_t>(scalarType);
-
-    Value dtypeVal = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getType<Torch::IntType>(), scalarVal);
+      return rewriter.notifyMatchFailure(
+          op, "Currently Support only floating point type");
 
     int64_t inputRank = inputType.getSizes().size();
     int64_t dim;
     if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
       return rewriter.notifyMatchFailure(
-          op, "Only constant dim value is supported");
+          op, "Unimplemented: Only constant dim value is supported");
    dim = toPositiveDim(dim, inputRank);
     if (!isValidDim(dim, inputRank))
       return rewriter.notifyMatchFailure(op, "invalid dim");
 
-    Value expInput = rewriter.create<AtenExpOp>(loc, input.getType(), input);
+    Value dtypeVal =
+        getDtypeIntValueForType(rewriter, loc, inputType.getDtype());
+
+    Value expInput = rewriter.create<AtenExpOp>(loc, resultType, input);
 
-    Value cumsum = rewriter.create<AtenCumsumOp>(
-        loc, expInput.getType(), expInput, op.getDim(), dtypeVal);
+    Value cumsum = rewriter.create<AtenCumsumOp>(loc, resultType, expInput,
+                                                 op.getDim(), dtypeVal);
 
-    Value result = rewriter.create<AtenLogOp>(loc, cumsum.getType(), cumsum);
+    Value result = rewriter.create<AtenLogOp>(loc, resultType, cumsum);
 
     rewriter.replaceOp(op, result);
     return success();
@@ -12079,7 +12069,10 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLogSoftmaxIntOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLogSigmoidOp>(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAtenLogCumsumExpOp>(patterns);
+    addPatternIfTargetOpIsIllegal<
+        DecomposeAtenLogCumsumExpOp<AtenLogcumsumexpOp>>(patterns);
+    addPatternIfTargetOpIsIllegal<
+        DecomposeAtenLogCumsumExpOp<Aten_LogcumsumexpOp>>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenHardshrinkOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSoftshrinkOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyLikeOp>(patterns);
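
The now-templated pattern rewrites both logcumsumexp variants using the identity from the comment above: logcumsumexp(x) = log(cumsum(exp(x), dim)). A minimal eager-mode sketch of the same three-op rewrite (illustration of the math only, not the torch-mlir pattern itself):

import torch

def decomposed_logcumsumexp(x: torch.Tensor, dim: int) -> torch.Tensor:
    # AtenExpOp -> AtenCumsumOp -> AtenLogOp, mirroring DecomposeAtenLogCumsumExpOp.
    return torch.log(torch.cumsum(torch.exp(x), dim=dim))

x = torch.rand(6, 2, 3)
ref = torch.logcumsumexp(x, dim=-1)
got = decomposed_logcumsumexp(x, dim=-1)
assert torch.allclose(ref, got, atol=1e-6)

Note that the naive exp/cumsum/log form can overflow for large inputs where the fused kernel stays finite; the e2e tests added below use tu.rand, which keeps values in [0, 1).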

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 1 deletion
@@ -375,7 +375,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<Aten_LogSoftmaxOp>();
   target.addIllegalOp<AtenLogSoftmaxIntOp>();
   target.addIllegalOp<AtenLogSigmoidOp>();
-  target.addIllegalOp<AtenLogcumsumexpOp>();
+  target.addIllegalOp<Aten_LogcumsumexpOp, AtenLogcumsumexpOp>();
   target.addIllegalOp<AtenHardshrinkOp>();
   target.addIllegalOp<AtenSoftshrinkOp>();
   target.addIllegalOp<AtenEmptyLikeOp>();

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 6 additions & 0 deletions
@@ -3349,6 +3349,8 @@
     # RuntimeError: Given input size: (1x1x1). Calculated output size: (1x0x0). Output size is too small
     "AvgPool2dWithoutPadFullDimIndivisibleByStrideModule_basic",
     "MaxPool2dWithoutPadFullDimIndivisibleByStrideModule_basic",
+    "_LogCumsumExpStaticModule_basic",
+    "_LogCumsumExpStaticNegativeDimModule_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.3.0.dev"):
@@ -3903,6 +3905,8 @@
     "ScaledDotProductAttentionSameDynamicModule_basic",
     "ScaledDotProductAttentionSameModule_basic",
     "ScaledDotProductAttentionGQAModule_basic",
+    "_LogCumsumExpStaticModule_basic",
+    "_LogCumsumExpStaticNegativeDimModule_basic",
 }
 
 ONNX_TOSA_CRASHING_SET = {
@@ -4976,6 +4980,8 @@
     "_ConvolutionDeprecated2DDeterministicModule_basic",
     "_LogSoftmaxModule_basic",
     "_SoftmaxModule_basic",
+    "_LogCumsumExpStaticModule_basic",
+    "_LogCumsumExpStaticNegativeDimModule_basic",
 }
 
 if torch_version_for_comparison() > version.parse("2.5.1"):

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 19 additions & 1 deletion
@@ -1543,6 +1543,9 @@ def aten〇cumprod〡shape(self: List[int], dim: int, dtype: Optional[int] = Non
 def aten〇logcumsumexp〡shape(self: List[int], dim: int) -> List[int]:
     return self
 
+def aten〇_logcumsumexp〡shape(self: List[int], dim: int) -> List[int]:
+    return self
+
 def aten〇rand_like〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
     return self
 
@@ -3239,7 +3242,22 @@ def aten〇cumsum〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Opt
 @check_dtype_function(
     _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0))
 def aten〇logcumsumexp〡dtype(self_rank_dtype: Tuple[int, int], dim: int) -> int:
-    return self_rank_dtype[1]
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(
+        tensor_shapes=[(1, 1)],
+        tensor_device="cpu",
+        dim=0,
+        error_types={*all_integer_dtypes()}
+    )
+)
+def aten〇_logcumsumexp〡dtype(self_rank_dtype: Tuple[int, int], dim: int) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    assert not is_integer_dtype(self_dtype)
+    return self_dtype
+
 
 @check_dtype_function(
     _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) +
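
The assert in the new dtype function mirrors eager PyTorch, which only defines _logcumsumexp for floating-point inputs; error_types={*all_integer_dtypes()} tells the checker to expect a failure for every integer dtype. A rough eager-mode check of that assumption (the exact exception type and message can vary across PyTorch versions):

import torch

torch.ops.aten._logcumsumexp(torch.rand(1, 1), 0)  # floating point: fine
try:
    torch.ops.aten._logcumsumexp(torch.ones(1, 1, dtype=torch.int64), 0)
except RuntimeError as err:
    print("integer input rejected:", err)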

projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 33 additions & 0 deletions
@@ -5204,6 +5204,39 @@ def LogCumsumExpDtypeModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class _LogCumsumExpStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([5, 3, 6, 9], torch.float32, True)])
+    def forward(self, x):
+        return torch.ops.aten._logcumsumexp(x, dim=1)
+
+
+@register_test_case(module_factory=lambda: _LogCumsumExpStaticModule())
+def _LogCumsumExpStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 3, 6, 9))
+
+
+class _LogCumsumExpStaticNegativeDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([6, 2, 3], torch.float32, True)])
+    def forward(self, x):
+        return torch.ops.aten.logcumsumexp(x, dim=-1)
+
+
+@register_test_case(module_factory=lambda: _LogCumsumExpStaticNegativeDimModule())
+def _LogCumsumExpStaticNegativeDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(6, 2, 3))
+
+
+# ==============================================================================
+
+
 class CumprodModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
