forked from EPFL-LAP/dynamatic
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathHandshakeInferBasicBlocks.cpp
151 lines (132 loc) · 5.14 KB
/
HandshakeInferBasicBlocks.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
//===- HandshakeInferBasicBlocks.cpp - Infer ops basic blocks ---*- C++ -*-===//
//
// Dynamatic is under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The basic block inference pass is implemented as a single operation
// conversion pattern that iterates over all operations in a function repeatedly
// until no more inferences can be performed, at which point it succeeds.
//
// A local inference heuristic is applied on each operation eligible for
// inference. The locality of the heuristic may require the pass to run the
// inference logic on eligible operations multiple times in order to let
// inference results propagate incrementally to their immediate graph neighbors.
//
//===----------------------------------------------------------------------===//
#include "dynamatic/Transforms/HandshakeInferBasicBlocks.h"
#include "dynamatic/Dialect/Handshake/HandshakeOps.h"
#include "dynamatic/Support/CFG.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace dynamatic;
/// Returns whether the pass may try to infer a missing basic block for `op`.
/// Operations implementing handshake::MemoryOpInterface and handshake::SinkOp
/// operations are excluded from inference; every other operation is eligible.
static bool isLegalForInference(Operation *op) {
  const bool excluded =
      isa<handshake::MemoryOpInterface>(op) || isa<handshake::SinkOp>(op);
  return !excluded;
}
/// Tries to infer and attach a basic block attribute (BB_ATTR_NAME) to a single
/// operation.
///
/// Returns true only when a new attribute was set on `op`. Returns false when
/// `op` is excluded from inference (see isLegalForInference), already has a
/// basic block, or when inferLogicBB could not determine one.
static bool inferBasicBlocks(Operation *op, PatternRewriter &rewriter) {
  // Skip excluded operations and operations that are already tagged.
  if (!isLegalForInference(op) || getLogicBB(op).has_value())
    return false;

  unsigned inferred = 0;
  if (failed(inferLogicBB(op, inferred)))
    return false;

  op->setAttr(BB_ATTR_NAME, rewriter.getUI32IntegerAttr(inferred));
  return true;
}
/// Infers the basic block an operation belongs to from its immediate graph
/// neighbors.
///
/// Two local analyses are attempted in order:
///   1. Successor analysis: if every user of every result of `op` has a known
///      basic block and they all agree, that block is chosen.
///   2. Predecessor analysis (only when 1. yields nothing): if every operand's
///      producer has a known basic block and they all agree, that block is
///      chosen. Operands without a defining operation (e.g., block arguments)
///      count as belonging to ENTRY_BB.
///
/// On success the inferred block is written to `logicBB`; on failure
/// `logicBB` is left untouched.
LogicalResult dynamatic::inferLogicBB(Operation *op, unsigned &logicBB) {
  // Block agreed upon so far by the analysis currently running; std::nullopt
  // means no consistent block has been found (yet).
  std::optional<unsigned> candidate;

  // Folds one neighbor's block into `candidate`. Fails, and clears the
  // candidate, when the neighbor's block is unknown or disagrees with it.
  auto merge = [&](std::optional<unsigned> neighborBB) -> LogicalResult {
    const bool unknown = !neighborBB.has_value();
    const bool mismatch =
        !unknown && candidate.has_value() && *candidate != *neighborBB;
    if (unknown || mismatch) {
      candidate.reset();
      return failure();
    }
    candidate = neighborBB;
    return success();
  };

  // Successor analysis: walk the users of every result, stopping at the first
  // user whose block is unknown or disagrees with the candidate.
  auto mergeUsers = [&]() {
    for (OpResult result : op->getResults())
      for (Operation *user : result.getUsers())
        if (failed(merge(getLogicBB(user))))
          return;
  };
  mergeUsers();

  if (candidate) {
    logicBB = *candidate;
    return success();
  }

  // Predecessor analysis. `candidate` is necessarily empty at this point:
  // either no user was visited or a failed merge cleared it.
  for (Value operand : op->getOperands()) {
    Operation *producer = operand.getDefiningOp();
    std::optional<unsigned> producerBB =
        producer ? getLogicBB(producer) : ENTRY_BB;
    if (failed(merge(producerBB)))
      return failure();
  }

  // An operation without operands leaves the candidate empty.
  if (!candidate)
    return failure();
  logicBB = *candidate;
  return success();
}
namespace {

/// Conversion pattern that tags untagged operations of a handshake function
/// with an inferred basic block. Inference is re-run over every operation of
/// the function body until a full sweep tags nothing new, so that results can
/// propagate from newly tagged operations to their neighbors.
struct FuncOpInferBasicBlocks : public OpConversionPattern<handshake::FuncOp> {
  explicit FuncOpInferBasicBlocks(MLIRContext *ctx) : OpConversionPattern(ctx) {}

  LogicalResult
  matchAndRewrite(handshake::FuncOp funcOp, OpAdaptor /*adaptor*/,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.updateRootInPlace(funcOp, [&] {
      // Keep sweeping while the previous sweep tagged at least one operation.
      // The first sweep always runs.
      for (bool changed = true; changed;) {
        changed = false;
        for (Operation &op : funcOp.getOps())
          changed |= inferBasicBlocks(&op, rewriter);
      }
    });
    return success();
  }
};

/// Driver for the basic block inference pass. It runs a partial conversion on
/// the pass's operation, using FuncOpInferBasicBlocks as the only pattern and a
/// conversion target with no legality rules declared. It signals pass failure
/// if the conversion fails.
struct HandshakeInferBasicBlocksPass
    : public dynamatic::impl::HandshakeInferBasicBlocksBase<
          HandshakeInferBasicBlocksPass> {
  void runDynamaticPass() override {
    MLIRContext *ctx = &getContext();

    RewritePatternSet patterns{ctx};
    patterns.add<FuncOpInferBasicBlocks>(ctx);

    ConversionTarget target(*ctx);
    LogicalResult conversion =
        applyPartialConversion(getOperation(), target, std::move(patterns));
    if (failed(conversion))
      signalPassFailure();
  }
};

} // namespace
/// Creates a new, heap-allocated instance of the basic block inference pass.
std::unique_ptr<dynamatic::DynamaticPass>
dynamatic::createHandshakeInferBasicBlocksPass() {
  std::unique_ptr<dynamatic::DynamaticPass> pass =
      std::make_unique<HandshakeInferBasicBlocksPass>();
  return pass;
}