Commit

Apply changes from code review

AndreyPavlenko committed Sep 24, 2024
1 parent 501d434 commit 74842e0
Showing 7 changed files with 181 additions and 160 deletions.
2 changes: 1 addition & 1 deletion cmake/imex-version.txt
@@ -1 +1 @@
20f2eef4f6c10fcbd68d358591c8b3ef4d1b97d2
20f2eef4f6c10fcbd68d358591c8b3ef4d1b97d2
17 changes: 13 additions & 4 deletions lib/gc/Transforms/GPU/AddContextArg.cpp
@@ -20,6 +20,10 @@ namespace {
struct AddContextArg final : gc::impl::AddContextArgBase<AddContextArg> {
void runOnOperation() override {
auto func = getOperation();
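// Skip external functions (declarations), which have no body to extend.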
if (func.isExternal()) {
return;
}

auto funcType = func.getFunctionType();
auto argTypes = llvm::to_vector<8>(funcType.getInputs());
auto resultTypes = llvm::to_vector<1>(funcType.getResults());
@@ -28,14 +32,19 @@ struct AddContextArg final : gc::impl::AddContextArgBase<AddContextArg> {
argTypes.emplace_back(newArgType);
auto newFuncType = FunctionType::get(ctx, argTypes, resultTypes);
func.setType(newFuncType);

if (func.getBody().hasOneBlock()) {
func.getBody().front().addArgument(newArgType, func.getLoc());
}
func.getBody().front().addArgument(newArgType, func.getLoc());

// Find all function calls and append the last argument of the current
// function to the call.
auto module = func->getParentOfType<ModuleOp>();
func.walk([&](func::CallOp call) {
// If the callee is defined in the current module, the context argument
// will be added to its signature as well; thus, we need to append the
// context argument to the call too.
if (auto callee = module.lookupSymbol<func::FuncOp>(call.getCallee());
!callee || callee.isExternal()) {
return;
}
auto args = llvm::to_vector<8>(call.getOperands());
args.emplace_back(func.getArgument(func.getNumArguments() - 1));
call->setOperands(args);
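To illustrate the pass's effect, here is a minimal before/after sketch on hypothetical IR (the function names are illustrative, and the context argument is assumed to lower to !llvm.ptr, as the gpu-to-gpuocl test below suggests):

// Before: @callee is defined in the same module.
func.func @caller(%arg0: memref<64xf32>) {
  func.call @callee(%arg0) : (memref<64xf32>) -> ()
  return
}

// After: the context argument is appended to the signature, the entry
// block, and the call; calls to external functions are left unchanged.
func.func @caller(%arg0: memref<64xf32>, %ctx: !llvm.ptr) {
  func.call @callee(%arg0, %ctx) : (memref<64xf32>, !llvm.ptr) -> ()
  return
}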
119 changes: 53 additions & 66 deletions lib/gc/Transforms/GPU/GpuToGpuOcl.cpp
@@ -69,6 +69,30 @@ struct Helper final {
rewriter.getIntegerAttr(idxType, static_cast<int64_t>(value)));
}

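// Returns the buffer size in bytes as an index constant (0 for rank-0
// memrefs), or a null Value if the shape is dynamic or the element type
// is not an int/index/float. E.g., memref<64x64xf32> yields
// 64 * 64 * 32 / 8 = 16384 bytes.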
Value calculateStaticSize(OpBuilder &rewriter, const Location loc,
const MemRefType type) const {
if (type.getRank() == 0) {
return idxConstant(rewriter, loc, 0);
}

auto elementType = type.getElementType();
if (!elementType.isIntOrIndexOrFloat()) {
return nullptr;
}

int64_t numElements = 1;
for (auto dim : type.getShape()) {
if (dim == ShapedType::kDynamic) {
return nullptr;
}
numElements = numElements * dim;
}
auto elementSize = elementType.isIndex()
? idxType.getIntOrFloatBitWidth()
: elementType.getIntOrFloatBitWidth();
return idxConstant(rewriter, loc, elementSize * numElements / 8);
}

void destroyKernels(OpBuilder &rewriter, Location loc,
ArrayRef<Value> kernelPtrs) const {
auto size = idxConstant(rewriter, loc, kernelPtrs.size());
@@ -102,82 +126,44 @@ struct ConvertAlloc final : ConvertOpPattern<gpu::AllocOp> {
ConversionPatternRewriter &rewriter) const override {
auto loc = allocOp.getLoc();
MemRefType type = allocOp.getType();
auto shape = type.getShape();
auto dynamics = adaptor.getDynamicSizes();

if (shape.empty() || dynamics.empty()) {
int64_t staticSize;
if (shape.empty()) {
staticSize = 0;
} else {
staticSize = type.getElementType().getIntOrFloatBitWidth() / 8;
for (auto dim : shape) {
assert(dim != ShapedType::kDynamic);
staticSize *= dim;
}
}
auto size = helper.idxConstant(rewriter, loc, staticSize);
if (auto staticSize = helper.calculateStaticSize(rewriter, loc, type)) {
auto ptr = funcCall(rewriter, GPU_OCL_MALLOC, helper.ptrType,
{helper.ptrType, helper.idxType}, loc,
{getCtxPtr(rewriter), size})
{getCtxPtr(rewriter), staticSize})
.getResult();
Value replacement = MemRefDescriptor::fromStaticShape(
rewriter, loc, helper.converter, type, ptr, ptr);
rewriter.replaceOp(allocOp, replacement);
return success();
}

auto ndims = shape.size();
SmallVector<Value> newShape;
SmallVector<Value> newStrides(ndims);
auto staticSize = type.getElementType().getIntOrFloatBitWidth() / 8;
auto size = dynamics[0];

auto idxMul = [&](Value x, Value y) -> Value {
if (auto xConst = getConstantIntValue(x)) {
if (auto yConst = getConstantIntValue(y)) {
return helper.idxConstant(rewriter, loc,
xConst.value() * yConst.value());
}
}
return rewriter.create<LLVM::MulOp>(loc, x, y);
};

for (size_t i = 0, j = 0; i < ndims; i++) {
auto dim = shape[i];
if (dim == ShapedType::kDynamic) {
auto dynSize = dynamics[j++];
newShape.emplace_back(dynSize);
if (j != 1) {
size = idxMul(size, dynSize);
}
} else {
staticSize *= dim;
newShape.emplace_back(helper.idxConstant(rewriter, loc, dim));
}
auto dstType = helper.converter.convertType(type);
if (!dstType) {
allocOp.emitError() << "Failed to convert the MemRefType";
return failure();
}

size = idxMul(size, helper.idxConstant(rewriter, loc, staticSize));
SmallVector<Value> shape;
SmallVector<Value> strides;
Value size;
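// Compute the runtime shape, strides, and total size in bytes via the
// shared MLIR LLVM-lowering helper.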
getMemRefDescriptorSizes(loc, type, adaptor.getDynamicSizes(), rewriter,
shape, strides, size);
assert(shape.size() == strides.size());

auto ptr = funcCall(rewriter, GPU_OCL_MALLOC, helper.ptrType,
{helper.ptrType, helper.idxType}, loc,
{getCtxPtr(rewriter), size})
.getResult();

newStrides[ndims - 1] = helper.idxConstant(rewriter, loc, 1);
for (int i = static_cast<int>(ndims) - 2; i >= 0; i--) {
newStrides[i] = idxMul(newStrides[i + 1], newShape[i]);
}

auto dsc = MemRefDescriptor::undef(rewriter, loc,
helper.converter.convertType(type));
auto dsc = MemRefDescriptor::undef(rewriter, loc, dstType);
dsc.setAllocatedPtr(rewriter, loc, ptr);
dsc.setAlignedPtr(rewriter, loc, ptr);
dsc.setOffset(rewriter, loc, helper.idxConstant(rewriter, loc, 0));

for (unsigned i = 0, n = static_cast<unsigned>(ndims); i < n; i++) {
dsc.setSize(rewriter, loc, i, newShape[i]);
dsc.setStride(rewriter, loc, i, newStrides[i]);
for (unsigned i = 0, n = static_cast<unsigned>(shape.size()); i < n; i++) {
dsc.setSize(rewriter, loc, i, shape[i]);
dsc.setStride(rewriter, loc, i, strides[i]);
}

rewriter.replaceOp(allocOp, static_cast<Value>(dsc));
@@ -209,23 +195,24 @@ struct ConvertMemcpy final : ConvertOpPattern<gpu::MemcpyOp> {
matchAndRewrite(gpu::MemcpyOp gpuMemcpy, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = gpuMemcpy.getLoc();
MemRefDescriptor srcDsc(adaptor.getSrc());
MemRefDescriptor dstDsc(adaptor.getDst());
auto srcType = gpuMemcpy.getSrc().getType();
auto elementSize = srcType.getElementType().getIntOrFloatBitWidth() / 8;
uint64_t numElements = 0;
for (auto dim : srcType.getShape()) {
if (dim == ShapedType::kDynamic) {
gpuMemcpy.emitOpError()
<< "dynamic shapes are not currently supported";
return failure();
Value size = helper.calculateStaticSize(rewriter, loc, srcType);

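// The size is not known statically; compute it at runtime from the
// source descriptor's dimensions.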
if (!size) {
auto numElements = helper.idxConstant(rewriter, loc, 1);
for (unsigned i = 0, n = srcType.getRank(); i < n; i++) {
numElements = rewriter.create<LLVM::MulOp>(
loc, numElements, srcDsc.size(rewriter, loc, i));
}
numElements = numElements ? numElements * dim : dim;
size = rewriter.create<mlir::LLVM::MulOp>(
loc, numElements,
getSizeInBytes(loc, srcType.getElementType(), rewriter));
}

MemRefDescriptor srcDsc(adaptor.getSrc());
MemRefDescriptor dstDsc(adaptor.getDst());
auto srcPtr = srcDsc.alignedPtr(rewriter, loc);
auto dstPtr = dstDsc.alignedPtr(rewriter, loc);
auto size = helper.idxConstant(rewriter, loc, elementSize * numElements);
auto oclMemcpy = funcCall(
rewriter, GPU_OCL_MEMCPY, helper.voidType,
{helper.ptrType, helper.ptrType, helper.ptrType, helper.idxType}, loc,
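Taken together, the static path now lowers a host-shared allocation and copy to direct runtime calls. A condensed sketch for memref<64x64xf32> (value names illustrative; the exact sequence is checked by the gpu-to-gpuocl.mlir test added below):

%size = llvm.mlir.constant(16384 : i64) : i64
%ptr = llvm.call @gcGpuOclMalloc(%ctx, %size) : (!llvm.ptr, i64) -> !llvm.ptr
...
llvm.call @gcGpuOclMemcpy(%ctx, %src, %dst, %size) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i64) -> ()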
40 changes: 15 additions & 25 deletions lib/gc/Transforms/GPU/Pipeline.cpp
@@ -6,45 +6,35 @@
//
//===----------------------------------------------------------------------===//

#include <string>

#include "gc/Transforms/Passes.h"

#include "imex/Conversion/Passes.h"
#include "imex/Transforms/Passes.h"

#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"
#include <iostream>

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"

#include <imex/Conversion/Passes.h>
#include <imex/Transforms/Passes.h>

#include <string>

#include "gc/Transforms/Passes.h"

namespace mlir::gc {

struct GPUPipelineOption : PassPipelineOptions<GPUPipelineOption> {
PassOptions::Option<bool> isUsmArgs{
Option<bool> isUsmArgs{
*this, "is-usm-args",
llvm::cl::desc("Whether to use USM(unified shared memory) func args, in "
"which the host and device could access the same buffer "
"and there is no need to add memcpy explicitly"),
llvm::cl::init(true)};
desc("Whether to use USM(unified shared memory) func args, in "
"which the host and device could access the same buffer "
"and there is no need to add memcpy explicitly"),
init(true)};
};
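// Note: `Option` resolves through the PassPipelineOptions base class;
// the unqualified `desc` and `init` are assumed to be brought into scope
// by using-declarations elided from this hunk.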

void populateGPUPipeline(OpPassManager &pm,
2 changes: 1 addition & 1 deletion lib/gc/Transforms/IterativeTilingAndFusion.cpp
@@ -680,7 +680,7 @@ defaultTilingOfType(RewriterBase &rewriter, Operation *op,
} else {
defaultTileSize.resize(iteratorTypes.size(), rewriter.getIndexAttr(0));
// Try tileSize from `32` to `16`.
SmallVector<int64_t> tsOrder = {16, 32};
SmallVector<int64_t> tsOrder = {32, 16};
// Record how many dims have been tiled, including fully tiled, i.e.
// tileSize == dimSize.
unsigned nonOneTileDims =
98 changes: 98 additions & 0 deletions test/mlir/test/gc/Transforms/GPU/gpu-to-gpuocl.mlir
@@ -0,0 +1,98 @@
// RUN: gc-opt %s --gpu-to-gpuocl | FileCheck %s

module @test attributes {gpu.container_module} {
llvm.func @entry(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: !llvm.ptr, %arg8: !llvm.ptr, %arg9: i64) attributes {llvm.emit_c_interface} {
%0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%1 = llvm.insertvalue %arg0, %0[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%2 = llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%3 = llvm.insertvalue %arg2, %2[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%5 = llvm.insertvalue %arg4, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%6 = llvm.insertvalue %arg5, %5[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%7 = llvm.insertvalue %arg6, %6[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%8 = builtin.unrealized_conversion_cast %7 : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<64x64xf32>
%gpu_mem = gpu.alloc host_shared () : memref<64x64xf32>
gpu.memcpy %gpu_mem, %8 : memref<64x64xf32>, memref<64x64xf32>
%9 = llvm.mlir.constant(32 : index) : i64
%10 = builtin.unrealized_conversion_cast %9 : i64 to index
%11 = llvm.mlir.constant(2 : index) : i64
%12 = builtin.unrealized_conversion_cast %11 : i64 to index
%13 = llvm.mlir.constant(1 : index) : i64
%14 = builtin.unrealized_conversion_cast %13 : i64 to index
gpu.launch_func @entry_kernel::@entry_kernel blocks in (%12, %12, %14) threads in (%14, %14, %14) args(%10 : index, %gpu_mem : memref<64x64xf32>)
gpu.memcpy %8, %gpu_mem : memref<64x64xf32>, memref<64x64xf32>
gpu.dealloc %gpu_mem : memref<64x64xf32>
llvm.return
}

gpu.module @entry_kernel attributes {gpu.binary = "Some SPIRV here \00"} {
gpu.func @entry_kernel(%arg0: index, %arg1: memref<64x64xf32>) kernel attributes {} {
gpu.return
}
}
}

// CHECK: llvm.mlir.global internal constant @gcGpuOclKernel_entry_kernel_SPIRV
// CHECK: llvm.mlir.global internal constant @gcGpuOclKernel_entry_kernel_Name
// CHECK: llvm.mlir.global internal @gcGpuOclKernel_entry_kernel_Ptr

// CHECK: llvm.func internal @createGcGpuOclKernel_entry_kernel([[CTX:%.+]]: !llvm.ptr) -> !llvm.ptr
// CHECK: [[NEW_PTR:%.+]] = llvm.call @gcGpuOclKernelCreate([[CTX]]
// CHECK: [[ZERO:%.+]] = llvm.mlir.zero
// CHECK: [[PTR_ADDR:%.+]] = llvm.mlir.addressof @gcGpuOclKernel_entry_kernel_Ptr
// CHECK: [[CMPXCHG:%.+]] = llvm.cmpxchg [[PTR_ADDR]], [[ZERO]], [[NEW_PTR]]
// CHECK: [[FLAG:%.+]] = llvm.extractvalue [[CMPXCHG]][1]
// CHECK: llvm.cond_br [[FLAG]], [[BB1:\^.+]], [[BB2:\^.+]]
// CHECK: [[BB1]]:
// CHECK: llvm.return [[NEW_PTR]]
// CHECK: [[BB2]]:
// CHECK: [[ONE:%.+]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: [[ARRAY:%.+]] = llvm.alloca [[ONE]]
// CHECK: [[ADDR:%.+]] = llvm.getelementptr [[ARRAY]]
// CHECK: llvm.store [[NEW_PTR]], [[ADDR]]
// CHECK: llvm.call @gcGpuOclKernelDestroy([[ONE]], [[ARRAY]])
// CHECK: [[OLD_PTR:%.+]] = llvm.extractvalue [[CMPXCHG]][0]
// CHECK: llvm.return [[OLD_PTR]]

// CHECK: llvm.func internal @getGcGpuOclKernel_entry_kernel([[CTX:%.+]]: !llvm.ptr) -> !llvm.ptr attributes {always_inline}
// CHECK: [[ZERO:%.+]] = llvm.mlir.zero
// CHECK: [[PTR_ADDR:%.+]] = llvm.mlir.addressof @gcGpuOclKernel_entry_kernel_Ptr
// CHECK: [[PTR:%.+]] = llvm.load [[PTR_ADDR]]
// CHECK: [[ICMP:%.+]] = llvm.icmp "eq" [[PTR]], [[ZERO]]
// CHECK: llvm.cond_br [[ICMP]], [[BB1:\^.+]], [[BB2:\^.+]]
// CHECK: [[BB1]]:
// CHECK: [[NEW_PTR:%.+]] = llvm.call @createGcGpuOclKernel_entry_kernel([[CTX]])
// CHECK: llvm.return [[NEW_PTR]]
// CHECK: [[BB2]]:
// CHECK: llvm.return [[PTR]]

// CHECK: llvm.func @entry(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, [[CTX:%.+]]: !llvm.ptr, %arg8: !llvm.ptr, %arg9: i64)
// CHECK: [[SIZE:%.+]] = llvm.mlir.constant(16384 : i64) : i64
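// (16384 bytes = 64 * 64 elements * 4 bytes per f32)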
// CHECK: llvm.call @gcGpuOclMalloc([[CTX]], [[SIZE]])
// CHECK: [[SIZE:%.+]] = llvm.mlir.constant(16384 : i64) : i64
// CHECK: [[SRC:%.+]] = llvm.extractvalue
// CHECK: [[DST:%.+]] = llvm.extractvalue [[GPU_MEMREF:%.+]][1]
// CHECK: llvm.call @gcGpuOclMemcpy([[CTX]], [[SRC]], [[DST]], [[SIZE]])
// CHECK: [[KERNEL:%.+]] = llvm.call @getGcGpuOclKernel_entry_kernel([[CTX:%.+]]) : (!llvm.ptr) -> !llvm.ptr
// CHECK: llvm.call @gcGpuOclKernelLaunch([[CTX]], [[KERNEL]],
// CHECK: [[SIZE:%.+]] = llvm.mlir.constant(16384 : i64) : i64
// CHECK: [[SRC:%.+]] = llvm.extractvalue [[GPU_MEMREF:%.+]][1]
// CHECK: [[DST:%.+]] = llvm.extractvalue
// CHECK: llvm.call @gcGpuOclMemcpy([[CTX]], [[SRC]], [[DST]], [[SIZE]])
// CHECK: [[GPU_PTR:%.+]] = llvm.extractvalue [[GPU_MEMREF:%.+]][0]
// CHECK: llvm.call @gcGpuOclDealloc([[CTX]], [[GPU_PTR]])

// CHECK: llvm.func @gcGpuOclKernelCreate
// CHECK: llvm.func @gcGpuOclKernelDestroy
// CHECK: llvm.func @gcGpuOclKernelLaunch


// CHECK: llvm.func @gcGpuOclModuleDestructor()
// CHECK: llvm.fence acquire
// CHECK: [[PTR_ADDR:%.+]] = llvm.mlir.addressof @gcGpuOclKernel_entry_kernel_Ptr
// CHECK: [[PTR:%.+]] = llvm.load [[PTR_ADDR]]
// CHECK: [[ONE:%.+]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: [[ARRAY:%.+]] = llvm.alloca [[ONE]]
// CHECK: [[ADDR:%.+]] = llvm.getelementptr [[ARRAY]]
// CHECK: llvm.store [[PTR]], [[ADDR]]
// CHECK: llvm.call @gcGpuOclKernelDestroy([[ONE]], [[ARRAY]])