From 79c98fa66f069d5603e2a561503de1f8dbdcfce5 Mon Sep 17 00:00:00 2001 From: Brendan Barnes Date: Wed, 10 Jul 2024 02:35:45 +0000 Subject: [PATCH] save --- include/squint/dynamic_tensor.hpp | 77 ++++++++++++++- include/squint/fixed_tensor.hpp | 153 +++++++++++++++++++++++++++--- 2 files changed, 217 insertions(+), 13 deletions(-) diff --git a/include/squint/dynamic_tensor.hpp b/include/squint/dynamic_tensor.hpp index 1602e80..0e863ea 100644 --- a/include/squint/dynamic_tensor.hpp +++ b/include/squint/dynamic_tensor.hpp @@ -162,6 +162,20 @@ class dynamic_tensor : public iterable_tensor, return view().subview(slices...); } + auto subview(const std::vector &slices) { + if constexpr (ErrorChecking == error_checking::enabled) { + this->check_subview_bounds(slices); + } + return view().subview(slices); + } + + auto subview(const std::vector &slices) const { + if constexpr (ErrorChecking == error_checking::enabled) { + this->check_subview_bounds(slices); + } + return view().subview(slices); + } + static dynamic_tensor zeros(const std::vector &shape, layout l = layout::column_major) { dynamic_tensor result(shape, l); result.fill(T{}); @@ -307,6 +321,62 @@ class dynamic_tensor : public iterable_tensor, return this->subviews(col_shape); } + auto row(std::size_t index) { + if constexpr (ErrorChecking == error_checking::enabled) { + if (index >= shape_[0]) { + throw std::out_of_range("Row index out of range"); + } + } + std::vector row_slices(shape_.size()); + row_slices[0] = slice{index, 1}; + for (std::size_t i = 1; i < shape_.size(); ++i) { + row_slices[i] = slice{0, shape_[i]}; + } + return this->subview(row_slices); + } + + auto row(std::size_t index) const { + if constexpr (ErrorChecking == error_checking::enabled) { + if (index >= shape_[0]) { + throw std::out_of_range("Row index out of range"); + } + } + std::vector row_slices(shape_.size()); + row_slices[0] = slice{index, 1}; + for (std::size_t i = 1; i < shape_.size(); ++i) { + row_slices[i] = slice{0, shape_[i]}; + } + return this->subview(row_slices); + } + + auto col(std::size_t index) { + if constexpr (ErrorChecking == error_checking::enabled) { + if (index >= shape_[shape_.size() - 1]) { + throw std::out_of_range("Column index out of range"); + } + } + std::vector col_slices(shape_.size()); + col_slices[shape_.size() - 1] = slice{index, 1}; + for (std::size_t i = 0; i < shape_.size() - 1; ++i) { + col_slices[i] = slice{0, shape_[i]}; + } + return this->subview(col_slices); + } + + auto col(std::size_t index) const { + if constexpr (ErrorChecking == error_checking::enabled) { + if (index >= shape_[shape_.size() - 1]) { + throw std::out_of_range("Column index out of range"); + } + } + std::vector col_slices(shape_.size()); + col_slices[shape_.size() - 1] = slice{index, 1}; + for (std::size_t i = 0; i < shape_.size() - 1; ++i) { + col_slices[i] = slice{0, shape_[i]}; + } + return this->subview(col_slices); + } + private: size_t calculate_index(const std::vector &indices) const { if constexpr (ErrorChecking == error_checking::enabled) { @@ -371,7 +441,12 @@ class dynamic_tensor : public iterable_tensor, } }; -template using dtens = dynamic_tensor; +template using tens_t = dynamic_tensor; +using itens = tens_t; +using utens = tens_t; +using tens = tens_t; +using dtens = tens_t; +using btens = tens_t; } // namespace squint diff --git a/include/squint/fixed_tensor.hpp b/include/squint/fixed_tensor.hpp index 38c3dd3..cbfbada 100644 --- a/include/squint/fixed_tensor.hpp +++ b/include/squint/fixed_tensor.hpp @@ -6,6 +6,7 @@ #include "squint/tensor_base.hpp" #include "squint/tensor_view.hpp" #include +#include #include namespace squint { @@ -254,6 +255,40 @@ class fixed_tensor : public iterable_tensor= std::array{Dims...}[0]) { + throw std::out_of_range("Row index out of range"); + } + } + if constexpr (sizeof...(Dims) == 0) { + // For 0D tensors + return this->template subview<>(); + } else { + return [&](std::index_sequence) { + constexpr std::array dims = {Dims...}; + return this->template subview<1, std::get(dims)...>(slice{index, 1}, slice{0, Is}...); + }(std::make_index_sequence{}); + } + } + + auto row(std::size_t index) const { + if constexpr (ErrorChecking == error_checking::enabled) { + if (index >= std::array{Dims...}[0]) { + throw std::out_of_range("Row index out of range"); + } + } + if constexpr (sizeof...(Dims) == 0) { + // For 0D tensors + return this->template subview<>(); + } else { + return [&](std::index_sequence) { + constexpr std::array dims = {Dims...}; + return this->template subview<1, std::get(dims)...>(slice{index, 1}, slice{0, Is}...); + }(std::make_index_sequence{}); + } + } + auto cols() { if constexpr (sizeof...(Dims) > 1) { return [&](std::index_sequence) { @@ -278,6 +313,40 @@ class fixed_tensor : public iterable_tensor= std::array{Dims...}[sizeof...(Dims) - 1]) { + throw std::out_of_range("Row index out of range"); + } + } + if constexpr (sizeof...(Dims) > 1) { + return [&](std::index_sequence) { + constexpr std::array dims = {Dims...}; + return this->template subview(dims)..., 1>(slice{0, Is}..., slice{index, 1}); + }(std::make_index_sequence{}); + } else { + // For 1D tensors + return this->template subview(slice{0, size()}); + } + } + + auto col(std::size_t index) const { + if constexpr (ErrorChecking == error_checking::enabled) { + if (index >= std::array{Dims...}[sizeof...(Dims) - 1]) { + throw std::out_of_range("Row index out of range"); + } + } + if constexpr (sizeof...(Dims) > 1) { + return [&](std::index_sequence) { + constexpr std::array dims = {Dims...}; + return this->template subview(dims)..., 1>(slice{0, Is}..., slice{index, 1}); + }(std::make_index_sequence{}); + } else { + // For 1D tensors + return this->template subview(slice{0, size()}); + } + } + private: template static constexpr size_t calculate_index(std::index_sequence /*unused*/, Indices... indices) { @@ -353,22 +422,82 @@ class fixed_tensor : public iterable_tensor using vec2 = fixed_tensor; -template using vec3 = fixed_tensor; -template using vec4 = fixed_tensor; +template using vec2_t = fixed_tensor; +template using vec3_t = fixed_tensor; +template using vec4_t = fixed_tensor; +using ivec2 = vec2_t; +using ivec3 = vec3_t; +using ivec4 = vec4_t; +using uvec2 = vec2_t; +using uvec3 = vec3_t; +using uvec4 = vec4_t; +using vec2 = vec2_t; +using vec3 = vec3_t; +using vec4 = vec4_t; +using dvec2 = vec2_t; +using dvec3 = vec3_t; +using dvec4 = vec4_t; +using bvec2 = vec2_t; +using bvec3 = vec3_t; +using bvec4 = vec4_t; // Square matrix types -template using mat2 = fixed_tensor; -template using mat3 = fixed_tensor; -template using mat4 = fixed_tensor; +template using mat2_t = fixed_tensor; +template using mat3_t = fixed_tensor; +template using mat4_t = fixed_tensor; +using imat2 = mat2_t; +using imat3 = mat3_t; +using imat4 = mat4_t; +using umat2 = mat2_t; +using umat3 = mat3_t; +using umat4 = mat4_t; +using mat2 = mat2_t; +using mat3 = mat3_t; +using mat4 = mat4_t; +using dmat2 = mat2_t; +using dmat3 = mat3_t; +using dmat4 = mat4_t; +using bmat2 = mat2_t; +using bmat3 = mat3_t; +using bmat4 = mat4_t; // Non-square matrix types -template using mat2x3 = fixed_tensor; -template using mat2x4 = fixed_tensor; -template using mat3x2 = fixed_tensor; -template using mat3x4 = fixed_tensor; -template using mat4x2 = fixed_tensor; -template using mat4x3 = fixed_tensor; +template using mat2x3_t = fixed_tensor; +template using mat2x4_t = fixed_tensor; +template using mat3x2_t = fixed_tensor; +template using mat3x4_t = fixed_tensor; +template using mat4x2_t = fixed_tensor; +template using mat4x3_t = fixed_tensor; +using imat2x3 = mat2x3_t; +using imat2x4 = mat2x4_t; +using imat3x2 = mat3x2_t; +using imat3x4 = mat3x4_t; +using imat4x2 = mat4x2_t; +using imat4x3 = mat4x3_t; +using umat2x3 = mat2x3_t; +using umat2x4 = mat2x4_t; +using umat3x2 = mat3x2_t; +using umat3x4 = mat3x4_t; +using umat4x2 = mat4x2_t; +using umat4x3 = mat4x3_t; +using mat2x3 = mat2x3_t; +using mat2x4 = mat2x4_t; +using mat3x2 = mat3x2_t; +using mat3x4 = mat3x4_t; +using mat4x2 = mat4x2_t; +using mat4x3 = mat4x3_t; +using dmat2x3 = mat2x3_t; +using dmat2x4 = mat2x4_t; +using dmat3x2 = mat3x2_t; +using dmat3x4 = mat3x4_t; +using dmat4x2 = mat4x2_t; +using dmat4x3 = mat4x3_t; +using bmat2x3 = mat2x3_t; +using bmat2x4 = mat2x4_t; +using bmat3x2 = mat3x2_t; +using bmat3x4 = mat3x4_t; +using bmat4x2 = mat4x2_t; +using bmat4x3 = mat4x3_t; } // namespace squint