Skip to content

Commit 47089ef

Browse files
committed
Lower replication pad 1d to linalg backend
Signed-off-by: Zahid Wakeel <[email protected]>
1 parent ad3ae54 commit 47089ef

File tree

2 files changed

+112
-0
lines changed

2 files changed

+112
-0
lines changed

lib/Conversion/TorchToLinalg/TensorConstructors.cpp

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,96 @@ class ConvertAtenConstantPadNdOp
116116

117117
namespace {
118118

119+
class ConvertAtenReplicationPad1dOp
120+
: public OpConversionPattern<AtenReplicationPad1dOp> {
121+
public:
122+
using OpConversionPattern::OpConversionPattern;
123+
124+
LogicalResult
125+
matchAndRewrite(AtenReplicationPad1dOp op, OpAdaptor adaptor,
126+
ConversionPatternRewriter &rewriter) const override {
127+
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
128+
return failure();
129+
130+
Location loc = op.getLoc();
131+
Value input = adaptor.getSelf();
132+
auto inputType = llvm::cast<RankedTensorType>(input.getType());
133+
int64_t inputRank = inputType.getRank();
134+
135+
if (inputRank < 2)
136+
return rewriter.notifyMatchFailure(op, "input rank must be at least 2");
137+
138+
SmallVector<int64_t> padInts;
139+
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts)))
140+
return rewriter.notifyMatchFailure(
141+
op, "only support constant int pad ranges");
142+
143+
if (padInts.size() != 2)
144+
return rewriter.notifyMatchFailure(
145+
op, "pad range must have exactly two values");
146+
147+
int64_t leftPad = padInts[0];
148+
int64_t rightPad = padInts[1];
149+
150+
int64_t widthDim = inputRank - 1;
151+
Type indexType = rewriter.getIndexType();
152+
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
153+
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
154+
155+
SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
156+
Value widthSize = inputShape[widthDim];
157+
Value widthMinusOne = rewriter.create<arith::SubIOp>(loc, widthSize, one);
158+
159+
// Build offset and size arrays for slicing
160+
SmallVector<OpFoldResult> allOneStrides(inputRank,
161+
rewriter.getIndexAttr(1));
162+
SmallVector<OpFoldResult> leftOffsets(inputRank, rewriter.getIndexAttr(0));
163+
SmallVector<OpFoldResult> rightOffsets(inputRank, rewriter.getIndexAttr(0));
164+
SmallVector<OpFoldResult> sizes(inputRank, rewriter.getIndexAttr(0));
165+
for (int i = 0; i < inputRank; ++i)
166+
sizes[i] = (i == widthDim) ? rewriter.getIndexAttr(1)
167+
: getAsOpFoldResult(inputShape[i]);
168+
169+
rightOffsets[widthDim] = getAsOpFoldResult(widthMinusOne);
170+
171+
// Extract leftmost and rightmost slices
172+
Value leftSlice = rewriter.create<tensor::ExtractSliceOp>(
173+
loc, input, leftOffsets, sizes, allOneStrides);
174+
Value rightSlice = rewriter.create<tensor::ExtractSliceOp>(
175+
loc, input, rightOffsets, sizes, allOneStrides);
176+
177+
// Create repeated tiles
178+
SmallVector<Value> resultParts;
179+
180+
if (leftPad > 0) {
181+
SmallVector<Value> leftTiles(leftPad, leftSlice);
182+
Value leftConcat =
183+
rewriter.create<tensor::ConcatOp>(loc, widthDim, leftTiles);
184+
resultParts.push_back(leftConcat);
185+
}
186+
187+
resultParts.push_back(input);
188+
189+
if (rightPad > 0) {
190+
SmallVector<Value> rightTiles(rightPad, rightSlice);
191+
Value rightConcat =
192+
rewriter.create<tensor::ConcatOp>(loc, widthDim, rightTiles);
193+
resultParts.push_back(rightConcat);
194+
}
195+
196+
Value result =
197+
rewriter.create<tensor::ConcatOp>(loc, widthDim, resultParts);
198+
Type resultType = getTypeConverter()->convertType(op.getType());
199+
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
200+
201+
return success();
202+
}
203+
};
204+
205+
} // namespace
206+
207+
namespace {
208+
119209
// Lower aten.replication_pad2d operator into a sequence of
120210
// tensor.extract_slice and tensor.concat operations.
121211

@@ -621,6 +711,8 @@ void mlir::torch::torch_to_linalg::
621711
MLIRContext *context = patterns.getContext();
622712
target.addIllegalOp<AtenReplicationPad2dOp>();
623713
patterns.add<ConvertAtenReplicationPad2dOp>(typeConverter, context);
714+
target.addIllegalOp<AtenReplicationPad1dOp>();
715+
patterns.add<ConvertAtenReplicationPad1dOp>(typeConverter, context);
624716
target.addIllegalOp<AtenConstantPadNdOp>();
625717
patterns.add<ConvertAtenConstantPadNdOp>(typeConverter, context);
626718
target.addIllegalOp<AtenZerosOp, AtenOnesOp>();

projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,26 @@ def ReplicationPad2dModule_left0(module, tu: TestUtils):
685685
# ==============================================================================
686686

687687

688+
# E2E test module exercising aten.replication_pad1d with a fixed pad of
# 3 elements on the left and 5 on the right of the width dimension.
class ReplicationPad1dModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            # Rank-3 float32 input with fully dynamic dims (N, C, W).
            ([-1, -1, -1], torch.float32, True),
        ]
    )
    def forward(self, x):
        # Replicate the border column: 3 copies on the left, 5 on the right.
        return torch.ops.aten.replication_pad1d(x, [3, 5])
701+
702+
703+
@register_test_case(module_factory=lambda: ReplicationPad1dModule())
def ReplicationPad1dModule_basic(module, tu: TestUtils):
    # Random (1, 15, 20) input including negative values, so the replicated
    # border is exercised with non-trivial data.
    sample = tu.rand(1, 15, 20, low=-1)
    module.forward(sample)
706+
707+
688708
class ReplicationPad2dModule_right0_module(torch.nn.Module):
689709
def __init__(self):
690710
super().__init__()

0 commit comments

Comments
 (0)