Skip to content

Commit

Permalink
FFT: Add batch support (#4327)
Browse files Browse the repository at this point in the history
To set the batch size, one calls `FFT::Info::setBatchSize` and passes the
`FFT::Info` object to the constructor of `FFT::R2C`. The data in `MultiFab`
and `cMultiFab` should have the batch size as the number of components.
  • Loading branch information
WeiqunZhang authored Feb 10, 2025
1 parent 823ec7f commit 2f64d71
Show file tree
Hide file tree
Showing 11 changed files with 471 additions and 181 deletions.
26 changes: 26 additions & 0 deletions Docs/sphinx_documentation/source/FFT.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,32 @@ object. Therefore, one should cache it for reuse if possible. Although
:cpp:`std::unique_ptr<FFT::R2C<Real>>` to store an object in one's class.


Class template :cpp:`FFT::R2C` also supports batched FFTs. The batch size is set
in an :cpp:`FFT::Info` object passed to the constructor of
:cpp:`FFT::R2C`. Below is an example.

.. highlight:: c++

::

int batch_size = 10;
Geometry geom(...);
MultiFab mf(ba, dm, batch_size, 0);

FFT::Info info{};
info.setBatchSize(batch_size);
FFT::R2C<Real,FFT::Direction::both> r2c(geom.Domain(), info);

auto const& [cba, cdm] = r2c.getSpectralDataLayout();
cMultiFab cmf(cba, cdm, batch_size, 0);

r2c.forward(mf, cmf);

// Do work on cmf.
// Function forwardThenBackward is not yet supported for a batched FFT.

r2c.backward(cmf, mf);

.. _sec:FFT:localr2c:

FFT::LocalR2C Class
Expand Down
99 changes: 59 additions & 40 deletions Src/FFT/AMReX_FFT_Helper.H
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ namespace amrex::FFT

enum struct Direction { forward, backward, both, none };

enum struct DomainStrategy { slab, pencil };
enum struct DomainStrategy { automatic, slab, pencil };

AMREX_ENUM( Boundary, periodic, even, odd );

Expand All @@ -56,15 +56,28 @@ enum struct Kind { none, r2c_f, r2c_b, c2c_f, c2c_b, r2r_ee_f, r2r_ee_b,

struct Info
{
//! Supported only in 3D. When batch_mode is true, FFT is performed on
//! Domain decomposition strategy.
DomainStrategy domain_strategy = DomainStrategy::automatic;

//! For automatic strategy, this is the size per process below which we
//! switch from slab to pencil.
int pencil_threshold = 8;

//! Supported only in 3D. When twod_mode is true, FFT is performed on
//! the first two dimensions only and the third dimension size is the
//! batch size.
bool batch_mode = false;
bool twod_mode = false;

//! Batched FFT size. Only supported in R2C, not R2X.
int batch_size = 1;

//! Max number of processes to use
int nprocs = std::numeric_limits<int>::max();

Info& setBatchMode (bool x) { batch_mode = x; return *this; }
Info& setDomainStrategy (DomainStrategy s) { domain_strategy = s; return *this; }
Info& setPencilThreshold (int t) { pencil_threshold = t; return *this; }
Info& setTwoDMode (bool x) { twod_mode = x; return *this; }
Info& setBatchSize (int bsize) { batch_size = bsize; return *this; }
Info& setNumProcs (int n) { nprocs = n; return *this; }
};

Expand Down Expand Up @@ -170,7 +183,7 @@ struct Plan
}

template <Direction D>
void init_r2c (Box const& box, T* pr, VendorComplex* pc, bool is_2d_transform = false)
void init_r2c (Box const& box, T* pr, VendorComplex* pc, bool is_2d_transform = false, int ncomp = 1)
{
static_assert(D == Direction::forward || D == Direction::backward);

Expand Down Expand Up @@ -198,6 +211,7 @@ struct Plan
howmany = (rank == 1) ? AMREX_D_TERM(1, *box.length(1), *box.length(2))
: AMREX_D_TERM(1, *1 , *box.length(2));
#endif
howmany *= ncomp;

amrex::ignore_unused(nc);

Expand Down Expand Up @@ -293,10 +307,10 @@ struct Plan
}

template <Direction D, int M>
void init_r2c (IntVectND<M> const& fft_size, void*, void*, bool cache);
void init_r2c (IntVectND<M> const& fft_size, void*, void*, bool cache, int ncomp = 1);

template <Direction D>
void init_c2c (Box const& box, VendorComplex* p)
void init_c2c (Box const& box, VendorComplex* p, int ncomp = 1)
{
static_assert(D == Direction::forward || D == Direction::backward);

Expand All @@ -307,6 +321,7 @@ struct Plan

n = box.length(0);
howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2));
howmany *= ncomp;

#if defined(AMREX_USE_CUDA)
AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
Expand Down Expand Up @@ -1131,7 +1146,7 @@ struct Plan
}
};

using Key = std::tuple<IntVectND<3>,Direction,Kind>;
using Key = std::tuple<IntVectND<3>,int,Direction,Kind>;
using PlanD = typename Plan<double>::VendorPlan;
using PlanF = typename Plan<float>::VendorPlan;

Expand All @@ -1143,7 +1158,7 @@ void add_vendor_plan_f (Key const& key, PlanF plan);

template <typename T>
template <Direction D, int M>
void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool cache)
void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool cache, int ncomp)
{
static_assert(D == Direction::forward || D == Direction::backward);

Expand All @@ -1154,10 +1169,10 @@ void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool

n = 1;
for (auto s : fft_size) { n *= s; }
howmany = 1;
howmany = ncomp;

#if defined(AMREX_USE_GPU)
Key key = {fft_size.template expand<3>(), D, kind};
Key key = {fft_size.template expand<3>(), ncomp, D, kind};
if (cache) {
VendorPlan* cached_plan = nullptr;
if constexpr (std::is_same_v<float,T>) {
Expand All @@ -1174,27 +1189,34 @@ void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool
amrex::ignore_unused(cache);
#endif

int len[M];
for (int i = 0; i < M; ++i) {
len[i] = fft_size[M-1-i];
}

int nc = fft_size[0]/2+1;
for (int i = 1; i < M; ++i) {
nc *= fft_size[i];
}

#if defined(AMREX_USE_CUDA)

AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0));
cufftType type;
int n_in, n_out;
if constexpr (D == Direction::forward) {
type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
n_in = n;
n_out = nc;
} else {
type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
n_in = nc;
n_out = n;
}
std::size_t work_size;
if constexpr (M == 1) {
AMREX_CUFFT_SAFE_CALL
(cufftMakePlan1d(plan, fft_size[0], type, howmany, &work_size));
} else if constexpr (M == 2) {
AMREX_CUFFT_SAFE_CALL
(cufftMakePlan2d(plan, fft_size[1], fft_size[0], type, &work_size));
} else if constexpr (M == 3) {
AMREX_CUFFT_SAFE_CALL
(cufftMakePlan3d(plan, fft_size[2], fft_size[1], fft_size[0], type, &work_size));
}
AMREX_CUFFT_SAFE_CALL
(cufftMakePlanMany(plan, M, len, nullptr, 1, n_in, nullptr, 1, n_out, type, howmany, &work_size));

#elif defined(AMREX_USE_HIP)

Expand All @@ -1219,19 +1241,21 @@ void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool
if (M == 1) {
pp = new mkl_desc_r(fft_size[0]);
} else {
std::vector<std::int64_t> len(M);
std::vector<std::int64_t> len64(M);
for (int idim = 0; idim < M; ++idim) {
len[idim] = fft_size[M-1-idim];
len64[idim] = len[idim];
}
pp = new mkl_desc_r(len);
pp = new mkl_desc_r(len64);
}
#ifndef AMREX_USE_MKL_DFTI_2024
pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
oneapi::mkl::dft::config_value::NOT_INPLACE);
#else
pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
#endif

pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n);
pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc);
std::vector<std::int64_t> strides(M+1);
strides[0] = 0;
strides[M] = 1;
Expand All @@ -1258,29 +1282,24 @@ void Plan<T>::init_r2c (IntVectND<M> const& fft_size, void* pbf, void* pbb, bool
return;
}

int size_for_row_major[M];
for (int idim = 0; idim < M; ++idim) {
size_for_row_major[idim] = fft_size[M-1-idim];
}

if constexpr (std::is_same_v<float,T>) {
if constexpr (D == Direction::forward) {
plan = fftwf_plan_dft_r2c
(M, size_for_row_major, (float*)pf, (fftwf_complex*)pb,
plan = fftwf_plan_many_dft_r2c
(M, len, howmany, (float*)pf, nullptr, 1, n, (fftwf_complex*)pb, nullptr, 1, nc,
FFTW_ESTIMATE);
} else {
plan = fftwf_plan_dft_c2r
(M, size_for_row_major, (fftwf_complex*)pb, (float*)pf,
plan = fftwf_plan_many_dft_c2r
(M, len, howmany, (fftwf_complex*)pb, nullptr, 1, nc, (float*)pf, nullptr, 1, n,
FFTW_ESTIMATE);
}
} else {
if constexpr (D == Direction::forward) {
plan = fftw_plan_dft_r2c
(M, size_for_row_major, (double*)pf, (fftw_complex*)pb,
plan = fftw_plan_many_dft_r2c
(M, len, howmany, (double*)pf, nullptr, 1, n, (fftw_complex*)pb, nullptr, 1, nc,
FFTW_ESTIMATE);
} else {
plan = fftw_plan_dft_c2r
(M, size_for_row_major, (fftw_complex*)pb, (double*)pf,
plan = fftw_plan_many_dft_c2r
(M, len, howmany, (fftw_complex*)pb, nullptr, 1, nc, (double*)pf, nullptr, 1, n,
FFTW_ESTIMATE);
}
}
Expand Down Expand Up @@ -1508,10 +1527,10 @@ namespace detail
b = make_box(b);
}
auto const& ng = make_iv(mf.nGrowVect());
FA submf(BoxArray(std::move(bl)), mf.DistributionMap(), 1, ng, MFInfo{}.SetAlloc(false));
FA submf(BoxArray(std::move(bl)), mf.DistributionMap(), mf.nComp(), ng, MFInfo{}.SetAlloc(false));
using FAB = typename FA::fab_type;
for (MFIter mfi(submf, MFItInfo().DisableDeviceSync()); mfi.isValid(); ++mfi) {
submf.setFab(mfi, FAB(mfi.fabbox(), 1, mf[mfi].dataPtr()));
submf.setFab(mfi, FAB(mfi.fabbox(), mf.nComp(), mf[mfi].dataPtr()));
}
return submf;
}
Expand Down
23 changes: 12 additions & 11 deletions Src/FFT/AMReX_FFT_OpenBCSolver.H
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ Box OpenBCSolver<T>::make_grown_domain (Box const& domain, Info const& info)
{
IntVect len = domain.length();
#if (AMREX_SPACEDIM == 3)
if (info.batch_mode) { len[2] = 0; }
if (info.twod_mode) { len[2] = 0; }
#else
amrex::ignore_unused(info);
#endif
Expand All @@ -48,18 +48,19 @@ template <typename T>
OpenBCSolver<T>::OpenBCSolver (Box const& domain, Info const& info)
: m_domain(domain),
m_info(info),
m_r2c(OpenBCSolver<T>::make_grown_domain(domain,info), info)
m_r2c(OpenBCSolver<T>::make_grown_domain(domain,info),
m_info.setDomainStrategy(FFT::DomainStrategy::slab))
{
#if (AMREX_SPACEDIM == 3)
if (m_info.batch_mode) {
if (m_info.twod_mode) {
auto gdom = make_grown_domain(domain,m_info);
gdom.enclosedCells(2);
gdom.setSmall(2, 0);
int nprocs = std::min({ParallelContext::NProcsSub(),
m_info.nprocs,
m_domain.length(2)});
gdom.setBig(2, nprocs-1);
m_r2c_green = std::make_unique<R2C<T>>(gdom,info);
m_r2c_green = std::make_unique<R2C<T>>(gdom,m_info);
auto [sd, ord] = m_r2c_green->getSpectralData();
m_G_fft = cMF(*sd, amrex::make_alias, 0, 1);
} else
Expand All @@ -78,7 +79,7 @@ void OpenBCSolver<T>::setGreensFunction (F const& greens_function)
{
BL_PROFILE("OpenBCSolver::setGreensFunction");

auto* infab = m_info.batch_mode ? detail::get_fab(m_r2c_green->m_rx)
auto* infab = m_info.twod_mode ? detail::get_fab(m_r2c_green->m_rx)
: detail::get_fab(m_r2c.m_rx);
auto const& lo = m_domain.smallEnd();
auto const& lo3 = lo.dim3();
Expand All @@ -87,7 +88,7 @@ void OpenBCSolver<T>::setGreensFunction (F const& greens_function)
auto const& a = infab->array();
auto box = infab->box();
GpuArray<int,3> nimages{1,1,1};
int ndims = m_info.batch_mode ? AMREX_SPACEDIM-1 : AMREX_SPACEDIM;
int ndims = m_info.twod_mode ? AMREX_SPACEDIM-1 : AMREX_SPACEDIM;
for (int idim = 0; idim < ndims; ++idim) {
if (box.smallEnd(idim) == lo[idim] && box.length(idim) == 2*len[idim]) {
box.growHi(idim, -len[idim]+1); // +1 to include the middle plane
Expand Down Expand Up @@ -129,13 +130,13 @@ void OpenBCSolver<T>::setGreensFunction (F const& greens_function)
});
}

if (m_info.batch_mode) {
if (m_info.twod_mode) {
m_r2c_green->forward(m_r2c_green->m_rx);
} else {
m_r2c.forward(m_r2c.m_rx);
}

if (!m_info.batch_mode) {
if (!m_info.twod_mode) {
auto [sd, ord] = m_r2c.getSpectralData();
amrex::ignore_unused(ord);
auto const* srcfab = detail::get_fab(*sd);
Expand Down Expand Up @@ -166,7 +167,7 @@ void OpenBCSolver<T>::solve (MF& phi, MF const& rho)
inmf.setVal(T(0));
inmf.ParallelCopy(rho, 0, 0, 1);

m_r2c.m_openbc_half = !m_info.batch_mode;
m_r2c.m_openbc_half = !m_info.twod_mode;
m_r2c.forward(inmf);
m_r2c.m_openbc_half = false;

Expand All @@ -183,7 +184,7 @@ void OpenBCSolver<T>::solve (MF& phi, MF const& rho)
Box const& rhobox = rhofab->box();
#if (AMREX_SPACEDIM == 3)
Long leng = gfab->box().numPts();
if (m_info.batch_mode) {
if (m_info.twod_mode) {
AMREX_ASSERT(gfab->box().length(2) == 1 &&
leng == (rhobox.length(0) * rhobox.length(1)));
} else {
Expand All @@ -204,7 +205,7 @@ void OpenBCSolver<T>::solve (MF& phi, MF const& rho)
}
}

m_r2c.m_openbc_half = !m_info.batch_mode;
m_r2c.m_openbc_half = !m_info.twod_mode;
m_r2c.backward_doit(phi, phi.nGrowVect());
m_r2c.m_openbc_half = false;
}
Expand Down
4 changes: 2 additions & 2 deletions Src/FFT/AMReX_FFT_Poisson.H
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ public:
}
}
Info info{};
info.setBatchMode(true);
info.setTwoDMode(true);
if (periodic_xy) {
m_r2c = std::make_unique<R2C<typename MF::value_type>>(m_geom.Domain(),
info);
Expand All @@ -145,7 +145,7 @@ public:
std::make_pair(Boundary::periodic,Boundary::periodic),
std::make_pair(Boundary::even,Boundary::even))},
m_r2c(std::make_unique<R2C<typename MF::value_type>>
(geom.Domain(), Info().setBatchMode(true)))
(geom.Domain(), Info().setTwoDMode(true)))
{
#if (AMREX_SPACEDIM == 3)
AMREX_ALWAYS_ASSERT(geom.isPeriodic(0) && geom.isPeriodic(1));
Expand Down
Loading

0 comments on commit 2f64d71

Please sign in to comment.