Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
barne856 committed Jul 12, 2024
1 parent 16252e1 commit e142faf
Showing 2 changed files with 174 additions and 90 deletions.
222 changes: 137 additions & 85 deletions include/squint/linear_algebra.hpp
Original file line number Diff line number Diff line change
@@ -68,6 +68,29 @@ template <fixed_shape_tensor A, fixed_shape_tensor B> constexpr bool compatible_
return matmat || matvec;
}

template <fixed_shape_tensor A, fixed_shape_tensor B>
static constexpr bool compatible_for_type()
requires((quantitative<typename B::value_type> || arithmetic<typename B::value_type>) &&
(quantitative<typename A::value_type> || arithmetic<typename A::value_type>))
{
// Type compatibility check
if constexpr (quantitative<typename A::value_type> && quantitative<typename B::value_type>) {
static_assert(std::is_same_v<typename A::value_type::value_type, typename B::value_type::value_type>,
"A and B underlying types must match");
} else if constexpr (quantitative<typename A::value_type> && arithmetic<typename B::value_type>) {
static_assert(std::is_same_v<typename A::value_type::value_type, typename B::value_type>,
"A and B underlying types must match");
} else if constexpr (arithmetic<typename A::value_type> && quantitative<typename B::value_type>) {
static_assert(std::is_same_v<typename A::value_type, typename B::value_type::value_type>,
"A and B underlying types must match");
} else {
static_assert(std::is_same_v<typename A::value_type, typename B::value_type>,
"A and B underlying types must match");
}

return true;
}

template <fixed_shape_tensor A, fixed_shape_tensor B>
static constexpr bool compatible_for_solve()
requires((quantitative<typename B::value_type> || arithmetic<typename B::value_type>) &&
@@ -105,19 +128,7 @@ static constexpr bool compatible_for_solve()
}

// Type compatibility check
if constexpr (quantitative<typename A::value_type> && quantitative<typename B::value_type>) {
static_assert(std::is_same_v<typename A::value_type::value_type, typename B::value_type::value_type>,
"A and B underlying types must match");
} else if constexpr (quantitative<typename A::value_type> && arithmetic<typename B::value_type>) {
static_assert(std::is_same_v<typename A::value_type::value_type, typename B::value_type>,
"A and B underlying types must match");
} else if constexpr (arithmetic<typename A::value_type> && quantitative<typename B::value_type>) {
static_assert(std::is_same_v<typename A::value_type, typename B::value_type::value_type>,
"A and B underlying types must match");
} else {
static_assert(std::is_same_v<typename A::value_type, typename B::value_type>,
"A and B underlying types must match");
}
compatible_for_type<A, B>();

return true;
}
@@ -160,45 +171,6 @@ class fixed_linear_algebra_mixin : public linear_algebra_mixin<Derived, ErrorChe
return *static_cast<Derived *>(this);
}

template <fixed_shape_tensor B> auto solve(B &b) const {
static_assert(compatible_for_solve<Derived, B>(), "Incompatible types or shapes for solving linear system");

constexpr auto a_shape = Derived::constexpr_shape();
constexpr auto b_shape = B::constexpr_shape();
constexpr auto a_strides = Derived::constexpr_strides();
constexpr auto b_strides = B::constexpr_strides();

constexpr int n = a_shape[0];
constexpr int nrhs = (B::rank() == 1) ? 1 : b_shape[1];

// Determine leading dimensions based on layout
constexpr int lda = (Derived::get_layout() == layout::row_major) ? a_strides[0] : a_strides[1];
constexpr int ldb = (B::get_layout() == layout::row_major) ? ((B::rank() == 1) ? 1 : b_strides[0])
: ((B::rank() == 1) ? n : b_strides[1]);

std::vector<int> ipiv(n);

// Determine LAPACK layout
int lapack_layout = (Derived::get_layout() == layout::row_major) ? LAPACK_ROW_MAJOR : LAPACK_COL_MAJOR;

int info;
if constexpr (std::is_same_v<decltype(b.raw_data()), float *>) {
info = LAPACKE_sgesv(lapack_layout, n, nrhs,
const_cast<float *>(static_cast<Derived const *>(this)->raw_data()), lda, ipiv.data(),
b.raw_data(), ldb);
} else {
info = LAPACKE_dgesv(lapack_layout, n, nrhs,
const_cast<double *>(static_cast<Derived const *>(this)->raw_data()), lda, ipiv.data(),
b.raw_data(), ldb);
}

if (info != 0) {
throw std::runtime_error("LAPACKE_gesv failed with error code " + std::to_string(info));
}
}

template <fixed_shape_tensor B> auto solve_lls(const B &b) const;

template <fixed_shape_tensor B> auto operator/(const B &b) const;

template <fixed_shape_tensor Other> bool operator==(const Other &other) const {
@@ -240,31 +212,121 @@ template <fixed_shape_tensor A, fixed_shape_tensor B> auto operator-(const A &a,
// Matrix multiplication between fixed tensors
template <fixed_shape_tensor A, fixed_shape_tensor B> auto operator*(const A &a, const B &b) {
static_assert(compatible_for_matmul<A, B>(), "Incompatible shapes for matrix multiplication");
constexpr auto other_rank = B::rank();
constexpr auto this_shape = A::constexpr_shape();
constexpr auto other_shape = B::constexpr_shape();
// get result value_type
using result_value_type = decltype(typename A::value_type{} * typename B::value_type{});
if constexpr (other_rank == 2) {
auto result =
fixed_tensor<result_value_type, A::get_layout(), A::get_error_checking(), this_shape[0], other_shape[1]>();
for (std::size_t i = 0; i < this_shape[0]; ++i) {
for (std::size_t j = 0; j < other_shape[1]; ++j) {
for (std::size_t k = 0; k < this_shape[1]; ++k) {
result[i, j] += a[i, k] * b[k, j];
}
}
}
return result;
static_assert(compatible_for_type<A, B>(), "Incompatible types for BLAS matrix multiplication");
// matmul using BLAS
constexpr auto a_shape = A::constexpr_shape();
constexpr auto b_shape = B::constexpr_shape();
constexpr auto a_strides = A::constexpr_strides();
constexpr auto b_strides = B::constexpr_strides();

constexpr int m = a_shape[0];
constexpr int n = a_shape[1];
constexpr int k = b_shape[1];

// Determine leading dimensions based on layout
constexpr int lda = (A::get_layout() == layout::row_major) ? a_strides[0] : a_strides[1];
constexpr int ldb = (B::get_layout() == layout::row_major) ? b_strides[0] : b_strides[1];
constexpr int ldc = (A::get_layout() == layout::row_major) ? a_strides[0] : a_strides[1];

auto result = fixed_tensor<decltype(typename A::value_type{} * typename B::value_type{}), A::get_layout(),
A::get_error_checking(), m, k>();

// Determine BLAS layout
auto layout = (A::get_layout() == layout::row_major) ? CBLAS_ORDER::CblasRowMajor : CBLAS_ORDER::CblasColMajor;

if constexpr (std::is_same_v<decltype(a.raw_data()), float *> ||
std::is_same_v<decltype(a.raw_data()), const float *>) {
cblas_sgemm(layout, CblasNoTrans, CblasNoTrans, m, k, n, 1.0F, const_cast<float *>(a.raw_data()), lda,
const_cast<float *>(b.raw_data()), ldb, 0.0F, result.raw_data(), ldc);
} else {
auto result = fixed_tensor<result_value_type, A::get_layout(), A::get_error_checking(), this_shape[0]>();
for (std::size_t i = 0; i < this_shape[0]; ++i) {
for (std::size_t j = 0; j < this_shape[1]; ++j) {
result[i] += a[i, j] * b[j];
}
}
return result;
cblas_dgemm(layout, CblasNoTrans, CblasNoTrans, m, k, n, 1.0, const_cast<double *>(a.raw_data()), lda,
const_cast<double *>(b.raw_data()), ldb, 0.0, result.raw_data(), ldc);
}

return result;
}

// Solve linear system of equations Ax = b
template <fixed_shape_tensor A, fixed_shape_tensor B> auto solve(A &a, B &b) {
static_assert(compatible_for_solve<A, B>(), "Incompatible types or shapes for solving linear system");

constexpr auto a_shape = A::constexpr_shape();
constexpr auto b_shape = B::constexpr_shape();
constexpr auto a_strides = A::constexpr_strides();
constexpr auto b_strides = B::constexpr_strides();

constexpr int n = a_shape[0];
constexpr int nrhs = (B::rank() == 1) ? 1 : b_shape[1];

// Determine leading dimensions based on layout
constexpr int lda = (A::get_layout() == layout::row_major) ? a_strides[0] : a_strides[1];
constexpr int ldb = (B::get_layout() == layout::row_major) ? ((B::rank() == 1) ? 1 : b_strides[0])
: ((B::rank() == 1) ? n : b_strides[1]);

std::vector<int> ipiv(n);

// Determine LAPACK layout
int lapack_layout = (A::get_layout() == layout::row_major) ? LAPACK_ROW_MAJOR : LAPACK_COL_MAJOR;

int info;
if constexpr (std::is_same_v<decltype(b.raw_data()), float *> ||
std::is_same_v<decltype(b.raw_data()), const float *>) {
info = LAPACKE_sgesv(lapack_layout, n, nrhs, const_cast<float *>(a.raw_data()), lda, ipiv.data(), b.raw_data(),
ldb);
} else {
info = LAPACKE_dgesv(lapack_layout, n, nrhs, const_cast<double *>(a.raw_data()), lda, ipiv.data(), b.raw_data(),
ldb);
}

if (info != 0) {
throw std::runtime_error("LAPACKE_gesv failed with error code " + std::to_string(info));
}
return ipiv;
}

// Solve linear least squares problem Ax = b
template <fixed_shape_tensor A, fixed_shape_tensor B> auto solve_lls(A &a, const B &b) {
static_assert(A::rank() == 2, "Matrix A must be 2-dimensional");
static_assert(B::rank() == 1 || B::rank() == 2, "B must be 1D or 2D");
static_assert(A::constexpr_shape()[0] == B::constexpr_shape()[0],
"Matrix A and vector/matrix B dimensions must match");

constexpr auto a_shape = A::constexpr_shape();
constexpr auto b_shape = B::constexpr_shape();
constexpr auto a_strides = A::constexpr_strides();
constexpr auto b_strides = B::constexpr_strides();

constexpr int m = a_shape[0];
constexpr int n = a_shape[1];
constexpr int nrhs = (B::rank() == 1) ? 1 : b_shape[1];

// Determine leading dimensions based on layout
constexpr int lda = (A::get_layout() == layout::row_major) ? a_strides[0] : a_strides[1];
constexpr int ldb = (B::get_layout() == layout::row_major) ? ((B::rank() == 1) ? 1 : b_strides[0])
: ((B::rank() == 1) ? std::max(m, n) : b_strides[1]);

// Determine LAPACK layout
int lapack_layout = (A::get_layout() == layout::row_major) ? LAPACK_ROW_MAJOR : LAPACK_COL_MAJOR;

// Create a copy of b to store the solution, ensuring it's large enough for any case
auto x = fixed_tensor < typename B::value_type, B::get_layout(), B::get_error_checking(), std::max(m, n),
(B::rank() == 1) ? 1 : nrhs > ::zeros();
std::copy(b.begin(), b.end(), x.begin());

int info;
if constexpr (std::is_same_v<decltype(a.raw_data()), float *> ||
std::is_same_v<decltype(a.raw_data()), const float *>) {
info = LAPACKE_sgels(lapack_layout, 'N', m, n, nrhs, const_cast<float *>(a.raw_data()), lda, x.raw_data(), ldb);
} else {
info =
LAPACKE_dgels(lapack_layout, 'N', m, n, nrhs, const_cast<double *>(a.raw_data()), lda, x.raw_data(), ldb);
}

if (info != 0) {
throw std::runtime_error("LAPACKE_gels failed with error code " + std::to_string(info));
}

return x;
}

// Cross product between two 3D vectors
@@ -314,21 +376,11 @@ template <typename Derived, error_checking ErrorChecking>
class dynamic_linear_algebra_mixin : public linear_algebra_mixin<Derived, ErrorChecking> {
public:
template <typename Scalar> auto &operator*=(const Scalar &scalar);

template <typename Scalar> auto &operator/=(const Scalar &scalar);

template <tensor Other> auto &operator+=(const Other &other);

template <tensor Other> auto &operator-=(const Other &other);

template <tensor B> auto solve(B &b) const;

template <tensor B> auto solve_lls(const B &b) const;

template <tensor B> auto operator/(const B &b) const;

template <tensor Other> bool operator==(const Other &other) const;

template <tensor Other> bool operator!=(const Other &other) const;
};

42 changes: 37 additions & 5 deletions main.cpp
Original file line number Diff line number Diff line change
@@ -3,16 +3,48 @@
#include "squint/quantity.hpp"
#include "squint/tensor.hpp"
#include <iostream>
#include <vector>

using namespace squint;
using namespace squint::units;

int main() {
const auto A = tens::arange({4, 4}, 1);
std::cout << A << std::endl;
for (const auto &view : A.subviews({2, 2})) {
std::cout << view << std::endl;
// Test overdetermined system: 4 equations, 3 unknowns
{
auto A = fixed_tensor<double, layout::row_major, error_checking::disabled, 4, 3>::random();
auto b = fixed_tensor<double, layout::row_major, error_checking::disabled, 4>::random();

std::cout << "Overdetermined system:" << std::endl;
std::cout << "A = " << A << std::endl;
std::cout << "b = " << b << std::endl;

auto x = solve_lls(A, b);
std::cout << "Solution x = " << x << std::endl << std::endl;
}

// Test underdetermined system: 3 equations, 4 unknowns
{
auto A = fixed_tensor<double, layout::row_major, error_checking::disabled, 3, 4>::random();
auto b = fixed_tensor<double, layout::row_major, error_checking::disabled, 3>::random();

std::cout << "Underdetermined system:" << std::endl;
std::cout << "A = " << A << std::endl;
std::cout << "b = " << b << std::endl;

auto x = solve_lls(A, b);
std::cout << "Solution x = " << x << std::endl << std::endl;
}

// Test exactly determined system: 3 equations, 3 unknowns
{
auto A = fixed_tensor<double, layout::row_major, error_checking::disabled, 3, 3>::random();
auto b = fixed_tensor<double, layout::row_major, error_checking::disabled, 3>::random();

std::cout << "Exactly determined system:" << std::endl;
std::cout << "A = " << A << std::endl;
std::cout << "b = " << b << std::endl;

auto x = solve_lls(A, b);
std::cout << "Solution x = " << x << std::endl;
}

return 0;

0 comments on commit e142faf

Please sign in to comment.