Skip to content

Commit d6c3a9b

Browse files
committed
Enable lowering to linalg
Signed-off-by: Zahid Wakeel <[email protected]>
1 parent 716303a commit d6c3a9b

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

lib/Conversion/TorchToLinalg/Pooling.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,10 @@ template <> struct DimensionTraits<AtenMaxPool1dOp> {
225225
static_assert(Dim == Dim);
226226
};
227227

228+
template <>
229+
struct DimensionTraits<AtenMaxPool1dWithIndicesOp>
230+
: DimensionTraits<AtenMaxPool1dOp> {};
231+
228232
template <> struct DimensionTraits<AtenMaxPool2dOp> {
229233
static constexpr int64_t Dim = 2;
230234
// unused const variable warning suppression:
@@ -250,7 +254,8 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern<OpTy> {
250254
using OpConversionPattern<OpTy>::OpConversionPattern;
251255

252256
static const bool withIndices =
253-
llvm::is_one_of<OpTy, AtenMaxPool2dWithIndicesOp,
257+
llvm::is_one_of<OpTy, AtenMaxPool1dWithIndicesOp,
258+
AtenMaxPool2dWithIndicesOp,
254259
AtenMaxPool3dWithIndicesOp>::value;
255260

256261
private:
@@ -1687,8 +1692,11 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
16871692
patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool2dOp>>(typeConverter, context);
16881693
patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool3dOp>>(typeConverter, context);
16891694

1695+
target.addIllegalOp<AtenMaxPool1dWithIndicesOp>();
16901696
target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
16911697
target.addIllegalOp<AtenMaxPool3dWithIndicesOp>();
1698+
patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool1dWithIndicesOp>>(typeConverter,
1699+
context);
16921700
patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
16931701
context);
16941702
patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool3dWithIndicesOp>>(typeConverter,

projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,55 @@ def AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic(module, tu: TestUtils):
180180
# ==============================================================================
181181

182182

183+
class MaxPool1dWithIndicesModule(torch.nn.Module):
184+
185+
def __init__(self):
186+
super().__init__()
187+
188+
@export
189+
@annotate_args(
190+
[
191+
None,
192+
([-1, -1, -1], torch.float32, True),
193+
]
194+
)
195+
def forward(self, x):
196+
return torch.ops.aten.max_pool1d_with_indices(
197+
x, kernel_size=[6], stride=[2], padding=[3], dilation=2, ceil_mode=False
198+
)
199+
200+
201+
@register_test_case(module_factory=lambda: MaxPool1dWithIndicesModule())
202+
def MaxPool1dWithIndicesModule_basic(module, tu: TestUtils):
203+
module.forward(tu.rand(1, 64, 112, low=-1))
204+
205+
206+
class MaxPool1dWithIndicesCeilModeModule(torch.nn.Module):
207+
208+
def __init__(self):
209+
super().__init__()
210+
211+
@export
212+
@annotate_args(
213+
[
214+
None,
215+
([-1, -1, -1], torch.float32, True),
216+
]
217+
)
218+
def forward(self, x):
219+
return torch.ops.aten.max_pool1d_with_indices(
220+
x, kernel_size=[6], stride=[2], padding=[3], dilation=2, ceil_mode=True
221+
)
222+
223+
224+
@register_test_case(module_factory=lambda: MaxPool1dWithIndicesCeilModeModule())
225+
def MaxPool1dWithIndicesCeilModeModule_basic(module, tu: TestUtils):
226+
module.forward(tu.rand(1, 64, 112, low=-1))
227+
228+
229+
# ==============================================================================
230+
231+
183232
class MaxPool1dModule(torch.nn.Module):
184233

185234
def __init__(self):

0 commit comments

Comments
 (0)