@@ -398,38 +398,37 @@ struct ConvertLaunch final : ConvertOpPattern<gpu::LaunchFuncOp> {
398398 IntegerAttr::get (helper.idxType ,
399399 static_cast <int64_t >(binaryAttr.size ())));
400400
401- SmallVector<int32_t > globalSize;
402- SmallVector<int32_t > localSize;
403- SmallVector<int32_t > argSize;
404- kernelMod->walk ([&](gpu::GPUFuncOp func) {
405- if (func.getName () == gpuLaunch.getKernelName ()) {
406- for (auto s : func.getKnownGridSize ().value ()) {
407- globalSize.emplace_back (s);
408- }
409- for (auto s : func.getKnownBlockSize ().value ()) {
410- localSize.emplace_back (s);
411- }
412- }
413- });
414- assert (globalSize.size () == 3 && localSize.size () == 3 );
415- globalSize = {globalSize[0 ] * localSize[0 ], globalSize[1 ] * localSize[1 ],
416- globalSize[2 ] * localSize[2 ]};
401+ SmallVector<Value> gridSize;
402+ SmallVector<Value> blockSize;
403+ SmallVector<Value> argSize;
404+ gridSize.emplace_back (gpuLaunch.getGridSizeX ());
405+ gridSize.emplace_back (gpuLaunch.getGridSizeY ());
406+ gridSize.emplace_back (gpuLaunch.getGridSizeZ ());
407+ blockSize.emplace_back (gpuLaunch.getBlockSizeX ());
408+ blockSize.emplace_back (gpuLaunch.getBlockSizeY ());
409+ blockSize.emplace_back (gpuLaunch.getBlockSizeZ ());
410+
417411 for (auto arg : adaptor.getKernelOperands ()) {
418412 auto type = arg.getType ();
419413 auto size = type.isIntOrFloat () ? type.getIntOrFloatBitWidth () / 8 : 0 ;
420- argSize.emplace_back (size);
414+ argSize.emplace_back (helper. idxConstant (rewriter, loc, size) );
421415 }
422416
423- auto array = [&](SmallVector<int32_t > &values) {
417+ auto array = [&](SmallVector<Value > &values) {
424418 auto size = helper.idxConstant (rewriter, loc, values.size ());
425419 auto arrayPtr = rewriter.create <LLVM::AllocaOp>(loc, helper.ptrType ,
426420 helper.idxType , size);
427421 for (size_t i = 0 , n = values.size (); i < n; i++) {
428422 auto elementPtr = rewriter.create <LLVM::GEPOp>(
429423 loc, helper.ptrType , helper.idxType , arrayPtr,
430424 helper.idxConstant (rewriter, loc, i));
431- rewriter.create <LLVM::StoreOp>(
432- loc, helper.idxConstant (rewriter, loc, values[i]), elementPtr);
425+ auto value = values[i];
426+ if (auto cast = value.getDefiningOp <UnrealizedConversionCastOp>()) {
427+ assert (getConstantIntValue (cast.getOperand (0 )));
428+ value = helper.idxConstant (
429+ rewriter, loc, getConstantIntValue (cast.getOperand (0 )).value ());
430+ }
431+ rewriter.create <LLVM::StoreOp>(loc, value, elementPtr);
433432 }
434433 return arrayPtr.getResult ();
435434 };
@@ -442,8 +441,8 @@ struct ConvertLaunch final : ConvertOpPattern<gpu::LaunchFuncOp> {
442441 {helper.ptrType , helper.idxType , helper.ptrType , helper.ptrType ,
443442 helper.ptrType , helper.ptrType , helper.idxType , helper.ptrType },
444443 loc,
445- {ctx, spirvSize, spirv, name, array (globalSize ), array (localSize) ,
446- argNum, array (argSize)});
444+ {ctx, spirvSize, spirv, name, array (gridSize ), array (blockSize), argNum ,
445+ array (argSize)});
447446 auto result = createKernelCall.getResult ();
448447
449448 // Save the kernel pointer to the global var using CAS
0 commit comments