Skip to content

Commit

Permalink
FFT: Support complex to complex
Browse files Browse the repository at this point in the history
  • Loading branch information
WeiqunZhang committed Feb 10, 2025
1 parent 2f64d71 commit ee7a868
Show file tree
Hide file tree
Showing 6 changed files with 403 additions and 103 deletions.
8 changes: 8 additions & 0 deletions Docs/sphinx_documentation/source/FFT.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,14 @@ in an :cpp:`FFT::Info` object passed to the constructor of

r2c.backward(cmf, mf);

.. _sec:FFT:c2c:

FFT::C2C Class
==============

:cpp:`FFT::C2C` is a class template that supports complex to complex Fourier
transforms. It has a similar interface as :cpp:`FFT::R2C`.

.. _sec:FFT:localr2c:

FFT::LocalR2C Class
Expand Down
78 changes: 63 additions & 15 deletions Src/FFT/AMReX_FFT_Helper.H
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ struct Info

//! For automatic strategy, this is the size per process below which we
//! switch from slab to pencil.
int pencil_threshold = 8;
int pencil_threshold = 4;

//! Supported only in 3D. When twod_mode is true, FFT is performed on
//! the first two dimensions only and the third dimension size is the
Expand Down Expand Up @@ -310,7 +310,7 @@ struct Plan
void init_r2c (IntVectND<M> const& fft_size, void*, void*, bool cache, int ncomp = 1);

template <Direction D>
void init_c2c (Box const& box, VendorComplex* p, int ncomp = 1)
void init_c2c (Box const& box, VendorComplex* p, int ncomp = 1, int ndims = 1)
{
static_assert(D == Direction::forward || D == Direction::backward);

Expand All @@ -319,9 +319,35 @@ struct Plan
pf = (void*)p;
pb = (void*)p;

n = box.length(0);
howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2));
howmany *= ncomp;
int len[3];

if (ndims == 1) {
n = box.length(0);
howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2));
howmany *= ncomp;
len[0] = box.length(0);
}
#if (AMREX_SPACEDIM >= 2)
else if (ndims == 2) {
n = box.length(0) * box.length(1);
#if (AMREX_SPACEDIM == 2)
howmany = ncomp;
#else
howmany = box.length(2) * ncomp;
#endif
len[0] = box.length(1);
len[1] = box.length(0);
}
#if (AMREX_SPACEDIM == 3)
else if (ndims == 3) {
n = box.length(0) * box.length(1) * box.length(2);
howmany = ncomp;
len[0] = box.length(2);
len[1] = box.length(1);
len[2] = box.length(0);
}
#endif
#endif

#if defined(AMREX_USE_CUDA)
AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
Expand All @@ -330,22 +356,39 @@ struct Plan
cufftType t = std::is_same_v<float,T> ? CUFFT_C2C : CUFFT_Z2Z;
std::size_t work_size;
AMREX_CUFFT_SAFE_CALL
(cufftMakePlanMany(plan, 1, &n, nullptr, 1, n, nullptr, 1, n, t, howmany, &work_size));
(cufftMakePlanMany(plan, ndims, len, nullptr, 1, n, nullptr, 1, n, t, howmany, &work_size));

#elif defined(AMREX_USE_HIP)

auto prec = std::is_same_v<float,T> ? rocfft_precision_single
: rocfft_precision_double;
auto dir= (D == Direction::forward) ? rocfft_transform_type_complex_forward
: rocfft_transform_type_complex_inverse;
const std::size_t length = n;
std::size_t length[3];
if (ndims == 1) {
length[0] = len[0];
} else if (ndims == 2) {
length[0] = len[1];
length[1] = len[0];
} else {
length[0] = len[2];
length[1] = len[1];
length[2] = len[0];
}
AMREX_ROCFFT_SAFE_CALL
(rocfft_plan_create(&plan, rocfft_placement_inplace, dir, prec, 1,
&length, howmany, nullptr));
(rocfft_plan_create(&plan, rocfft_placement_inplace, dir, prec, ndims,
length, howmany, nullptr));

#elif defined(AMREX_USE_SYCL)

auto* pp = new mkl_desc_c(n);
mkl_desc_c* pp;
if (ndims == 1) {
pp = new mkl_desc_c(n);
} else if (ndims == 2) {
pp = new mkl_desc_c({std::int64_t(len[0]), std::int64_t(len[1])});
} else {
pp = new mkl_desc_c({std::int64_t(len[0]), std::int64_t(len[1]), std::int64_t(len[2])});
}
#ifndef AMREX_USE_MKL_DFTI_2024
pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
oneapi::mkl::dft::config_value::INPLACE);
Expand All @@ -355,7 +398,12 @@ struct Plan
pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n);
pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, n);
std::vector<std::int64_t> strides = {0,1};
std::vector<std::int64_t> strides(ndims+1);
strides[0] = 0;
strides[ndims] = 1;
for (int i = ndims-1; i >= 1; --i) {
strides[i] = strides[i+1] * len[ndims-1-i];
}
#ifndef AMREX_USE_MKL_DFTI_2024
pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
Expand All @@ -373,21 +421,21 @@ struct Plan
if constexpr (std::is_same_v<float,T>) {
if constexpr (D == Direction::forward) {
plan = fftwf_plan_many_dft
(1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
(ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
FFTW_ESTIMATE);
} else {
plan = fftwf_plan_many_dft
(1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
(ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
FFTW_ESTIMATE);
}
} else {
if constexpr (D == Direction::forward) {
plan = fftw_plan_many_dft
(1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
(ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
FFTW_ESTIMATE);
} else {
plan = fftw_plan_many_dft
(1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
(ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
FFTW_ESTIMATE);
}
}
Expand Down
Loading

0 comments on commit ee7a868

Please sign in to comment.