Skip to content

Commit

Permalink
Implement length and distance, + a few fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
PeterTh committed Dec 28, 2023
1 parent b7e9858 commit a8d5f69
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 14 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ add_library(simsycl
include/simsycl/sycl/item.hh
include/simsycl/sycl/kernel.hh
include/simsycl/sycl/math.hh
include/simsycl/sycl/math_geometric.hh
include/simsycl/sycl/marray.hh
include/simsycl/sycl/multi_ptr.hh
include/simsycl/sycl/nd_item.hh
Expand Down
1 change: 1 addition & 0 deletions include/simsycl/sycl.hh
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "sycl/kernel.hh"
#include "sycl/marray.hh"
#include "sycl/math.hh"
#include "sycl/math_geometric.hh"
#include "sycl/multi_ptr.hh"
#include "sycl/nd_item.hh"
#include "sycl/nd_range.hh"
Expand Down
14 changes: 1 addition & 13 deletions include/simsycl/sycl/math.hh
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ SIMSYCL_DETAIL_MATH_DEFINE_BINARY_COMPONENT_WISE_VEC_FUNCTION(exp)
SIMSYCL_DETAIL_MATH_DEFINE_UNARY_COMPONENT_WISE_VEC_FUNCTION(exp2)
// TODO exp10
SIMSYCL_DETAIL_MATH_DEFINE_UNARY_COMPONENT_WISE_VEC_FUNCTION(fabs)
SIMSYCL_DETAIL_MATH_DEFINE_UNARY_COMPONENT_WISE_VEC_FUNCTION(fdim)
SIMSYCL_DETAIL_MATH_DEFINE_BINARY_COMPONENT_WISE_VEC_FUNCTION(fdim)
SIMSYCL_DETAIL_MATH_DEFINE_UNARY_COMPONENT_WISE_VEC_FUNCTION(floor)
SIMSYCL_DETAIL_MATH_DEFINE_BINARY_COMPONENT_WISE_VEC_FUNCTION(fma)
SIMSYCL_DETAIL_MATH_DEFINE_BINARY_COMPONENT_WISE_VEC_FUNCTION(fmax)
Expand Down Expand Up @@ -317,18 +317,6 @@ SIMSYCL_DETAIL_MATH_DEFINE_BINARY_COMPONENT_WISE_VEC_FUNCTION(min)

namespace simsycl::sycl {

// TODO dot
// TODO distance
// TODO length
// TODO normalize
// TODO fast_distance
// TODO fast_length
// TODO fast_normalize

}

namespace simsycl::sycl {

// TODO isequal
// TODO isnotequal
using std::isfinite;
Expand Down
75 changes: 75 additions & 0 deletions include/simsycl/sycl/math_geometric.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#pragma once

#include "math.hh"
#include "vec.hh"

namespace simsycl::detail {

template<typename T>
concept SyclFloat = std::is_same_v<T, float> || std::is_same_v<T, double>
#if SIMSYCL_FEATURE_HALF_TYPE
|| std::is_same_v<T, sycl::half>
#endif
;

template<typename T>
concept GeoFloat = //
SyclFloat<T>
|| ((is_swizzle_v<T> || is_vec_v<T>)&&(num_elements_v<T> > 0 && num_elements_v<T> <= 4)
&& SyclFloat<typename T::element_type>); // TODO: marray

template<typename T>
requires(is_vec_v<T> || is_swizzle_v<T>)
auto sum(const T &f) {
auto ret = f[0];
for(int i = 1; i < num_elements_v<T>; ++i) { ret += f[i]; }
return ret;
}
template<SyclFloat T>
auto sum(const T &f) {
return f;
}

template<GeoFloat T>
struct element_type {
using type = T;
};
template<GeoFloat T>
requires(is_vec_v<T> || is_swizzle_v<T>)
struct element_type<T> {
using type = typename T::element_type;
};
template<GeoFloat T>
using element_type_t = typename element_type<T>::type;

template<typename VT, typename T>
auto to_matching_vec(const T &v) {
return detail::to_vec<detail::element_type_t<VT>, detail::num_elements_v<VT>>(v);
}

} // namespace simsycl::detail

namespace simsycl::sycl {

// Note: arguments are passed by const ref rather than value to avoid gcc warnings
// the standard requires pass-by-value, but I'm not sure if this is visible to the user

// TODO cross
// TODO dot

template<detail::GeoFloat T>
auto length(const T &f) {
return sqrt(detail::sum(pow(detail::to_matching_vec<T>(f), detail::to_matching_vec<T>(2))));
}

template<detail::GeoFloat T1, detail::GeoFloat T2>
auto distance(const T1 &p0, const T2 &p1) {
return length(p1 - p0);
}

// TODO normalize
// TODO fast_distance
// TODO fast_length
// TODO fast_normalize

} // namespace simsycl::sycl
2 changes: 1 addition & 1 deletion include/simsycl/sycl/vec.hh
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ template<typename T, int... Indices>
struct num_elements<detail::swizzled_vec<T, Indices...>> : std::integral_constant<int, sizeof...(Indices)> {};

template<typename VecOrSwizzle>
static constexpr bool num_elements_v = num_elements<VecOrSwizzle>::value;
static constexpr int num_elements_v = num_elements<VecOrSwizzle>::value;


template<int... Is>
Expand Down
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ add_executable(tests
hierarchical_tests.cc
launch_tests.cc
marray_tests.cc
math_tests.cc
reduction_tests.cc
vec_tests.cc
)
Expand Down
61 changes: 61 additions & 0 deletions test/math_tests.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#include <catch2/catch_approx.hpp>
#include <catch2/catch_test_macros.hpp>

#define SYCL_SIMPLE_SWIZZLES
#include <sycl/sycl.hpp>

using namespace sycl;

TEST_CASE("Length function works as expected", "[math][geometric]") {
double x = 8.0f;
float y = 7.0f;
vec<double, 2> v1 = {1.0, 2.0};
vec<float, 2> v2 = {3.0f, 4.0f};
vec<double, 3> v3 = {5.0, 6.0, 7.0};
vec<float, 3> v4 = {8.0f, 9.0f, 10.0f};
vec<double, 4> v5 = {11.0, 12.0, 13.0, 14.0};
vec<float, 4> v6 = {15.0f, 16.0f, 17.0f, 18.0f};

CHECK(length(x) == Catch::Approx(8.0));
CHECK(length(y) == Catch::Approx(7.0f));
CHECK(length(v1) == Catch::Approx(2.23606797749979));
CHECK(length(v2) == Catch::Approx(5.0f));
CHECK(length(v3) == Catch::Approx(10.488088481701515));
CHECK(length(v4) == Catch::Approx(15.6524758425f));
CHECK(length(v5) == Catch::Approx(25.099800796));
CHECK(length(v6) == Catch::Approx(33.07567f));
CHECK(length(v1.xx()) == Catch::Approx(1.4142135624));
CHECK(length(v6.argb()) == Catch::Approx(33.07567f));

#if SIMSYCL_FEATURE_HALF_TYPE
using sycl::half;
half h = 7.0f;
vec<half, 2> vh1 = vec<half, 2>(3.0f, 4.0f);
CHECK(length(h) == Catch::Approx(7.0f));
CHECK(length(vh1) == Catch::Approx(5.0f));
CHECK(length(vh1.yx()) == Catch::Approx(5.0f));
#endif
}

TEST_CASE("Distance function works as expected", "[math][geometric]") {
double x = 8.0f;
float y = 7.0f;
vec<double, 2> v1 = {1.0, 2.0};
vec<float, 2> v2 = {3.0f, 4.0f};
vec<double, 3> v3 = {5.0, 6.0, 7.0};
vec<float, 3> v4 = {8.0f, 9.0f, 10.0f};
vec<double, 4> v5 = {11.0, 12.0, 13.0, 14.0};
vec<float, 4> v6 = {15.0f, 16.0f, 17.0f, 18.0f};

CHECK(distance(x, 2.0) == Catch::Approx(6.0));
CHECK(distance(y, -1.0f) == Catch::Approx(8.0f));
CHECK(distance(v1, vec<double, 2>{0.0, 0.0}) == Catch::Approx(2.23606797749979));
CHECK(distance(v2, vec<float, 2>{0.0f, 0.0f}) == Catch::Approx(5.0f));
CHECK(distance(v3, vec<double, 3>{0.0, 0.0, 0.0}) == Catch::Approx(10.488088481701515));
CHECK(distance(v4, vec<float, 3>{0.0f, 0.0f, 0.0f}) == Catch::Approx(15.6524758425f));
CHECK(distance(v5, vec<double, 4>{0.0, 0.0, 0.0, 0.0}) == Catch::Approx(25.099800796));
CHECK(distance(v6, vec<float, 4>{0.0f, 0.0f, 0.0f, 0.0f}) == Catch::Approx(33.07567f));
CHECK(distance(v1.xx(), v1.yx()) == Catch::Approx(1.0));
CHECK(distance(v6.argb(), vec<float, 4>{0.0f, 0.0f, 0.0f, 0.0f}) == Catch::Approx(33.07567f));
CHECK(distance(v6.argb(), v2.xyxy()) == Catch::Approx(26.15339f));
}
8 changes: 8 additions & 0 deletions test/vec_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ bool check_bool_vec(sycl::vec<bool, Dimensions> a) {
return true;
}

TEST_CASE("Compile time vector operations work as expected", "[vec]") {
CHECK(detail::num_elements_v<float> == 1);
CHECK(detail::num_elements_v<sycl::vec<float, 1>> == 1);
CHECK(detail::num_elements_v<sycl::vec<double, 2>> == 2);
CHECK(detail::num_elements_v<sycl::vec<int, 3>> == 3);
CHECK(detail::num_elements_v<sycl::vec<float, 4>> == 4);
}

TEST_CASE("Basic vector operations work as expected", "[vec]") {
auto vi1 = sycl::vec<int, 1>(1);

Expand Down

0 comments on commit a8d5f69

Please sign in to comment.