Skip to content

Commit

Permalink
Add comments explaining trial_elements_tuple and refactor for clarity…
Browse files Browse the repository at this point in the history
…. Clang format
  • Loading branch information
johnbowen42 committed Nov 1, 2023
1 parent d1fcd28 commit 37eb4e1
Show file tree
Hide file tree
Showing 28 changed files with 118 additions and 102 deletions.
16 changes: 2 additions & 14 deletions src/serac/numerics/functional/boundary_integral_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,18 +153,6 @@ auto batch_apply_qf(lambda qf, const tensor<double, 3, n>& positions, const tens
return outputs;
}

template <mfem::Geometry::Type geom, typename test, typename... trials>
auto get_trial_elements(FunctionSignature<test(trials...)>)
{
return tuple<finite_element<geom, trials>...>{};
}

template <mfem::Geometry::Type geom, typename test, typename... trials>
auto get_test(FunctionSignature<test(trials...)>)
{
return finite_element<geom, test>{};
}

/// @trial_elements the element type for each trial space
template <uint32_t differentiation_index, int Q, mfem::Geometry::Type geom, typename test_element,
typename trial_element_type, typename lambda_type, typename derivative_type, int... indices>
Expand Down Expand Up @@ -339,8 +327,8 @@ std::function<void(const std::vector<const double*>&, double*, bool)> evaluation
signature s, lambda_type qf, const double* positions, const double* jacobians,
std::shared_ptr<derivative_type> qf_derivatives, uint32_t num_elements)
{
auto trial_elements = get_trial_elements<geom>(s);
auto test = get_test<geom>(s);
auto trial_elements = trial_elements_tuple<geom>(s);
auto test = get_test_element<geom>(s);
return [=](const std::vector<const double*>& inputs, double* outputs, bool /* update state */) {
evaluation_kernel_impl<wrt, Q, geom>(trial_elements, test, inputs, outputs, positions, jacobians, qf,
qf_derivatives.get(), num_elements, s.index_seq);
Expand Down
6 changes: 3 additions & 3 deletions src/serac/numerics/functional/detail/hexahedron_H1.inl
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,7 @@ struct finite_element<mfem::Geometry::CUBE, H1<p, c> > {
}

template <int q>
SERAC_HOST_DEVICE
static auto interpolate(const dof_type& X, const TensorProductQuadratureRule<q>&)
SERAC_HOST_DEVICE static auto interpolate(const dof_type& X, const TensorProductQuadratureRule<q>&)
{
// we want to compute the following:
//
Expand Down Expand Up @@ -230,7 +229,8 @@ static auto interpolate(const dof_type& X, const TensorProductQuadratureRule<q>&

template <typename source_type, typename flux_type, int q>
SERAC_HOST_DEVICE static void integrate(const tensor<tuple<source_type, flux_type>, q * q * q>& qf_output,
const TensorProductQuadratureRule<q>&, dof_type* element_residual, int step = 1)
const TensorProductQuadratureRule<q>&, dof_type* element_residual,
int step = 1)
{
if constexpr (is_zero<source_type>{} && is_zero<flux_type>{}) {
return;
Expand Down
7 changes: 3 additions & 4 deletions src/serac/numerics/functional/detail/hexahedron_Hcurl.inl
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,7 @@ struct finite_element<mfem::Geometry::CUBE, Hcurl<p>> {
}

template <int q>
SERAC_HOST_DEVICE
static auto interpolate(const dof_type& element_values, const TensorProductQuadratureRule<q>&)
SERAC_HOST_DEVICE static auto interpolate(const dof_type& element_values, const TensorProductQuadratureRule<q>&)
{
constexpr bool apply_weights = false;
constexpr tensor<double, q, p> B1 = calculate_B1<apply_weights, q>();
Expand Down Expand Up @@ -397,8 +396,8 @@ static auto interpolate(const dof_type& element_values, const TensorProductQuadr

template <typename source_type, typename flux_type, int q>
SERAC_HOST_DEVICE static void integrate(const tensor<tuple<source_type, flux_type>, q * q * q>& qf_output,
const TensorProductQuadratureRule<q>&, dof_type* element_residual,
[[maybe_unused]] int step = 1)
const TensorProductQuadratureRule<q>&, dof_type* element_residual,
[[maybe_unused]] int step = 1)
{
constexpr bool apply_weights = true;
constexpr tensor<double, q, p> B1 = calculate_B1<apply_weights, q>();
Expand Down
6 changes: 3 additions & 3 deletions src/serac/numerics/functional/detail/hexahedron_L2.inl
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,7 @@ struct finite_element<mfem::Geometry::CUBE, L2<p, c> > {
}

template <int q>
SERAC_HOST_DEVICE
static auto interpolate(const dof_type& X, const TensorProductQuadratureRule<q>&)
SERAC_HOST_DEVICE static auto interpolate(const dof_type& X, const TensorProductQuadratureRule<q>&)
{
// we want to compute the following:
//
Expand Down Expand Up @@ -234,7 +233,8 @@ static auto interpolate(const dof_type& X, const TensorProductQuadratureRule<q>&

template <typename source_type, typename flux_type, int q>
SERAC_HOST_DEVICE static void integrate(const tensor<tuple<source_type, flux_type>, q * q * q>& qf_output,
const TensorProductQuadratureRule<q>&, dof_type* element_residual, int step = 1)
const TensorProductQuadratureRule<q>&, dof_type* element_residual,
int step = 1)
{
if constexpr (is_zero<source_type>{} && is_zero<flux_type>{}) {
return;
Expand Down
3 changes: 2 additions & 1 deletion src/serac/numerics/functional/detail/metaprogramming.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ constexpr auto get(std::integer_sequence<int, n...>)
namespace detail {

template <typename T>
struct always_false : std::false_type {};
struct always_false : std::false_type {
};

/**
* @brief unfortunately std::integral_constant doesn't have __host__ __device__ annotations
Expand Down
7 changes: 4 additions & 3 deletions src/serac/numerics/functional/detail/qoi.inl
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ struct finite_element<g, QOI> {

template <int Q, int q>
SERAC_HOST_DEVICE static void integrate(const tensor<zero, Q>&, const TensorProductQuadratureRule<q>&, dof_type*,
[[maybe_unused]] int step = 1)
[[maybe_unused]] int step = 1)
{
return; // integrating zeros is a no-op
}

template <int Q, int q>
SERAC_HOST_DEVICE static void integrate(const tensor<double, Q>& qf_output, const TensorProductQuadratureRule<q>&,
dof_type* element_total, [[maybe_unused]] int step = 1)
dof_type* element_total, [[maybe_unused]] int step = 1)
{
if constexpr (geometry == mfem::Geometry::SEGMENT) {
static_assert(Q == q);
Expand Down Expand Up @@ -78,7 +78,8 @@ struct finite_element<g, QOI> {
// output to be a tuple with a hardcoded `zero` flux term
template <typename source_type, int Q, int q>
SERAC_HOST_DEVICE static void integrate(const tensor<serac::tuple<source_type, zero>, Q>& qf_output,
const TensorProductQuadratureRule<q>&, dof_type* element_total, [[maybe_unused]] int step = 1)
const TensorProductQuadratureRule<q>&, dof_type* element_total,
[[maybe_unused]] int step = 1)
{
if constexpr (is_zero<source_type>{}) {
return;
Expand Down
6 changes: 3 additions & 3 deletions src/serac/numerics/functional/detail/quadrilateral_H1.inl
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,7 @@ struct finite_element<mfem::Geometry::SQUARE, H1<p, c> > {
// A(dy, qx) := B(qx, dx) * X_e(dy, dx)
// X_q(qy, qx) := B(qy, dy) * A(dy, qx)
template <int q>
SERAC_HOST_DEVICE
static auto interpolate(const dof_type& X, const TensorProductQuadratureRule<q>&)
SERAC_HOST_DEVICE static auto interpolate(const dof_type& X, const TensorProductQuadratureRule<q>&)
{
static constexpr bool apply_weights = false;
static constexpr auto B = calculate_B<apply_weights, q>();
Expand Down Expand Up @@ -247,7 +246,8 @@ static auto interpolate(const dof_type& X, const TensorProductQuadratureRule<q>&
// tensor<double,dim,dim,dim>}
template <typename source_type, typename flux_type, int q>
SERAC_HOST_DEVICE static void integrate(const tensor<tuple<source_type, flux_type>, q * q>& qf_output,
const TensorProductQuadratureRule<q>&, dof_type* element_residual, int step = 1)
const TensorProductQuadratureRule<q>&, dof_type* element_residual,
int step = 1)
{
if constexpr (is_zero<source_type>{} && is_zero<flux_type>{}) {
return;
Expand Down
7 changes: 3 additions & 4 deletions src/serac/numerics/functional/detail/quadrilateral_Hcurl.inl
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,7 @@ struct finite_element<mfem::Geometry::SQUARE, Hcurl<p> > {
}

template <int q>
SERAC_HOST_DEVICE
static auto interpolate(const dof_type& element_values, const TensorProductQuadratureRule<q>&)
SERAC_HOST_DEVICE static auto interpolate(const dof_type& element_values, const TensorProductQuadratureRule<q>&)
{
constexpr bool apply_weights = false;
constexpr tensor<double, q, p> B1 = calculate_B1<apply_weights, q>();
Expand Down Expand Up @@ -325,8 +324,8 @@ static auto interpolate(const dof_type& element_values, const TensorProductQuadr

template <typename source_type, typename flux_type, int q>
SERAC_HOST_DEVICE static void integrate(const tensor<tuple<source_type, flux_type>, q * q>& qf_output,
const TensorProductQuadratureRule<q>&, dof_type* element_residual,
[[maybe_unused]] int step = 1)
const TensorProductQuadratureRule<q>&, dof_type* element_residual,
[[maybe_unused]] int step = 1)
{
constexpr bool apply_weights = true;
constexpr tensor<double, q, p> B1 = calculate_B1<apply_weights, q>();
Expand Down
6 changes: 3 additions & 3 deletions src/serac/numerics/functional/detail/quadrilateral_L2.inl
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,7 @@ struct finite_element<mfem::Geometry::SQUARE, L2<p, c> > {
// A(dy, qx) := B(qx, dx) * X_e(dy, dx)
// X_q(qy, qx) := B(qy, dy) * A(dy, qx)
template <int q>
SERAC_HOST_DEVICE
static auto interpolate(const dof_type& X, const TensorProductQuadratureRule<q>&)
SERAC_HOST_DEVICE static auto interpolate(const dof_type& X, const TensorProductQuadratureRule<q>&)
{
static constexpr bool apply_weights = false;
static constexpr auto B = calculate_B<apply_weights, q>();
Expand Down Expand Up @@ -246,7 +245,8 @@ static auto interpolate(const dof_type& X, const TensorProductQuadratureRule<q>&
// tensor<double,dim,dim,dim>}
template <typename source_type, typename flux_type, int q>
SERAC_HOST_DEVICE static void integrate(const tensor<tuple<source_type, flux_type>, q * q>& qf_output,
const TensorProductQuadratureRule<q>&, dof_type* element_residual, int step = 1)
const TensorProductQuadratureRule<q>&, dof_type* element_residual,
int step = 1)
{
if constexpr (is_zero<source_type>{} && is_zero<flux_type>{}) {
return;
Expand Down
10 changes: 4 additions & 6 deletions src/serac/numerics/functional/detail/segment_H1.inl
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,7 @@ struct finite_element<mfem::Geometry::SEGMENT, H1<p, c> > {
}

template <typename T, int q>
SERAC_HOST_DEVICE
static auto batch_apply_shape_fn(int jx, tensor<T, q> input, const TensorProductQuadratureRule<q>&)
SERAC_HOST_DEVICE static auto batch_apply_shape_fn(int jx, tensor<T, q> input, const TensorProductQuadratureRule<q>&)
{
static constexpr bool apply_weights = false;
static constexpr auto B = calculate_B<apply_weights, q>();
Expand All @@ -122,8 +121,7 @@ struct finite_element<mfem::Geometry::SEGMENT, H1<p, c> > {
}

template <int q>
SERAC_HOST_DEVICE
static auto interpolate(const dof_type& X, const TensorProductQuadratureRule<q>&)
SERAC_HOST_DEVICE static auto interpolate(const dof_type& X, const TensorProductQuadratureRule<q>&)
{
static constexpr bool apply_weights = false;
static constexpr auto B = calculate_B<apply_weights, q>();
Expand Down Expand Up @@ -158,8 +156,8 @@ static auto interpolate(const dof_type& X, const TensorProductQuadratureRule<q>&

template <typename source_type, typename flux_type, int q>
SERAC_HOST_DEVICE static void integrate(const tensor<tuple<source_type, flux_type>, q>& qf_output,
const TensorProductQuadratureRule<q>&, dof_type* element_residual,
[[maybe_unused]] int step = 1)
const TensorProductQuadratureRule<q>&, dof_type* element_residual,
[[maybe_unused]] int step = 1)
{
if constexpr (is_zero<source_type>{} && is_zero<flux_type>{}) {
return;
Expand Down
7 changes: 3 additions & 4 deletions src/serac/numerics/functional/detail/segment_L2.inl
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,7 @@ struct finite_element<mfem::Geometry::SEGMENT, L2<p, c> > {
}

template <int q>
SERAC_HOST_DEVICE
static auto interpolate(const dof_type& X, const TensorProductQuadratureRule<q>&)
SERAC_HOST_DEVICE static auto interpolate(const dof_type& X, const TensorProductQuadratureRule<q>&)
{
static constexpr bool apply_weights = false;
static constexpr auto B = calculate_B<apply_weights, q>();
Expand Down Expand Up @@ -157,8 +156,8 @@ static auto interpolate(const dof_type& X, const TensorProductQuadratureRule<q>&

template <typename source_type, typename flux_type, int q>
SERAC_HOST_DEVICE static void integrate(const tensor<tuple<source_type, flux_type>, q>& qf_output,
const TensorProductQuadratureRule<q>&, dof_type* element_residual,
[[maybe_unused]] int step = 1)
const TensorProductQuadratureRule<q>&, dof_type* element_residual,
[[maybe_unused]] int step = 1)
{
if constexpr (is_zero<source_type>{} && is_zero<flux_type>{}) {
return;
Expand Down
6 changes: 3 additions & 3 deletions src/serac/numerics/functional/detail/tetrahedron_H1.inl
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,7 @@ struct finite_element<mfem::Geometry::TETRAHEDRON, H1<p, c> > {
}

template <int q>
SERAC_HOST_DEVICE
static auto interpolate(const tensor<double, c, ndof>& X, const TensorProductQuadratureRule<q>&)
SERAC_HOST_DEVICE static auto interpolate(const tensor<double, c, ndof>& X, const TensorProductQuadratureRule<q>&)
{
constexpr auto xi = GaussLegendreNodes<q, mfem::Geometry::TETRAHEDRON>();

Expand All @@ -386,7 +385,8 @@ static auto interpolate(const tensor<double, c, ndof>& X, const TensorProductQua

template <typename source_type, typename flux_type, int q>
SERAC_HOST_DEVICE static void integrate(const tensor<tuple<source_type, flux_type>, nqpts(q)>& qf_output,
const TensorProductQuadratureRule<q>&, tensor<double, c, ndof>* element_residual, int step = 1)
const TensorProductQuadratureRule<q>&,
tensor<double, c, ndof>* element_residual, int step = 1)
{
if constexpr (is_zero<source_type>{} && is_zero<flux_type>{}) {
return;
Expand Down
6 changes: 3 additions & 3 deletions src/serac/numerics/functional/detail/tetrahedron_L2.inl
Original file line number Diff line number Diff line change
Expand Up @@ -366,8 +366,7 @@ struct finite_element<mfem::Geometry::TETRAHEDRON, L2<p, c> > {
}

template <int q>
SERAC_HOST_DEVICE
static auto interpolate(const tensor<double, c, ndof>& X, const TensorProductQuadratureRule<q>&)
SERAC_HOST_DEVICE static auto interpolate(const tensor<double, c, ndof>& X, const TensorProductQuadratureRule<q>&)
{
constexpr auto xi = GaussLegendreNodes<q, mfem::Geometry::TETRAHEDRON>();

Expand All @@ -391,7 +390,8 @@ static auto interpolate(const tensor<double, c, ndof>& X, const TensorProductQua

template <typename source_type, typename flux_type, int q>
SERAC_HOST_DEVICE static void integrate(const tensor<tuple<source_type, flux_type>, nqpts(q)>& qf_output,
const TensorProductQuadratureRule<q>&, tensor<double, c, ndof>* element_residual, int step = 1)
const TensorProductQuadratureRule<q>&,
tensor<double, c, ndof>* element_residual, int step = 1)
{
if constexpr (is_zero<source_type>{} && is_zero<flux_type>{}) {
return;
Expand Down
6 changes: 3 additions & 3 deletions src/serac/numerics/functional/detail/triangle_H1.inl
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,7 @@ struct finite_element<mfem::Geometry::TRIANGLE, H1<p, c> > {
}

template <int q>
SERAC_HOST_DEVICE
static auto interpolate(const tensor<double, c, ndof>& X, const TensorProductQuadratureRule<q>&)
SERAC_HOST_DEVICE static auto interpolate(const tensor<double, c, ndof>& X, const TensorProductQuadratureRule<q>&)
{
constexpr auto xi = GaussLegendreNodes<q, mfem::Geometry::TRIANGLE>();
static constexpr int num_quadrature_points = q * (q + 1) / 2;
Expand All @@ -292,7 +291,8 @@ static auto interpolate(const tensor<double, c, ndof>& X, const TensorProductQua

template <typename source_type, typename flux_type, int q>
SERAC_HOST_DEVICE static void integrate(const tensor<tuple<source_type, flux_type>, q*(q + 1) / 2>& qf_output,
const TensorProductQuadratureRule<q>&, tensor<double, c, ndof>* element_residual, int step = 1)
const TensorProductQuadratureRule<q>&,
tensor<double, c, ndof>* element_residual, int step = 1)
{
if constexpr (is_zero<source_type>{} && is_zero<flux_type>{}) {
return;
Expand Down
6 changes: 3 additions & 3 deletions src/serac/numerics/functional/detail/triangle_L2.inl
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,7 @@ struct finite_element<mfem::Geometry::TRIANGLE, L2<p, c> > {
}

template <int q>
SERAC_HOST_DEVICE
static auto interpolate(const tensor<double, c, ndof>& X, const TensorProductQuadratureRule<q>&)
SERAC_HOST_DEVICE static auto interpolate(const tensor<double, c, ndof>& X, const TensorProductQuadratureRule<q>&)
{
constexpr auto xi = GaussLegendreNodes<q, mfem::Geometry::TRIANGLE>();
static constexpr int num_quadrature_points = q * (q + 1) / 2;
Expand All @@ -301,7 +300,8 @@ static auto interpolate(const tensor<double, c, ndof>& X, const TensorProductQua

template <typename source_type, typename flux_type, int q>
SERAC_HOST_DEVICE static void integrate(const tensor<tuple<source_type, flux_type>, q*(q + 1) / 2>& qf_output,
const TensorProductQuadratureRule<q>&, tensor<double, c, ndof>* element_residual, int step = 1)
const TensorProductQuadratureRule<q>&,
tensor<double, c, ndof>* element_residual, int step = 1)
{
if constexpr (is_zero<source_type>{} && is_zero<flux_type>{}) {
return;
Expand Down
5 changes: 3 additions & 2 deletions src/serac/numerics/functional/differentiate_wrt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ namespace serac {
static constexpr uint32_t NO_DIFFERENTIATION = uint32_t(1) << 31;

template <uint32_t i>
struct DifferentiateWRT {};
struct DifferentiateWRT {
};

/**
* @brief this type exists solely as a way to signal to `serac::Functional` that the function
Expand All @@ -22,7 +23,7 @@ struct differentiate_wrt_this {
const mfem::Vector& ref; ///< the actual data wrapped by this type

/// @brief implicitly convert back to `mfem::Vector` to extract the actual data
operator const mfem::Vector&() const { return ref; }
operator const mfem::Vector &() const { return ref; }
};

/**
Expand Down
16 changes: 2 additions & 14 deletions src/serac/numerics/functional/domain_integral_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,18 +137,6 @@ SERAC_HOST_DEVICE auto batch_apply_qf(lambda qf, const tensor<double, dim, n> x,
return outputs;
}

template <mfem::Geometry::Type geom, typename test, typename... trials>
auto get_trial_elements(FunctionSignature<test(trials...)>)
{
return tuple<finite_element<geom, trials>...>{};
}

template <mfem::Geometry::Type geom, typename test, typename... trials>
auto get_test(FunctionSignature<test(trials...)>)
{
return finite_element<geom, test>{};
}

template <uint32_t differentiation_index, int Q, mfem::Geometry::Type geom, typename test_element,
typename trial_element_type, typename lambda_type, typename state_type, typename derivative_type,
int... indices>
Expand Down Expand Up @@ -365,8 +353,8 @@ std::function<void(const std::vector<const double*>&, double*, bool)> evaluation
std::shared_ptr<QuadratureData<state_type> > qf_state, std::shared_ptr<derivative_type> qf_derivatives,
uint32_t num_elements)
{
auto trial_elements = get_trial_elements<geom>(s);
auto test = get_test<geom>(s);
auto trial_elements = trial_elements_tuple<geom>(s);
auto test = get_test_element<geom>(s);
return [=](const std::vector<const double*>& inputs, double* outputs, bool update_state) {
domain_integral::evaluation_kernel_impl<wrt, Q, geom>(trial_elements, test, inputs, outputs, positions, jacobians,
qf, (*qf_state)[geom], qf_derivatives.get(), num_elements,
Expand Down
5 changes: 3 additions & 2 deletions src/serac/numerics/functional/finite_element.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ struct TensorProductQuadratureRule {
};

template <auto val>
struct CompileTimeValue {};
struct CompileTimeValue {
};

/**
* @brief this struct is used to look up mfem's memory layout of
Expand Down Expand Up @@ -346,7 +347,7 @@ SERAC_HOST_DEVICE void physical_to_parent(tensor<T, q>& qf_output, const tensor<
*
*/
template <mfem::Geometry::Type g, typename family>
SERAC_HOST_DEVICE struct finite_element;
struct finite_element;

#include "detail/segment_H1.inl"
#include "detail/segment_Hcurl.inl"
Expand Down
Loading

0 comments on commit 37eb4e1

Please sign in to comment.