Skip to content

Commit

Permalink
Use CUDA native functions in math namespace
Browse files Browse the repository at this point in the history
This commit extends the math namespace to replace the C++ standard
library's floating-point math functions (e.g. `atan`) with their CUDA
variants. The benefit is that the CUDA versions are better optimized
for GPU execution, and using them enables us to speed these functions
up massively with fast-math mode.
  • Loading branch information
stephenswat committed Oct 15, 2024
1 parent ce2f877 commit 5b82560
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 2 deletions.
152 changes: 152 additions & 0 deletions core/include/detray/definitions/detail/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

#pragma once

#include <concepts>

// SYCL include(s).
#if defined(CL_SYCL_LANGUAGE_VERSION) || defined(SYCL_LANGUAGE_VERSION)
#include <CL/sycl.hpp>
Expand Down Expand Up @@ -137,6 +139,156 @@ requires Vc::Traits::is_simd_vector<T>::value inline decltype(auto) fma(T &&x,
}
/// @}

} // namespace math
#elif defined(__CUDA_ARCH__)
/// CUDA device-side implementation of the @c math namespace.
///
/// Every wrapper below forwards to the global-namespace function of the
/// matching precision (the @c f -suffixed function for @c float, the
/// unsuffixed one for @c double), so that a call such as @c math::asin(x)
/// made from device code resolves to that global function instead of the
/// C++ standard library overload set.
namespace math {

/// @c abs is taken over from the standard library unchanged.
using std::abs;

/// @returns the arc sine of @p i (single precision)
DETRAY_DEVICE inline float asin(float i) {
    return ::asinf(i);
}

/// @returns the arc sine of @p i (double precision)
DETRAY_DEVICE inline double asin(double i) {
    return ::asin(i);
}

/// @returns the arc tangent of @p i (single precision)
DETRAY_DEVICE inline float atan(float i) {
    return ::atanf(i);
}

/// @returns the arc tangent of @p i (double precision)
DETRAY_DEVICE inline double atan(double i) {
    return ::atan(i);
}

/// @returns @p i rounded up to an integral value (single precision)
DETRAY_DEVICE inline float ceil(float i) {
    return ::ceilf(i);
}

/// @returns @p i rounded up to an integral value (double precision)
DETRAY_DEVICE inline double ceil(double i) {
    return ::ceil(i);
}

/// @returns a value with the magnitude of @p i and the sign of @p j
/// (single precision)
DETRAY_DEVICE inline float copysign(float i, float j) {
    return ::copysignf(i, j);
}

/// @returns a value with the magnitude of @p i and the sign of @p j
/// (double precision)
DETRAY_DEVICE inline double copysign(double i, double j) {
    return ::copysign(i, j);
}

/// @returns the cosine of @p i (single precision)
DETRAY_DEVICE inline float cos(float i) {
    return ::cosf(i);
}

/// @returns the cosine of @p i (double precision)
DETRAY_DEVICE inline double cos(double i) {
    return ::cos(i);
}

/// @returns e raised to the power @p i (single precision)
DETRAY_DEVICE inline float exp(float i) {
    return ::expf(i);
}

/// @returns e raised to the power @p i (double precision)
DETRAY_DEVICE inline double exp(double i) {
    return ::exp(i);
}

/// @returns the absolute value of @p i (single precision)
DETRAY_DEVICE inline float fabs(float i) {
    return ::fabsf(i);
}

/// @returns the absolute value of @p i (double precision)
DETRAY_DEVICE inline double fabs(double i) {
    return ::fabs(i);
}

/// @returns the fused multiply-add @p i * @p j + @p k (single precision)
DETRAY_DEVICE inline float fma(float i, float j, float k) {
    return ::fmaf(i, j, k);
}

/// @returns the fused multiply-add @p i * @p j + @p k (double precision)
DETRAY_DEVICE inline double fma(double i, double j, double k) {
    return ::fma(i, j, k);
}

/// @returns the natural logarithm of @p i (single precision)
DETRAY_DEVICE inline float log(float i) {
    return ::logf(i);
}

/// @returns the natural logarithm of @p i (double precision)
DETRAY_DEVICE inline double log(double i) {
    return ::log(i);
}

/// @returns the base-10 logarithm of @p i (single precision)
DETRAY_DEVICE inline float log10(float i) {
    return ::log10f(i);
}

/// @returns the base-10 logarithm of @p i (double precision)
DETRAY_DEVICE inline double log10(double i) {
    return ::log10(i);
}

/// @returns the smaller of the two integral values @p i and @p j
///
/// Implemented as a plain comparison instead of a call to @c std::min :
/// the latter is a constexpr function without a device annotation, which
/// nvcc only allows to be called from device code when the experimental
/// @c --expt-relaxed-constexpr flag is set. The selected value is the same
/// as the one @c std::min(i, j) yields.
template <std::integral T>
DETRAY_DEVICE inline auto min(T i, T j) {
    return (j < i) ? j : i;
}

/// @returns the smaller of @p i and @p j (single precision)
DETRAY_DEVICE inline float min(float i, float j) {
    return ::min(i, j);
}

/// @returns the smaller of @p i and @p j (double precision)
DETRAY_DEVICE inline double min(double i, double j) {
    return ::min(i, j);
}

/// @returns the larger of the two integral values @p i and @p j
///
/// Implemented as a plain comparison instead of a call to @c std::max for
/// the same reason as the integral @c min overload: no constexpr function
/// without a device annotation is called from device code. The selected
/// value is the same as the one @c std::max(i, j) yields.
template <std::integral T>
DETRAY_DEVICE inline auto max(T i, T j) {
    return (i < j) ? j : i;
}

/// @returns the larger of @p i and @p j (single precision)
DETRAY_DEVICE inline float max(float i, float j) {
    return ::max(i, j);
}

/// @returns the larger of @p i and @p j (double precision)
DETRAY_DEVICE inline double max(double i, double j) {
    return ::max(i, j);
}

/// @returns @p i raised to the power @p p (single precision)
DETRAY_DEVICE inline float pow(float i, float p) {
    return ::powf(i, p);
}

/// @returns @p i raised to the power @p p (double precision)
DETRAY_DEVICE inline double pow(double i, double p) {
    return ::pow(i, p);
}

/// @returns the result of the global @c ::signbit for @p i (single
/// precision); the return type is deduced from that call
DETRAY_DEVICE inline auto signbit(float i) {
    return ::signbit(i);
}

/// @returns the result of the global @c ::signbit for @p i (double
/// precision); the return type is deduced from that call
DETRAY_DEVICE inline auto signbit(double i) {
    return ::signbit(i);
}

/// @returns the sine of @p i (single precision)
DETRAY_DEVICE inline float sin(float i) {
    return ::sinf(i);
}

/// @returns the sine of @p i (double precision)
DETRAY_DEVICE inline double sin(double i) {
    return ::sin(i);
}

/// @returns the square root of @p i (single precision)
DETRAY_DEVICE inline float sqrt(float i) {
    return ::sqrtf(i);
}

/// @returns the square root of @p i (double precision)
DETRAY_DEVICE inline double sqrt(double i) {
    return ::sqrt(i);
}

/// @returns the tangent of @p i (single precision)
DETRAY_DEVICE inline float tan(float i) {
    return ::tanf(i);
}

/// @returns the tangent of @p i (double precision)
DETRAY_DEVICE inline double tan(double i) {
    return ::tan(i);
}

}  // namespace math
#else
namespace math = std;
Expand Down
3 changes: 2 additions & 1 deletion tests/include/detray/test/validation/detector_scan_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,8 @@ inline std::string print_trace(const truth_trace_t &truth_trace,
/// Print an adjacency list
inline std::string print_adj(const dvector<dindex> &adjacency_matrix) {

std::size_t dim = static_cast<dindex>(math::sqrt(adjacency_matrix.size()));
std::size_t dim = static_cast<dindex>(
math::sqrt(static_cast<scalar>(adjacency_matrix.size())));
std::stringstream out_stream{};

for (std::size_t i = 0u; i < dim - 1; ++i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,8 @@ inline auto print_efficiency(std::size_t n_tracks, std::size_t n_surfaces,
}

// How many significant digits to print
const auto n_sig{2 + static_cast<int>(math::ceil(math::log10(n_surfaces)))};
const auto n_sig{2 + static_cast<int>(math::ceil(
math::log10(static_cast<scalar>(n_surfaces))))};

assert(n_miss_nav <= n_surfaces);

Expand Down

0 comments on commit 5b82560

Please sign in to comment.