@@ -20,9 +20,301 @@ constexpr char GPU_OCL_MOD_DESTRUCTOR[] = "gcGpuOclModuleDestructor";
20
20
} // namespace mlir::gc::gpu
21
21
22
22
#ifndef GC_GPU_OCL_CONST_ONLY
23
#include <cassert>
#include <cstdarg>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <shared_mutex>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <CL/cl.h>

#include <llvm/ADT/SmallString.h>
#include <llvm/ADT/SmallVector.h>

#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/IR/BuiltinOps.h"
34
+ namespace mlir ::gc::gpu {
35
+ struct OclDevCtxPair {
36
+ cl_device_id device;
37
+ cl_context context;
38
+ explicit OclDevCtxPair (cl_device_id device, cl_context context)
39
+ : device(device), context(context) {}
40
+
41
+ bool operator ==(const OclDevCtxPair &other) const {
42
+ return device == other.device && context == other.context ;
43
+ }
44
+ };
45
+ } // namespace mlir::gc::gpu
46
+ template <> struct std ::hash<const mlir::gc::gpu::OclDevCtxPair> {
47
+ std::size_t
48
+ operator ()(const mlir::gc::gpu::OclDevCtxPair &pair) const noexcept {
49
+ return std::hash<cl_device_id>()(pair.device ) ^
50
+ std::hash<cl_context>()(pair.context );
51
+ }
52
+ }; // namespace std
53
+ namespace mlir ::gc::gpu {
54
+ struct OclModule ;
55
+ struct OclContext ;
56
+ struct OclModuleBuilder ;
57
+
58
+ struct OclRuntime {
59
+ // Returns the available Intel GPU device ids.
60
+ [[nodiscard]] static llvm::Expected<SmallVector<cl_device_id, 2 >>
61
+ gcIntelDevices (size_t max = std::numeric_limits<size_t >::max());
62
+
63
+ [[nodiscard]] static llvm::Expected<OclRuntime> get ();
64
+
65
+ [[nodiscard]] static llvm::Expected<OclRuntime> get (cl_device_id device);
66
+
67
+ [[nodiscard]] static llvm::Expected<OclRuntime> get (cl_command_queue queue);
68
+
69
+ [[nodiscard]] static llvm::Expected<OclRuntime> get (cl_device_id device,
70
+ cl_context context);
71
+
72
+ static bool isOutOfOrder (cl_command_queue queue);
73
+
74
+ [[nodiscard]] cl_context getContext () const ;
75
+
76
+ [[nodiscard]] cl_device_id getDevice () const ;
77
+
78
+ [[nodiscard]] llvm::Expected<cl_command_queue>
79
+ createQueue (bool outOfOrder = false ) const ;
80
+
81
+ [[nodiscard]] static llvm::Expected<bool >
82
+ releaseQueue (cl_command_queue queue);
83
+
84
+ [[nodiscard]] llvm::Expected<void *> usmAllocDev (size_t size) const ;
85
+
86
+ [[nodiscard]] llvm::Expected<void *> usmAllocShared (size_t size) const ;
87
+
88
+ [[nodiscard]] llvm::Expected<bool > usmFree (const void *ptr) const ;
89
+
90
+ [[nodiscard]] llvm::Expected<bool > usmCpy (OclContext &ctx, const void *src,
91
+ void *dst, size_t size) const ;
92
+
93
+ template <typename T>
94
+ [[nodiscard]] llvm::Expected<T *> usmNewDev (size_t size) const {
95
+ auto expected = usmAllocDev (size * sizeof (T));
96
+ if (expected) {
97
+ return static_cast <T *>(*expected);
98
+ }
99
+ return expected.takeError ();
100
+ }
101
+
102
+ template <typename T>
103
+ [[nodiscard]] llvm::Expected<T *> usmNewShared (size_t size) const {
104
+ auto expected = usmAllocShared (size * sizeof (T));
105
+ if (expected) {
106
+ return static_cast <T *>(*expected);
107
+ }
108
+ return expected.takeError ();
109
+ }
110
+
111
+ template <typename T>
112
+ [[nodiscard]] llvm::Expected<bool > usmCpy (OclContext &ctx, const T *src,
113
+ T *dst, size_t size) const {
114
+ return usmCpy (ctx, static_cast <const void *>(src), static_cast <void *>(dst),
115
+ size * sizeof (T));
116
+ }
117
+
118
+ // Use with caution! This is safe to check validity of USM, but may be false
119
+ // positive for any other kinds.
120
+ bool isUsm (const void *ptr) const ;
121
+
122
+ bool operator ==(const OclRuntime &other) const {
123
+ return getDevice () == other.getDevice () &&
124
+ getContext () == other.getContext ();
125
+ }
126
+
127
+ private:
128
+ struct Ext ;
129
+ struct Exports ;
130
+ friend OclContext;
131
+ friend OclModuleBuilder;
132
+ explicit OclRuntime (const Ext &ext);
133
+ const Ext &ext;
134
+ };
135
+
136
// Shared zero constant, used as the default memref offset in the argument
// lists below. 'inline' (C++17, which this file already requires) guarantees
// a single object across all translation units including this header, so
// ZERO_PTR compares equal program-wide; 'static constexpr' would create one
// copy per TU.
inline constexpr int64_t ZERO = 0;
// Non-const pointer to ZERO, as required by the packed-argument ABI below.
// Never write through it: modifying a constexpr object is undefined behavior.
inline constexpr auto ZERO_PTR = const_cast<int64_t *>(&ZERO);
138
+
139
+ // NOTE: The context is mutable and not thread-safe! It's expected to be used in
140
+ // a single thread only.
141
+ struct OclContext {
142
+ const OclRuntime &runtime;
143
+ cl_command_queue const queue;
144
+ // Preserve the execution order. This is required in case of out-of-order
145
+ // execution (CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE). When the execution
146
+ // is completed, the 'lastEvent' field contains the event of the last enqueued
147
+ // command. If this field is false, 'waitList' is ignored.
148
+ const bool preserveOrder;
149
+ cl_uint waitListLen;
150
+ cl_event *waitList;
151
+ cl_event lastEvent;
152
+
153
+ explicit OclContext (const OclRuntime &runtime, cl_command_queue queue,
154
+ cl_uint waitListLen = 0 , cl_event *waitList = nullptr )
155
+ : OclContext(runtime, queue, OclRuntime::isOutOfOrder(queue), waitListLen,
156
+ waitList) {}
157
+
158
+ explicit OclContext (const OclRuntime &runtime, cl_command_queue queue,
159
+ bool preserveOrder, cl_uint waitListLen,
160
+ cl_event *waitList)
161
+ : runtime(runtime), queue(queue), preserveOrder(preserveOrder),
162
+ waitListLen(preserveOrder ? waitListLen : 0 ),
163
+ waitList(preserveOrder ? waitList : nullptr ), lastEvent(nullptr ) {
164
+ assert (!OclRuntime::isOutOfOrder (queue) || preserveOrder);
165
+ assert (preserveOrder || (waitListLen == 0 && waitList == nullptr ));
166
+ }
167
+
168
+ OclContext (const OclContext &) = delete ;
169
+ OclContext &operator =(const OclContext &) = delete ;
170
+
171
+ void finish ();
172
+
173
+ private:
174
+ friend OclRuntime;
175
+ friend OclRuntime::Exports;
176
+ template <unsigned N> friend struct OclModuleArgs ;
177
+ // Contains the pointers of all non-USM arguments. It's expected, that the
178
+ // arguments are either USM or CL pointers and most probably are USM, thus,
179
+ // in most cases, this set will be empty.
180
+ std::unordered_set<void *> clPtrs;
181
+
182
+ void setLastEvent (cl_event event) {
183
+ lastEvent = event;
184
+ if (event) {
185
+ waitListLen = 1 ;
186
+ waitList = &lastEvent;
187
+ } else {
188
+ waitListLen = 0 ;
189
+ waitList = nullptr ;
190
+ }
191
+ }
192
+ };
193
+
194
+ // The main function arguments in the following format -
195
+ // https://mlir.llvm.org/docs/TargetLLVMIR/#c-compatible-wrapper-emission.
196
+ // NOTE: The values are not copied, only the pointers are stored!
197
+ // NOTE: This class is mutable and not thread-safe!
198
+ template <unsigned N = 64 > struct OclModuleArgs {
199
+ explicit OclModuleArgs (OclContext &ctx) : ctx(ctx) {}
200
+ OclModuleArgs (const OclModuleArgs &) = delete ;
201
+ OclModuleArgs &operator =(const OclModuleArgs &) = delete ;
202
+
203
+ void add (void *&alignedPtr, size_t rank, const int64_t *shape,
204
+ const int64_t *strides, bool isUsm = true ) {
205
+ add (alignedPtr, alignedPtr, ZERO, rank, shape, strides, isUsm);
206
+ }
207
+
208
+ void add (void *&allocatedPtr, void *&alignedPtr, const int64_t &offset,
209
+ size_t rank, const int64_t *shape, const int64_t *strides,
210
+ bool isUsm = true ) {
211
+ #ifndef NDEBUG
212
+ assert (!isUsm || ctx.runtime .isUsm (alignedPtr));
213
+ // It's recommended to have at least 16-byte alignment
214
+ assert (reinterpret_cast <std::uintptr_t >(alignedPtr) % 16 == 0 );
215
+ #endif
216
+
217
+ args.emplace_back (&allocatedPtr);
218
+ args.emplace_back (&alignedPtr);
219
+ args.emplace_back (const_cast <int64_t *>(&offset));
220
+ for (size_t i = 0 ; i < rank; i++) {
221
+ args.emplace_back (const_cast <int64_t *>(&shape[i]));
222
+ }
223
+ for (size_t i = 0 ; i < rank; i++) {
224
+ args.emplace_back (const_cast <int64_t *>(&strides[i]));
225
+ }
226
+ if (!isUsm) {
227
+ ctx.clPtrs .insert (alignedPtr);
228
+ }
229
+ }
230
+
231
+ template <typename T>
232
+ void add (T *&alignedPtr, size_t rank, const int64_t *shape,
233
+ const int64_t *strides, bool isUsm = true ) {
234
+ add (reinterpret_cast <void *&>(alignedPtr), rank, shape, strides, isUsm);
235
+ }
236
+
237
+ template <typename T>
238
+ void add (T *&allocatedPtr, T *&alignedPtr, const int64_t &offset, size_t rank,
239
+ const int64_t *shape, const int64_t *strides, bool isUsm = true ) {
240
+ add (reinterpret_cast <void *&>(allocatedPtr),
241
+ reinterpret_cast <void *&>(alignedPtr), offset, rank, shape, strides,
242
+ isUsm);
243
+ }
244
+
245
+ void clear () {
246
+ args.clear ();
247
+ ctx.clPtrs .clear ();
248
+ }
249
+
250
+ private:
251
+ friend OclModule;
252
+ OclContext &ctx;
253
+ SmallVector<void *, N + 3 > args;
254
+ };
255
+
256
+ struct OclModule {
257
+ const OclRuntime runtime;
258
+
259
+ using MainFunc = void (*)(void **);
260
+
261
+ explicit OclModule (const OclRuntime &runtime,
262
+ std::unique_ptr<ExecutionEngine> engine, MainFunc main)
263
+ : runtime(runtime), engine(std::move(engine)), main(main) {}
264
+
265
+ template <unsigned N> void exec (OclModuleArgs<N> &args) const {
266
+ OclContext &ctx = args.ctx ;
267
+ #ifndef NDEBUG
268
+ auto rt = OclRuntime::get (ctx.queue );
269
+ assert (rt);
270
+ assert (*rt == this ->runtime );
271
+ #endif
272
+ auto size = args.args .size ();
273
+ auto ctxPtr = &ctx;
274
+ args.args .emplace_back (&ctxPtr);
275
+ args.args .emplace_back (&ctxPtr);
276
+ args.args .emplace_back (ZERO_PTR);
277
+ main (args.args .data ());
278
+ args.args .truncate (size);
279
+ }
280
+
281
+ ~OclModule ();
282
+ OclModule (const OclModule &) = delete ;
283
+ OclModule &operator =(const OclModule &) = delete ;
284
+ OclModule (const OclModule &&) = delete ;
285
+ OclModule &operator =(const OclModule &&) = delete ;
286
+
287
+ private:
288
+ std::unique_ptr<ExecutionEngine> engine;
289
+ MainFunc main;
290
+ };
291
+
292
+ struct OclModuleBuilder {
293
+ friend OclRuntime;
294
+ explicit OclModuleBuilder (ModuleOp module );
295
+ explicit OclModuleBuilder (OwningOpRef<ModuleOp> &module )
296
+ : OclModuleBuilder(module .release()) {}
297
+
298
+ llvm::Expected<std::shared_ptr<const OclModule>>
299
+ build (const OclRuntime &runtime);
300
+
301
+ llvm::Expected<std::shared_ptr<const OclModule>>
302
+ build (cl_command_queue queue);
303
+
304
+ llvm::Expected<std::shared_ptr<const OclModule>> build (cl_device_id device,
305
+ cl_context context);
306
+
307
+ private:
308
+ std::shared_mutex mux;
309
+ ModuleOp mlirModule;
310
+ SmallString<32 > funcName;
311
+ std::unordered_map<const OclDevCtxPair, std::shared_ptr<const OclModule>>
312
+ cache;
313
+ llvm::Expected<std::shared_ptr<const OclModule>>
314
+
315
+ build (const OclRuntime::Ext &ext);
316
+ };
317
+ }; // namespace mlir::gc::gpu
26
318
#else
27
319
#undef GC_GPU_OCL_CONST_ONLY
28
320
#endif
0 commit comments