Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
barne856 committed Jul 10, 2024
1 parent 79c98fa commit 4a730f1
Show file tree
Hide file tree
Showing 6 changed files with 373 additions and 62 deletions.
3 changes: 3 additions & 0 deletions include/squint/dynamic_tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define SQUINT_DYNAMIC_TENSOR_HPP

#include "squint/iterable_tensor.hpp"
#include "squint/quantity.hpp"
#include "squint/tensor_base.hpp"
#include "squint/tensor_view.hpp"
#include <numeric>
Expand All @@ -18,6 +19,7 @@ class dynamic_tensor : public iterable_tensor<dynamic_tensor<T, ErrorChecking>,
layout layout_;

public:
using value_type = T;
using iterable_tensor<dynamic_tensor<T, ErrorChecking>, T, ErrorChecking>::subviews;
constexpr dynamic_tensor() = default;
// virtual destructor
Expand Down Expand Up @@ -125,6 +127,7 @@ class dynamic_tensor : public iterable_tensor<dynamic_tensor<T, ErrorChecking>,
constexpr std::size_t size() const { return data_.size(); }
constexpr std::vector<std::size_t> shape() const { return shape_; }
constexpr layout get_layout() const { return layout_; }
constexpr error_checking get_error_checking() const { return ErrorChecking; }
std::vector<std::size_t> strides() const { return calculate_strides(); }

T &at_impl(const std::vector<size_t> &indices) { return data_[calculate_index(indices)]; }
Expand Down
14 changes: 13 additions & 1 deletion include/squint/fixed_tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define SQUINT_FIXED_TENSOR_HPP

#include "squint/iterable_tensor.hpp"
#include "squint/linear_algebra.hpp"
#include "squint/quantity.hpp"
#include "squint/tensor_base.hpp"
#include "squint/tensor_view.hpp"
Expand All @@ -22,11 +23,13 @@ constexpr auto make_subviews_iterator(const BlockTensor & /*unused*/, Tensor &te

// Fixed tensor implementation
template <typename T, layout L, error_checking ErrorChecking, std::size_t... Dims>
class fixed_tensor : public iterable_tensor<fixed_tensor<T, L, ErrorChecking, Dims...>, T, ErrorChecking> {
class fixed_tensor : public iterable_tensor<fixed_tensor<T, L, ErrorChecking, Dims...>, T, ErrorChecking>,
public fixed_linear_algebra_mixin<fixed_tensor<T, L, ErrorChecking, Dims...>, ErrorChecking> {
static constexpr std::size_t total_size = (Dims * ...);
std::array<T, total_size> data_;

public:
using value_type = T;
using iterable_tensor<fixed_tensor<T, L, ErrorChecking, Dims...>, T, ErrorChecking>::subviews;
// virtual destructor
virtual ~fixed_tensor() = default;
Expand Down Expand Up @@ -59,6 +62,7 @@ class fixed_tensor : public iterable_tensor<fixed_tensor<T, L, ErrorChecking, Di
static constexpr std::size_t rank() { return sizeof...(Dims); }
static constexpr std::size_t size() { return total_size; }
static constexpr layout get_layout() { return L; }
static constexpr error_checking get_error_checking() { return ErrorChecking; }
static constexpr std::vector<std::size_t> strides() {
auto strides_array = calculate_strides();
return std::vector<std::size_t>(std::begin(strides_array), std::end(strides_array));
Expand Down Expand Up @@ -131,6 +135,14 @@ class fixed_tensor : public iterable_tensor<fixed_tensor<T, L, ErrorChecking, Di
return view().template subview<NewDims...>(slices...);
}

template <typename U> auto as() const {
fixed_tensor<U, L, ErrorChecking, Dims...> result;
for (std::size_t i = 0; i < total_size; ++i) {
result.data()[i] = static_cast<U>(data_[i]);
}
return result;
}

static constexpr fixed_tensor zeros() {
fixed_tensor result;
result.data_.fill(T{});
Expand Down
Loading

0 comments on commit 4a730f1

Please sign in to comment.