@@ -116,6 +116,96 @@ class ConvertAtenConstantPadNdOp
116
116
117
117
namespace {
118
118
119
+ class ConvertAtenReplicationPad1dOp
120
+ : public OpConversionPattern<AtenReplicationPad1dOp> {
121
+ public:
122
+ using OpConversionPattern::OpConversionPattern;
123
+
124
+ LogicalResult
125
+ matchAndRewrite (AtenReplicationPad1dOp op, OpAdaptor adaptor,
126
+ ConversionPatternRewriter &rewriter) const override {
127
+ if (failed (verifyLinalgCompatibleTypes (op, rewriter)))
128
+ return failure ();
129
+
130
+ Location loc = op.getLoc ();
131
+ Value input = adaptor.getSelf ();
132
+ auto inputType = llvm::cast<RankedTensorType>(input.getType ());
133
+ int64_t inputRank = inputType.getRank ();
134
+
135
+ if (inputRank < 2 )
136
+ return rewriter.notifyMatchFailure (op, " input rank must be at least 2" );
137
+
138
+ SmallVector<int64_t > padInts;
139
+ if (!matchPattern (op.getPadding (), m_TorchListOfConstantInts (padInts)))
140
+ return rewriter.notifyMatchFailure (
141
+ op, " only support constant int pad ranges" );
142
+
143
+ if (padInts.size () != 2 )
144
+ return rewriter.notifyMatchFailure (
145
+ op, " pad range must have exactly two values" );
146
+
147
+ int64_t leftPad = padInts[0 ];
148
+ int64_t rightPad = padInts[1 ];
149
+
150
+ int64_t widthDim = inputRank - 1 ;
151
+ Type indexType = rewriter.getIndexType ();
152
+ Value one = rewriter.create <arith::ConstantIndexOp>(loc, 1 );
153
+ Value zero = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
154
+
155
+ SmallVector<Value> inputShape = getTensorSizes (rewriter, loc, input);
156
+ Value widthSize = inputShape[widthDim];
157
+ Value widthMinusOne = rewriter.create <arith::SubIOp>(loc, widthSize, one);
158
+
159
+ // Build offset and size arrays for slicing
160
+ SmallVector<OpFoldResult> allOneStrides (inputRank,
161
+ rewriter.getIndexAttr (1 ));
162
+ SmallVector<OpFoldResult> leftOffsets (inputRank, rewriter.getIndexAttr (0 ));
163
+ SmallVector<OpFoldResult> rightOffsets (inputRank, rewriter.getIndexAttr (0 ));
164
+ SmallVector<OpFoldResult> sizes (inputRank, rewriter.getIndexAttr (0 ));
165
+ for (int i = 0 ; i < inputRank; ++i)
166
+ sizes[i] = (i == widthDim) ? rewriter.getIndexAttr (1 )
167
+ : getAsOpFoldResult (inputShape[i]);
168
+
169
+ rightOffsets[widthDim] = getAsOpFoldResult (widthMinusOne);
170
+
171
+ // Extract leftmost and rightmost slices
172
+ Value leftSlice = rewriter.create <tensor::ExtractSliceOp>(
173
+ loc, input, leftOffsets, sizes, allOneStrides);
174
+ Value rightSlice = rewriter.create <tensor::ExtractSliceOp>(
175
+ loc, input, rightOffsets, sizes, allOneStrides);
176
+
177
+ // Create repeated tiles
178
+ SmallVector<Value> resultParts;
179
+
180
+ if (leftPad > 0 ) {
181
+ SmallVector<Value> leftTiles (leftPad, leftSlice);
182
+ Value leftConcat =
183
+ rewriter.create <tensor::ConcatOp>(loc, widthDim, leftTiles);
184
+ resultParts.push_back (leftConcat);
185
+ }
186
+
187
+ resultParts.push_back (input);
188
+
189
+ if (rightPad > 0 ) {
190
+ SmallVector<Value> rightTiles (rightPad, rightSlice);
191
+ Value rightConcat =
192
+ rewriter.create <tensor::ConcatOp>(loc, widthDim, rightTiles);
193
+ resultParts.push_back (rightConcat);
194
+ }
195
+
196
+ Value result =
197
+ rewriter.create <tensor::ConcatOp>(loc, widthDim, resultParts);
198
+ Type resultType = getTypeConverter ()->convertType (op.getType ());
199
+ rewriter.replaceOpWithNewOp <tensor::CastOp>(op, resultType, result);
200
+
201
+ return success ();
202
+ }
203
+ };
204
+
205
+ } // namespace
206
+
207
+ namespace {
208
+
119
209
// Lower aten.replication_pad2d operator into a sequence of
120
210
// tensor.extract_slice and tensor.concat operations.
121
211
@@ -621,6 +711,8 @@ void mlir::torch::torch_to_linalg::
621
711
MLIRContext *context = patterns.getContext ();
622
712
target.addIllegalOp <AtenReplicationPad2dOp>();
623
713
patterns.add <ConvertAtenReplicationPad2dOp>(typeConverter, context);
714
+ target.addIllegalOp <AtenReplicationPad1dOp>();
715
+ patterns.add <ConvertAtenReplicationPad1dOp>(typeConverter, context);
624
716
target.addIllegalOp <AtenConstantPadNdOp>();
625
717
patterns.add <ConvertAtenConstantPadNdOp>(typeConverter, context);
626
718
target.addIllegalOp <AtenZerosOp, AtenOnesOp>();
0 commit comments