Skip to content

Commit 048f66d

Browse files
Implemented GPU OpenCL runtime
1 parent 8635ad2 commit 048f66d

File tree

20 files changed

+1403
-11
lines changed

20 files changed

+1403
-11
lines changed

.github/workflows/clang-tidy.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414

1515
steps:
1616
- name: Install OpenMP
17-
run: "sudo apt install -y libomp-dev"
17+
run: "sudo apt install -y libomp-dev opencl-c-headers"
1818

1919
- name: Fetch sources
2020
uses: actions/checkout@v4

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,15 @@ get_property(GC_TOOLS GLOBAL PROPERTY GC_TOOLS)
9898
get_property(GC_MLIR_LIBS GLOBAL PROPERTY GC_MLIR_LIBS)
9999
get_property(GC_PASS_LIBS GLOBAL PROPERTY GC_PASS_LIBS)
100100
get_property(GC_DIALECT_LIBS GLOBAL PROPERTY GC_DIALECT_LIBS)
101+
get_property(IMEX_LIBS GLOBAL PROPERTY IMEX_LIBS)
102+
101103
install(TARGETS
102104
GcInterface
103105
${GC_TOOLS}
104106
${GC_MLIR_LIBS}
105107
${GC_PASS_LIBS}
106108
${GC_DIALECT_LIBS}
109+
${IMEX_LIBS}
107110
EXPORT ${PROJECT_NAME}Targets
108111
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
109112
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}

cmake/imex.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,5 @@ if (NOT DEFINED IMEX_INCLUDES)
2424
${imex_SOURCE_DIR}/src
2525
)
2626
set_property(GLOBAL PROPERTY IMEX_INCLUDES ${IMEX_INCLUDES})
27+
target_compile_options(GcInterface INTERFACE -DGC_USE_IMEX)
2728
endif ()

include/gc/ExecutionEngine/GPURuntime/GpuOclRuntime.h

Lines changed: 293 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,301 @@ constexpr char GPU_OCL_MOD_DESTRUCTOR[] = "gcGpuOclModuleDestructor";
2020
} // namespace mlir::gc::gpu
2121

2222
#ifndef GC_GPU_OCL_CONST_ONLY
23+
#include <cassert>
#include <cstdarg>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <shared_mutex>
#include <unordered_map>
#include <unordered_set>
#include <vector>
2326

24-
// TBD
27+
#include <CL/cl.h>
2528

29+
#include <llvm/ADT/SmallString.h>
30+
31+
#include "mlir/ExecutionEngine/ExecutionEngine.h"
32+
#include "mlir/IR/BuiltinOps.h"
33+
34+
namespace mlir::gc::gpu {
35+
struct OclDevCtxPair {
36+
cl_device_id device;
37+
cl_context context;
38+
explicit OclDevCtxPair(cl_device_id device, cl_context context)
39+
: device(device), context(context) {}
40+
41+
bool operator==(const OclDevCtxPair &other) const {
42+
return device == other.device && context == other.context;
43+
}
44+
};
45+
} // namespace mlir::gc::gpu
46+
template <> struct std::hash<const mlir::gc::gpu::OclDevCtxPair> {
47+
std::size_t
48+
operator()(const mlir::gc::gpu::OclDevCtxPair &pair) const noexcept {
49+
return std::hash<cl_device_id>()(pair.device) ^
50+
std::hash<cl_context>()(pair.context);
51+
}
52+
}; // namespace std
53+
namespace mlir::gc::gpu {
54+
struct OclModule;
55+
struct OclContext;
56+
struct OclModuleBuilder;
57+
58+
struct OclRuntime {
59+
// Returns the available Intel GPU device ids.
60+
[[nodiscard]] static llvm::Expected<SmallVector<cl_device_id, 2>>
61+
gcIntelDevices(size_t max = std::numeric_limits<size_t>::max());
62+
63+
[[nodiscard]] static llvm::Expected<OclRuntime> get();
64+
65+
[[nodiscard]] static llvm::Expected<OclRuntime> get(cl_device_id device);
66+
67+
[[nodiscard]] static llvm::Expected<OclRuntime> get(cl_command_queue queue);
68+
69+
[[nodiscard]] static llvm::Expected<OclRuntime> get(cl_device_id device,
70+
cl_context context);
71+
72+
static bool isOutOfOrder(cl_command_queue queue);
73+
74+
[[nodiscard]] cl_context getContext() const;
75+
76+
[[nodiscard]] cl_device_id getDevice() const;
77+
78+
[[nodiscard]] llvm::Expected<cl_command_queue>
79+
createQueue(bool outOfOrder = false) const;
80+
81+
[[nodiscard]] static llvm::Expected<bool>
82+
releaseQueue(cl_command_queue queue);
83+
84+
[[nodiscard]] llvm::Expected<void *> usmAllocDev(size_t size) const;
85+
86+
[[nodiscard]] llvm::Expected<void *> usmAllocShared(size_t size) const;
87+
88+
[[nodiscard]] llvm::Expected<bool> usmFree(const void *ptr) const;
89+
90+
[[nodiscard]] llvm::Expected<bool> usmCpy(OclContext &ctx, const void *src,
91+
void *dst, size_t size) const;
92+
93+
template <typename T>
94+
[[nodiscard]] llvm::Expected<T *> usmNewDev(size_t size) const {
95+
auto expected = usmAllocDev(size * sizeof(T));
96+
if (expected) {
97+
return static_cast<T *>(*expected);
98+
}
99+
return expected.takeError();
100+
}
101+
102+
template <typename T>
103+
[[nodiscard]] llvm::Expected<T *> usmNewShared(size_t size) const {
104+
auto expected = usmAllocShared(size * sizeof(T));
105+
if (expected) {
106+
return static_cast<T *>(*expected);
107+
}
108+
return expected.takeError();
109+
}
110+
111+
template <typename T>
112+
[[nodiscard]] llvm::Expected<bool> usmCpy(OclContext &ctx, const T *src,
113+
T *dst, size_t size) const {
114+
return usmCpy(ctx, static_cast<const void *>(src), static_cast<void *>(dst),
115+
size * sizeof(T));
116+
}
117+
118+
// Use with caution! This is safe to check validity of USM, but may be false
119+
// positive for any other kinds.
120+
bool isUsm(const void *ptr) const;
121+
122+
bool operator==(const OclRuntime &other) const {
123+
return getDevice() == other.getDevice() &&
124+
getContext() == other.getContext();
125+
}
126+
127+
private:
128+
struct Ext;
129+
struct Exports;
130+
friend OclContext;
131+
friend OclModuleBuilder;
132+
explicit OclRuntime(const Ext &ext);
133+
const Ext &ext;
134+
};
135+
136+
// Shared zero constant. It is bound by reference and its address is stored in
// argument lists by the inline templates of this header, so it has to be one
// object for the whole program: 'inline' instead of a per-TU 'static' copy.
inline constexpr int64_t ZERO = 0;
// Pointer to ZERO with the constness cast away, for argument lists that store
// 'void *'. The pointee is a constant and must never be written through it.
inline constexpr auto ZERO_PTR = const_cast<int64_t *>(&ZERO);
138+
139+
// NOTE: The context is mutable and not thread-safe! It's expected to be used in
140+
// a single thread only.
141+
struct OclContext {
142+
const OclRuntime &runtime;
143+
cl_command_queue const queue;
144+
// Preserve the execution order. This is required in case of out-of-order
145+
// execution (CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE). When the execution
146+
// is completed, the 'lastEvent' field contains the event of the last enqueued
147+
// command. If this field is false, 'waitList' is ignored.
148+
const bool preserveOrder;
149+
cl_uint waitListLen;
150+
cl_event *waitList;
151+
cl_event lastEvent;
152+
153+
explicit OclContext(const OclRuntime &runtime, cl_command_queue queue,
154+
cl_uint waitListLen = 0, cl_event *waitList = nullptr)
155+
: OclContext(runtime, queue, OclRuntime::isOutOfOrder(queue), waitListLen,
156+
waitList) {}
157+
158+
explicit OclContext(const OclRuntime &runtime, cl_command_queue queue,
159+
bool preserveOrder, cl_uint waitListLen,
160+
cl_event *waitList)
161+
: runtime(runtime), queue(queue), preserveOrder(preserveOrder),
162+
waitListLen(preserveOrder ? waitListLen : 0),
163+
waitList(preserveOrder ? waitList : nullptr), lastEvent(nullptr) {
164+
assert(!OclRuntime::isOutOfOrder(queue) || preserveOrder);
165+
assert(preserveOrder || (waitListLen == 0 && waitList == nullptr));
166+
}
167+
168+
OclContext(const OclContext &) = delete;
169+
OclContext &operator=(const OclContext &) = delete;
170+
171+
void finish();
172+
173+
private:
174+
friend OclRuntime;
175+
friend OclRuntime::Exports;
176+
template <unsigned N> friend struct OclModuleArgs;
177+
// Contains the pointers of all non-USM arguments. It's expected, that the
178+
// arguments are either USM or CL pointers and most probably are USM, thus,
179+
// in most cases, this set will be empty.
180+
std::unordered_set<void *> clPtrs;
181+
182+
void setLastEvent(cl_event event) {
183+
lastEvent = event;
184+
if (event) {
185+
waitListLen = 1;
186+
waitList = &lastEvent;
187+
} else {
188+
waitListLen = 0;
189+
waitList = nullptr;
190+
}
191+
}
192+
};
193+
194+
// The main function arguments in the following format -
195+
// https://mlir.llvm.org/docs/TargetLLVMIR/#c-compatible-wrapper-emission.
196+
// NOTE: The values are not copied, only the pointers are stored!
197+
// NOTE: This class is mutable and not thread-safe!
198+
template <unsigned N = 64> struct OclModuleArgs {
199+
explicit OclModuleArgs(OclContext &ctx) : ctx(ctx) {}
200+
OclModuleArgs(const OclModuleArgs &) = delete;
201+
OclModuleArgs &operator=(const OclModuleArgs &) = delete;
202+
203+
void add(void *&alignedPtr, size_t rank, const int64_t *shape,
204+
const int64_t *strides, bool isUsm = true) {
205+
add(alignedPtr, alignedPtr, ZERO, rank, shape, strides, isUsm);
206+
}
207+
208+
void add(void *&allocatedPtr, void *&alignedPtr, const int64_t &offset,
209+
size_t rank, const int64_t *shape, const int64_t *strides,
210+
bool isUsm = true) {
211+
#ifndef NDEBUG
212+
assert(!isUsm || ctx.runtime.isUsm(alignedPtr));
213+
// It's recommended to have at least 16-byte alignment
214+
assert(reinterpret_cast<std::uintptr_t>(alignedPtr) % 16 == 0);
215+
#endif
216+
217+
args.emplace_back(&allocatedPtr);
218+
args.emplace_back(&alignedPtr);
219+
args.emplace_back(const_cast<int64_t *>(&offset));
220+
for (size_t i = 0; i < rank; i++) {
221+
args.emplace_back(const_cast<int64_t *>(&shape[i]));
222+
}
223+
for (size_t i = 0; i < rank; i++) {
224+
args.emplace_back(const_cast<int64_t *>(&strides[i]));
225+
}
226+
if (!isUsm) {
227+
ctx.clPtrs.insert(alignedPtr);
228+
}
229+
}
230+
231+
template <typename T>
232+
void add(T *&alignedPtr, size_t rank, const int64_t *shape,
233+
const int64_t *strides, bool isUsm = true) {
234+
add(reinterpret_cast<void *&>(alignedPtr), rank, shape, strides, isUsm);
235+
}
236+
237+
template <typename T>
238+
void add(T *&allocatedPtr, T *&alignedPtr, const int64_t &offset, size_t rank,
239+
const int64_t *shape, const int64_t *strides, bool isUsm = true) {
240+
add(reinterpret_cast<void *&>(allocatedPtr),
241+
reinterpret_cast<void *&>(alignedPtr), offset, rank, shape, strides,
242+
isUsm);
243+
}
244+
245+
void clear() {
246+
args.clear();
247+
ctx.clPtrs.clear();
248+
}
249+
250+
private:
251+
friend OclModule;
252+
OclContext &ctx;
253+
SmallVector<void *, N + 3> args;
254+
};
255+
256+
struct OclModule {
257+
const OclRuntime runtime;
258+
259+
using MainFunc = void (*)(void **);
260+
261+
explicit OclModule(const OclRuntime &runtime,
262+
std::unique_ptr<ExecutionEngine> engine, MainFunc main)
263+
: runtime(runtime), engine(std::move(engine)), main(main) {}
264+
265+
template <unsigned N> void exec(OclModuleArgs<N> &args) const {
266+
OclContext &ctx = args.ctx;
267+
#ifndef NDEBUG
268+
auto rt = OclRuntime::get(ctx.queue);
269+
assert(rt);
270+
assert(*rt == this->runtime);
271+
#endif
272+
auto size = args.args.size();
273+
auto ctxPtr = &ctx;
274+
args.args.emplace_back(&ctxPtr);
275+
args.args.emplace_back(&ctxPtr);
276+
args.args.emplace_back(ZERO_PTR);
277+
main(args.args.data());
278+
args.args.truncate(size);
279+
}
280+
281+
~OclModule();
282+
OclModule(const OclModule &) = delete;
283+
OclModule &operator=(const OclModule &) = delete;
284+
OclModule(const OclModule &&) = delete;
285+
OclModule &operator=(const OclModule &&) = delete;
286+
287+
private:
288+
std::unique_ptr<ExecutionEngine> engine;
289+
MainFunc main;
290+
};
291+
292+
struct OclModuleBuilder {
293+
friend OclRuntime;
294+
explicit OclModuleBuilder(ModuleOp module);
295+
explicit OclModuleBuilder(OwningOpRef<ModuleOp> &module)
296+
: OclModuleBuilder(module.release()) {}
297+
298+
llvm::Expected<std::shared_ptr<const OclModule>>
299+
build(const OclRuntime &runtime);
300+
301+
llvm::Expected<std::shared_ptr<const OclModule>>
302+
build(cl_command_queue queue);
303+
304+
llvm::Expected<std::shared_ptr<const OclModule>> build(cl_device_id device,
305+
cl_context context);
306+
307+
private:
308+
std::shared_mutex mux;
309+
ModuleOp mlirModule;
310+
SmallString<32> funcName;
311+
std::unordered_map<const OclDevCtxPair, std::shared_ptr<const OclModule>>
312+
cache;
313+
llvm::Expected<std::shared_ptr<const OclModule>>
314+
315+
build(const OclRuntime::Ext &ext);
316+
};
317+
}; // namespace mlir::gc::gpu
26318
#else
27319
#undef GC_GPU_OCL_CONST_ONLY
28320
#endif

include/gc/Transforms/Passes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ std::unique_ptr<Pass> createMergeAllocPass();
115115
void populateFrontendPasses(mlir::OpPassManager &);
116116
void populateCPUPipeline(mlir::OpPassManager &);
117117

118+
#ifdef GC_USE_IMEX
119+
void populateGPUPipeline(mlir::OpPassManager &);
120+
#endif
121+
118122
#define GEN_PASS_DECL
119123
#include "gc/Transforms/Passes.h.inc"
120124

0 commit comments

Comments
 (0)