
Commit c6624b1

added the code for logcumsumexpOp
1 parent 1cb25e9 commit c6624b1

8 files changed, +201 −0 lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 48 additions & 0 deletions
@@ -8716,6 +8716,54 @@ def Torch_AtenCumprodOp : Torch_Op<"aten.cumprod", [
   }];
 }
 
+def Torch_AtenLogcumsumexpOp : Torch_Op<"aten.logcumsumexp", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::logcumsumexp : (Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$dim
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenLogcumsumexpOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenLogcumsumexpOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_Aten_LogcumsumexpOp : Torch_Op<"aten._logcumsumexp", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::_logcumsumexp : (Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$dim
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult Aten_LogcumsumexpOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void Aten_LogcumsumexpOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenFloorDivideScalarOp : Torch_Op<"aten.floor_divide.Scalar", [
     AllowsTypeRefinement,
     HasValueSemantics,
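
These generated definitions mirror the existing PyTorch registrations. A minimal eager-mode sketch in plain PyTorch (not part of this patch, and assuming both overloads are exposed under torch.ops.aten in your build) showing the `(Tensor, int) -> (Tensor)` signatures being modeled:

# Eager-mode signature check for the two ops registered above.
import torch

x = torch.rand(2, 3)
y = torch.ops.aten.logcumsumexp(x, 1)   # aten::logcumsumexp : (Tensor, int) -> (Tensor)
z = torch.ops.aten._logcumsumexp(x, 1)  # internal variant with the same signature
print(y.shape, torch.allclose(y, z))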

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 8 additions & 0 deletions
@@ -9365,6 +9365,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "  func.func @\"__torch_mlir_shape_fn.aten.cumprod\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.logcumsumexp\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
+"    return %arg0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
@@ -12543,6 +12546,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.logcumsumexp\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
+"    return %0 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.cumprod\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.int {\n"
 "    %int4 = torch.constant.int 4\n"
 "    %none = torch.constant.none\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 60 additions & 0 deletions
@@ -2962,6 +2962,65 @@ class DecomposeAten_LogSoftmaxOp : public OpRewritePattern<Aten_LogSoftmaxOp> {
 };
 } // namespace
 
+// Decompose AtenLogCumsumExpOp to:
+// AtenExpOp
+// AtenCumsumOp
+// AtenLogOp
+namespace {
+
+class DecomposeAtenLogCumsumExpOp
+    : public OpRewritePattern<AtenLogcumsumexpOp> {
+public:
+  using OpRewritePattern<AtenLogcumsumexpOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenLogcumsumexpOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value input = op.getSelf();
+
+    auto inputType = dyn_cast<BaseTensorType>(input.getType());
+    if (!inputType || !inputType.getDtype())
+      return rewriter.notifyMatchFailure(op, "Supports only tensor type");
+
+    if (!inputType.hasDtype() || !isa<mlir::FloatType>(inputType.getDtype()))
+      return rewriter.notifyMatchFailure(op, "Support only floating type");
+
+    Type elementType = inputType.getDtype();
+    torch_upstream::ScalarType scalarType;
+    // logcumsumexp is only supported for Float datatype
+    if (elementType.isF16())
+      scalarType = torch_upstream::ScalarType::Half;
+    else if (elementType.isF32())
+      scalarType = torch_upstream::ScalarType::Float;
+    else
+      scalarType = torch_upstream::ScalarType::Double;
+
+    int64_t scalarVal = static_cast<int64_t>(scalarType);
+
+    Value dtypeVal = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getType<Torch::IntType>(), scalarVal);
+
+    int64_t inputRank = inputType.getSizes().size();
+    int64_t dim;
+    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+      return rewriter.notifyMatchFailure(
+          op, "Only constant dim value is supported");
+    dim = toPositiveDim(dim, inputRank);
+    if (!isValidDim(dim, inputRank))
+      return rewriter.notifyMatchFailure(op, "invalid dim");
+
+    Value expInput = rewriter.create<AtenExpOp>(loc, input.getType(), input);
+
+    Value cumsum = rewriter.create<AtenCumsumOp>(
+        loc, expInput.getType(), expInput, op.getDim(), dtypeVal);
+
+    Value result = rewriter.create<AtenLogOp>(loc, cumsum.getType(), cumsum);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenLogSigmoidOp : public OpRewritePattern<AtenLogSigmoidOp> {
 public:
@@ -12020,6 +12079,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLogSoftmaxIntOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLogSigmoidOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenLogCumsumExpOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenHardshrinkOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenSoftshrinkOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyLikeOp>(patterns);
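
The pattern above rewrites aten.logcumsumexp as exp → cumsum → log. A minimal sketch of that identity in plain PyTorch (not part of this patch), useful for sanity-checking the decomposition on moderate-magnitude inputs:

# logcumsumexp(x, d) equals log(cumsum(exp(x), d)) mathematically; note that
# exp() can overflow for large inputs, so the fused ATen kernel is more robust
# than this naive form.
import torch

x = torch.rand(1, 2, 3)
decomposed = torch.log(torch.cumsum(torch.exp(x), dim=1))
reference = torch.logcumsumexp(x, dim=1)
print(torch.allclose(decomposed, reference))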

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
@@ -375,6 +375,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<Aten_LogSoftmaxOp>();
   target.addIllegalOp<AtenLogSoftmaxIntOp>();
   target.addIllegalOp<AtenLogSigmoidOp>();
+  target.addIllegalOp<AtenLogcumsumexpOp>();
   target.addIllegalOp<AtenHardshrinkOp>();
   target.addIllegalOp<AtenSoftshrinkOp>();
   target.addIllegalOp<AtenEmptyLikeOp>();

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 12 additions & 0 deletions
@@ -2988,6 +2988,10 @@
     "LinalgNormKeepDimComplexModule_basic",
     "LinalgVectorNormComplexModule_basic",
     "LogSoftmaxBackwardModule_basic",
+    "LogCumsumExpModule_basic",
+    "LogCumsumExpStaticModule_basic",
+    "LogCumsumExpStaticNegativeDimModule_basic",
+    "LogCumsumExpDtypeModule_basic",
     "MaxPool1dCeilModeTrueModule_basic",
     "MaxPool1dModule_basic",
     "MaxPool2dCeilModeTrueModule_basic",
@@ -3713,6 +3717,10 @@
     "LinalgNormKeepDimComplexModule_basic",
     "LinalgVectorNormComplexModule_basic",
     "LinspaceEmptyModule_basic",
+    "LogCumsumExpModule_basic",
+    "LogCumsumExpStaticModule_basic",
+    "LogCumsumExpStaticNegativeDimModule_basic",
+    "LogCumsumExpDtypeModule_basic",
     "MaskedScatterStaticBasic_basic",
     "MaxPool1dCeilModeTrueModule_basic",
     "MaxPool1dModule_basic",
@@ -4510,6 +4518,10 @@
     "LinalgVectorNormComplexModule_basic",
     "LogSoftmaxBackwardModule_basic",
     "LogSoftmaxIntModule_basic",
+    "LogCumsumExpModule_basic",
+    "LogCumsumExpStaticModule_basic",
+    "LogCumsumExpStaticNegativeDimModule_basic",
+    "LogCumsumExpDtypeModule_basic",
     "MaskedFillTensorFloatValueModule_basic",
     "MatmulBroadcastBatchDim_basic",
     "MatmulSingleDynamicBatchDim_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 7 additions & 0 deletions
@@ -1540,6 +1540,9 @@ def aten〇cumsum〡shape(self: List[int], dim: int, dtype: Optional[int] = None
 def aten〇cumprod〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]:
     return self
 
+def aten〇logcumsumexp〡shape(self: List[int], dim: int) -> List[int]:
+    return self
+
 def aten〇rand_like〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
     return self
 
@@ -3233,6 +3236,10 @@ def aten〇cumsum〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Opt
         return torch.int64
     return self_dtype
 
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0))
+def aten〇logcumsumexp〡dtype(self_rank_dtype: Tuple[int, int], dim: int) -> int:
+    return self_rank_dtype[1]
 
 @check_dtype_function(
     _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) +
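
The new shape function returns `self` unchanged and the dtype function returns the input dtype. A small plain-PyTorch sketch (not part of this patch) of the invariant those abstract-interp functions encode:

# logcumsumexp preserves both the input shape and the input dtype.
import torch

x = torch.rand(5, 3, 6, 9, dtype=torch.float64)
y = torch.logcumsumexp(x, dim=1)
assert y.shape == x.shape and y.dtype == x.dtype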

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 2 additions & 0 deletions
@@ -720,6 +720,8 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::bmm : (Tensor, Tensor) -> (Tensor)")
     emit("aten::cumsum : (Tensor, int, int?) -> (Tensor)")
     emit("aten::cumprod : (Tensor, int, int?) -> (Tensor)")
+    emit("aten::logcumsumexp : (Tensor, int) -> (Tensor)")
+    emit("aten::_logcumsumexp : (Tensor, int) -> (Tensor)")
     emit("aten::floor_divide.Scalar : (Tensor, Scalar) -> (Tensor)")
     emit("aten::logsumexp : (Tensor, int[], bool) -> (Tensor)")
     emit("aten::mean.dim : (Tensor, int[]?, bool, int?) -> (Tensor)")

projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 63 additions & 0 deletions
@@ -5141,6 +5141,69 @@ def CumsumWithDtypeModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class LogCumsumExpModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1, -1, -1], torch.float32, True)])
+    def forward(self, x):
+        return torch.ops.aten.logcumsumexp(x, dim=1)
+
+
+@register_test_case(module_factory=lambda: LogCumsumExpModule())
+def LogCumsumExpModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 2, 3))
+
+
+class LogCumsumExpStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([1, 2, 3], torch.float32, True)])
+    def forward(self, x):
+        return torch.ops.aten.logcumsumexp(x, dim=1)
+
+
+@register_test_case(module_factory=lambda: LogCumsumExpStaticModule())
+def LogCumsumExpStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 2, 3))
+
+
+class LogCumsumExpStaticNegativeDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([8, 5, 6], torch.float32, True)])
+    def forward(self, x):
+        return torch.ops.aten.logcumsumexp(x, dim=-2)
+
+
+@register_test_case(module_factory=lambda: LogCumsumExpStaticNegativeDimModule())
+def LogCumsumExpStaticNegativeDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(8, 5, 6))
+
+
+class LogCumsumExpDtypeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([5, 3, 6, 9], torch.float64, True)])
+    def forward(self, x):
+        return torch.ops.aten.logcumsumexp(x, dim=1)
+
+
+@register_test_case(module_factory=lambda: LogCumsumExpDtypeModule())
+def LogCumsumExpDtypeModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 3, 6, 9).to(torch.float64))
+
+
+# ==============================================================================
+
+
 class CumprodModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
