Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
barne856 committed Jul 10, 2024
1 parent 0eda89e commit 79c98fa
Show file tree
Hide file tree
Showing 2 changed files with 217 additions and 13 deletions.
77 changes: 76 additions & 1 deletion include/squint/dynamic_tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,20 @@ class dynamic_tensor : public iterable_tensor<dynamic_tensor<T, ErrorChecking>,
return view().subview(slices...);
}

auto subview(const std::vector<slice> &slices) {
if constexpr (ErrorChecking == error_checking::enabled) {
this->check_subview_bounds(slices);
}
return view().subview(slices);
}

auto subview(const std::vector<slice> &slices) const {
if constexpr (ErrorChecking == error_checking::enabled) {
this->check_subview_bounds(slices);
}
return view().subview(slices);
}

static dynamic_tensor zeros(const std::vector<std::size_t> &shape, layout l = layout::column_major) {
dynamic_tensor result(shape, l);
result.fill(T{});
Expand Down Expand Up @@ -307,6 +321,62 @@ class dynamic_tensor : public iterable_tensor<dynamic_tensor<T, ErrorChecking>,
return this->subviews(col_shape);
}

auto row(std::size_t index) {
if constexpr (ErrorChecking == error_checking::enabled) {
if (index >= shape_[0]) {
throw std::out_of_range("Row index out of range");
}
}
std::vector<slice> row_slices(shape_.size());
row_slices[0] = slice{index, 1};
for (std::size_t i = 1; i < shape_.size(); ++i) {
row_slices[i] = slice{0, shape_[i]};
}
return this->subview(row_slices);
}

auto row(std::size_t index) const {
if constexpr (ErrorChecking == error_checking::enabled) {
if (index >= shape_[0]) {
throw std::out_of_range("Row index out of range");
}
}
std::vector<slice> row_slices(shape_.size());
row_slices[0] = slice{index, 1};
for (std::size_t i = 1; i < shape_.size(); ++i) {
row_slices[i] = slice{0, shape_[i]};
}
return this->subview(row_slices);
}

auto col(std::size_t index) {
if constexpr (ErrorChecking == error_checking::enabled) {
if (index >= shape_[shape_.size() - 1]) {
throw std::out_of_range("Column index out of range");
}
}
std::vector<slice> col_slices(shape_.size());
col_slices[shape_.size() - 1] = slice{index, 1};
for (std::size_t i = 0; i < shape_.size() - 1; ++i) {
col_slices[i] = slice{0, shape_[i]};
}
return this->subview(col_slices);
}

auto col(std::size_t index) const {
if constexpr (ErrorChecking == error_checking::enabled) {
if (index >= shape_[shape_.size() - 1]) {
throw std::out_of_range("Column index out of range");
}
}
std::vector<slice> col_slices(shape_.size());
col_slices[shape_.size() - 1] = slice{index, 1};
for (std::size_t i = 0; i < shape_.size() - 1; ++i) {
col_slices[i] = slice{0, shape_[i]};
}
return this->subview(col_slices);
}

private:
size_t calculate_index(const std::vector<size_t> &indices) const {
if constexpr (ErrorChecking == error_checking::enabled) {
Expand Down Expand Up @@ -371,7 +441,12 @@ class dynamic_tensor : public iterable_tensor<dynamic_tensor<T, ErrorChecking>,
}
};

template <typename T> using dtens = dynamic_tensor<T, error_checking::disabled>;
template <typename T> using tens_t = dynamic_tensor<T, error_checking::disabled>;
using itens = tens_t<int>;
using utens = tens_t<unsigned char>;
using tens = tens_t<float>;
using dtens = tens_t<double>;
using btens = tens_t<bool>;

} // namespace squint

Expand Down
153 changes: 141 additions & 12 deletions include/squint/fixed_tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "squint/tensor_base.hpp"
#include "squint/tensor_view.hpp"
#include <array>
#include <cstddef>
#include <random>

namespace squint {
Expand Down Expand Up @@ -254,6 +255,40 @@ class fixed_tensor : public iterable_tensor<fixed_tensor<T, L, ErrorChecking, Di
}
}

auto row(std::size_t index) {
if constexpr (ErrorChecking == error_checking::enabled) {
if (index >= std::array{Dims...}[0]) {
throw std::out_of_range("Row index out of range");
}
}
if constexpr (sizeof...(Dims) == 0) {
// For 0D tensors
return this->template subview<>();
} else {
return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
constexpr std::array<std::size_t, sizeof...(Dims)> dims = {Dims...};
return this->template subview<1, std::get<Is + 1>(dims)...>(slice{index, 1}, slice{0, Is}...);
}(std::make_index_sequence<sizeof...(Dims) - 1>{});
}
}

auto row(std::size_t index) const {
if constexpr (ErrorChecking == error_checking::enabled) {
if (index >= std::array{Dims...}[0]) {
throw std::out_of_range("Row index out of range");
}
}
if constexpr (sizeof...(Dims) == 0) {
// For 0D tensors
return this->template subview<>();
} else {
return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
constexpr std::array<std::size_t, sizeof...(Dims)> dims = {Dims...};
return this->template subview<1, std::get<Is + 1>(dims)...>(slice{index, 1}, slice{0, Is}...);
}(std::make_index_sequence<sizeof...(Dims) - 1>{});
}
}

auto cols() {
if constexpr (sizeof...(Dims) > 1) {
return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
Expand All @@ -278,6 +313,40 @@ class fixed_tensor : public iterable_tensor<fixed_tensor<T, L, ErrorChecking, Di
}
}

auto col(std::size_t index) {
if constexpr (ErrorChecking == error_checking::enabled) {
if (index >= std::array{Dims...}[sizeof...(Dims) - 1]) {
throw std::out_of_range("Row index out of range");
}
}
if constexpr (sizeof...(Dims) > 1) {
return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
constexpr std::array<std::size_t, sizeof...(Dims)> dims = {Dims...};
return this->template subview<std::get<Is>(dims)..., 1>(slice{0, Is}..., slice{index, 1});
}(std::make_index_sequence<sizeof...(Dims) - 1>{});
} else {
// For 1D tensors
return this->template subview<Dims...>(slice{0, size()});
}
}

auto col(std::size_t index) const {
if constexpr (ErrorChecking == error_checking::enabled) {
if (index >= std::array{Dims...}[sizeof...(Dims) - 1]) {
throw std::out_of_range("Row index out of range");
}
}
if constexpr (sizeof...(Dims) > 1) {
return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
constexpr std::array<std::size_t, sizeof...(Dims)> dims = {Dims...};
return this->template subview<std::get<Is>(dims)..., 1>(slice{0, Is}..., slice{index, 1});
}(std::make_index_sequence<sizeof...(Dims) - 1>{});
} else {
// For 1D tensors
return this->template subview<Dims...>(slice{0, size()});
}
}

private:
template <std::size_t... Is, typename... Indices>
static constexpr size_t calculate_index(std::index_sequence<Is...> /*unused*/, Indices... indices) {
Expand Down Expand Up @@ -353,22 +422,82 @@ class fixed_tensor : public iterable_tensor<fixed_tensor<T, L, ErrorChecking, Di
};

// Vector types
template <typename T> using vec2 = fixed_tensor<T, layout::column_major, error_checking::disabled, 2>;
template <typename T> using vec3 = fixed_tensor<T, layout::column_major, error_checking::disabled, 3>;
template <typename T> using vec4 = fixed_tensor<T, layout::column_major, error_checking::disabled, 4>;
template <typename T> using vec2_t = fixed_tensor<T, layout::column_major, error_checking::disabled, 2>;
template <typename T> using vec3_t = fixed_tensor<T, layout::column_major, error_checking::disabled, 3>;
template <typename T> using vec4_t = fixed_tensor<T, layout::column_major, error_checking::disabled, 4>;
using ivec2 = vec2_t<int>;
using ivec3 = vec3_t<int>;
using ivec4 = vec4_t<int>;
using uvec2 = vec2_t<unsigned char>;
using uvec3 = vec3_t<unsigned char>;
using uvec4 = vec4_t<unsigned char>;
using vec2 = vec2_t<float>;
using vec3 = vec3_t<float>;
using vec4 = vec4_t<float>;
using dvec2 = vec2_t<double>;
using dvec3 = vec3_t<double>;
using dvec4 = vec4_t<double>;
using bvec2 = vec2_t<bool>;
using bvec3 = vec3_t<bool>;
using bvec4 = vec4_t<bool>;

// Square matrix types
template <typename T> using mat2 = fixed_tensor<T, layout::column_major, error_checking::disabled, 2, 2>;
template <typename T> using mat3 = fixed_tensor<T, layout::column_major, error_checking::disabled, 3, 3>;
template <typename T> using mat4 = fixed_tensor<T, layout::column_major, error_checking::disabled, 4, 4>;
template <typename T> using mat2_t = fixed_tensor<T, layout::column_major, error_checking::disabled, 2, 2>;
template <typename T> using mat3_t = fixed_tensor<T, layout::column_major, error_checking::disabled, 3, 3>;
template <typename T> using mat4_t = fixed_tensor<T, layout::column_major, error_checking::disabled, 4, 4>;
using imat2 = mat2_t<int>;
using imat3 = mat3_t<int>;
using imat4 = mat4_t<int>;
using umat2 = mat2_t<unsigned char>;
using umat3 = mat3_t<unsigned char>;
using umat4 = mat4_t<unsigned char>;
using mat2 = mat2_t<float>;
using mat3 = mat3_t<float>;
using mat4 = mat4_t<float>;
using dmat2 = mat2_t<double>;
using dmat3 = mat3_t<double>;
using dmat4 = mat4_t<double>;
using bmat2 = mat2_t<bool>;
using bmat3 = mat3_t<bool>;
using bmat4 = mat4_t<bool>;

// Non-square matrix types
template <typename T> using mat2x3 = fixed_tensor<T, layout::column_major, error_checking::disabled, 2, 3>;
template <typename T> using mat2x4 = fixed_tensor<T, layout::column_major, error_checking::disabled, 2, 4>;
template <typename T> using mat3x2 = fixed_tensor<T, layout::column_major, error_checking::disabled, 3, 2>;
template <typename T> using mat3x4 = fixed_tensor<T, layout::column_major, error_checking::disabled, 3, 4>;
template <typename T> using mat4x2 = fixed_tensor<T, layout::column_major, error_checking::disabled, 4, 2>;
template <typename T> using mat4x3 = fixed_tensor<T, layout::column_major, error_checking::disabled, 4, 3>;
template <typename T> using mat2x3_t = fixed_tensor<T, layout::column_major, error_checking::disabled, 2, 3>;
template <typename T> using mat2x4_t = fixed_tensor<T, layout::column_major, error_checking::disabled, 2, 4>;
template <typename T> using mat3x2_t = fixed_tensor<T, layout::column_major, error_checking::disabled, 3, 2>;
template <typename T> using mat3x4_t = fixed_tensor<T, layout::column_major, error_checking::disabled, 3, 4>;
template <typename T> using mat4x2_t = fixed_tensor<T, layout::column_major, error_checking::disabled, 4, 2>;
template <typename T> using mat4x3_t = fixed_tensor<T, layout::column_major, error_checking::disabled, 4, 3>;
using imat2x3 = mat2x3_t<int>;
using imat2x4 = mat2x4_t<int>;
using imat3x2 = mat3x2_t<int>;
using imat3x4 = mat3x4_t<int>;
using imat4x2 = mat4x2_t<int>;
using imat4x3 = mat4x3_t<int>;
using umat2x3 = mat2x3_t<unsigned char>;
using umat2x4 = mat2x4_t<unsigned char>;
using umat3x2 = mat3x2_t<unsigned char>;
using umat3x4 = mat3x4_t<unsigned char>;
using umat4x2 = mat4x2_t<unsigned char>;
using umat4x3 = mat4x3_t<unsigned char>;
using mat2x3 = mat2x3_t<float>;
using mat2x4 = mat2x4_t<float>;
using mat3x2 = mat3x2_t<float>;
using mat3x4 = mat3x4_t<float>;
using mat4x2 = mat4x2_t<float>;
using mat4x3 = mat4x3_t<float>;
using dmat2x3 = mat2x3_t<double>;
using dmat2x4 = mat2x4_t<double>;
using dmat3x2 = mat3x2_t<double>;
using dmat3x4 = mat3x4_t<double>;
using dmat4x2 = mat4x2_t<double>;
using dmat4x3 = mat4x3_t<double>;
using bmat2x3 = mat2x3_t<bool>;
using bmat2x4 = mat2x4_t<bool>;
using bmat3x2 = mat3x2_t<bool>;
using bmat3x4 = mat3x4_t<bool>;
using bmat4x2 = mat4x2_t<bool>;
using bmat4x3 = mat4x3_t<bool>;

} // namespace squint

Expand Down

0 comments on commit 79c98fa

Please sign in to comment.