Complete addition of execution space parameter to code
johnbowen42 committed Sep 24, 2024
1 parent bdafb37 commit 0a8b42e
Showing 8 changed files with 136 additions and 106 deletions.
103 changes: 55 additions & 48 deletions src/serac/numerics/functional/boundary_integral_kernels.hpp
@@ -200,10 +200,10 @@ void evaluation_kernel_impl(trial_element_type trial_elements, test_element, dou

auto e_range = RAJA::TypedRangeSegment<uint32_t>(0, num_elements);
// for each element in the domain
RAJA::launch<launch_policy>(
RAJA::launch<typename EvaluationSpacePolicy<exec>::launch_t>(
RAJA::LaunchParams(RAJA::Teams(static_cast<int>(num_elements)), RAJA::Threads(BLOCK_SZ)),
[=] RAJA_HOST_DEVICE(RAJA::LaunchContext ctx) {
RAJA::loop<teams_e>(
RAJA::loop<typename EvaluationSpacePolicy<exec>::teams_t>(
ctx, e_range,
// The explicit capture list is needed here because the capture occurs in a function
// template with a variadic non-type parameter.
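
Throughout this file, the hard-coded launch_policy / teams_e / threads_x aliases are replaced by policy types pulled from EvaluationSpacePolicy<exec>, so the same kernel body can be instantiated for either execution space. This section does not show the trait itself; below is a minimal sketch of what such a mapping could look like, assuming a two-value ExecutionSpace enum and stock RAJA launch policies (the specific policy choices and the enum are illustrative assumptions, not serac's actual definitions):

    // Sketch only: an execution-space -> RAJA policy trait like the one the
    // kernels above consult. Names below other than the member typedefs are
    // assumptions for illustration.
    #include "RAJA/RAJA.hpp"

    enum class ExecutionSpace { CPU, GPU };

    template <ExecutionSpace exec>
    struct EvaluationSpacePolicy;

    template <>
    struct EvaluationSpacePolicy<ExecutionSpace::CPU> {
      using launch_t  = RAJA::LaunchPolicy<RAJA::seq_launch_t>;
      using teams_t   = RAJA::LoopPolicy<RAJA::seq_exec>;
      using threads_t = RAJA::LoopPolicy<RAJA::seq_exec>;
    };

    #if defined(RAJA_ENABLE_CUDA)
    template <>
    struct EvaluationSpacePolicy<ExecutionSpace::GPU> {
      using launch_t  = RAJA::LaunchPolicy<RAJA::cuda_launch_t<false>>;  // synchronous CUDA launch
      using teams_t   = RAJA::LoopPolicy<RAJA::cuda_block_x_direct>;     // one team per CUDA block
      using threads_t = RAJA::LoopPolicy<RAJA::cuda_thread_x_loop>;      // threads within a block
    };
    #endif

With a specialization per execution space, RAJA::launch<typename EvaluationSpacePolicy<exec>::launch_t> resolves to a sequential host launch or a CUDA launch purely from the exec template parameter, which is what lets these kernels stay single-source.
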
@@ -222,8 +222,8 @@ void evaluation_kernel_impl(trial_element_type trial_elements, test_element, dou

ctx.teamSync();

(promote_each_to_dual_when<indices == differentiation_index>(get<indices>(interpolate_result[e]),
&get<indices>(qf_inputs[e]), ctx),
(promote_each_to_dual_when<indices == differentiation_index, exec>(get<indices>(interpolate_result[e]),
&get<indices>(qf_inputs[e]), ctx),
...);

ctx.teamSync();
@@ -239,7 +239,7 @@ void evaluation_kernel_impl(trial_element_type trial_elements, test_element, dou
// won't need to be applied in the action_of_gradient and element_gradient kernels
if constexpr (differentiation_index != serac::NO_DIFFERENTIATION) {
RAJA::RangeSegment x_range(0, leading_dimension(qf_outputs));
RAJA::loop<threads_x>(ctx, x_range, [&](int q) {
RAJA::loop<typename EvaluationSpacePolicy<exec>::threads_t>(ctx, x_range, [&](int q) {
qf_derivatives[e * qpts_per_elem + uint32_t(q)] = get_gradient(qf_outputs[q]);
});
}
@@ -263,15 +263,14 @@ SERAC_HOST_DEVICE auto chain_rule(const S& dfdx, const T& dx)
}
//clang-format on

template <typename derivative_type, int n, typename T>
SERAC_HOST_DEVICE auto batch_apply_chain_rule(derivative_type* qf_derivatives, const tensor<T, n>& inputs,
const RAJA::LaunchContext& ctx)
template <ExecutionSpace exec, typename derivative_type, int n, typename T>
SERAC_HOST_DEVICE auto batch_apply_chain_rule(
derivative_type* qf_derivatives, const tensor<T, n>& inputs,
tensor<tuple<decltype(chain_rule(derivative_type{}, T{})), zero>, n>& outputs, const RAJA::LaunchContext& ctx)
{
using return_type = decltype(chain_rule(derivative_type{}, T{}));
tensor<tuple<return_type, zero>, n> outputs{};
RAJA::RangeSegment i_range(0, n);

RAJA::loop<threads_x>(ctx, i_range, [&](int i) { get<0>(outputs[i]) = chain_rule(qf_derivatives[i], inputs[i]); });
RAJA::RangeSegment i_range(0, n);
RAJA::loop<typename EvaluationSpacePolicy<exec>::threads_t>(
ctx, i_range, [&](int i) { get<0>(outputs[i]) = chain_rule(qf_derivatives[i], inputs[i]); });
return outputs;
}
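
The change above also reworks batch_apply_chain_rule: instead of building and returning the outputs tensor by value, it now writes into a caller-provided reference, so the caller can keep a single copy in team-shared memory. Here is a self-contained sketch of that calling pattern, using sequential RAJA policies and toy types in place of serac's (everything in it is illustrative, not the library's code):

    #include "RAJA/RAJA.hpp"
    #include <cstdio>

    // Sequential stand-ins for EvaluationSpacePolicy<exec>::{launch_t, teams_t, threads_t}.
    using launch_t  = RAJA::LaunchPolicy<RAJA::seq_launch_t>;
    using teams_t   = RAJA::LoopPolicy<RAJA::seq_exec>;
    using threads_t = RAJA::LoopPolicy<RAJA::seq_exec>;

    constexpr int nqp = 8;  // quadrature points per "element" in this toy example

    // Analogue of the new batch_apply_chain_rule signature: write into caller-owned storage.
    RAJA_HOST_DEVICE void batch_scale(const double (&inputs)[nqp], double (&outputs)[nqp],
                                      const RAJA::LaunchContext& ctx)
    {
      RAJA::loop<threads_t>(ctx, RAJA::RangeSegment(0, nqp),
                            [&](int q) { outputs[q] = 2.0 * inputs[q]; });
    }

    int main()
    {
      RAJA::launch<launch_t>(
          RAJA::LaunchParams(RAJA::Teams(1), RAJA::Threads(nqp)),
          [=] RAJA_HOST_DEVICE(RAJA::LaunchContext ctx) {
            RAJA::loop<teams_t>(ctx, RAJA::RangeSegment(0, 1), [&](int /*e*/) {
              double inputs[nqp] = {1, 2, 3, 4, 5, 6, 7, 8};
              RAJA_TEAM_SHARED double outputs[nqp];  // one buffer shared by the whole team
              batch_scale(inputs, outputs, ctx);
              ctx.teamSync();  // writes must be visible before any thread reads
              RAJA::loop<threads_t>(ctx, RAJA::RangeSegment(0, nqp),
                                    [&](int q) { std::printf("%d -> %f\n", q, outputs[q]); });
            });
          });
      return 0;
    }

This matches the usage in the action_of_gradient_kernel hunk below, where qf_outputs is declared RAJA_TEAM_SHARED once per team and a teamSync separates the cooperative write from the read in test_element::integrate.
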

@@ -322,25 +321,30 @@ void action_of_gradient_kernel(const double* dU, double* dR, derivatives_type* q
auto& rm = umpire::ResourceManager::getInstance();
auto allocator = rm.getAllocator(device_name);
qf_inputs_type* qf_inputs = static_cast<qf_inputs_type*>(allocator.allocate(sizeof(qf_inputs_type) * num_elements));
// This typedef is needed to declare qf_outputs in shared memory.
using qf_outputs_type = decltype(batch_apply_chain_rule(qf_derivatives, *qf_inputs, RAJA::LaunchContext{}));
rm.memset(qf_inputs, 0);
// This typedef is needed to declare qf_outputs in shared memory.
using qf_outputs_type =
tensor<tuple<decltype(chain_rule(derivatives_type{}, typename trial_element::qf_input_type{})), zero>, nqp>;

// for each element in the domain
RAJA::launch<launch_policy>(RAJA::LaunchParams(RAJA::Teams(static_cast<int>(num_elements)), RAJA::Threads(BLOCK_SZ)),
[=] RAJA_HOST_DEVICE(RAJA::LaunchContext ctx) {
RAJA::loop<teams_e>(ctx, e_range, [&](int e) {
// (batch) interpolate each quadrature point's value
trial_element::interpolate(du[elements[e]], rule, qf_inputs, ctx);

// (batch) evalute the q-function at each quadrature point
RAJA_TEAM_SHARED qf_outputs_type qf_outputs;
qf_outputs = batch_apply_chain_rule(qf_derivatives + e * nqp, *qf_inputs, ctx);

// (batch) integrate the material response against the test-space basis functions
test_element::integrate(qf_outputs, rule, &dr[elements[e]], ctx);
});
});
RAJA::launch<typename EvaluationSpacePolicy<exec>::launch_t>(
RAJA::LaunchParams(RAJA::Teams(static_cast<int>(num_elements)), RAJA::Threads(BLOCK_SZ)),
[=] RAJA_HOST_DEVICE(RAJA::LaunchContext ctx) {
RAJA::loop<typename EvaluationSpacePolicy<exec>::teams_t>(ctx, e_range, [&](int e) {
// (batch) interpolate each quadrature point's value
trial_element::interpolate(du[elements[e]], rule, &(qf_inputs[e]), ctx);
ctx.teamSync();

// (batch) evalute the q-function at each quadrature point
RAJA_TEAM_SHARED qf_outputs_type qf_outputs;
batch_apply_chain_rule<exec>(qf_derivatives + e * nqp, qf_inputs[e], qf_outputs, ctx);
ctx.teamSync();

// (batch) integrate the material response against the test-space basis functions
test_element::integrate(qf_outputs, rule, &dr[elements[e]], ctx);
ctx.teamSync();
});
});
rm.deallocate(qf_inputs);
}
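
action_of_gradient_kernel stages its per-element qf_inputs scratch through an Umpire allocator looked up by device_name and zeroes it with ResourceManager::memset before the launch. A rough sketch of that allocation pattern, reusing the ExecutionSpace enum sketched earlier (the exec-to-allocator-name mapping and the helper itself are assumptions for illustration, not serac's exact logic):

    #include "umpire/ResourceManager.hpp"
    #include <cstddef>

    // Hypothetical helper mirroring the scratch-buffer pattern above. "HOST" and
    // "DEVICE" are Umpire's built-in allocator names.
    template <ExecutionSpace exec, typename T>
    T* allocate_scratch(std::size_t count)
    {
      auto&       rm        = umpire::ResourceManager::getInstance();
      const char* name      = (exec == ExecutionSpace::GPU) ? "DEVICE" : "HOST";
      auto        allocator = rm.getAllocator(name);
      T*          buffer    = static_cast<T*>(allocator.allocate(sizeof(T) * count));
      rm.memset(buffer, 0);  // zero the entire allocation, host or device
      return buffer;
    }

    // after the RAJA::launch completes:
    //   umpire::ResourceManager::getInstance().deallocate(buffer);

Allocating one qf_inputs slot per element rather than sharing a single buffer lets each team interpolate into its own storage concurrently, which appears to be why the new code indexes qf_inputs[e] instead of reusing one pointer.
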

@@ -382,29 +386,31 @@ void element_gradient_kernel(ExecArrayView<double, 3, exec> dK, derivatives_type
constexpr TensorProductQuadratureRule<Q> rule{};

// for each element in the domain
RAJA::launch<launch_policy>(
RAJA::launch<typename EvaluationSpacePolicy<exec>::launch_t>(
RAJA::LaunchParams(RAJA::Teams(static_cast<int>(num_elements)), RAJA::Threads(BLOCK_SZ)),
[=] RAJA_HOST_DEVICE(RAJA::LaunchContext ctx) {
RAJA::loop<teams_e>(ctx, elements_range, [&ctx, dK, elements, qf_derivatives, nquad, rule](uint32_t e) {
(void)nquad;
auto* output_ptr = reinterpret_cast<typename test_element::dof_type*>(&dK(elements[e], 0, 0));
RAJA::loop<typename EvaluationSpacePolicy<exec>::teams_t>(
ctx, elements_range, [&ctx, dK, elements, qf_derivatives, nquad, rule](uint32_t e) {
(void)nquad;
auto* output_ptr = reinterpret_cast<typename test_element::dof_type*>(&dK(elements[e], 0, 0));

RAJA_TEAM_SHARED tensor<derivatives_type, nquad> derivatives;
RAJA::RangeSegment x_range(0, nquad);
RAJA::loop<threads_x>(ctx, x_range, [&](int q) { derivatives(q) = qf_derivatives[e * nquad + uint32_t(q)]; });
RAJA_TEAM_SHARED tensor<derivatives_type, nquad> derivatives;
RAJA::RangeSegment x_range(0, nquad);
RAJA::loop<typename EvaluationSpacePolicy<exec>::threads_t>(
ctx, x_range, [&](int q) { derivatives(q) = qf_derivatives[e * nquad + uint32_t(q)]; });

ctx.teamSync();
ctx.teamSync();

RAJA_TEAM_SHARED
typename trial_element::template batch_apply_shape_fn_output<derivatives_type, Q>::type source_and_flux;
for (int J = 0; J < trial_element::ndof; J++) {
trial_element::batch_apply_shape_fn(J, derivatives, &source_and_flux, rule, ctx);
ctx.teamSync();
RAJA_TEAM_SHARED
typename trial_element::template batch_apply_shape_fn_output<derivatives_type, Q>::type source_and_flux;
for (int J = 0; J < trial_element::ndof; J++) {
trial_element::batch_apply_shape_fn(J, derivatives, &source_and_flux, rule, ctx);
ctx.teamSync();

test_element::integrate(source_and_flux, rule, output_ptr + J, ctx, trial_element::ndof);
ctx.teamSync();
}
});
test_element::integrate(source_and_flux, rule, output_ptr + J, ctx, trial_element::ndof);
ctx.teamSync();
}
});
});
}

@@ -416,8 +422,9 @@ auto evaluation_kernel(signature s, lambda_type qf, const double* positions, con
auto trial_elements = trial_elements_tuple<geom, exec>(s);
auto test_element = get_test_element<geom, exec>(s);
return [=](double time, const std::vector<const double*>& inputs, double* outputs, bool /* update state */) {
evaluation_kernel_impl<wrt, Q, geom, exec>(trial_elements, test_element, time, inputs, outputs, positions, jacobians, qf,
qf_derivatives.get(), elements, num_elements, s.index_seq);
evaluation_kernel_impl<wrt, Q, geom, exec>(trial_elements, test_element, time, inputs, outputs, positions,
jacobians, qf, qf_derivatives.get(), elements, num_elements,
s.index_seq);
};
}

4 changes: 2 additions & 2 deletions src/serac/numerics/functional/detail/hexahedron_H1.inl
@@ -197,7 +197,7 @@ struct finite_element<mfem::Geometry::CUBE, H1<p, c>, exec> {
// A1(dz, dy, qx) := B(qx, dx) * X_e(dz, dy, dx)
// A2(dz, qy, qx) := B(qy, dy) * A1(dz, dy, qx)
// X_q(qz, qy, qx) := B(qz, dz) * A2(dz, qy, qx)
using threads_t = typename EvaluationSpacePolicy<exec>::EvaluationSpacePolicy<exec>::threads_t;
using threads_t = typename EvaluationSpacePolicy<exec>::threads_t;
static constexpr bool apply_weights = false;

RAJA::RangeSegment x_range(0, BLOCK_SZ);
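
The three comment lines above spell out the sum-factorization order used by the hexahedron interpolate: contract the 1D basis B through one direction at a time via intermediates A1 and A2. As a reference for what those contractions mean, here is the same algorithm written as plain nested loops (a standalone illustration with raw arrays, not the tensor types serac uses):

    // A1(dz, dy, qx) := B(qx, dx) * X_e(dz, dy, dx)
    // A2(dz, qy, qx) := B(qy, dy) * A1(dz, dy, qx)
    // X_q(qz, qy, qx) := B(qz, dz) * A2(dz, qy, qx)
    template <int n, int q>
    void sum_factorized_interpolate(const double (&B)[q][n],       // 1D basis values B(qx, dx)
                                    const double (&X_e)[n][n][n],  // nodal values on one element
                                    double (&X_q)[q][q][q])        // values at quadrature points
    {
      double A1[n][n][q] = {};
      for (int dz = 0; dz < n; dz++)
        for (int dy = 0; dy < n; dy++)
          for (int qx = 0; qx < q; qx++)
            for (int dx = 0; dx < n; dx++) A1[dz][dy][qx] += B[qx][dx] * X_e[dz][dy][dx];

      double A2[n][q][q] = {};
      for (int dz = 0; dz < n; dz++)
        for (int qy = 0; qy < q; qy++)
          for (int dy = 0; dy < n; dy++)
            for (int qx = 0; qx < q; qx++) A2[dz][qy][qx] += B[qy][dy] * A1[dz][dy][qx];

      for (int qz = 0; qz < q; qz++)
        for (int qy = 0; qy < q; qy++)
          for (int qx = 0; qx < q; qx++) {
            double acc = 0.0;
            for (int dz = 0; dz < n; dz++) acc += B[qz][dz] * A2[dz][qy][qx];
            X_q[qz][qy][qx] = acc;
          }
    }

Done this way the work per element scales like n^3 q + n^2 q^2 + n q^3 (roughly O(p^4) when q ~ n ~ p) instead of the O(p^6) of contracting B in all three directions at once, which is the point of the staged A1/A2 intermediates.
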
Expand Down Expand Up @@ -309,7 +309,7 @@ struct finite_element<mfem::Geometry::CUBE, H1<p, c>, exec> {
}

constexpr int ntrial = std::max(size(source_type{}), size(flux_type{}) / dim) / c;
using threads_t = typename EvaluationSpacePolicy<exec>::EvaluationSpacePolicy<exec>::threads_t;
using threads_t = typename EvaluationSpacePolicy<exec>::threads_t;
using s_buffer_type = std::conditional_t<is_zero<source_type>{}, zero, tensor<double, q, q, q>>;
using f_buffer_type = std::conditional_t<is_zero<flux_type>{}, zero, tensor<double, dim, q, q, q>>;

2 changes: 1 addition & 1 deletion src/serac/numerics/functional/detail/tetrahedron_H1.inl
@@ -352,7 +352,7 @@ struct finite_element<mfem::Geometry::TETRAHEDRON, H1<p, c>, exec> {
constexpr auto xi = GaussLegendreNodes<q, mfem::Geometry::TETRAHEDRON>();

auto x_range = RAJA::RangeSegment(0, nqpts(q));
RAJA::loop<threads_x>(ctx, x_range, [&](int i) {
RAJA::loop<typename EvaluationSpacePolicy<exec>::threads_t>(ctx, x_range, [&](int i) {
double phi_j = shape_function(xi[i], j);
tensor<double, dim> dphi_j_dxi = shape_function_gradient(xi[i], j);
