55// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66//
77// ===----------------------------------------------------------------------===//
8- #include < vector >
8+ #include < unordered_set >
99
10- #define GC_GPU_OCL_DEF_ONLY
10+ #define GC_GPU_OCL_CONST_ONLY
1111#include " gc/ExecutionEngine/GPURuntime/GpuOclRuntime.h"
1212
1313#include " mlir/Conversion/LLVMCommon/ConversionTarget.h"
1717#include " mlir/Dialect/GPU/Transforms/Passes.h"
1818
1919using namespace mlir ;
20+ using namespace mlir ::gc::gpu;
2021
21- namespace mlir {
22- namespace gc {
22+ namespace mlir ::gc {
2323#define GEN_PASS_DECL_GPUTOGPUOCL
2424#define GEN_PASS_DEF_GPUTOGPUOCL
2525#include " gc/Transforms/Passes.h.inc"
26- } // namespace gc
27- } // namespace mlir
26+ } // namespace mlir::gc
2827
2928namespace {
30-
3129LLVM::CallOp funcCall (OpBuilder &builder, const StringRef name,
3230 const Type returnType, const ArrayRef<Type> argTypes,
3331 const Location loc, const ArrayRef<Value> arguments,
@@ -42,8 +40,10 @@ LLVM::CallOp funcCall(OpBuilder &builder, const StringRef name,
4240 return builder.create <LLVM::CallOp>(loc, function, arguments);
4341}
4442
45- // Assuming that the pointer to GcGpuOclContext is passed as the last
46- // memref<anyType> with zero dims argument of the current function.
43+ // Assuming that the pointer to the context is passed as the last argument
44+ // of the current function of type memref<anyType> with zero dims. When lowering
45+ // to LLVM, the memref arg is replaced with 3 args of types ptr, ptr, i64.
46+ // Returning the first one.
4747Value getCtxPtr (const OpBuilder &rewriter) {
4848 auto func =
4949 rewriter.getBlock ()->getParent ()->getParentOfType <LLVM::LLVMFuncOp>();
@@ -55,7 +55,7 @@ struct Helper final {
5555 Type voidType;
5656 Type ptrType;
5757 Type idxType;
58- mutable std::set<SmallString< 32 > > kernelNames;
58+ mutable std::unordered_set<std::string > kernelNames;
5959
6060 explicit Helper (MLIRContext *ctx, LLVMTypeConverter &converter)
6161 : converter(converter), voidType(LLVM::LLVMVoidType::get(ctx)),
@@ -81,7 +81,7 @@ struct Helper final {
8181 rewriter.create <LLVM::StoreOp>(loc, kernelPtrs[i], elementPtr);
8282 }
8383
84- funcCall (rewriter, GC_GPU_OCL_KERNEL_DESTROY , voidType, {idxType, ptrType},
84+ funcCall (rewriter, GPU_OCL_KERNEL_DESTROY , voidType, {idxType, ptrType},
8585 loc, {size, kernelPtrsArray});
8686 }
8787};
@@ -117,7 +117,7 @@ struct ConvertAlloc final : ConvertOpPattern<gpu::AllocOp> {
117117 }
118118 }
119119 auto size = helper.idxConstant (rewriter, loc, staticSize);
120- auto ptr = funcCall (rewriter, GC_GPU_OCL_MALLOC , helper.ptrType ,
120+ auto ptr = funcCall (rewriter, GPU_OCL_MALLOC , helper.ptrType ,
121121 {helper.ptrType , helper.idxType }, loc,
122122 {getCtxPtr (rewriter), size})
123123 .getResult ();
@@ -158,7 +158,7 @@ struct ConvertAlloc final : ConvertOpPattern<gpu::AllocOp> {
158158 }
159159
160160 size = idxMul (size, helper.idxConstant (rewriter, loc, staticSize));
161- auto ptr = funcCall (rewriter, GC_GPU_OCL_MALLOC , helper.ptrType ,
161+ auto ptr = funcCall (rewriter, GPU_OCL_MALLOC , helper.ptrType ,
162162 {helper.ptrType , helper.idxType }, loc,
163163 {getCtxPtr (rewriter), size})
164164 .getResult ();
@@ -194,7 +194,7 @@ struct ConvertDealloc final : ConvertOpPattern<gpu::DeallocOp> {
194194 auto loc = gpuDealloc.getLoc ();
195195 MemRefDescriptor dsc (adaptor.getMemref ());
196196 auto ptr = dsc.allocatedPtr (rewriter, loc);
197- auto oclDealloc = funcCall (rewriter, GC_GPU_OCL_DEALLOC , helper.voidType ,
197+ auto oclDealloc = funcCall (rewriter, GPU_OCL_DEALLOC , helper.voidType ,
198198 {helper.ptrType , helper.ptrType }, loc,
199199 {getCtxPtr (rewriter), ptr});
200200 rewriter.replaceOp (gpuDealloc, oclDealloc);
@@ -227,7 +227,7 @@ struct ConvertMemcpy final : ConvertOpPattern<gpu::MemcpyOp> {
227227 auto dstPtr = dstDsc.alignedPtr (rewriter, loc);
228228 auto size = helper.idxConstant (rewriter, loc, elementSize * numElements);
229229 auto oclMemcpy = funcCall (
230- rewriter, GC_GPU_OCL_MEMCPY , helper.voidType ,
230+ rewriter, GPU_OCL_MEMCPY , helper.voidType ,
231231 {helper.ptrType , helper.ptrType , helper.ptrType , helper.idxType }, loc,
232232 {getCtxPtr (rewriter), srcPtr, dstPtr, size});
233233 rewriter.replaceOp (gpuMemcpy, oclMemcpy);
@@ -249,7 +249,7 @@ struct ConvertLaunch final : ConvertOpPattern<gpu::LaunchFuncOp> {
249249
250250 const Location loc = gpuLaunch.getLoc ();
251251 auto kernelArgs = adaptor.getKernelOperands ();
252- std::vector <Value> args;
252+ SmallVector <Value> args;
253253 args.reserve (kernelArgs.size () + 2 );
254254 args.emplace_back (getCtxPtr (rewriter));
255255 args.emplace_back (kernelPtr.value ());
@@ -265,7 +265,7 @@ struct ConvertLaunch final : ConvertOpPattern<gpu::LaunchFuncOp> {
265265 }
266266
267267 const auto gpuOclLaunch =
268- funcCall (rewriter, GC_GPU_OCL_KERNEL_LAUNCH , helper.voidType ,
268+ funcCall (rewriter, GPU_OCL_KERNEL_LAUNCH , helper.voidType ,
269269 {helper.ptrType , helper.ptrType }, loc, args, true );
270270 rewriter.replaceOp (gpuLaunch, gpuOclLaunch);
271271 return success ();
@@ -284,7 +284,9 @@ struct ConvertLaunch final : ConvertOpPattern<gpu::LaunchFuncOp> {
284284 SmallString<128 > getFuncName (" getGcGpuOclKernel_" );
285285 getFuncName.append (kernelModName);
286286
287- if (helper.kernelNames .insert (SmallString<32 >(kernelModName)).second ) {
287+ if (helper.kernelNames
288+ .insert (std::string (kernelModName.begin (), kernelModName.end ()))
289+ .second ) {
288290 auto insPoint = rewriter.saveInsertionPoint ();
289291 SmallString<128 > strBuf (" gcGpuOclKernel_" );
290292 strBuf.append (kernelModName);
@@ -391,10 +393,10 @@ struct ConvertLaunch final : ConvertOpPattern<gpu::LaunchFuncOp> {
391393 auto spirv = LLVM::createGlobalString (loc, rewriter, str (" SPIRV" ),
392394 binaryAttr.getValue (),
393395 LLVM::Linkage::Internal);
394- auto spirvSize = rewriter.create <mlir:: LLVM::ConstantOp>(
396+ auto spirvSize = rewriter.create <LLVM::ConstantOp>(
395397 loc, helper.idxType ,
396- mlir:: IntegerAttr::get (helper.idxType ,
397- static_cast <int64_t >(binaryAttr.size ())));
398+ IntegerAttr::get (helper.idxType ,
399+ static_cast <int64_t >(binaryAttr.size ())));
398400
399401 SmallVector<int32_t > globalSize;
400402 SmallVector<int32_t > localSize;
@@ -436,7 +438,7 @@ struct ConvertLaunch final : ConvertOpPattern<gpu::LaunchFuncOp> {
436438 auto argNum =
437439 helper.idxConstant (rewriter, loc, adaptor.getKernelOperands ().size ());
438440 auto createKernelCall = funcCall (
439- rewriter, GC_GPU_OCL_KERNEL_CREATE , helper.ptrType ,
441+ rewriter, GPU_OCL_KERNEL_CREATE , helper.ptrType ,
440442 {helper.ptrType , helper.idxType , helper.ptrType , helper.ptrType ,
441443 helper.ptrType , helper.ptrType , helper.idxType , helper.ptrType },
442444 loc,
@@ -501,7 +503,7 @@ struct GpuToGpuOcl final : gc::impl::GpuToGpuOclBase<GpuToGpuOcl> {
501503 assert (mod);
502504 OpBuilder rewriter (mod.getBody (), mod.getBody ()->end ());
503505 auto destruct = rewriter.create <LLVM::LLVMFuncOp>(
504- mod.getLoc (), GC_GPU_OCL_MOD_DESTRUCTOR ,
506+ mod.getLoc (), GPU_OCL_MOD_DESTRUCTOR ,
505507 LLVM::LLVMFunctionType::get (helper.voidType , {}),
506508 LLVM::Linkage::External);
507509 auto loc = destruct.getLoc ();
0 commit comments