From 4a730f17b1490edd61105ca9efdc7977102378a8 Mon Sep 17 00:00:00 2001 From: Brendan Barnes Date: Wed, 10 Jul 2024 07:16:25 +0000 Subject: [PATCH] save --- include/squint/dynamic_tensor.hpp | 3 + include/squint/fixed_tensor.hpp | 14 +- include/squint/linear_algebra.hpp | 336 +++++++++++++++++++++++++----- include/squint/quantity.hpp | 56 ++++- include/squint/tensor_view.hpp | 8 + main.cpp | 18 +- 6 files changed, 373 insertions(+), 62 deletions(-) diff --git a/include/squint/dynamic_tensor.hpp b/include/squint/dynamic_tensor.hpp index 0e863ea..a2ee570 100644 --- a/include/squint/dynamic_tensor.hpp +++ b/include/squint/dynamic_tensor.hpp @@ -2,6 +2,7 @@ #define SQUINT_DYNAMIC_TENSOR_HPP #include "squint/iterable_tensor.hpp" +#include "squint/quantity.hpp" #include "squint/tensor_base.hpp" #include "squint/tensor_view.hpp" #include @@ -18,6 +19,7 @@ class dynamic_tensor : public iterable_tensor, layout layout_; public: + using value_type = T; using iterable_tensor, T, ErrorChecking>::subviews; constexpr dynamic_tensor() = default; // virtual destructor @@ -125,6 +127,7 @@ class dynamic_tensor : public iterable_tensor, constexpr std::size_t size() const { return data_.size(); } constexpr std::vector shape() const { return shape_; } constexpr layout get_layout() const { return layout_; } + constexpr error_checking get_error_checking() const { return ErrorChecking; } std::vector strides() const { return calculate_strides(); } T &at_impl(const std::vector &indices) { return data_[calculate_index(indices)]; } diff --git a/include/squint/fixed_tensor.hpp b/include/squint/fixed_tensor.hpp index cbfbada..6001392 100644 --- a/include/squint/fixed_tensor.hpp +++ b/include/squint/fixed_tensor.hpp @@ -2,6 +2,7 @@ #define SQUINT_FIXED_TENSOR_HPP #include "squint/iterable_tensor.hpp" +#include "squint/linear_algebra.hpp" #include "squint/quantity.hpp" #include "squint/tensor_base.hpp" #include "squint/tensor_view.hpp" @@ -22,11 +23,13 @@ constexpr auto make_subviews_iterator(const BlockTensor & /*unused*/, Tensor &te // Fixed tensor implementation template -class fixed_tensor : public iterable_tensor, T, ErrorChecking> { +class fixed_tensor : public iterable_tensor, T, ErrorChecking>, + public fixed_linear_algebra_mixin, ErrorChecking> { static constexpr std::size_t total_size = (Dims * ...); std::array data_; public: + using value_type = T; using iterable_tensor, T, ErrorChecking>::subviews; // virtual destructor virtual ~fixed_tensor() = default; @@ -59,6 +62,7 @@ class fixed_tensor : public iterable_tensor strides() { auto strides_array = calculate_strides(); return std::vector(std::begin(strides_array), std::end(strides_array)); @@ -131,6 +135,14 @@ class fixed_tensor : public iterable_tensor(slices...); } + template auto as() const { + fixed_tensor result; + for (std::size_t i = 0; i < total_size; ++i) { + result.data()[i] = static_cast(data_[i]); + } + return result; + } + static constexpr fixed_tensor zeros() { fixed_tensor result; result.data_.fill(T{}); diff --git a/include/squint/linear_algebra.hpp b/include/squint/linear_algebra.hpp index 04f7bd9..c8c1c3e 100644 --- a/include/squint/linear_algebra.hpp +++ b/include/squint/linear_algebra.hpp @@ -1,86 +1,316 @@ #ifndef SQUINT_LINEAR_ALGEBRA_HPP #define SQUINT_LINEAR_ALGEBRA_HPP -#include "squint/dynamic_tensor.hpp" -#include "squint/fixed_tensor.hpp" #include "squint/quantity.hpp" #include "squint/tensor_base.hpp" -#include "squint/tensor_view.hpp" #include +#include namespace squint { -// Linear algebra mixin +// Base linear algebra mixin template class linear_algebra_mixin { public: - // Element-wise addition - template - constexpr auto operator+(const linear_algebra_mixin &other) const { - const auto &a = static_cast(*this); - const auto &b = static_cast(other); - - if constexpr (ErrorChecking == error_checking::enabled) { - if (a.shape() != b.shape()) { - throw std::invalid_argument("Incompatible shapes for addition"); + auto transpose(); + auto transpose() const; + auto determinant() const; + auto inv() const; + auto norm(int p = 2) const; + auto trace() const; + auto eigenvalues() const; + auto eigenvectors() const; + auto pinv() const; + + auto mean() const { return sum() / static_cast(this)->size(); } + auto sum() const { + auto result = typename Derived::value_type{}; + for (const auto &elem : *static_cast(this)) { + result += elem; + } + return result; + } + auto min() const { + auto result = *static_cast(this)->begin(); + for (const auto &elem : *static_cast(this)) { + if (elem < result) { + result = elem; + } + } + return result; + } + auto max() const { + auto result = *static_cast(this)->begin(); + for (const auto &elem : *static_cast(this)) { + if (elem > result) { + result = elem; } } + return result; + } +}; + +// Compile-time shape checks +template constexpr bool compatible_for_element_wise_op() { + auto min_dims = std::min(A::constexpr_shape().size(), B::constexpr_shape().size()); + for (std::size_t i = 0; i < min_dims; ++i) { + if (A::constexpr_shape()[i] != B::constexpr_shape()[i]) { + return false; + } + } + return A::size() == B::size(); +} + +template constexpr bool compatible_for_matmul() { + constexpr auto this_shape = A::constexpr_shape(); + constexpr auto other_shape = B::constexpr_shape(); + constexpr bool matmat = (A::rank() == 2) && (B::rank() == 2) && (this_shape[1] == other_shape[0]); + constexpr bool matvec = (A::rank() == 2) && (B::rank() == 1) && (this_shape[1] == other_shape[0]); + return matmat || matvec; +} + +// Fixed tensor with linear algebra +template +class fixed_linear_algebra_mixin : public linear_algebra_mixin { + public: + template auto &operator*=(const Scalar &scalar) { + for (auto &elem : *static_cast(this)) { + elem *= scalar; + } + return *static_cast(this); + } + + template auto &operator/=(const Scalar &scalar) { + for (auto &elem : *static_cast(this)) { + elem /= scalar; + } + return *static_cast(this); + } + + template auto &operator+=(const Other &other) { + static_assert(compatible_for_element_wise_op(), + "Incompatible shapes for element-wise addition"); + auto it = static_cast(this)->begin(); + for (const auto &elem : other) { + *it++ += elem; + } + return *static_cast(this); + } - if constexpr (fixed_shape_tensor && fixed_shape_tensor) { - constexpr auto shape = Derived::constexpr_shape(); - fixed_tensor result; - for (std::size_t i = 0; i < a.size(); ++i) { - result.at(i) = a.at(i) + b.at(i); + template auto &operator-=(const Other &other) { + static_assert(compatible_for_element_wise_op(), + "Incompatible shapes for element-wise subtraction"); + auto it = static_cast(this)->begin(); + for (const auto &elem : other) { + *it++ -= elem; + } + return *static_cast(this); + } + + template auto solve(const B &b) const { + // PA=LU factorization (Doolittle algorithm) + constexpr std::size_t count = Derived::constexpr_shape()[0]; + auto U = static_cast(this)->template as(); + using L_type = decltype(typename Derived::value_type{} / typename Derived::value_type{}); + auto L = static_cast(this)->template as(); + auto P = L; + for (std::size_t i = 0; i < count; ++i) { + P[i, i] = L_type{1}; + } + for (std::size_t k = 0; k < count - 1; ++k) { + // select index >= k to maximize |A[i][k]| + std::size_t index = k; + auto val = typename Derived::value_type{}; + for (std::size_t i = k; i < count; ++i) { + auto a_val = U[i, k]; + a_val = a_val < typename Derived::value_type(0) ? -a_val : a_val; + if (a_val > val) { + index = i; + val = a_val; + } + } + // Swap Rows + auto U_row = U.row(k); + U.row(k) = U.row(index); + U.row(index) = U_row; + auto L_row = L.row(k); + L.row(k) = L.row(index); + L.row(index) = L_row; + auto P_row = P.row(k); + P.row(k) = P.row(index); + P.row(index) = P_row; + // compute factorization + for (std::size_t j = k + 1; j < count; ++j) { + L[j, k] = U[j, k] / U[k, k]; + for (std::size_t i = k; i < count; ++i) { + U[j, i] = U[j, i] - L[j, k] * U[k, i]; + } + } + } + // fill diagonals of L with 1 + for (std::size_t i = 0; i < count; ++i) { + L[i, i] = L_type{1}; + } + // for each column in B, solve the system using forward and back substitution + using result_type = decltype(typename Derived::value_type{} * typename B::value_type{}); + auto result = b.template as(); + for (auto &col : result.cols()) { + // forward substitute + auto y = col.template as().template reshape; + auto b_col = col.template as().template reshape; + // permute b + b_col = P * b_col; + for (std::size_t i = 0; i < count; ++i) { + auto tmp = b_col[i]; + for (std::size_t j = 0; j < i; ++j) { + tmp -= L[i, j] * y[j]; + } + y[i] = tmp / L[i, i]; } - return result; - } else { - dynamic_tensor result(a.shape()); - for (std::size_t i = 0; i < a.size(); ++i) { - result.at(i) = a.at(i) + b.at(i); + // back substitute into column + for (int i = count - 1; i > -1; --i) { + auto tmp = y[i]; + for (std::size_t j = i + 1; j < count; ++j) { + tmp -= U[i, j] * col[j]; + } + col[i] = tmp / U[i, i]; } - return result; } + return result; } - // Scalar multiplication - constexpr auto operator*(const typename Derived::value_type &scalar) const { - const auto &a = static_cast(*this); + template auto solve_lls(const B &b) const; - if constexpr (fixed_shape_tensor) { - constexpr auto shape = Derived::constexpr_shape(); - fixed_tensor result; - for (std::size_t i = 0; i < a.size(); ++i) { - result.at(i) = a.at(i) * scalar; - } - return result; - } else { - dynamic_tensor result(a.shape()); - for (std::size_t i = 0; i < a.size(); ++i) { - result.at(i) = a.at(i) * scalar; + template auto operator/(const B &b) const; + + template bool operator==(const Other &other) const { + static_assert(compatible_for_element_wise_op(), + "Incompatible shapes for element-wise addition"); + auto it = static_cast(this)->begin(); + for (const auto &elem : other) { + if (*it++ != elem) { + return false; } - return result; } + return true; } + + template bool operator!=(const Other &other) const { return !(*this == other); } }; -// Fixed tensor with linear algebra -template -class fixed_tensor_with_la - : public fixed_tensor, - public linear_algebra_mixin, ErrorChecking> { - public: - using fixed_tensor::fixed_tensor; +// Element-wise operations between fixed tensors +template auto operator+(const A &a, const B &b) { + static_assert(compatible_for_element_wise_op(), "Incompatible shapes for element-wise addition"); + auto result = a; + auto it = result.begin(); + for (const auto &elem : b) { + *it++ += elem; + } + return result; +} - // Add any additional methods or overrides specific to fixed_tensor_with_la -}; +template auto operator-(const A &a, const B &b) { + static_assert(compatible_for_element_wise_op(), "Incompatible shapes for element-wise subtraction"); + auto result = a; + auto it = result.begin(); + for (const auto &elem : b) { + *it++ -= elem; + } + return result; +} + +// Matrix multiplication between fixed tensors +template auto operator*(const A &a, const B &b) { + static_assert(compatible_for_matmul(), "Incompatible shapes for matrix multiplication"); + constexpr auto other_rank = B::rank(); + constexpr auto this_shape = A::constexpr_shape(); + constexpr auto other_shape = B::constexpr_shape(); + // get result value_type + using result_value_type = decltype(typename A::value_type{} * typename B::value_type{}); + if constexpr (other_rank == 2) { + auto result = + fixed_tensor(); + for (std::size_t i = 0; i < this_shape[0]; ++i) { + for (std::size_t j = 0; j < other_shape[1]; ++j) { + for (std::size_t k = 0; k < this_shape[1]; ++k) { + result[i, j] += a[i, k] * b[k, j]; + } + } + } + return result; + } else { + auto result = fixed_tensor(); + for (std::size_t i = 0; i < this_shape[0]; ++i) { + for (std::size_t j = 0; j < this_shape[1]; ++j) { + result[i] += a[i, j] * b[j]; + } + } + return result; + } +} + +// Cross product between two 3D vectors +template auto cross(const A &a, const B &b) { + static_assert(A::size() == 3 && B::size() == 3, "Cross product requires 3D vectors"); + using result_value_type = decltype(typename A::value_type{} * typename B::value_type{}); + auto result = fixed_tensor(); + result[0] = a[1] * b[2] - a[2] * b[1]; + result[1] = a[2] * b[0] - a[0] * b[2]; + result[2] = a[0] * b[1] - a[1] * b[0]; + return result; +} + +template +bool approx_equal(const A &a, const B &b, const Epsilon &epsilon) { + static_assert(compatible_for_element_wise_op(), "Incompatible shapes for element-wise subtraction"); + auto it = a.begin(); + for (const auto &elem : b) { + if (!approx_equal(*it++, elem, epsilon)) { + return false; + } + } + return true; +} + +// Runtime shape checks +template bool compatible_for_element_wise_op(const A &a, const B &b) { + auto min_dims = std::min(a.rank(), b.rank()); + for (std::size_t i = 0; i < min_dims; ++i) { + if (a.shape()[i] != b.shape()[i]) { + return false; + } + } + return a.size() == b.size(); +} + +template bool compatible_for_matmul(const A &a, const B &b) { + const auto this_shape = a.shape(); + const auto other_shape = b.shape(); + const bool matmat = (a.rank() == 2) && (b.rank() == 2) && (this_shape[1] == other_shape[0]); + const bool matvec = (a.rank() == 2) && (b.rank() == 1) && (this_shape[1] == other_shape[0]); + return matmat || matvec; +} // Dynamic tensor with linear algebra -template -class dynamic_tensor_with_la : public dynamic_tensor, - public linear_algebra_mixin, ErrorChecking> { +template +class dynamic_linear_algebra_mixin : public linear_algebra_mixin { public: - using dynamic_tensor::dynamic_tensor; + template auto &operator*=(const Scalar &scalar); + + template auto &operator/=(const Scalar &scalar); + + template auto &operator+=(const Other &other); + + template auto &operator-=(const Other &other); + + template auto solve(const B &b) const; + + template auto solve_lls(const B &b) const; + + template auto operator/(const B &b) const; + + template bool operator==(const Other &other) const; - // Add any additional methods or overrides specific to dynamic_tensor_with_la + template bool operator!=(const Other &other) const; }; } // namespace squint diff --git a/include/squint/quantity.hpp b/include/squint/quantity.hpp index 0ceb928..7243f43 100644 --- a/include/squint/quantity.hpp +++ b/include/squint/quantity.hpp @@ -9,6 +9,7 @@ #define SQUINT_QUANTITY_HPP #include "squint/dimension.hpp" +#include #include #include #include @@ -67,8 +68,8 @@ template ) : value_(static_cast(value)) {} + // Allow explicit cast to other types + template explicit operator U() const noexcept { return U(value_); } + // Explicit conversion operator for non-dimensionless quantities explicit constexpr operator T() const noexcept requires(!std::is_same_v) @@ -181,13 +185,19 @@ template constexpr quantity &operator*=(const U &scalar) { + template + constexpr quantity &operator*=(const U &scalar) + requires(arithmetic || std::is_same_v) + { check_overflow_multiply(value_, scalar); value_ *= scalar; return *this; } - template constexpr quantity &operator/=(const U &scalar) { + template + constexpr quantity &operator/=(const U &scalar) + requires(arithmetic || std::is_same_v) + { check_division_by_zero(scalar); check_underflow_divide(value_, scalar); value_ /= scalar; @@ -407,6 +417,44 @@ concept quantitative = requires { typename T::dimension_type; }; +// approx equal for arithmetic types +template + requires arithmetic +bool approx_equal(T a, T b, T epsilon = 128 * 1.192092896e-04, T abs_th = std::numeric_limits::epsilon()) { + assert(std::numeric_limits::epsilon() <= epsilon); + assert(epsilon < 1.F); + + if (a == b) + return true; + + auto diff = std::abs(a - b); + auto norm = std::min((std::abs(a) + std::abs(b)), std::numeric_limits::max()); + return diff < std::max(abs_th, epsilon * norm); +} + +// approx equal for quantities +template +bool approx_equal(const T &a, const U &b, const Epsilon &epsilon = Epsilon{128 * 1.192092896e-04}) { + static_assert(std::is_same_v, + "Quantities must have the same dimension"); + return approx_equal(a.value(), b.value(), epsilon); +} + +// approx equal for mixed types +template +bool approx_equal(const T &a, const U &b, const V &epsilon = V{128 * 1.192092896e-04}) + requires std::is_same_v +{ + return approx_equal(a.value(), b, epsilon); +} + +template +bool approx_equal(const T &a, const U &b, const V &epsilon = V{128 * 1.192092896e-04}) + requires std::is_same_v +{ + return approx_equal(a, b.value(), epsilon); +} + namespace units { // Base unit type diff --git a/include/squint/tensor_view.hpp b/include/squint/tensor_view.hpp index 2d747b2..3ac77b0 100644 --- a/include/squint/tensor_view.hpp +++ b/include/squint/tensor_view.hpp @@ -132,6 +132,14 @@ class fixed_tensor_view_base : public tensor_view_base auto as() const { + fixed_tensor result; + for (std::size_t i = 0; i < Derived::size(); ++i) { + result.data()[i] = static_cast(data_[i]); + } + return result; + } + template auto reshape() const { static_assert((NewDims * ...) == size(), "New shape must have the same total size"); using new_strides = compile_time_strides; diff --git a/main.cpp b/main.cpp index a6fd549..0985f17 100644 --- a/main.cpp +++ b/main.cpp @@ -1,7 +1,17 @@ -#include "squint/dynamic_tensor.hpp" -#include "squint/fixed_tensor.hpp" +#include "squint/quantity.hpp" +#include "squint/tensor.hpp" #include #include -#include -int main() { return 0; } \ No newline at end of file +using namespace squint; +using namespace squint::units; + +void func(length t) { std::cout << "float: " << t << std::endl; } + +int main() { + + mat3 a({1, 2, 3, 4, 5, 6, 7, 8, 9}); + vec3 b({1, 2, 3}); + // a.solve(b); + return 0; +} \ No newline at end of file