Skip to content

Commit

Permalink
Automatic creation of well-partitioned meshes based on #tasks and #el…
Browse files Browse the repository at this point in the history
…ements-per-task.
  • Loading branch information
vladotomov committed Dec 3, 2024
1 parent 47f56e6 commit ef0678d
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 14 deletions.
51 changes: 37 additions & 14 deletions remhos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ MFEM_EXPORT int remhos(int argc, char *argv[], double &final_mass_u)
mfem::MPI_Session mpi(argc, argv);
const int myid = mpi.WorldRank();

const char *mesh_file = "data/periodic-square.mesh";
const char *mesh_file = "default";
int dim = 3;
int elem_per_mpi = 1;
int rs_levels = 2;
int rp_levels = 0;
int order = 3;
Expand Down Expand Up @@ -189,6 +191,9 @@ MFEM_EXPORT int remhos(int argc, char *argv[], double &final_mass_u)
OptionsParser args(argc, argv);
args.AddOption(&mesh_file, "-m", "--mesh",
"Mesh file to use.");
args.AddOption(&dim, "-dim", "--dimension", "Dimension of the problem.");
args.AddOption(&elem_per_mpi, "-epm", "--elem-per-mpi",
"Number of element per mpi task.");
args.AddOption(&problem_num, "-p", "--problem",
"Problem setup to use. See options in velocity_function().");
args.AddOption(&rs_levels, "-rs", "--refine-serial",
Expand Down Expand Up @@ -282,29 +287,47 @@ MFEM_EXPORT int remhos(int argc, char *argv[], double &final_mass_u)
else if (problem_num < 20) { exec_mode = 1; }
else { MFEM_ABORT("Unspecified execution mode."); }

// Read the serial mesh from the given mesh file on all processors.
// Refine the mesh in serial to increase the resolution.
Mesh *mesh = new Mesh(Mesh::LoadFromFile(mesh_file, 1, 1));
const int dim = mesh->Dimension();
for (int lev = 0; lev < rs_levels; lev++) { mesh->UniformRefinement(); }
mesh->GetBoundingBox(bb_min, bb_max, max(order, 1));

// Only standard assembly in 1D (some mfem functions just abort in 1D).
if ((pa || next_gen_full) && dim == 1)
Mesh *mesh = nullptr;
int *mpi_partitioning = nullptr;
if (strncmp(mesh_file, "default", 7) != 0)
{
MFEM_WARNING("Disabling PA / FA for 1D.");
pa = false;
next_gen_full = false;
// Read the serial mesh from the given mesh file on all processors.
// Refine the mesh in serial to increase the resolution.
mesh = new Mesh(Mesh::LoadFromFile(mesh_file, 1, 1));
for (int lev = 0; lev < rs_levels; lev++) { mesh->UniformRefinement(); }
}
else
{
mesh = CartesianMesh(dim, Mpi::WorldSize(), elem_per_mpi, myid == 0,
rp_levels, &mpi_partitioning);
}
dim = mesh->Dimension();
mesh->GetBoundingBox(bb_min, bb_max, max(order, 1));

// Parallel partitioning of the mesh.
// Refine the mesh further in parallel to increase the resolution.
ParMesh pmesh(MPI_COMM_WORLD, *mesh);
ParMesh pmesh(MPI_COMM_WORLD, *mesh, mpi_partitioning);
delete mesh;
delete mpi_partitioning;
for (int lev = 0; lev < rp_levels; lev++) { pmesh.UniformRefinement(); }
MPI_Comm comm = pmesh.GetComm();
const int NE = pmesh.GetNE();

if (strncmp(mesh_file, "default", 7) == 0)
{
MFEM_VERIFY(pmesh.GetGlobalNE() == Mpi::WorldSize() * elem_per_mpi,
"Mesh generation error.");
MFEM_VERIFY(NE == elem_per_mpi, "Mesh generation error.");
}

// Only standard assembly in 1D (some mfem functions just abort in 1D).
if ((pa || next_gen_full) && dim == 1)
{
MFEM_WARNING("Disabling PA / FA for 1D.");
pa = false;
next_gen_full = false;
}

// Define the ODE solver used for time integration. Several explicit
// Runge-Kutta methods are available.
ODESolver *ode_solver = NULL;
Expand Down
121 changes: 121 additions & 0 deletions remhos_tools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1075,6 +1075,127 @@ void MixedConvectionIntegrator::AssembleElementMatrix2(
}
}

Mesh *CartesianMesh(int dim, int mpi_cnt, int elem_per_mpi, bool print,
int &par_ref, int **partitioning)
{
MFEM_VERIFY(dim > 1, "Not implemented for 1D meshes.");

auto factor = [&](int N)
{
for (int i = static_cast<int>(sqrt(N)); i > 0; i--)
{ if (N % i == 0) { return i; } }
return 1;
};

par_ref = 0;
const int ref_factor = (dim == 2) ? 4 : 8;

// Elements per task before performing parallel refinements.
// This will be used to form the serial mesh.
int el0 = elem_per_mpi;
while (el0 % ref_factor == 0)
{
el0 /= ref_factor;
par_ref++;
}

// In the serial mesh we have:
// The number of MPI blocks is mpi_cnt = mp_x.mpy_y.mpy_z.
// The size of each MPI block is el0 = el0_x.el0_y.el0_z.
int mpi_x, mpi_y, mpi_z;
int el0_x, el0_y, el0_z;
if (dim == 2)
{
mpi_x = factor(mpi_cnt);
mpi_y = mpi_cnt / mpi_x;

// Switch order for better balance.
el0_y = factor(el0);
el0_x = el0 / el0_y;
}
else
{
mpi_x = factor(mpi_cnt);
mpi_y = factor(mpi_cnt / mpi_x);
mpi_z = mpi_cnt / mpi_x / mpi_y;

// Switch order for better balance.
el0_z = factor(el0);
el0_y = factor(el0 / el0_z);
el0_x = el0 / el0_y / el0_z;
}

if (print && dim == 2)
{
int elem_par_x = mpi_x * el0_x * pow(2, par_ref),
elem_par_y = mpi_y * el0_y * pow(2, par_ref);

std::cout << "--- Mesh generation: \n";
std::cout << "Par mesh: " << elem_par_x << " x " << elem_par_y
<< " (" << elem_par_x * elem_par_y << " elements)\n"
<< "Elem / task: "
<< el0_x * pow(2, par_ref) << " x "
<< el0_y * pow(2, par_ref)
<< " (" << el0_x * pow(2, 2*par_ref) * el0_y << " elements)\n"
<< "MPI blocks: " << mpi_x << " x " << mpi_y
<< " (" << mpi_x * mpi_y << " mpi tasks)\n" << "-\n"
<< "Serial mesh: "
<< mpi_x * el0_x << " x " << mpi_y * el0_y
<< " (" << mpi_x * el0_x * mpi_y * el0_y << " elements)\n"
<< "Elem / task: " << el0_x << " x " << el0_y << std::endl
<< "Par refine: " << par_ref << std::endl;
std::cout << "--- \n";
}

if (print && dim == 3)
{
int elem_par_x = mpi_x * el0_x * pow(2, par_ref),
elem_par_y = mpi_y * el0_y * pow(2, par_ref),
elem_par_z = mpi_z * el0_z * pow(2, par_ref);

std::cout << "--- Mesh generation: \n";
std::cout << "Par mesh: "
<< elem_par_x << " x " << elem_par_y << " x " << elem_par_z
<< " (" << elem_par_x*elem_par_y*elem_par_z << " elements)\n"
<< "Elem / task: "
<< el0_x * pow(2, par_ref) << " x "
<< el0_y * pow(2, par_ref) << " x "
<< el0_z * pow(2, par_ref)
<< " (" << el0_x*pow(2, 3*par_ref)*el0_y*el0_z << " elements)\n"
<< "MPI blocks: " << mpi_x << " x " << mpi_y << " x " << mpi_z
<< " (" << mpi_x * mpi_y * mpi_z << " mpi tasks)\n" << "-\n"
<< "Serial mesh: "
<< mpi_x*el0_x << " x " << mpi_y*el0_y << " x " << mpi_z*el0_z
<< " (" << mpi_x*el0_x*mpi_y*el0_y*mpi_z*el0_z << " elements)\n"
<< "Elem / task: "
<< el0_x << " x " << el0_y << " x " << el0_z << std::endl
<< "Par refine: " << par_ref << std::endl;
std::cout << "--- \n";
}

Mesh *mesh;
if (dim == 2)
{
mesh = new Mesh(Mesh::MakeCartesian2D(mpi_x * el0_x,
mpi_y * el0_y,
Element::QUADRILATERAL, true));
}
else
{
mesh = new Mesh(Mesh::MakeCartesian3D(mpi_x * el0_x,
mpi_y * el0_y,
mpi_z * el0_z,
Element::HEXAHEDRON, true));
}

int nxyz[dim];
if (dim == 2) { nxyz[0] = mpi_x; nxyz[1] = mpi_y; }
else { nxyz[0] = mpi_x; nxyz[1] = mpi_y; nxyz[2] = mpi_z; }
*partitioning = mesh->CartesianPartitioning(nxyz);

return mesh;
}

int GetLocalFaceDofIndex3D(int loc_face_id, int face_orient,
int face_dof_id, int face_dof1D_cnt)
{
Expand Down
3 changes: 3 additions & 0 deletions remhos_tools.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
namespace mfem
{

Mesh *CartesianMesh(int dim, int mpi_cnt, int elem_per_mpi, bool print,
int &par_ref, int **partitioning);

int GetLocalFaceDofIndex(int dim, int loc_face_id, int face_orient,
int face_dof_id, int face_dof1D_cnt);

Expand Down

0 comments on commit ef0678d

Please sign in to comment.