Skip to content

Ivan/re #2368

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: re
Choose a base branch
from
Open

Ivan/re #2368

Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 109 additions & 28 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
// the function passed as the first argument.
//
//===----------------------------------------------------------------------===//
#include <llvm/IR/IntrinsicInst.h>
#include <llvm/Support/PointerLikeTypeTraits.h>
#define private public
#include "llvm/IR/Module.h"
#undef private
Expand Down Expand Up @@ -117,6 +119,10 @@ SmallVector<CallBase *> gatherCallers(Function *F) {
}

void fixup(Module &M) {

if (getenv("ENZYME_CLANG_DUMP_BEFORE_FIXUP"))
llvm::errs() << "BEFORE FIXUP:\n" << M << "\n";

auto LaunchKernelFunc = M.getFunction(cudaLaunchSymbolName);
if (!LaunchKernelFunc)
return;
Expand All @@ -137,19 +143,109 @@ void fixup(Module &M) {
BlockDim2, SharedMemSize, StreamPtr,
};
auto StubFunc = cast<Function>(CI->getArgOperand(0));

size_t idx = 0;
for (auto &Arg : StubFunc->args()) {
auto gep = Builder.CreateConstInBoundsGEP1_64(llvm::PointerType::getUnqual(CI->getContext()), ArgPtr, idx);
auto ld = Builder.CreateLoad(llvm::PointerType::getUnqual(CI->getContext()), gep);
ld = Builder.CreateLoad(Arg.getType(), ld);
Args.push_back(ld);
idx++;
LLVM_DEBUG(dbgs() << "StubFunc " << *StubFunc << "\n");

AllocaInst *ArgPtrAlloca = cast<AllocaInst>(ArgPtr);
assert(ArgPtrAlloca->getAllocatedType()->isPointerTy());
LLVM_DEBUG(dbgs() << "ALLOCA " << *ArgPtrAlloca << "\n");
unsigned NumArgs =
cast<ConstantInt>(ArgPtrAlloca->getArraySize())->getZExtValue();
unsigned ArgsOffset = Args.size();
for (unsigned I = 0; I < NumArgs; I++)
Args.push_back(nullptr);

LLVM_DEBUG(dbgs() << "ARG PTR " << *ArgPtr << "\n");
for (Use &ArgPtrUse : ArgPtr->uses()) {
LLVM_DEBUG(dbgs() << "USE " << *ArgPtrUse.getUser() << "\n");

Value *ThisArgPtr;
int ArgIdx;
auto Gep = dyn_cast<GetElementPtrInst>(ArgPtrUse.getUser());
auto SI = dyn_cast<StoreInst>(ArgPtrUse.getUser());
if (Gep && Gep->getPointerOperand() == ArgPtr) {
assert(Gep->getPointerOperand() == ArgPtr);
assert(Gep->getNumIndices() == 1);
Value *GepIdx = Gep->idx_begin()->get();
ArgIdx = cast<ConstantInt>(GepIdx)->getSExtValue();
assert(ArgIdx >= 0);
assert(Gep->getSourceElementType()->isPointerTy());
ThisArgPtr = Gep;
} else if (SI && SI->getPointerOperand() == ArgPtr) {
assert(false && "Should never happen if we properly run in StartEP");
ArgIdx = 0;
ThisArgPtr = ArgPtr;
} else {
continue;
}
LLVM_DEBUG(dbgs() << *ThisArgPtr << "\n");

for (Use &ThisArgPtrUse : ThisArgPtr->uses()) {
if (StoreInst *SI = dyn_cast<StoreInst>(ThisArgPtrUse.getUser())) {
assert(SI->getPointerOperand() == ThisArgPtr);
LLVM_DEBUG(dbgs() << ArgsOffset << " " << ArgIdx << " " << *SI << " "
<< "\n");
if (AllocaInst *ThisArgAlloca =
dyn_cast<AllocaInst>(SI->getValueOperand())) {
LLVM_DEBUG(dbgs() << ArgsOffset << " " << ArgIdx << " " << *SI
<< " " << *ThisArgAlloca << "\n");
for (Use &ThisArgAllocaUse : ThisArgAlloca->uses()) {
LLVM_DEBUG(dbgs()
<< "RealSI " << *ThisArgAllocaUse.getUser() << "\n");
StoreInst *RealSI =
dyn_cast<StoreInst>(ThisArgAllocaUse.getUser());
if (RealSI && RealSI->getPointerOperand() == ThisArgAlloca) {
LLVM_DEBUG(dbgs() << "YES\n");
assert(Args[ArgsOffset + ArgIdx] == nullptr);
Args[ArgsOffset + ArgIdx] = RealSI->getValueOperand();
}
}
if (Args[ArgsOffset + ArgIdx] == nullptr) {
errs() << "WARNING: Could not find corresponding store to `"
<< *ThisArgAlloca << "'.\n";
assert(cast<ConstantInt>(ThisArgAlloca->getArraySize())
->getZExtValue() == 1);
Args[ArgsOffset + ArgIdx] =
UndefValue::get(ThisArgAlloca->getAllocatedType());
}
} else {
// TODO this needs to be fixed
assert(isa<Argument>(SI->getValueOperand()));
errs()
<< "WARNING: Found argument that we cannot see the stores to `"
<< *SI->getValueOperand() << "'.\n";

Args[ArgsOffset + ArgIdx] =
ConstantPointerNull::get(PointerType::get(M.getContext(), 0));
}
}
}
}

if (NumArgs == 1) {
// There is one case where having a null ptr is allowed and it is because
// even if the kernel has 0 arguments, the codegen will still allocate a ptr
//
// i.e. in the case of
// 0 args:
// %kernel_args = alloca ptr
// 1 arg:
// %kernel_args = alloca ptr
// 2 (or more) args:
// %kernel_args = alloca [ 2 x ptr ]
//
if (Args[Args.size() - 1] == nullptr) {
Args.pop_back();
NumArgs = 0;
}
} else {
assert(all_of(Args, [](Value *V) {return V != nullptr;}));
}


SmallVector<Type *> ArgTypes;
for (Value *V : Args)
ArgTypes.push_back(V->getType());
auto MlirLaunchFunc = Function::Create(
Function *MlirLaunchFunc = Function::Create(
FunctionType::get(Type::getVoidTy(M.getContext()), ArgTypes,
/*isVarAtg=*/false),
llvm::GlobalValue::ExternalLinkage,
Expand Down Expand Up @@ -252,7 +348,8 @@ void fixup(Module &M) {

for (CallBase *PushCall : gatherCallers(PushConfigFunc)) {
// Replace with success
PushCall->replaceAllUsesWith(ConstantInt::get(IntegerType::get(PushCall->getContext(), 32), 0));
PushCall->replaceAllUsesWith(
ConstantInt::get(IntegerType::get(PushCall->getContext(), 32), 0));
PushCall->eraseFromParent();
}
for (CallBase *CI : CoercedKernels) {
Expand Down Expand Up @@ -631,13 +728,7 @@ extern "C" void registerReactantAndPassPipeline(llvm::PassBuilder &PB,
extern "C" void registerReactant(llvm::PassBuilder &PB, std::vector<std::string> gpubinaries) {

llvm::errs() << " registering reactant\n";
#if LLVM_VERSION_MAJOR >= 20
auto loadPass = [=](ModulePassManager &MPM, OptimizationLevel Level,
ThinOrFullLTOPhase)
#else
auto loadPass = [=](ModulePassManager &MPM, OptimizationLevel Level)
#endif
{
auto loadPass = [=](ModulePassManager &MPM, OptimizationLevel Level) {
MPM.addPass(ReactantNewPM(gpubinaries));
};

Expand All @@ -652,17 +743,7 @@ extern "C" void registerReactant(llvm::PassBuilder &PB, std::vector<std::string>
});

// TODO need for perf reasons to move Enzyme pass to the pre vectorization.
PB.registerOptimizerEarlyEPCallback(loadPass);

auto loadLTO = [loadPass](ModulePassManager &MPM,
OptimizationLevel Level) {
#if LLVM_VERSION_MAJOR >= 20
loadPass(MPM, Level, ThinOrFullLTOPhase::None);
#else
loadPass(MPM, Level);
#endif
};
PB.registerFullLinkTimeOptimizationEarlyEPCallback(loadLTO);
PB.registerPipelineStartEPCallback(loadPass);
}

extern "C" void registerReactant2(llvm::PassBuilder &PB) {
Expand Down