Skip to content

Commit

Permalink
Added constructors for fixed_tensor and dynamic_tensor to allow const…
Browse files Browse the repository at this point in the history
…ruction from initializer lists.
  • Loading branch information
barne856 committed Jul 18, 2024
1 parent 2d60eeb commit 3c5d881
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
13 changes: 12 additions & 1 deletion include/squint/dynamic_tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,17 @@ class dynamic_tensor : public iterable_tensor<dynamic_tensor<T, ErrorChecking>,
std::size_t total_size = std::accumulate(shape_.begin(), shape_.end(), 1ULL, std::multiplies<>());
data_.resize(total_size);
}
// Construct from initializer list
dynamic_tensor(std::vector<std::size_t> shape, std::initializer_list<T> init, layout layout = layout::column_major)
: shape_(std::move(shape)), layout_(layout) {
std::size_t total_size = std::accumulate(shape_.begin(), shape_.end(), 1ULL, std::multiplies<>());
if constexpr (ErrorChecking == error_checking::enabled) {
if (init.size() != total_size) {
throw std::invalid_argument("Initializer list size must match total size");
}
}
data_ = std::vector<T>(init);
}
// Construct from vector of elements
dynamic_tensor(std::vector<std::size_t> shape, const std::vector<T> &elements, layout layout = layout::column_major)
: shape_(std::move(shape)), layout_(layout) {
Expand Down Expand Up @@ -454,4 +465,4 @@ using btens = tens_t<bool>;

} // namespace squint

#endif // SQUINT_DYNAMIC_TENSOR_HPP
#endif // SQUINT_DYNAMIC_TENSOR_HPP
11 changes: 10 additions & 1 deletion include/squint/fixed_tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ class fixed_tensor : public iterable_tensor<fixed_tensor<T, L, ErrorChecking, Di
constexpr fixed_tensor() = default;
// insert elements into the layout
constexpr fixed_tensor(const std::array<T, total_size> &elements) : data_(elements) {}
// Construct from initializer list
constexpr fixed_tensor(std::initializer_list<T> init) {
if constexpr (ErrorChecking == error_checking::enabled) {
if (init.size() != total_size) {
throw std::invalid_argument("Initializer list size must match total size");
}
}
std::copy(init.begin(), init.end(), data_.begin());
}
// Fill the tensor with a single value
explicit constexpr fixed_tensor(const T &value) { data_.fill(value); }
// Fill the tensor with a single block or view
Expand Down Expand Up @@ -585,4 +594,4 @@ template <std::size_t... Dims> using bndarr = ndarr_t<bool, Dims...>;

} // namespace squint

#endif // SQUINT_FIXED_TENSOR_HPP
#endif // SQUINT_FIXED_TENSOR_HPP

0 comments on commit 3c5d881

Please sign in to comment.