-
-
Notifications
You must be signed in to change notification settings - Fork 120
Open
Description
Hi! I'm looking for an example of how to configure vkFFT to perform 3D FFTs. I have RustFFT and cuFFT implementations that get the same, correct answer, and I've tried various vkFFT configs, but have been unable to get meaningful or correct results. I've streamlined the interfaces so everything is identical between these 3 FFT tools except for the functions called. I'm using vkFFT's CUDA backend.
Stated another way, I'm looking for how to do exactly this in vkFFT. I'm confident it's a matter of changing a few configuration steps, but I'm not sure which ones precisely. (I can post the RustFFT version as well, but I think this is simpler.)
cuFFT (Basically defaults):
// https://docs.nvidia.com/cuda/cufft/#cufftplan3d
extern "C"
void* make_plan(int nx, int ny, int nz, void* cu_stream) {
auto* w = new PlanWrap();
w->stream = reinterpret_cast<cudaStream_t>(cu_stream);
// With Plan3D, Z is the fastest-changing dimension (contiguous); x is the slowest.
CUFFT_CHECK(cufftPlan3d(&w->plan_r2c, nx, ny, nz, CUFFT_R2C));
CUFFT_CHECK(cufftPlan3d(&w->plan_c2r, nx, ny, nz, CUFFT_C2R));
CUFFT_CHECK(cufftSetStream(w->plan_r2c, w->stream));
CUFFT_CHECK(cufftSetStream(w->plan_c2r, w->stream));
return w;
}
extern "C"
void destroy_plan(void* plan) {
auto* w = reinterpret_cast<PlanWrap*>(plan);
if (!w) return;
cufftDestroy(w->plan_r2c);
cufftDestroy(w->plan_c2r);
delete w;
}
// https://docs.nvidia.com/cuda/cufft/#cufftexecr2c-and-cufftexecd2z
// Performs a forward real-to-copmlex FFT of rho. Note: This is more efficient
// than complex-to-complex.
extern "C"
void exec_forward(void* plan, float* rho_real, cufftComplex* rho) {
auto* w = reinterpret_cast<PlanWrap*>(plan);
if (!w) return;
CUFFT_CHECK(cufftExecR2C(w->plan_r2c, rho_real, rho));
}
extern "C"
void exec_inverse(
void* plan,
cufftComplex* exk,
cufftComplex* eyk,
cufftComplex* ezk,
float* ex,
float* ey,
float* ez
){
auto* w = reinterpret_cast<PlanWrap*>(plan);
if (!w) return;
CUFFT_CHECK(cufftExecC2R(w->plan_c2r, exk, ex));
CUFFT_CHECK(cufftExecC2R(w->plan_c2r, eyk, ey));
CUFFT_CHECK(cufftExecC2R(w->plan_c2r, ezk, ez));
}My latest vkFFT attempt:
// CUDA driver-API state handed to the vkFFT wrapper functions.
typedef struct VkContext {
// Device handle, filled in via cuDeviceGet or cuCtxGetDevice by the constructors.
CUdevice dev;
// Context that make_plan copies into the plan and makes current before use.
CUcontext ctx;
// Stream: either borrowed from the caller or created by vk_make_context_default.
CUstream stream;
// 1 when this struct created `stream` (vk_destroy_context then destroys the
// stream and releases the device's primary context); 0 when it is borrowed.
int owns_stream;
} VkContext;
// One vkFFT application together with the CUDA handles and grid extents it was
// built for. make_plan stores pointers to cu_dev and stream inside cfg, so a
// plan must stay at the address it was allocated at.
typedef struct VkFftPlan {
// vkFFT application state: set up by initializeVkFFT, torn down by deleteVkFFT.
VkFFTApplication app;
// Configuration that make_plan fills in and passes by value to initializeVkFFT.
VkFFTConfiguration cfg;
// Copy of the owning VkContext's device; cfg.device points here.
CUdevice cu_dev;
// Copy of the owning VkContext's context; made current before vkFFT calls.
CUcontext cu_ctx;
// Stream used for launches; cfg.stream points here.
cudaStream_t stream;
// Grid extents exactly as passed to make_plan (nx, ny, nz).
uint64_t Nx, Ny, Nz;
} VkFftPlan;
void* vk_make_context_from_stream(void* cu_stream_void) {
VkContext* c = (VkContext*)calloc(1, sizeof(VkContext));
if (!c) return NULL;
c->stream = (CUstream)cu_stream_void;
c->owns_stream = 0;
cuInit(0);
CUcontext cur = NULL;
cuCtxGetCurrent(&cur);
if (cur == NULL) {
CUdevice dev0;
cuDeviceGet(&dev0, 0);
cuDevicePrimaryCtxRetain(&cur, dev0);
cuCtxSetCurrent(cur);
}
c->ctx = cur;
cuCtxGetDevice(&c->dev);
return c;
}
void* vk_make_context_default(void) {
VkContext* c = (VkContext*)calloc(1, sizeof(VkContext));
if (!c) return NULL;
cuInit(0);
cuDeviceGet(&c->dev, 0);
CUcontext primary = NULL;
cuDevicePrimaryCtxRetain(&primary, c->dev);
cuCtxSetCurrent(primary);
c->ctx = primary;
cuStreamCreate(&c->stream, CU_STREAM_DEFAULT);
c->owns_stream = 1;
return c;
}
void vk_destroy_context(void* ctx_) {
VkContext* c = (VkContext*)ctx_;
if (!c) return;
if (c->owns_stream) cuStreamDestroy(c->stream);
if (c->owns_stream) cuDevicePrimaryCtxRelease(c->dev);
free(c);
}
void* make_plan(void* ctx_, int32_t nx, int32_t ny, int32_t nz, void* cu_stream)
{
VkContext* g = (VkContext*)ctx_;
VkFftPlan* p = (VkFftPlan*)calloc(1, sizeof(VkFftPlan));
if (!p) return NULL;
p->cu_dev = g->dev;
p->cu_ctx = g->ctx;
p->stream = cu_stream ? (cudaStream_t)cu_stream : (cudaStream_t)g->stream;
p->Nx = (uint64_t)nx;
p->Ny = (uint64_t)ny;
p->Nz = (uint64_t)nz;
VkFFTConfiguration* cfg = &p->cfg;
memset(cfg, 0, sizeof(*cfg));
// make sure this context is current for init
cuCtxSetCurrent(p->cu_ctx);
cfg->device = &p->cu_dev;
cfg->stream = &p->stream;
cfg->num_streams = 1;
cfg->isInputFormatted = 1;
cfg->isOutputFormatted = 1;
cfg->FFTdim = 3;
cfg->size[0] = (uint64_t)nz;
cfg->size[1] = (uint64_t)ny;
cfg->size[2] = (uint64_t)nx;
cfg->performR2C = 1;
cfg->normalize = 0;
cfg->numberBatches = 1;
VkFFTResult res = initializeVkFFT(&p->app, *cfg);
if (res != VKFFT_SUCCESS) {
free(p);
return NULL;
}
return p;
}
void destroy_plan(void* plan_) {
VkFftPlan* p = (VkFftPlan*)plan_;
if (!p) return;
deleteVkFFT(&p->app);
free(p);
}
void exec_forward(void* plan_, void* real_in, void* complex_out) {
VkFftPlan* p = (VkFftPlan*)plan_;
cuCtxSetCurrent(p->cu_ctx);
CUdeviceptr in = (CUdeviceptr)real_in;
CUdeviceptr out = (CUdeviceptr)complex_out;
VkFFTLaunchParams lp;
memset(&lp, 0, sizeof(lp));
lp.buffer = (void**)∈
lp.outputBuffer = (void**)&out;
VkFFTResult res = VkFFTAppend(&p->app, -1, &lp);
if (res != VKFFT_SUCCESS) {
// printf("VkFFT forward failed: %d\n", res);
}
}
void exec_inverse(void* plan_, void* complex_in, void* real_out) {
VkFftPlan* p = (VkFftPlan*)plan_;
cuCtxSetCurrent(p->cu_ctx);
CUdeviceptr in = (CUdeviceptr)complex_in;
CUdeviceptr out = (CUdeviceptr)real_out;
VkFFTLaunchParams lp;
memset(&lp, 0, sizeof(lp));
lp.buffer = (void**)∈
lp.outputBuffer = (void**)&out;
VkFFTResult res = VkFFTAppend(&p->app, 1, &lp);
if (res != VKFFT_SUCCESS) {
// printf("VkFFT inverse failed: %d\n", res);
}
}Ty! I would love to replace cuFFT with vkFFT, but I am stuck.
Metadata
Metadata
Assignees
Labels
No labels