Skip to content

Commit

Permalink
put everything into namespace cu:: and fix ADL issues
Browse files Browse the repository at this point in the history
  • Loading branch information
neilkichler committed Jun 25, 2024
1 parent e129ed0 commit 2c3e194
Show file tree
Hide file tree
Showing 28 changed files with 347 additions and 255 deletions.
1 change: 1 addition & 0 deletions examples/bisection.cu
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "utils.h"

using cu::interval;

template<typename I>
__device__ I f(I x)
Expand Down
8 changes: 4 additions & 4 deletions examples/cuinterval/examples/bisection.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ struct local_stack
size_type len {};
};

typedef interval<double> (*fn_t)(interval<double>);
typedef cu::interval<double> (*fn_t)(cu::interval<double>);

// Example implementation of the bisection method for finding all roots in a given interval.
template<typename T, int max_depth>
__global__ void bisection(fn_t f, interval<T> x_init, double tol, interval<T> *roots, std::size_t *max_roots)
__global__ void bisection(fn_t f, cu::interval<T> x_init, double tol, cu::interval<T> *roots, std::size_t *max_roots)
{
using I = interval<T>;
using I = cu::interval<T>;

std::size_t n_roots = 0;
local_stack<I, max_depth> intervals;
Expand Down Expand Up @@ -66,7 +66,7 @@ __global__ void bisection(fn_t f, interval<T> x_init, double tol, interval<T> *r
}
} else {
// interval could still contain a root -> bisect
split<T> c = bisect(x, 0.5);
cu::split<T> c = bisect(x, 0.5);
// we do depth-first search which often will not be optimal
intervals.push(c.upper_half);
intervals.push(c.lower_half);
Expand Down
75 changes: 69 additions & 6 deletions include/cuinterval/arithmetic/basic.cuh
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
#ifndef CUINTERVAL_ARITHMETIC_BASIC_CUH
#define CUINTERVAL_ARITHMETIC_BASIC_CUH

#include <cuinterval/interval.h>
#include "intrinsic.cuh"
#include <cuinterval/interval.h>

#include <cassert>
#include <cmath>
#include <numbers>

namespace cu
{

//
// Constant intervals
//
Expand Down Expand Up @@ -120,6 +124,8 @@ inline constexpr __device__ interval<T> sqrt(interval<T> x)
template<typename T>
inline constexpr __device__ interval<T> cbrt(interval<T> x)
{
using std::cbrt;

if (empty(x)) {
return x;
}
Expand All @@ -144,9 +150,9 @@ inline constexpr __device__ interval<T> recip(interval<T> a)
} else if (a.lb == zero && zero < a.ub) {
return { intrinsic::rcp_down(a.ub), intrinsic::pos_inf<T>() };
} else if (a.lb < zero && zero < a.ub) {
return ::entire<T>();
return entire<T>();
} else if (a.lb == zero && zero == a.ub) {
return ::empty<T>();
return empty<T>();
}
}

Expand Down Expand Up @@ -214,6 +220,8 @@ inline constexpr __device__ interval<T> div(interval<T> x, interval<T> y)
template<typename T>
inline constexpr __device__ T mag(interval<T> x)
{
using std::max;

if (empty(x)) {
return intrinsic::nan<T>();
}
Expand All @@ -223,6 +231,8 @@ inline constexpr __device__ T mag(interval<T> x)
template<typename T>
inline constexpr __device__ T mig(interval<T> x)
{
using std::min;

// TODO: we might want to split up the function into the bare interval operation and this part.
// we could perhaps use a monad for either result or empty using expected?
if (empty(x)) {
Expand Down Expand Up @@ -264,6 +274,8 @@ inline constexpr __device__ interval<T> abs(interval<T> x)
template<typename T>
inline constexpr __device__ interval<T> max(interval<T> x, interval<T> y)
{
using std::max;

if (empty(x) || empty(y)) {
return empty<T>();
}
Expand All @@ -274,6 +286,8 @@ inline constexpr __device__ interval<T> max(interval<T> x, interval<T> y)
template<typename T>
inline constexpr __device__ interval<T> min(interval<T> x, interval<T> y)
{
using std::min;

if (empty(x) || empty(y)) {
return empty<T>();
}
Expand Down Expand Up @@ -565,6 +579,9 @@ inline constexpr __device__ interval<T> cancel_plus(interval<T> x, interval<T> y
template<typename T>
inline constexpr __device__ interval<T> intersection(interval<T> x, interval<T> y)
{
using std::max;
using std::min;

// extended
if (disjoint(x, y)) {
return empty<T>();
Expand All @@ -576,6 +593,9 @@ inline constexpr __device__ interval<T> intersection(interval<T> x, interval<T>
template<typename T>
inline constexpr __device__ interval<T> convex_hull(interval<T> x, interval<T> y)
{
using std::max;
using std::min;

// extended
if (empty(x)) {
return y;
Expand Down Expand Up @@ -770,9 +790,9 @@ inline constexpr __device__ interval<T> pown(interval<T> x, std::integral auto n
return empty<T>();
}

using intrinsic::next_after;
using intrinsic::next_floating;
using intrinsic::prev_floating;
using intrinsic::next_after;

if (n % 2) { // odd power
if (entire(x)) {
Expand Down Expand Up @@ -857,6 +877,8 @@ inline constexpr __device__ interval<T> pow_(interval<T> x, T y)
template<typename T>
inline constexpr __device__ interval<T> rootn(interval<T> x, std::integral auto n)
{
using std::pow;

if (empty(x)) {
return x;
}
Expand Down Expand Up @@ -941,6 +963,10 @@ inline constexpr __device__ unsigned int quadrant_pi(T v)
template<typename T>
inline constexpr __device__ interval<T> sin(interval<T> x)
{
using std::max;
using std::min;
using std::sin;

if (empty(x)) {
return x;
}
Expand Down Expand Up @@ -1016,6 +1042,10 @@ inline constexpr __device__ interval<T> sin(interval<T> x)
template<typename T>
inline constexpr __device__ interval<T> sinpi(interval<T> x)
{
using ::sinpi;
using std::max;
using std::min;

if (empty(x)) {
return x;
}
Expand Down Expand Up @@ -1064,6 +1094,10 @@ inline constexpr __device__ interval<T> sinpi(interval<T> x)
template<typename T>
inline constexpr __device__ interval<T> cos(interval<T> x)
{
using std::cos;
using std::max;
using std::min;

if (empty(x)) {
return x;
}
Expand Down Expand Up @@ -1114,6 +1148,10 @@ inline constexpr __device__ interval<T> cos(interval<T> x)
template<typename T>
inline constexpr __device__ interval<T> cospi(interval<T> x)
{
using ::cospi;
using std::max;
using std::min;

if (empty(x)) {
return x;
}
Expand Down Expand Up @@ -1161,6 +1199,8 @@ inline constexpr __device__ interval<T> cospi(interval<T> x)
template<typename T>
inline constexpr __device__ interval<T> tan(interval<T> x)
{
using std::tan;

if (empty(x)) {
return x;
}
Expand Down Expand Up @@ -1192,6 +1232,8 @@ inline constexpr __device__ interval<T> tan(interval<T> x)
template<typename T>
inline constexpr __device__ interval<T> asin(interval<T> x)
{
using std::asin;

if (empty(x)) {
return x;
}
Expand All @@ -1207,6 +1249,8 @@ inline constexpr __device__ interval<T> asin(interval<T> x)
template<typename T>
inline constexpr __device__ interval<T> acos(interval<T> x)
{
using std::acos;

if (empty(x)) {
return x;
}
Expand All @@ -1222,6 +1266,8 @@ inline constexpr __device__ interval<T> acos(interval<T> x)
template<typename T>
inline constexpr __device__ interval<T> atan(interval<T> x)
{
using std::atan;

if (empty(x)) {
return x;
}
Expand All @@ -1235,6 +1281,9 @@ inline constexpr __device__ interval<T> atan(interval<T> x)
template<typename T>
inline constexpr __device__ interval<T> atan2(interval<T> y, interval<T> x)
{
using std::abs;
using std::atan2;

if (empty(x) || empty(y)) {
return empty<T>();
}
Expand Down Expand Up @@ -1324,6 +1373,8 @@ inline constexpr __device__ interval<T> atan2(interval<T> y, interval<T> x)
template<typename T>
inline constexpr __device__ interval<T> sinh(interval<T> x)
{
using std::sinh;

if (empty(x)) {
return x;
}
Expand All @@ -1335,6 +1386,8 @@ inline constexpr __device__ interval<T> sinh(interval<T> x)
template<typename T>
inline constexpr __device__ interval<T> cosh(interval<T> x)
{
using std::cosh;

if (empty(x)) {
return x;
}
Expand All @@ -1348,6 +1401,8 @@ inline constexpr __device__ interval<T> cosh(interval<T> x)
template<typename T>
inline constexpr __device__ interval<T> tanh(interval<T> x)
{
using std::tanh;

if (empty(x)) {
return x;
}
Expand All @@ -1361,6 +1416,8 @@ inline constexpr __device__ interval<T> tanh(interval<T> x)
template<typename T>
inline constexpr __device__ interval<T> asinh(interval<T> x)
{
using std::asinh;

if (empty(x)) {
return x;
}
Expand All @@ -1372,6 +1429,8 @@ inline constexpr __device__ interval<T> asinh(interval<T> x)
template<typename T>
inline constexpr __device__ interval<T> acosh(interval<T> x)
{
using std::acosh;

if (empty(x)) {
return x;
}
Expand All @@ -1388,6 +1447,8 @@ inline constexpr __device__ interval<T> acosh(interval<T> x)
template<typename T>
inline constexpr __device__ interval<T> atanh(interval<T> x)
{
using std::atanh;

if (empty(x)) {
return x;
}
Expand Down Expand Up @@ -1465,8 +1526,8 @@ inline constexpr __device__ void mince(interval<T> x, interval<T> *xs, std::size
xs[i] = empty<T>();
}
} else {
T lb = x.lb;
T ub = x.ub;
T lb = x.lb;
T ub = x.ub;
T step = (ub - lb) / static_cast<T>(out_xs_size);

for (std::size_t i = 0; i < out_xs_size; i++) {
Expand All @@ -1475,4 +1536,6 @@ inline constexpr __device__ void mince(interval<T> x, interval<T> *xs, std::size
}
}

} // namespace cu

#endif // CUINTERVAL_ARITHMETIC_BASIC_CUH
4 changes: 2 additions & 2 deletions include/cuinterval/arithmetic/intrinsic.cuh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#ifndef CUINTERVAL_ARITHMETIC_INTRINSIC_CUH
#define CUINTERVAL_ARITHMETIC_INTRINSIC_CUH

namespace intrinsic
namespace cu::intrinsic
{
// clang-format off
template<typename T> inline __device__ T fma_down (T x, T y, T z);
Expand Down Expand Up @@ -104,6 +104,6 @@ namespace intrinsic
template<> inline __device__ float prev_floating(float x) { return nextafterf(x, intrinsic::neg_inf<float>()); }

// clang-format on
} // namespace intrinsic
} // namespace cu::intrinsic

#endif // CUINTERVAL_ARITHMETIC_INTRINSIC_CUH
25 changes: 25 additions & 0 deletions include/cuinterval/format.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef CUINTERVAL_FORMAT_H
#define CUINTERVAL_FORMAT_H

#include <cuinterval/interval.h>

#include <ostream>

namespace cu
{

template<typename T>
std::ostream &operator<<(std::ostream &os, interval<T> x)
{
return os << "[" << x.lb << ", " << x.ub << "]";
}

template<typename T>
std::ostream &operator<<(std::ostream &os, split<T> x)
{
return os << "[" << x.lower_half << ", " << x.upper_half << "]";
}

} // namespace cu

#endif
6 changes: 5 additions & 1 deletion include/cuinterval/interval.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#ifndef CUINTERVAL_INTERVAL_H
#define CUINTERVAL_INTERVAL_H

namespace cu
{

template<typename T>
struct interval
{
Expand All @@ -11,7 +14,6 @@ struct interval
template<typename T>
bool operator==(interval<T> lhs, interval<T> rhs)
{

auto empty = [](interval<T> x) { return !(x.lb <= x.ub); };

if (empty(lhs) && empty(rhs)) {
Expand All @@ -30,4 +32,6 @@ struct split
auto operator<=>(const split &) const = default;
};

} // namespace cu

#endif // CUINTERVAL_INTERVAL_H
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ set(headers
"${LIB_PATH}/arithmetic/intrinsic.cuh"
"${LIB_PATH}/arithmetic/basic.cuh"
"${LIB_PATH}/cuinterval.h"
"${LIB_PATH}/format.h"
)

add_library(cuinterval "main.cu" ${headers})
Expand Down
2 changes: 1 addition & 1 deletion tests/generated/tests_atan2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ void tests_atan2(cuda_buffer buffer, cudaStream_t stream, cudaEvent_t event) {
using namespace boost::ut;

using T = double;
using I = interval<T>;
using I = cu::interval<T>;
using B = bool;
using N = int;

Expand Down
Loading

0 comments on commit 2c3e194

Please sign in to comment.