Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement popcount, dot, normalize #27

Merged
merged 1 commit into from
Dec 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions include/simsycl/sycl/math.hh
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ using std::round;

template<detail::GenFloat T>
auto rsqrt(T x) {
return detail::component_wise_op(x, [](auto x) { return static_cast<decltype(x)>(1.0) / sqrt(x); });
return detail::component_wise_op(x, [](auto v) { return static_cast<decltype(v)>(1.0) / sqrt(v); });
}

using std::sin;
Expand Down Expand Up @@ -308,7 +308,13 @@ using std::min;
// TODO rotate
// TODO sub_sat
// TODO upsample
// TODO popcount

template<detail::GenInt T>
auto popcount(T x) {
return detail::component_wise_op(
x, [](auto v) { return std::popcount(static_cast<std::make_unsigned_t<decltype(v)>>(v)); });
}

// TODO mad24
// TODO mul24

Expand Down
12 changes: 10 additions & 2 deletions include/simsycl/sycl/math_geometric.hh
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ namespace simsycl::sycl {
// the standard requires pass-by-value, but I'm not sure if this is visible to the user

// TODO cross
// TODO dot

template<detail::GeoFloat T1, detail::GeoFloat T2>
auto dot(const T1 &f, const T2 &g) {
return detail::sum(detail::to_matching_vec<T1>(f) * detail::to_matching_vec<T2>(g));
}

template<detail::GeoFloat T>
auto length(const T &f) {
Expand All @@ -20,7 +24,11 @@ auto distance(const T1 &p0, const T2 &p1) {
return length(p1 - p0);
}

// TODO normalize
template<detail::GeoFloat T>
auto normalize(const T &f) {
return detail::to_matching_vec<T>(f) / detail::to_matching_vec<T>(length(f));
}

// TODO fast_distance
// TODO fast_length
// TODO fast_normalize
Expand Down
124 changes: 123 additions & 1 deletion test/math_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

using namespace sycl;


TEST_CASE("Length function works as expected", "[math][geometric]") {
double x = 8.0f;
float y = 7.0f;
Expand Down Expand Up @@ -42,6 +43,45 @@ TEST_CASE("Length function works as expected", "[math][geometric]") {
#endif
}

TEST_CASE("Dot function works as expected", "[math][geometric]") {
float x = 8.0f;
double y = 7.0f;
CHECK(dot(x, x) == Catch::Approx(64.0f));
CHECK(dot(y, y) == Catch::Approx(49.0f));


vec<double, 2> v1 = {1.0, 2.0};
vec<float, 2> v2 = {3.0f, 4.0f};
vec<double, 3> v3 = {5.0, 6.0, 7.0};
vec<float, 3> v4 = {8.0f, 9.0f, 10.0f};
vec<double, 4> v5 = {11.0, 12.0, 13.0, 14.0};
vec<float, 4> v6 = {15.0f, 16.0f, 17.0f, 18.0f};

CHECK(dot(v1, v1) == Catch::Approx(5.0));
CHECK(dot(v2, v2) == Catch::Approx(25.0f));
CHECK(dot(v3, v3) == Catch::Approx(110.0));
CHECK(dot(v4, v4) == Catch::Approx(245.0f));
CHECK(dot(v5, v5) == Catch::Approx(630.0));
CHECK(dot(v6, v6) == Catch::Approx(1094.0f));
CHECK(dot(v1.xy(), v1.yx()) == Catch::Approx(4.0));
CHECK(dot(v2.xx(), v2.xy()) == Catch::Approx(21.0f));
CHECK(dot(v6.argb(), v2.xyxy()) == Catch::Approx(230.0f));
CHECK(dot(v3, v1.xyx()) == Catch::Approx(24.0));

marray<float, 4> m1 = {1.0f, 2.0f, 3.0f, 4.0f};
marray<float, 4> m2 = {5.0f, 6.0f, 7.0f, 8.0f};
CHECK(dot(m1, m2) == Catch::Approx(70.0f));

#if SIMSYCL_FEATURE_HALF_TYPE
using sycl::half;
auto vh1 = vec<half, 2>(half(1.0), half(2.0));
auto vh2 = vec<half, 2>(half(3.0), half(4.0));
CHECK(dot(vh1, vh1) == Catch::Approx(5.0f));
CHECK(dot(vh2, vh2) == Catch::Approx(25.0f));
CHECK(dot(vh1, vh2) == Catch::Approx(11.0f));
#endif
}

TEST_CASE("Distance function works as expected", "[math][geometric]") {
double x = 8.0f;
float y = 7.0f;
Expand All @@ -65,6 +105,56 @@ TEST_CASE("Distance function works as expected", "[math][geometric]") {
CHECK(distance(v6.argb(), v2.xyxy()) == Catch::Approx(26.15339f));
}

TEST_CASE("Normalize function works as expected", "[math][geometric]") {
double x = 8.0f;
float y = 7.0f;
vec<double, 2> v1 = {1.0, 2.0};
vec<float, 2> v2 = {3.0f, 4.0f};
vec<double, 3> v3 = {5.0, 6.0, 7.0};
vec<float, 3> v4 = {8.0f, 9.0f, 10.0f};
vec<double, 4> v5 = {11.0, 12.0, 13.0, 14.0};
vec<float, 4> v6 = {15.0f, 16.0f, 17.0f, 18.0f};

CHECK(normalize(x) == Catch::Approx(1.0));
CHECK(normalize(y) == Catch::Approx(1.0f));
auto v1n = normalize(v1);
CHECK(v1n.x() == Catch::Approx(0.4472135954999579));
CHECK(v1n.y() == Catch::Approx(0.8944271909999159));
auto v1xxn = normalize(v1.xx());
CHECK(v1xxn.x() == Catch::Approx(0.7071067811865475));
CHECK(v1xxn.y() == Catch::Approx(0.7071067811865475));
auto v2n = normalize(v2);
CHECK(v2n.x() == Catch::Approx(0.6f));
CHECK(v2n.y() == Catch::Approx(0.8f));
auto v2yxn = normalize(v2.yx());
CHECK(v2yxn.x() == Catch::Approx(0.8f));
CHECK(v2yxn.y() == Catch::Approx(0.6f));
auto v3n = normalize(v3);
CHECK(v3n.x() == Catch::Approx(0.4767312946227962));
CHECK(v3n.y() == Catch::Approx(0.5720775535473553));
CHECK(v3n.z() == Catch::Approx(0.6674238124719146));
auto v4n = normalize(v4);
CHECK(v4n.x() == Catch::Approx(0.5111012519999519f));
CHECK(v4n.y() == Catch::Approx(0.5749889084999459f));
CHECK(v4n.z() == Catch::Approx(0.6388765649999398f));
auto v5n = normalize(v5);
CHECK(v5n.x() == Catch::Approx(0.4382504900892777));
CHECK(v5n.y() == Catch::Approx(0.4780914437337575));
CHECK(v5n.z() == Catch::Approx(0.5179323973782373));
CHECK(v5n.w() == Catch::Approx(0.5577733510227171));
auto v6n = normalize(v6);
CHECK(v6n.r() == Catch::Approx(0.4535055413676754f));
CHECK(v6n.g() == Catch::Approx(0.4837392441255204f));
CHECK(v6n.b() == Catch::Approx(0.5139729468833655f));
CHECK(v6n.a() == Catch::Approx(0.5442066496412105f));
auto v6argnn = normalize(v6.argb());
CHECK(v6argnn.x() == Catch::Approx(0.5442066496412105f));
CHECK(v6argnn.y() == Catch::Approx(0.4535055413676754f));
CHECK(v6argnn.z() == Catch::Approx(0.4837392441255204f));
CHECK(v6argnn.w() == Catch::Approx(0.5139729468833655f));
}


TEST_CASE("Clamp function works as expected", "[math]") {
using simsycl::test::check_bool_vec;

Expand Down Expand Up @@ -92,4 +182,36 @@ TEST_CASE("Inverse square root function works as expected", "[math]") {
vec<double, 2> v1_result = rsqrt(v1);
CHECK(v1_result.x() == Catch::Approx(rsqrt(v1.x())));
CHECK(v1_result.y() == Catch::Approx(rsqrt(v1.y())));
}
}

TEST_CASE("Popcount function works as expected", "[math]") {
CHECK(popcount(static_cast<std::uint8_t>(-1)) == 8);
CHECK(popcount(static_cast<std::uint16_t>(-1)) == 16);
CHECK(popcount(static_cast<std::uint32_t>(-1)) == 32);
CHECK(popcount(static_cast<std::uint64_t>(-1)) == 64);

CHECK(popcount(static_cast<signed char>(0b101010)) == 3);
CHECK(popcount(static_cast<unsigned char>(0b111111)) == 6);
CHECK(popcount(static_cast<signed short>(0b101010)) == 3);
CHECK(popcount(static_cast<unsigned short>(0b111111)) == 6);
CHECK(popcount(static_cast<signed int>(0b101010)) == 3);
CHECK(popcount(static_cast<unsigned int>(0b111111)) == 6);
CHECK(popcount(static_cast<signed long>(0b101010)) == 3);
CHECK(popcount(static_cast<unsigned long>(0b111111)) == 6);
CHECK(popcount(static_cast<signed long long>(0b101010)) == 3);
CHECK(popcount(static_cast<unsigned long long>(0b111111)) == 6);

vec<int, 4> v1 = {0b101010, 0b111111, 0b101010, -2};
vec<int, 4> v1_result = popcount(v1);
CHECK(v1_result.x() == 3);
CHECK(v1_result.y() == 6);
CHECK(v1_result.z() == 3);
CHECK(v1_result.w() == 31);

vec<std::uint8_t, 3> v2 = {0b0, 0b1101, 0b1111};
auto v2_result = popcount(v2.xxyz());
CHECK(v2_result.x() == 0);
CHECK(v2_result.y() == 0);
CHECK(v2_result.z() == 3);
CHECK(v2_result.w() == 4);
}
Loading