Skip to content

Commit 57d14be

Browse files
committed
Add ceil_mode support in TorchToStablehlo lowering for AtenMaxPool1dWithIndicesOp
1 parent facbea2 commit 57d14be

File tree

3 files changed

+25
-3
lines changed

3 files changed

+25
-3
lines changed

lib/Conversion/TorchToStablehlo/Pooling.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,29 @@ LogicalResult ConvertAtenOp<AtenMaxPool1dWithIndicesOp>::matchAndRewrite(
132132
stablehloPadding[stablehloPadding.size() - 1] = padding[0];
133133
stablehloPadding[stablehloPadding.size() - 2] = padding[0];
134134

135+
if (ceilMode) {
136+
// Match PyTorch output shape with extra padding. See
137+
// https://github.com/pytorch/pytorch/blob/c5de6ff079e3e5b453d6ff5190c90f02db458928/aten/src/ATen/native/Pool.h#L79
138+
const int64_t inputSize = inputShape[inputRank - 1];
139+
const int64_t numerator =
140+
(inputSize + 2 * padding[0] - dilation[0] * (kernelSize[0] - 1) - 1);
141+
const int64_t floor_output_size = (numerator) / stride[0] + 1;
142+
const int64_t adj = (stride[0] - 1);
143+
int64_t ceil_output_size = std::ceil((numerator + adj) / stride[0]) + 1;
144+
145+
// Ensure last pooling starts inside input
146+
if ((ceil_output_size - 1) * stride[0] >= inputSize + padding[0]) {
147+
ceil_output_size--;
148+
}
149+
150+
// Add extra padding to make output size same as torch
151+
if (ceil_output_size > floor_output_size) {
152+
const int64_t sizeDiff = ceil_output_size - floor_output_size;
153+
const int64_t extraPadding = sizeDiff * stride[0];
154+
stablehloPadding[stablehloPadding.size() - 1] += extraPadding;
155+
}
156+
}
157+
135158
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
136159

137160
auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,6 @@
561561
"ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic",
562562
"ElementwiseRemainderTensorModule_Int_NegativeDividend_basic",
563563
"ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic",
564-
"MaxPool1dWithIndicesCeilModeModule_basic",
565564
"MaxPool1dCeilModeTrueModule_basic",
566565
"MaxPool1dStaticCeilModeTrueModule_basic",
567566
"MaxUnpool3dModulePad0_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,13 +217,13 @@ def __init__(self):
217217
)
218218
def forward(self, x):
219219
return torch.ops.aten.max_pool1d_with_indices(
220-
x, kernel_size=[6], stride=[2], padding=[3], dilation=2, ceil_mode=True
220+
x, kernel_size=[4], stride=[2], padding=[2], dilation=2, ceil_mode=True
221221
)
222222

223223

224224
@register_test_case(module_factory=lambda: MaxPool1dWithIndicesCeilModeModule())
225225
def MaxPool1dWithIndicesCeilModeModule_basic(module, tu: TestUtils):
226-
module.forward(tu.rand(1, 64, 112, low=-1))
226+
module.forward(tu.rand(3, 25, 37, low=-1))
227227

228228

229229
# ==============================================================================

0 commit comments

Comments (0)