//===- FlattenMemRefRowMajor.cpp - MemRef flattening pass -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file originates from the CIRCT project (https://github.com/llvm/circt).
// It includes modifications made as part of Dynamatic.
//
//===----------------------------------------------------------------------===//
//
// Contains the definitions of the MemRef flattening pass. It is closely
// modeled on the MemRef flattening pass from CIRCT but uses row-major
// indexing to convert multidimensional load and store operations.
//
//===----------------------------------------------------------------------===//
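//
// As an illustrative sketch (the SSA value names here are hypothetical), the
// pass turns a two-dimensional access such as
//
//   %v = memref.load %mem[%i, %j] : memref<4x8xi32>
//
// into an access over a flattened memref, with the index computed in
// row-major order as %i * 8 + %j. Because the stride 8 is a power of two, the
// multiplication is emitted as a left shift; %flat denotes the flattened
// buffer the pass produces for %mem:
//
//   %c3 = arith.constant 3 : index
//   %tmp = arith.shli %i, %c3 : index
//   %idx = arith.addi %j, %tmp : index
//   %v = memref.load %flat[%idx] : memref<32xi32>
//
//===----------------------------------------------------------------------===//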
#include "dynamatic/Transforms/FlattenMemRefRowMajor.h"
#include "dynamatic/Dialect/Handshake/HandshakeOps.h"
#include "dynamatic/Dialect/Handshake/MemoryInterfaces.h"
#include "dynamatic/Support/Attribute.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/MathExtras.h"

using namespace mlir;
using namespace dynamatic;

static inline bool isUniDimensional(MemRefType memref) {
return memref.getShape().size() == 1;
}

/// Flatten indices in row-major style, making adjacent indices in the last
/// memref dimension be adjacent indices in the flattened memref.
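/// For example, a memref<4x8xi32> accessed at [i, j] yields the flat index
/// i * 8 + j; because the stride 8 is a power of two, the multiplication is
/// emitted as a left shift by 3 rather than an arith.muli.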
static Value flattenIndices(ConversionPatternRewriter &rewriter, Operation *op,
ValueRange indices, MemRefType memrefType) {
assert(memrefType.hasStaticShape() && "expected statically shaped memref");
Location loc = op->getLoc();
auto numIndices = indices.size();
if (numIndices == 0) {
// Singleton memref (e.g. memref<i32>) - return 0
return rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0))
.getResult();
}
if (numIndices == 1)
// Memref is already unidimensional
return indices.front();
// Iterate over indices to compute the final unidimensional index
Value finalIdx = indices.back();
int64_t dimProduct = 1;
for (size_t i = 0, e = numIndices - 1; i < e; ++i) {
auto memIdx = numIndices - i - 2;
Value partialIdx = indices[memIdx];
    dimProduct *= memrefType.getShape()[memIdx + 1];
    // Scale the current index operand by its row-major stride (the product of
    // all inner dimension sizes), using a shift when the stride is a power of
    // two
if (llvm::isPowerOf2_64(dimProduct)) {
auto constant =
rewriter
.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(llvm::Log2_64(dimProduct)))
.getResult();
partialIdx =
rewriter.create<arith::ShLIOp>(loc, partialIdx, constant).getResult();
} else {
auto constant =
rewriter
.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(dimProduct))
.getResult();
partialIdx =
rewriter.create<arith::MulIOp>(loc, partialIdx, constant).getResult();
}
// Sum up with the prior lower dimension accessors
auto sumOp = rewriter.create<arith::AddIOp>(loc, finalIdx, partialIdx);
finalIdx = sumOp.getResult();
}
return finalIdx;
}

static bool hasMultiDimMemRef(ValueRange values) {
return llvm::any_of(values, [](Value v) {
auto memref = v.getType().dyn_cast<MemRefType>();
if (!memref)
return false;
return !isUniDimensional(memref);
});
}

namespace {

struct LoadOpConversion : public OpConversionPattern<memref::LoadOp> {
using OpConversionPattern::OpConversionPattern;
LoadOpConversion(MemoryOpLowering &memOpLowering, TypeConverter &converter,
MLIRContext *ctx)
      : OpConversionPattern(converter, ctx), memOpLowering(memOpLowering) {}
LogicalResult
matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemRefType type = loadOp.getMemRefType();
if (isUniDimensional(type) || !type.hasStaticShape() ||
/*Already converted?*/ loadOp.getIndices().size() == 1)
return failure();
Value finalIdx = flattenIndices(rewriter, loadOp, adaptor.getIndices(),
loadOp.getMemRefType());
memref::LoadOp flatLoadOp = rewriter.replaceOpWithNewOp<memref::LoadOp>(
loadOp, adaptor.getMemref(), SmallVector<Value>{finalIdx});
memOpLowering.recordReplacement(loadOp, flatLoadOp);
return success();
}

private:
/// Used to record the operation replacement.
MemoryOpLowering &memOpLowering;
};

struct StoreOpConversion : public OpConversionPattern<memref::StoreOp> {
using OpConversionPattern::OpConversionPattern;
StoreOpConversion(MemoryOpLowering &memOpLowering, TypeConverter &converter,
MLIRContext *ctx)
      : OpConversionPattern(converter, ctx), memOpLowering(memOpLowering) {}
LogicalResult
matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemRefType type = storeOp.getMemRefType();
if (isUniDimensional(type) || !type.hasStaticShape() ||
/*Already converted?*/ storeOp.getIndices().size() == 1)
return failure();
Value finalIdx = flattenIndices(rewriter, storeOp, adaptor.getIndices(),
storeOp.getMemRefType());
memref::StoreOp flatStoreOp = rewriter.replaceOpWithNewOp<memref::StoreOp>(
storeOp, adaptor.getValue(), adaptor.getMemref(),
SmallVector<Value>{finalIdx});
memOpLowering.recordReplacement(storeOp, flatStoreOp);
return success();
}

private:
/// Used to record the operation replacement.
MemoryOpLowering &memOpLowering;
};

struct AllocOpConversion : public OpConversionPattern<memref::AllocOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::AllocOp op, OpAdaptor /*adaptor*/,
ConversionPatternRewriter &rewriter) const override {
MemRefType type = op.getType();
if (isUniDimensional(type) || !type.hasStaticShape())
return failure();
MemRefType newType = MemRefType::get(
SmallVector<int64_t>{type.getNumElements()}, type.getElementType());
rewriter.replaceOpWithNewOp<memref::AllocOp>(op, newType);
return success();
}
};

// A generic pattern which will replace an op with a new op of the same type
// but using the adaptor (type converted) operands.
template <typename TOp>
struct OperandConversionPattern : public OpConversionPattern<TOp> {
using OpConversionPattern<TOp>::OpConversionPattern;
using OpAdaptor = typename TOp::Adaptor;
LogicalResult
matchAndRewrite(TOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<TOp>(op, op->getResultTypes(),
adaptor.getOperands(), op->getAttrs());
return success();
}
};
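// For instance, a memref.dealloc or memref.copy whose operand was a
// memref<4x8xi32> is recreated with that operand retyped by the converter to
// the flattened memref<32xi32>; the op itself is otherwise unchanged.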

// Cannot use OperandConversionPattern for branch op since the default builder
// doesn't provide a method for communicating block successors.
struct CondBranchOpConversion
: public OpConversionPattern<mlir::cf::CondBranchOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(mlir::cf::CondBranchOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>(
op, adaptor.getCondition(), adaptor.getTrueDestOperands(),
adaptor.getFalseDestOperands(), op.getTrueDest(), op.getFalseDest());
return success();
}
};

// Rewrites a call op signature to flattened types. If rewriteFunctions is
// set, also replaces the callee with a private definition of the called
// function with the updated signature.
struct CallOpConversion : public OpConversionPattern<func::CallOp> {
CallOpConversion(TypeConverter &typeConverter, MLIRContext *context,
bool rewriteFunctions = false)
: OpConversionPattern(typeConverter, context),
rewriteFunctions(rewriteFunctions) {}
LogicalResult
matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
llvm::SmallVector<Type> convResTypes;
if (typeConverter->convertTypes(op.getResultTypes(), convResTypes).failed())
return failure();
auto newCallOp = rewriter.replaceOpWithNewOp<func::CallOp>(
op, adaptor.getCallee(), convResTypes, adaptor.getOperands());
if (!rewriteFunctions)
return success();
// Override any definition corresponding to the updated signature.
// It is up to users of this pass to define how these rewritten functions
// are to be implemented.
rewriter.setInsertionPoint(op->getParentOfType<func::FuncOp>());
auto *calledFunction = dyn_cast<CallOpInterface>(*op).resolveCallable();
FunctionType funcType = FunctionType::get(
op.getContext(), newCallOp.getOperandTypes(), convResTypes);
func::FuncOp newFuncOp;
if (calledFunction)
newFuncOp = rewriter.replaceOpWithNewOp<func::FuncOp>(
calledFunction, op.getCallee(), funcType);
else
newFuncOp =
rewriter.create<func::FuncOp>(op.getLoc(), op.getCallee(), funcType);
newFuncOp.setVisibility(SymbolTable::Visibility::Private);
return success();
}
private:
bool rewriteFunctions;
};
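// For example, with rewriteFunctions set, a call to a function taking a
// memref<4x8xi32> is redirected to a private declaration of that function
// whose signature instead takes the flattened memref<32xi32>.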

template <typename... TOp>
void addGenericLegalityConstraint(ConversionTarget &target) {
(target.addDynamicallyLegalOp<TOp>([](TOp op) {
return !hasMultiDimMemRef(op->getOperands()) &&
!hasMultiDimMemRef(op->getResults());
}),
...);
}

static void populateFlattenMemRefsLegality(ConversionTarget &target) {
target.addLegalDialect<arith::ArithDialect>();
target.addDynamicallyLegalOp<memref::AllocOp>(
[](memref::AllocOp op) { return isUniDimensional(op.getType()); });
target.addDynamicallyLegalOp<memref::StoreOp>(
[](memref::StoreOp op) { return op.getIndices().size() == 1; });
target.addDynamicallyLegalOp<memref::LoadOp>(
[](memref::LoadOp op) { return op.getIndices().size() == 1; });
addGenericLegalityConstraint<mlir::cf::CondBranchOp, mlir::cf::BranchOp,
func::CallOp, func::ReturnOp, memref::DeallocOp,
memref::CopyOp>(target);
target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp op) {
auto argsConverted = llvm::none_of(op.getBlocks(), [](auto &block) {
return hasMultiDimMemRef(block.getArguments());
});
auto resultsConverted = llvm::all_of(op.getResultTypes(), [](Type type) {
if (auto memref = type.dyn_cast<MemRefType>())
return isUniDimensional(memref);
return true;
});
return argsConverted && resultsConverted;
});
}

static void populateTypeConversionPatterns(TypeConverter &typeConverter) {
// Add default conversion for all types generically.
typeConverter.addConversion([](Type type) { return type; });
// Add specific conversion for memref types.
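  // (e.g., a memref<4x8xi32> becomes a memref<32xi32>).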
typeConverter.addConversion([](MemRefType memref) {
if (isUniDimensional(memref))
return memref;
return MemRefType::get(llvm::SmallVector<int64_t>{memref.getNumElements()},
memref.getElementType());
});
}

struct FlattenMemRefRowMajorPass
: public dynamatic::impl::FlattenMemRefRowMajorBase<
FlattenMemRefRowMajorPass> {
public:
void runDynamaticPass() override {
mlir::ModuleOp modOp = getOperation();
MLIRContext *ctx = &getContext();
TypeConverter typeConverter;
MemoryOpLowering memOpLowering(getAnalysis<NameAnalysis>());
populateTypeConversionPatterns(typeConverter);
RewritePatternSet patterns(ctx);
    patterns.add<AllocOpConversion, OperandConversionPattern<func::ReturnOp>,
                 OperandConversionPattern<memref::DeallocOp>,
                 CondBranchOpConversion,
                 OperandConversionPattern<memref::CopyOp>, CallOpConversion>(
        typeConverter, ctx);
patterns.add<LoadOpConversion, StoreOpConversion>(memOpLowering,
typeConverter, ctx);
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns, typeConverter);
ConversionTarget target(*ctx);
populateFlattenMemRefsLegality(target);
if (failed(applyPartialConversion(modOp, target, std::move(patterns))))
return signalPassFailure();
    // Change the name of destination memory accesses in all stored memory
    // dependencies to reflect the new access names
memOpLowering.renameDependencies(modOp);
}
};

} // namespace

namespace dynamatic {

std::unique_ptr<dynamatic::DynamaticPass> createFlattenMemRefRowMajorPass() {
return std::make_unique<FlattenMemRefRowMajorPass>();
}
} // namespace dynamatic
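
// A minimal usage sketch. The pass flag below is assumed from the pass's base
// class name and may differ from the one registered in the actual tablegen
// definition:
//
//   dynamatic-opt --flatten-memref-row-major input.mlir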