Skip to content

Seeking wisdom on 3D R2C/C2R FFTs #213

@David-OConnor

Description

@David-OConnor

Hi! I'm looking for an example of how to configure vkFFT to perform 3D FFTs. I have RustFFT and cuFFT implementations that produce the same, correct results, and I've tried various vkFFT configurations, but I have been unable to get meaningful or correct results. I've streamlined the interfaces so that everything is identical across these three FFT libraries except for the functions called. I'm using vkFFT's CUDA backend.

Stated another way, I'm looking for how to do exactly the following in vkFFT. I'm confident it's a matter of changing a few configuration settings, but I'm not sure which ones. (I can post the RustFFT version as well, but I think the cuFFT one is simpler.)

cuFFT (essentially default settings):

// https://docs.nvidia.com/cuda/cufft/#cufftplan3d
//
// Creates a real-to-complex and a complex-to-real 3D cuFFT plan for an
// nx * ny * nz grid and binds both to the caller's CUDA stream.
// Returns an opaque handle; release it with destroy_plan().
// (cuFFT docs: the R2C half-spectrum holds nx * ny * (nz/2 + 1) complex values —
// size the complex buffers accordingly.)
extern "C"
void* make_plan(int nx, int ny, int nz, void* cu_stream) {
    auto* wrap = new PlanWrap();
    wrap->stream = static_cast<cudaStream_t>(cu_stream);

    // cufftPlan3d takes extents slowest-first: nz is the contiguous (fastest)
    // dimension and nx the slowest.
    CUFFT_CHECK(cufftPlan3d(&wrap->plan_r2c, nx, ny, nz, CUFFT_R2C));
    CUFFT_CHECK(cufftPlan3d(&wrap->plan_c2r, nx, ny, nz, CUFFT_C2R));

    // Both transforms launch on the same stream as the rest of the pipeline.
    CUFFT_CHECK(cufftSetStream(wrap->plan_r2c, wrap->stream));
    CUFFT_CHECK(cufftSetStream(wrap->plan_c2r, wrap->stream));

    return wrap;
}


extern "C"
void destroy_plan(void* plan) {
    auto* w = reinterpret_cast<PlanWrap*>(plan);
    if (!w) return;

    cufftDestroy(w->plan_r2c);
    cufftDestroy(w->plan_c2r);

    delete w;
}

// https://docs.nvidia.com/cuda/cufft/#cufftexecr2c-and-cufftexecd2z
// Performs a forward real-to-copmlex FFT of rho. Note: This is more efficient
// than complex-to-complex.
extern "C"
void exec_forward(void* plan, float* rho_real, cufftComplex* rho) {
    auto* w = reinterpret_cast<PlanWrap*>(plan);
    if (!w) return;

    CUFFT_CHECK(cufftExecR2C(w->plan_r2c, rho_real, rho));
}

extern "C"
void exec_inverse(
    void* plan,
    cufftComplex* exk,
    cufftComplex* eyk,
    cufftComplex* ezk,
    float* ex,
    float* ey,
    float* ez
){
    auto* w = reinterpret_cast<PlanWrap*>(plan);
    if (!w) return;

    CUFFT_CHECK(cufftExecC2R(w->plan_c2r, exk, ex));
    CUFFT_CHECK(cufftExecC2R(w->plan_c2r, eyk, ey));
    CUFFT_CHECK(cufftExecC2R(w->plan_c2r, ezk, ez));
}

My latest vkFFT attempt:

// CUDA driver handles shared by the VkFFT plans created with make_plan().
typedef struct VkContext {
    CUdevice  dev;          // device that `ctx` belongs to
    CUcontext ctx;          // context made current before VkFFT init/launch
    CUstream  stream;       // fallback stream for plans created without one
    int       owns_stream;  // 1 when built by vk_make_context_default(): vk_destroy_context then
                            // destroys `stream` and releases the primary-context retain
} VkContext;

// One VkFFT 3D real<->complex plan plus the CUDA handles it was built with.
typedef struct VkFftPlan {
    VkFFTApplication   app;     // built by initializeVkFFT in make_plan, freed by deleteVkFFT
    VkFFTConfiguration cfg;     // config used to build `app`; cfg.device / cfg.stream point at
                                // cu_dev / stream below — presumably read again by VkFFT at
                                // launch, so keep this struct at a fixed address
    CUdevice           cu_dev;
    CUcontext          cu_ctx;  // made current before VkFFT calls
    cudaStream_t       stream;  // stream VkFFT launches on
    uint64_t           Nx, Ny, Nz;  // grid extents; Nz is mapped to cfg.size[0] (contiguous axis)
} VkFftPlan;

void* vk_make_context_from_stream(void* cu_stream_void) {
    VkContext* c = (VkContext*)calloc(1, sizeof(VkContext));
    if (!c) return NULL;

    c->stream = (CUstream)cu_stream_void;
    c->owns_stream = 0;

    cuInit(0);

    CUcontext cur = NULL;
    cuCtxGetCurrent(&cur);
    if (cur == NULL) {
        CUdevice dev0;
        cuDeviceGet(&dev0, 0);
        cuDevicePrimaryCtxRetain(&cur, dev0);
        cuCtxSetCurrent(cur);
    }

    c->ctx = cur;
    cuCtxGetDevice(&c->dev);
    return c;
}

void* vk_make_context_default(void) {
    VkContext* c = (VkContext*)calloc(1, sizeof(VkContext));
    if (!c) return NULL;

    cuInit(0);
    cuDeviceGet(&c->dev, 0);

    CUcontext primary = NULL;
    cuDevicePrimaryCtxRetain(&primary, c->dev);
    cuCtxSetCurrent(primary);
    c->ctx = primary;

    cuStreamCreate(&c->stream, CU_STREAM_DEFAULT);
    c->owns_stream = 1;
    return c;
}

void vk_destroy_context(void* ctx_) {
    VkContext* c = (VkContext*)ctx_;
    if (!c) return;
    if (c->owns_stream) cuStreamDestroy(c->stream);
    if (c->owns_stream) cuDevicePrimaryCtxRelease(c->dev);
    free(c);
}

void* make_plan(void* ctx_, int32_t nx, int32_t ny, int32_t nz, void* cu_stream)
{
    VkContext* g = (VkContext*)ctx_;

    VkFftPlan* p = (VkFftPlan*)calloc(1, sizeof(VkFftPlan));
    if (!p) return NULL;

    p->cu_dev  = g->dev;
    p->cu_ctx  = g->ctx;
    p->stream  = cu_stream ? (cudaStream_t)cu_stream : (cudaStream_t)g->stream;

    p->Nx = (uint64_t)nx;
    p->Ny = (uint64_t)ny;
    p->Nz = (uint64_t)nz;

    VkFFTConfiguration* cfg = &p->cfg;
    memset(cfg, 0, sizeof(*cfg));

    // make sure this context is current for init
    cuCtxSetCurrent(p->cu_ctx);

    cfg->device      = &p->cu_dev;
    cfg->stream      = &p->stream;
    cfg->num_streams = 1;

    cfg->isInputFormatted  = 1;
    cfg->isOutputFormatted = 1;

    cfg->FFTdim  = 3;
    cfg->size[0] = (uint64_t)nz;
    cfg->size[1] = (uint64_t)ny;
    cfg->size[2] = (uint64_t)nx;

    cfg->performR2C    = 1;
    cfg->normalize     = 0;
    cfg->numberBatches = 1;

    VkFFTResult res = initializeVkFFT(&p->app, *cfg);
    if (res != VKFFT_SUCCESS) {
        free(p);
        return NULL;
    }

    return p;
}

void destroy_plan(void* plan_) {
    VkFftPlan* p = (VkFftPlan*)plan_;
    if (!p) return;
    deleteVkFFT(&p->app);
    free(p);
}

void exec_forward(void* plan_, void* real_in, void* complex_out) {
    VkFftPlan* p = (VkFftPlan*)plan_;

    cuCtxSetCurrent(p->cu_ctx);

    CUdeviceptr in  = (CUdeviceptr)real_in;
    CUdeviceptr out = (CUdeviceptr)complex_out;

    VkFFTLaunchParams lp;
    memset(&lp, 0, sizeof(lp));

    lp.buffer       = (void**)&in;
    lp.outputBuffer = (void**)&out;

    VkFFTResult res = VkFFTAppend(&p->app, -1, &lp);
    if (res != VKFFT_SUCCESS) {
        // printf("VkFFT forward failed: %d\n", res);
    }
}

void exec_inverse(void* plan_, void* complex_in, void* real_out) {
    VkFftPlan* p = (VkFftPlan*)plan_;

    cuCtxSetCurrent(p->cu_ctx);

    CUdeviceptr in  = (CUdeviceptr)complex_in;
    CUdeviceptr out = (CUdeviceptr)real_out;

    VkFFTLaunchParams lp;
    memset(&lp, 0, sizeof(lp));

    lp.buffer       = (void**)&in;
    lp.outputBuffer = (void**)&out;

    VkFFTResult res = VkFFTAppend(&p->app, 1, &lp);
    if (res != VKFFT_SUCCESS) {
        // printf("VkFFT inverse failed: %d\n", res);
    }
}

Thank you! I would love to replace cuFFT with vkFFT, but I'm stuck.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions