diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile
index 29171ee..1d71bb7 100644
--- a/.devcontainer/Dockerfile
+++ b/.devcontainer/Dockerfile
@@ -40,7 +40,6 @@ RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
     netpbm \
     imagemagick \
     ghostscript \
-    nvidia-cuda-toolkit \
     && apt-get clean && rm -rf /var/lib/apt/lists/* \ 
     && ln -s /usr/bin/clang-$LLVM_VERSION /usr/bin/clang \
     && ln -s /usr/bin/clang++-$LLVM_VERSION /usr/bin/clang++ \
@@ -70,6 +69,17 @@ RUN wget https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cm
     && ln -s /opt/cmake-${CMAKE_VERSION}/bin/* /usr/local/bin \
     && rm /tmp/cmake-install.sh
 
+# Install CUDA
+RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb \
+    && dpkg -i cuda-keyring_1.1-1_all.deb
+
+RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
+    && apt-get -y install --no-install-recommends \
+    cuda-toolkit-12-6
+
+# add CUDA to the path
+ENV PATH=/usr/local/cuda-12.6/bin:/usr/local/cuda-12.6:$PATH
+
 # Clean up
 RUN apt-get -y remove wget gnupg software-properties-common
 
diff --git a/CMakeLists.txt b/CMakeLists.txt
index f17a35f..6d7ae68 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -212,3 +212,7 @@ if(SQUINT_BUILD_DOCUMENTATION)
     DESTINATION ${CMAKE_INSTALL_DOCDIR})
 
 endif()
+
+add_executable(main main.cpp)
+target_link_libraries(main PRIVATE SQUINT)
+include_directories(include)
\ No newline at end of file
diff --git a/include/squint/tensor/cuda/cuda_context.hpp b/include/squint/tensor/cuda/cuda_context.hpp
index 6745550..d24613e 100644
--- a/include/squint/tensor/cuda/cuda_context.hpp
+++ b/include/squint/tensor/cuda/cuda_context.hpp
@@ -5,32 +5,32 @@
 
 namespace squint::cuda {
 
-class CudaContext {
+class cuda_context {
   public:
-    static auto instance() -> CudaContext & {
-        static CudaContext instance;
+    static auto instance() -> cuda_context & {
+        static cuda_context instance;
         return instance;
     }
 
     [[nodiscard]] auto cublas_handle() const -> cublasHandle_t { return cublas_handle_; }
 
     // Delete copy constructor and assignment operator
-    CudaContext(const CudaContext &) = delete;
-    auto operator=(const CudaContext &) -> CudaContext & = delete;
+    cuda_context(const cuda_context &) = delete;
+    auto operator=(const cuda_context &) -> cuda_context & = delete;
 
     // Delete move constructor and assignment operator
-    CudaContext(CudaContext &&) = delete;
-    auto operator=(CudaContext &&) -> CudaContext & = delete;
+    cuda_context(cuda_context &&) = delete;
+    auto operator=(cuda_context &&) -> cuda_context & = delete;
 
   private:
-    CudaContext() {
+    cuda_context() {
         cublasStatus_t status = cublasCreate(&cublas_handle_);
         if (status != CUBLAS_STATUS_SUCCESS) {
             throw std::runtime_error("Failed to create cuBLAS handle");
         }
     }
 
-    ~CudaContext() { cublasDestroy(cublas_handle_); }
+    ~cuda_context() { cublasDestroy(cublas_handle_); }
 
     cublasHandle_t cublas_handle_{};
 };
diff --git a/include/squint/tensor/cuda/element_wise.hpp b/include/squint/tensor/cuda/element_wise.hpp
index f894a41..3930228 100644
--- a/include/squint/tensor/cuda/element_wise.hpp
+++ b/include/squint/tensor/cuda/element_wise.hpp
@@ -1,3 +1,4 @@
+#include <cstdint>
 #ifndef SQUINT_TENSOR_CUDA_ELEMENT_WISE_HPP
 
 template <typename T>
@@ -7,17 +8,21 @@ void element_wise_addition(T *output, const T *a, const T *b, const unsigned lon
 
 template <typename T>
 void element_wise_subtraction(T *output, const T *a, const T *b, const unsigned long *dims,
-                           const unsigned long *strides_out, const unsigned long *strides_a,
-                           const unsigned long *strides_b, unsigned long num_dims, unsigned long total_size);
+                              const unsigned long *strides_out, const unsigned long *strides_a,
+                              const unsigned long *strides_b, unsigned long num_dims, unsigned long total_size);
 
 template <typename T>
-void element_wise_equality(T *output, const T *a, const T *b, const unsigned long *dims,
+void element_wise_equality(uint8_t *output, const T *a, const T *b, const unsigned long *dims,
                            const unsigned long *strides_out, const unsigned long *strides_a,
                            const unsigned long *strides_b, unsigned long num_dims, unsigned long total_size);
 
 template <typename T>
-void element_wise_inequality(T *output, const T *a, const T *b, const unsigned long *dims,
-                           const unsigned long *strides_out, const unsigned long *strides_a,
-                           const unsigned long *strides_b, unsigned long num_dims, unsigned long total_size);
+void element_wise_inequality(uint8_t *output, const T *a, const T *b, const unsigned long *dims,
+                             const unsigned long *strides_out, const unsigned long *strides_a,
+                             const unsigned long *strides_b, unsigned long num_dims, unsigned long total_size);
+
+template <typename T>
+void element_wise_negation(T *output, const T *a, const unsigned long *dims, const unsigned long *strides_out,
+                           const unsigned long *strides_a, unsigned long num_dims, unsigned long total_size);
 
 #endif // SQUINT_TENSOR_CUDA_ELEMENT_WISE_HPP
diff --git a/include/squint/tensor/cuda/scalar.hpp b/include/squint/tensor/cuda/scalar.hpp
new file mode 100644
index 0000000..48b1ff2
--- /dev/null
+++ b/include/squint/tensor/cuda/scalar.hpp
@@ -0,0 +1,7 @@
+#ifndef SQUINT_TENSOR_CUDA_SCALAR_HPP
+
+template <typename T>
+void scalar_multiplication(T scalar, T *output, const T *a, const unsigned long *dims, const unsigned long *strides_out,
+                           const unsigned long *strides_a, unsigned long num_dims, unsigned long total_size);
+
+#endif // SQUINT_TENSOR_CUDA_SCALAR_HPP
diff --git a/include/squint/tensor/element_wise_ops.hpp b/include/squint/tensor/element_wise_ops.hpp
index 1663bb9..c3ed39e 100644
--- a/include/squint/tensor/element_wise_ops.hpp
+++ b/include/squint/tensor/element_wise_ops.hpp
@@ -46,13 +46,15 @@ auto tensor<T, Shape, Strides, ErrorChecking, OwnershipType, MemorySpace>::opera
         // NOLINTBEGIN
         using blas_type = std::common_type_t<blas_type_t<T>, blas_type_t<U>>;
         if constexpr (std::is_same_v<blas_type, float>) {
-            element_wise_addition<float>(reinterpret_cast<float *>(data()), reinterpret_cast<const float *>(other.data()),
-                                  reinterpret_cast<const float *>(data()), device_shape(), device_strides(),
-                                  other.device_strides(), device_strides(), shape().size(), size());
+            element_wise_addition<float>(reinterpret_cast<float *>(data()),
+                                         reinterpret_cast<const float *>(other.data()),
+                                         reinterpret_cast<const float *>(data()), device_shape(), device_strides(),
+                                         other.device_strides(), device_strides(), shape().size(), size());
         } else if constexpr (std::is_same_v<blas_type, double>) {
-            element_wise_addition<double>(reinterpret_cast<double *>(data()), reinterpret_cast<const double *>(other.data()),
-                                  reinterpret_cast<const double *>(data()), device_shape(), device_strides(),
-                                  other.device_strides(), device_strides(), shape().size(), size());
+            element_wise_addition<double>(reinterpret_cast<double *>(data()),
+                                          reinterpret_cast<const double *>(other.data()),
+                                          reinterpret_cast<const double *>(data()), device_shape(), device_strides(),
+                                          other.device_strides(), device_strides(), shape().size(), size());
         }
         // NOLINTEND
 #endif
@@ -80,13 +82,15 @@ auto tensor<T, Shape, Strides, ErrorChecking, OwnershipType, MemorySpace>::opera
         // NOLINTBEGIN
         using blas_type = std::common_type_t<blas_type_t<T>, blas_type_t<U>>;
         if constexpr (std::is_same_v<blas_type, float>) {
-            element_wise_subtraction<float>(reinterpret_cast<float *>(data()), reinterpret_cast<const float *>(other.data()),
-                                  reinterpret_cast<const float *>(data()), device_shape(), device_strides(),
-                                  other.device_strides(), device_strides(), shape().size(), size());
+            element_wise_subtraction<float>(reinterpret_cast<float *>(data()),
+                                            reinterpret_cast<const float *>(other.data()),
+                                            reinterpret_cast<const float *>(data()), device_shape(), device_strides(),
+                                            other.device_strides(), device_strides(), shape().size(), size());
         } else if constexpr (std::is_same_v<blas_type, double>) {
-            element_wise_subtraction<double>(reinterpret_cast<double *>(data()), reinterpret_cast<const double *>(other.data()),
-                                  reinterpret_cast<const double *>(data()), device_shape(), device_strides(),
-                                  other.device_strides(), device_strides(), shape().size(), size());
+            element_wise_subtraction<double>(reinterpret_cast<double *>(data()),
+                                             reinterpret_cast<const double *>(other.data()),
+                                             reinterpret_cast<const double *>(data()), device_shape(), device_strides(),
+                                             other.device_strides(), device_strides(), shape().size(), size());
         }
         // NOLINTEND
 #endif
@@ -106,10 +110,10 @@ template <typename U, typename OtherShape, typename OtherStrides, enum error_che
           enum ownership_type OtherOwnershipType>
 auto tensor<T, Shape, Strides, ErrorChecking, OwnershipType, MemorySpace>::operator==(
     const tensor<U, OtherShape, OtherStrides, OtherErrorChecking, OtherOwnershipType, MemorySpace> &other) const
-    -> tensor<bool, Shape, Strides, ErrorChecking, ownership_type::owner, MemorySpace> {
+    -> tensor<std::uint8_t, Shape, Strides, ErrorChecking, ownership_type::owner, MemorySpace> {
     element_wise_compatible(*this, other);
-    if constexpr (fixed_shape<Shape>){
-        tensor<bool, Shape, Strides, ErrorChecking, ownership_type::owner, MemorySpace> result;
+    if constexpr (fixed_shape<Shape>) {
+        tensor<std::uint8_t, Shape, Strides, ErrorChecking, ownership_type::owner, MemorySpace> result;
         if constexpr (MemorySpace == memory_space::host) {
             std::transform(begin(), end(), other.begin(), result.begin(), std::equal_to{});
         } else {
@@ -117,21 +121,22 @@ auto tensor<T, Shape, Strides, ErrorChecking, OwnershipType, MemorySpace>::opera
             // NOLINTBEGIN
             using blas_type = std::common_type_t<blas_type_t<T>, blas_type_t<U>>;
             if constexpr (std::is_same_v<blas_type, float>) {
-                element_wise_equality<float>(reinterpret_cast<float *>(result.data()), reinterpret_cast<const float *>(data()),
-                                      reinterpret_cast<const float *>(other.data()), device_shape(), device_strides(),
-                                      other.device_strides(), device_strides(), shape().size(), size());
+                element_wise_equality<float>(
+                    reinterpret_cast<float *>(result.data()), reinterpret_cast<const float *>(data()),
+                    reinterpret_cast<const float *>(other.data()), device_shape(), device_strides(),
+                    other.device_strides(), device_strides(), shape().size(), size());
             } else if constexpr (std::is_same_v<blas_type, double>) {
-                element_wise_equality<double>(reinterpret_cast<double *>(result.data()), reinterpret_cast<const double *>(data()),
-                                      reinterpret_cast<const double *>(other.data()), device_shape(), device_strides(),
-                                      other.device_strides(), device_strides(), shape().size(), size());
+                element_wise_equality<double>(
+                    reinterpret_cast<double *>(result.data()), reinterpret_cast<const double *>(data()),
+                    reinterpret_cast<const double *>(other.data()), device_shape(), device_strides(),
+                    other.device_strides(), device_strides(), shape().size(), size());
             }
             // NOLINTEND
 #endif
         }
         return result;
-    }
-    else {
-        tensor<bool, Shape, Strides, ErrorChecking, ownership_type::owner, MemorySpace> result(shape());
+    } else {
+        tensor<std::uint8_t, Shape, Strides, ErrorChecking, ownership_type::owner, MemorySpace> result(shape());
         if constexpr (MemorySpace == memory_space::host) {
             std::transform(begin(), end(), other.begin(), result.begin(), std::equal_to{});
         } else {
@@ -139,13 +144,15 @@ auto tensor<T, Shape, Strides, ErrorChecking, OwnershipType, MemorySpace>::opera
             // NOLINTBEGIN
             using blas_type = std::common_type_t<blas_type_t<T>, blas_type_t<U>>;
             if constexpr (std::is_same_v<blas_type, float>) {
-                element_wise_equality<float>(reinterpret_cast<float *>(result.data()), reinterpret_cast<const float *>(data()),
-                                      reinterpret_cast<const float *>(other.data()), device_shape(), device_strides(),
-                                      other.device_strides(), device_strides(), shape().size(), size());
+                element_wise_equality<float>(
+                    reinterpret_cast<float *>(result.data()), reinterpret_cast<const float *>(data()),
+                    reinterpret_cast<const float *>(other.data()), device_shape(), device_strides(),
+                    other.device_strides(), device_strides(), shape().size(), size());
             } else if constexpr (std::is_same_v<blas_type, double>) {
-                element_wise_equality<double>(reinterpret_cast<double *>(result.data()), reinterpret_cast<const double *>(data()),
-                                      reinterpret_cast<const double *>(other.data()), device_shape(), device_strides(),
-                                      other.device_strides(), device_strides(), shape().size(), size());
+                element_wise_equality<double>(
+                    reinterpret_cast<double *>(result.data()), reinterpret_cast<const double *>(data()),
+                    reinterpret_cast<const double *>(other.data()), device_shape(), device_strides(),
+                    other.device_strides(), device_strides(), shape().size(), size());
             }
             // NOLINTEND
 #endif
@@ -166,10 +173,10 @@ template <typename U, typename OtherShape, typename OtherStrides, enum error_che
           enum ownership_type OtherOwnershipType>
 auto tensor<T, Shape, Strides, ErrorChecking, OwnershipType, MemorySpace>::operator!=(
     const tensor<U, OtherShape, OtherStrides, OtherErrorChecking, OtherOwnershipType, MemorySpace> &other) const
-    -> tensor<bool, Shape, Strides, ErrorChecking, ownership_type::owner, MemorySpace>  {
+    -> tensor<std::uint8_t, Shape, Strides, ErrorChecking, ownership_type::owner, MemorySpace> {
     element_wise_compatible(*this, other);
-    if constexpr (fixed_shape<Shape>){
-        tensor<bool, Shape, Strides, ErrorChecking, ownership_type::owner, MemorySpace> result;
+    if constexpr (fixed_shape<Shape>) {
+        tensor<std::uint8_t, Shape, Strides, ErrorChecking, ownership_type::owner, MemorySpace> result;
         if constexpr (MemorySpace == memory_space::host) {
             std::transform(begin(), end(), other.begin(), result.begin(), std::not_equal_to{});
         } else {
@@ -177,21 +184,22 @@ auto tensor<T, Shape, Strides, ErrorChecking, OwnershipType, MemorySpace>::opera
             // NOLINTBEGIN
             using blas_type = std::common_type_t<blas_type_t<T>, blas_type_t<U>>;
             if constexpr (std::is_same_v<blas_type, float>) {
-                element_wise_inequality<float>(reinterpret_cast<float *>(result.data()), reinterpret_cast<const float *>(data()),
-                                      reinterpret_cast<const float *>(other.data()), device_shape(), device_strides(),
-                                      other.device_strides(), device_strides(), shape().size(), size());
+                element_wise_inequality<float>(
+                    reinterpret_cast<float *>(result.data()), reinterpret_cast<const float *>(data()),
+                    reinterpret_cast<const float *>(other.data()), device_shape(), device_strides(),
+                    other.device_strides(), device_strides(), shape().size(), size());
             } else if constexpr (std::is_same_v<blas_type, double>) {
-                element_wise_inequality<double>(reinterpret_cast<double *>(result.data()), reinterpret_cast<const double *>(data()),
-                                      reinterpret_cast<const double *>(other.data()), device_shape(), device_strides(),
-                                      other.device_strides(), device_strides(), shape().size(), size());
+                element_wise_inequality<double>(
+                    reinterpret_cast<double *>(result.data()), reinterpret_cast<const double *>(data()),
+                    reinterpret_cast<const double *>(other.data()), device_shape(), device_strides(),
+                    other.device_strides(), device_strides(), shape().size(), size());
             }
             // NOLINTEND
 #endif
         }
         return result;
-    }
-    else {
-        tensor<bool, Shape, Strides, ErrorChecking, ownership_type::owner, MemorySpace> result(shape());
+    } else {
+        tensor<std::uint8_t, Shape, Strides, ErrorChecking, ownership_type::owner, MemorySpace> result(shape());
         if constexpr (MemorySpace == memory_space::host) {
             std::transform(begin(), end(), other.begin(), result.begin(), std::not_equal_to{});
         } else {
@@ -199,13 +207,15 @@ auto tensor<T, Shape, Strides, ErrorChecking, OwnershipType, MemorySpace>::opera
             // NOLINTBEGIN
             using blas_type = std::common_type_t<blas_type_t<T>, blas_type_t<U>>;
             if constexpr (std::is_same_v<blas_type, float>) {
-                element_wise_inequality<float>(reinterpret_cast<float *>(result.data()), reinterpret_cast<const float *>(data()),
-                                      reinterpret_cast<const float *>(other.data()), device_shape(), device_strides(),
-                                      other.device_strides(), device_strides(), shape().size(), size());
+                element_wise_inequality<float>(
+                    reinterpret_cast<float *>(result.data()), reinterpret_cast<const float *>(data()),
+                    reinterpret_cast<const float *>(other.data()), device_shape(), device_strides(),
+                    other.device_strides(), device_strides(), shape().size(), size());
             } else if constexpr (std::is_same_v<blas_type, double>) {
-                element_wise_inequality<double>(reinterpret_cast<double *>(result.data()), reinterpret_cast<const double *>(data()),
-                                      reinterpret_cast<const double *>(other.data()), device_shape(), device_strides(),
-                                      other.device_strides(), device_strides(), shape().size(), size());
+                element_wise_inequality<double>(
+                    reinterpret_cast<double *>(result.data()), reinterpret_cast<const double *>(data()),
+                    reinterpret_cast<const double *>(other.data()), device_shape(), device_strides(),
+                    other.device_strides(), device_strides(), shape().size(), size());
             }
             // NOLINTEND
 #endif
@@ -222,8 +232,25 @@ auto tensor<T, Shape, Strides, ErrorChecking, OwnershipType, MemorySpace>::opera
 template <typename T, typename Shape, typename Strides, error_checking ErrorChecking, ownership_type OwnershipType,
           memory_space MemorySpace>
 auto tensor<T, Shape, Strides, ErrorChecking, OwnershipType, MemorySpace>::operator-() const -> tensor {
-    tensor result(*this);
-    std::transform(result.begin(), result.end(), result.begin(), std::negate{});
+    tensor result = this->copy();
+    if constexpr (MemorySpace == memory_space::host) {
+        std::transform(result.begin(), result.end(), result.begin(), std::negate{});
+    } else {
+#ifdef SQUINT_USE_CUDA
+        // NOLINTBEGIN
+        using blas_type = blas_type_t<T>;
+        if constexpr (std::is_same_v<blas_type, float>) {
+            element_wise_negation<float>(reinterpret_cast<float *>(result.data()),
+                                         reinterpret_cast<const float *>(data()), device_shape(), device_strides(),
+                                         device_strides(), shape().size(), size());
+        } else if constexpr (std::is_same_v<blas_type, double>) {
+            element_wise_negation<double>(reinterpret_cast<double *>(result.data()),
+                                          reinterpret_cast<const double *>(data()), device_shape(), device_strides(),
+                                          device_strides(), shape().size(), size());
+        }
+        // NOLINTEND
+#endif
+    }
     return result;
 }
 
@@ -238,19 +265,11 @@ template <typename T, typename Shape, typename Strides, error_checking ErrorChec
           memory_space MemorySpace, typename U, typename OtherShape, typename OtherStrides,
           enum error_checking OtherErrorChecking, enum ownership_type OtherOwnershipType>
 auto operator+(const tensor<T, Shape, Strides, ErrorChecking, OwnershipType, MemorySpace> &lhs,
-               const tensor<U, OtherShape, OtherStrides, OtherErrorChecking, OtherOwnershipType, MemorySpace> &rhs)
-    -> tensor<decltype(std::declval<T>() + std::declval<U>()),
-              std::conditional_t<fixed_shape<Shape> && fixed_shape<OtherShape>, Shape, std::vector<std::size_t>>,
-              std::conditional_t<fixed_shape<Shape> && fixed_shape<OtherShape>, Strides, std::vector<std::size_t>>,
-              resulting_error_checking<ErrorChecking, OtherErrorChecking>::value, ownership_type::owner, MemorySpace> {
+               const tensor<U, OtherShape, OtherStrides, OtherErrorChecking, OtherOwnershipType, MemorySpace> &rhs) {
     element_wise_compatible(lhs, rhs);
-    tensor<decltype(std::declval<T>() + std::declval<U>()),
-           std::conditional_t<fixed_shape<Shape> && fixed_shape<OtherShape>, Shape, std::vector<std::size_t>>,
-           std::conditional_t<fixed_shape<Shape> && fixed_shape<OtherShape>, Strides, std::vector<std::size_t>>,
-           resulting_error_checking<ErrorChecking, OtherErrorChecking>::value, ownership_type::owner, MemorySpace>
-        result(lhs);
-    std::transform(lhs.begin(), lhs.end(), rhs.begin(), result.begin(), std::plus{});
-    return result;
+    auto result = lhs.copy();
+    result += rhs;
+    return std::move(result);
 }
 
 // Element-wise subtraction
@@ -264,19 +283,11 @@ template <typename T, typename Shape, typename Strides, error_checking ErrorChec
           memory_space MemorySpace, typename U, typename OtherShape, typename OtherStrides,
           enum error_checking OtherErrorChecking, enum ownership_type OtherOwnershipType>
 auto operator-(const tensor<T, Shape, Strides, ErrorChecking, OwnershipType, MemorySpace> &lhs,
-               const tensor<U, OtherShape, OtherStrides, OtherErrorChecking, OtherOwnershipType, MemorySpace> &rhs)
-    -> tensor<decltype(std::declval<T>() - std::declval<U>()),
-              std::conditional_t<fixed_shape<Shape> && fixed_shape<OtherShape>, Shape, std::vector<std::size_t>>,
-              std::conditional_t<fixed_shape<Shape> && fixed_shape<OtherShape>, Strides, std::vector<std::size_t>>,
-              resulting_error_checking<ErrorChecking, OtherErrorChecking>::value, ownership_type::owner, MemorySpace> {
+               const tensor<U, OtherShape, OtherStrides, OtherErrorChecking, OtherOwnershipType, MemorySpace> &rhs) {
     element_wise_compatible(lhs, rhs);
-    tensor<decltype(std::declval<T>() - std::declval<U>()),
-           std::conditional_t<fixed_shape<Shape> && fixed_shape<OtherShape>, Shape, std::vector<std::size_t>>,
-           std::conditional_t<fixed_shape<Shape> && fixed_shape<OtherShape>, Strides, std::vector<std::size_t>>,
-           resulting_error_checking<ErrorChecking, OtherErrorChecking>::value, ownership_type::owner, MemorySpace>
-        result(lhs);
-    std::transform(lhs.begin(), lhs.end(), rhs.begin(), result.begin(), std::minus{});
-    return result;
+    auto result = lhs.copy();
+    result -= rhs;
+    return std::move(result);
 }
 
 } // namespace squint
diff --git a/include/squint/tensor/scalar_ops.hpp b/include/squint/tensor/scalar_ops.hpp
index 521017d..4d5f112 100644
--- a/include/squint/tensor/scalar_ops.hpp
+++ b/include/squint/tensor/scalar_ops.hpp
@@ -12,9 +12,22 @@
 #include "squint/core/error_checking.hpp"
 #include "squint/core/memory.hpp"
 #include "squint/tensor/tensor.hpp"
+#include "squint/tensor/tensor_op_compatibility.hpp"
+
+#ifdef SQUINT_USE_CUDA
+#include "squint/tensor/cuda/scalar.hpp"
+#endif
 
 namespace squint {
 
+template <scalar T> auto get_scalar_value(const T &s) {
+    if constexpr (quantitative<T>) {
+        return s.value();
+    } else {
+        return s;
+    }
+}
+
 // Scalar multiplication assignment
 /**
  * @brief Scalar multiplication assignment operator.
@@ -25,8 +38,26 @@ template <typename T, typename Shape, typename Strides, error_checking ErrorChec
           memory_space MemorySpace>
 template <dimensionless_scalar U>
 auto tensor<T, Shape, Strides, ErrorChecking, OwnershipType, MemorySpace>::operator*=(const U &s) -> tensor & {
-    for (auto &element : *this) {
-        element *= s;
+    if constexpr (MemorySpace == memory_space::host) {
+        for (auto &element : *this) {
+            element *= s;
+        }
+    } else {
+#ifdef SQUINT_USE_CUDA
+        // NOLINTBEGIN
+        using blas_type = blas_type_t<T>;
+        blas_type scalar = static_cast<blas_type>(get_scalar_value(s));
+        if constexpr (std::is_same_v<blas_type, float>) {
+            scalar_multiplication<float>(scalar, reinterpret_cast<float *>(data()),
+                                         reinterpret_cast<const float *>(data()), device_shape(), device_strides(),
+                                         device_strides(), shape().size(), size());
+        } else if constexpr (std::is_same_v<blas_type, double>) {
+            scalar_multiplication<double>(scalar, reinterpret_cast<double *>(data()),
+                                          reinterpret_cast<const double *>(data()), device_shape(), device_strides(),
+                                          device_strides(), shape().size(), size());
+        }
+        // NOLINTEND
+#endif
     }
     return *this;
 }
@@ -41,8 +72,26 @@ template <typename T, typename Shape, typename Strides, error_checking ErrorChec
           memory_space MemorySpace>
 template <dimensionless_scalar U>
 auto tensor<T, Shape, Strides, ErrorChecking, OwnershipType, MemorySpace>::operator/=(const U &s) -> tensor & {
-    for (auto &element : *this) {
-        element /= s;
+    if constexpr (MemorySpace == memory_space::host) {
+        for (auto &element : *this) {
+            element /= s;
+        }
+    } else {
+#ifdef SQUINT_USE_CUDA
+        // NOLINTBEGIN
+        using blas_type = blas_type_t<T>;
+        blas_type scalar = blas_type(1) / static_cast<blas_type>(get_scalar_value(s));
+        if constexpr (std::is_same_v<blas_type, float>) {
+            scalar_multiplication<float>(scalar, reinterpret_cast<float *>(data()),
+                                         reinterpret_cast<const float *>(data()), device_shape(), device_strides(),
+                                         device_strides(), shape().size(), size());
+        } else if constexpr (std::is_same_v<blas_type, double>) {
+            scalar_multiplication<double>(scalar, reinterpret_cast<double *>(data()),
+                                          reinterpret_cast<const double *>(data()), device_shape(), device_strides(),
+                                          device_strides(), shape().size(), size());
+        }
+        // NOLINTEND
+#endif
     }
     return *this;
 }
@@ -56,25 +105,69 @@ auto tensor<T, Shape, Strides, ErrorChecking, OwnershipType, MemorySpace>::opera
  */
 template <typename T, typename Shape, typename Strides, error_checking ErrorChecking, ownership_type OwnershipType,
           memory_space MemorySpace, scalar U>
-auto operator*(const tensor<T, Shape, Strides, ErrorChecking, OwnershipType, MemorySpace> &t,
-               const U &s) -> tensor<decltype(std::declval<T>() * std::declval<U>()), Shape, Strides, ErrorChecking,
-                                     ownership_type::owner, MemorySpace> {
-    using result_type = tensor<decltype(std::declval<T>() * std::declval<U>()), Shape, Strides, ErrorChecking,
-                               ownership_type::owner, MemorySpace>;
+auto operator*(const tensor<T, Shape, Strides, ErrorChecking, OwnershipType, MemorySpace> &t, const U &s) {
     if constexpr (fixed_shape<Shape>) {
-        result_type result;
-        auto result_it = result.begin();
-        for (auto it = t.begin(); it != t.end(); ++it, ++result_it) {
-            *result_it = *it * s;
+        if constexpr (MemorySpace == memory_space::host) {
+            using result_type = tensor<decltype(std::declval<T>() * std::declval<U>()), Shape, Strides, ErrorChecking,
+                                       ownership_type::owner, MemorySpace>;
+            result_type result{};
+            auto result_it = result.begin();
+            for (auto it = t.begin(); it != t.end(); ++it, ++result_it) {
+                *result_it = *it * s;
+            }
+            return std::move(result);
+        } else {
+#ifdef SQUINT_USE_CUDA
+            using result_type = tensor<decltype(std::declval<T>() * std::declval<U>()), Shape, Strides, ErrorChecking,
+                                       ownership_type::reference, MemorySpace>;
+            result_type result{};
+            // NOLINTBEGIN
+            using blas_type = std::common_type_t<blas_type_t<T>, blas_type_t<U>>;
+            blas_type scalar = static_cast<blas_type>(get_scalar_value(s));
+            if constexpr (std::is_same_v<blas_type, float>) {
+                scalar_multiplication<float>(scalar, reinterpret_cast<float *>(result.data()),
+                                             reinterpret_cast<const float *>(t.data()), t.device_shape(),
+                                             t.device_strides(), t.device_strides(), t.shape().size(), t.size());
+            } else if constexpr (std::is_same_v<blas_type, double>) {
+                scalar_multiplication<double>(scalar, reinterpret_cast<double *>(result.data()),
+                                              reinterpret_cast<const double *>(t.data()), t.device_shape(),
+                                              t.device_strides(), t.device_strides(), t.shape().size(), t.size());
+            }
+            // NOLINTEND
+#endif
+            return std::move(result);
         }
-        return result;
     } else {
-        result_type result(t.shape());
-        auto result_it = result.begin();
-        for (auto it = t.begin(); it != t.end(); ++it, ++result_it) {
-            *result_it = *it * s;
+        if constexpr (MemorySpace == memory_space::host) {
+            using result_type = tensor<decltype(std::declval<T>() * std::declval<U>()), Shape, Strides, ErrorChecking,
+                                       ownership_type::owner, MemorySpace>;
+            result_type result(t.shape());
+            auto result_it = result.begin();
+            for (auto it = t.begin(); it != t.end(); ++it, ++result_it) {
+                *result_it = *it * s;
+            }
+            return std::move(result);
+        } else {
+#ifdef SQUINT_USE_CUDA
+            using result_type = tensor<decltype(std::declval<T>() * std::declval<U>()), Shape, Strides, ErrorChecking,
+                                       ownership_type::reference, MemorySpace>;
+            result_type result(t.shape());
+            // NOLINTBEGIN
+            using blas_type = std::common_type_t<blas_type_t<T>, blas_type_t<U>>;
+            blas_type scalar = static_cast<blas_type>(get_scalar_value(s));
+            if constexpr (std::is_same_v<blas_type, float>) {
+                scalar_multiplication<float>(scalar, reinterpret_cast<float *>(result.data()),
+                                             reinterpret_cast<const float *>(t.data()), t.device_shape(),
+                                             t.device_strides(), t.device_strides(), t.shape().size(), t.size());
+            } else if constexpr (std::is_same_v<blas_type, double>) {
+                scalar_multiplication<double>(scalar, reinterpret_cast<double *>(result.data()),
+                                              reinterpret_cast<const double *>(t.data()), t.device_shape(),
+                                              t.device_strides(), t.device_strides(), t.shape().size(), t.size());
+            }
+            // NOLINTEND
+#endif
+            return std::move(result);
         }
-        return result;
     }
 }
 
@@ -87,10 +180,9 @@ auto operator*(const tensor<T, Shape, Strides, ErrorChecking, OwnershipType, Mem
  */
 template <typename T, typename Shape, typename Strides, error_checking ErrorChecking, ownership_type OwnershipType,
           memory_space MemorySpace, scalar U>
-auto operator*(const U &s, const tensor<T, Shape, Strides, ErrorChecking, OwnershipType, MemorySpace> &t)
-    -> tensor<decltype(std::declval<T>() * std::declval<U>()), Shape, Strides, ErrorChecking, ownership_type::owner,
-              MemorySpace> {
-    return t * s;
+auto operator*(const U &s, const tensor<T, Shape, Strides, ErrorChecking, OwnershipType, MemorySpace> &t) {
+    auto result = t * s;
+    return std::move(result);
 }
 
 // Tensor-scalar division
@@ -102,25 +194,70 @@ auto operator*(const U &s, const tensor<T, Shape, Strides, ErrorChecking, Owners
  */
 template <typename T, typename Shape, typename Strides, error_checking ErrorChecking, ownership_type OwnershipType,
           memory_space MemorySpace, scalar U>
-auto operator/(const tensor<T, Shape, Strides, ErrorChecking, OwnershipType, MemorySpace> &t,
-               const U &s) -> tensor<decltype(std::declval<T>() / std::declval<U>()), Shape, Strides, ErrorChecking,
-                                     ownership_type::owner, MemorySpace> {
-    using result_type = tensor<decltype(std::declval<T>() / std::declval<U>()), Shape, Strides, ErrorChecking,
-                               ownership_type::owner, MemorySpace>;
+auto operator/(const tensor<T, Shape, Strides, ErrorChecking, OwnershipType, MemorySpace> &t, const U &s) {
+
     if constexpr (fixed_shape<Shape>) {
-        result_type result;
-        auto result_it = result.begin();
-        for (auto it = t.begin(); it != t.end(); ++it, ++result_it) {
-            *result_it = *it / s;
+        if constexpr (MemorySpace == memory_space::host) {
+            using result_type = tensor<decltype(std::declval<T>() / std::declval<U>()), Shape, Strides, ErrorChecking,
+                                       ownership_type::owner, MemorySpace>;
+            result_type result{};
+            auto result_it = result.begin();
+            for (auto it = t.begin(); it != t.end(); ++it, ++result_it) {
+                *result_it = *it / s;
+            }
+            return std::move(result);
+        } else {
+#ifdef SQUINT_USE_CUDA
+            using result_type = tensor<decltype(std::declval<T>() / std::declval<U>()), Shape, Strides, ErrorChecking,
+                                       ownership_type::reference, MemorySpace>;
+            result_type result{};
+            // NOLINTBEGIN
+            using blas_type = std::common_type_t<blas_type_t<T>, blas_type_t<U>>;
+            blas_type scalar = blas_type(1) / static_cast<blas_type>(get_scalar_value(s));
+            if constexpr (std::is_same_v<blas_type, float>) {
+                scalar_multiplication<float>(scalar, reinterpret_cast<float *>(result.data()),
+                                             reinterpret_cast<const float *>(t.data()), t.device_shape(),
+                                             t.device_strides(), t.device_strides(), t.shape().size(), t.size());
+            } else if constexpr (std::is_same_v<blas_type, double>) {
+                scalar_multiplication<double>(scalar, reinterpret_cast<double *>(result.data()),
+                                              reinterpret_cast<const double *>(t.data()), t.device_shape(),
+                                              t.device_strides(), t.device_strides(), t.shape().size(), t.size());
+            }
+            // NOLINTEND
+#endif
+            return std::move(result);
         }
-        return result;
     } else {
-        result_type result(t.shape());
-        auto result_it = result.begin();
-        for (auto it = t.begin(); it != t.end(); ++it, ++result_it) {
-            *result_it = *it / s;
+        if constexpr (MemorySpace == memory_space::host) {
+            using result_type = tensor<decltype(std::declval<T>() / std::declval<U>()), Shape, Strides, ErrorChecking,
+                                       ownership_type::owner, MemorySpace>;
+            result_type result(t.shape());
+            auto result_it = result.begin();
+            for (auto it = t.begin(); it != t.end(); ++it, ++result_it) {
+                *result_it = *it / s;
+            }
+            return std::move(result);
+        } else {
+#ifdef SQUINT_USE_CUDA
+            using result_type = tensor<decltype(std::declval<T>() / std::declval<U>()), Shape, Strides, ErrorChecking,
+                                       ownership_type::reference, MemorySpace>;
+            result_type result(t.shape());
+            // NOLINTBEGIN
+            using blas_type = std::common_type_t<blas_type_t<T>, blas_type_t<U>>;
+            blas_type scalar = blas_type(1) / static_cast<blas_type>(get_scalar_value(s));
+            if constexpr (std::is_same_v<blas_type, float>) {
+                scalar_multiplication<float>(scalar, reinterpret_cast<float *>(result.data()),
+                                             reinterpret_cast<const float *>(t.data()), t.device_shape(),
+                                             t.device_strides(), t.device_strides(), t.shape().size(), t.size());
+            } else if constexpr (std::is_same_v<blas_type, double>) {
+                scalar_multiplication<double>(scalar, reinterpret_cast<double *>(result.data()),
+                                              reinterpret_cast<const double *>(t.data()), t.device_shape(),
+                                              t.device_strides(), t.device_strides(), t.shape().size(), t.size());
+            }
+            // NOLINTEND
+#endif
+            return std::move(result);
         }
-        return result;
     }
 }
 
diff --git a/include/squint/tensor/tensor.hpp b/include/squint/tensor/tensor.hpp
index a6fa91f..4c9c331 100644
--- a/include/squint/tensor/tensor.hpp
+++ b/include/squint/tensor/tensor.hpp
@@ -112,10 +112,31 @@ class tensor {
     tensor()
         requires(OwnershipType == ownership_type::owner && MemorySpace == memory_space::host)
     = default;
-    tensor(const tensor &other) = default;
-    tensor(tensor &&other) noexcept = default;
+    tensor(const tensor &other)
+        requires(OwnershipType == ownership_type::owner && MemorySpace == memory_space::host)
+    = default;
+    tensor(tensor &&other) noexcept
+        requires(OwnershipType == ownership_type::owner && MemorySpace == memory_space::host)
+    = default;
+    tensor(tensor &&other) noexcept
+        requires(OwnershipType == ownership_type::reference && MemorySpace == memory_space::device)
+    {
+        // move constructor for device tensors
+        data_ = other.data_;
+        if constexpr (dynamic_shape<Shape>) {
+            shape_ = std::move(other.shape_);
+            strides_ = std::move(other.strides_);
+        }
+        device_shape_ = other.device_shape_;
+        device_strides_ = other.device_strides_;
+        other.data_ = nullptr;
+        other.device_shape_ = nullptr;
+        other.device_strides_ = nullptr;
+    }
     // Device constructors for uninitialized data
-    tensor() requires(fixed_shape<Shape> && MemorySpace == memory_space::device && OwnershipType == ownership_type::reference)
+    tensor()
+        requires(fixed_shape<Shape> && MemorySpace == memory_space::device &&
+                 OwnershipType == ownership_type::reference)
     {
 #ifdef SQUINT_USE_CUDA
         cudaError_t malloc_status = cudaMalloc(&data_, _size() * sizeof(T));
@@ -132,38 +153,44 @@ class tensor {
         }
         auto host_shape = this->shape();
         auto host_strides = make_array(strides::column_major<Shape>{});
-        cudaError_t memcpy_status = cudaMemcpy(device_shape_, host_shape.data(), _rank() * sizeof(std::size_t), cudaMemcpyHostToDevice);
+        cudaError_t memcpy_status =
+            cudaMemcpy(device_shape_, host_shape.data(), _rank() * sizeof(std::size_t), cudaMemcpyHostToDevice);
         if (memcpy_status != cudaSuccess) {
             throw std::runtime_error("Failed to copy tensor shape to device");
         }
-        memcpy_status = cudaMemcpy(device_strides_, host_strides.data(), _rank() * sizeof(std::size_t), cudaMemcpyHostToDevice);
+        memcpy_status =
+            cudaMemcpy(device_strides_, host_strides.data(), _rank() * sizeof(std::size_t), cudaMemcpyHostToDevice);
         if (memcpy_status != cudaSuccess) {
             throw std::runtime_error("Failed to copy tensor strides to device");
         }
 #endif
     }
-    tensor(std::vector<size_t> shape) requires(dynamic_shape<Shape> && MemorySpace == memory_space::device && OwnershipType == ownership_type::reference)
+    tensor(std::vector<size_t> shape)
+        requires(dynamic_shape<Shape> && MemorySpace == memory_space::device &&
+                 OwnershipType == ownership_type::reference)
     {
 #ifdef SQUINT_USE_CUDA
         shape_ = shape;
-        strides_ = compute_strides(shape, layout::column_major);
-        cudaError_t malloc_status = cudaMalloc(&data_, _size() * sizeof(T));
+        strides_ = compute_strides(layout::column_major, shape);
+        cudaError_t malloc_status = cudaMalloc(&data_, size() * sizeof(T));
         if (malloc_status != cudaSuccess) {
             throw std::runtime_error("Failed to allocate device memory for tensor data");
         }
-        malloc_status = cudaMalloc(&device_shape_, _rank() * sizeof(std::size_t));
+        malloc_status = cudaMalloc(&device_shape_, rank() * sizeof(std::size_t));
         if (malloc_status != cudaSuccess) {
             throw std::runtime_error("Failed to allocate device memory for tensor shape");
         }
-        malloc_status = cudaMalloc(&device_strides_, _rank() * sizeof(std::size_t));
+        malloc_status = cudaMalloc(&device_strides_, rank() * sizeof(std::size_t));
         if (malloc_status != cudaSuccess) {
             throw std::runtime_error("Failed to allocate device memory for tensor strides");
         }
-        cudaError_t memcpy_status = cudaMemcpy(device_shape_, shape.data(), _rank() * sizeof(std::size_t), cudaMemcpyHostToDevice);
+        cudaError_t memcpy_status =
+            cudaMemcpy(device_shape_, shape.data(), rank() * sizeof(std::size_t), cudaMemcpyHostToDevice);
         if (memcpy_status != cudaSuccess) {
             throw std::runtime_error("Failed to copy tensor shape to device");
         }
-        memcpy_status = cudaMemcpy(device_strides_, strides_.data(), _rank() * sizeof(std::size_t), cudaMemcpyHostToDevice);
+        memcpy_status =
+            cudaMemcpy(device_strides_, strides_.data(), rank() * sizeof(std::size_t), cudaMemcpyHostToDevice);
         if (memcpy_status != cudaSuccess) {
             throw std::runtime_error("Failed to copy tensor strides to device");
         }
@@ -191,7 +218,7 @@ class tensor {
     // Conversion constructors
     template <typename U, typename OtherShape, typename OtherStrides>
     tensor(const tensor<U, OtherShape, OtherStrides, ErrorChecking, OwnershipType, MemorySpace> &other)
-        requires fixed_shape<Shape>;
+        requires(fixed_shape<Shape> && MemorySpace == memory_space::host);
     template <typename U, typename OtherShape, typename OtherStrides>
     tensor(const tensor<U, OtherShape, OtherStrides, ErrorChecking, ownership_type::reference, MemorySpace> &other)
         requires(OwnershipType == ownership_type::owner);
@@ -206,9 +233,15 @@ class tensor {
         requires(MemorySpace == memory_space::device && OwnershipType == ownership_type::reference)
     {
 #ifdef SQUINT_USE_CUDA
-        cudaFree(data_);
-        cudaFree(device_shape_);
-        cudaFree(device_strides_);
+        if (data_ != nullptr) {
+            cudaFree(data_);
+        }
+        if (device_shape_ != nullptr) {
+            cudaFree(device_shape_);
+        }
+        if (device_strides_ != nullptr) {
+            cudaFree(device_strides_);
+        }
 #endif
     }
 
@@ -290,35 +323,32 @@ class tensor {
     auto copy() const -> auto
         requires(OwnershipType == ownership_type::reference && MemorySpace == memory_space::device)
     {
-        if constexpr (fixed_shape<Shape>) {
-            using device_tensor_type = tensor<std::remove_const_t<T>, Shape, Strides, ErrorChecking,
-                                              ownership_type::reference, memory_space::device>;
-            size_t size = this->size() * sizeof(T);
-            // Create device pointer
-            void *device_ptr = nullptr;
-
-            // Allocate memory on the device
-            cudaError_t malloc_status = cudaMalloc(&device_ptr, size);
-            if (malloc_status != cudaSuccess) {
-                throw std::runtime_error("Failed to allocate device memory");
-            }
+        using device_tensor_type = tensor<std::remove_const_t<T>, Shape, Strides, ErrorChecking,
+                                          ownership_type::reference, memory_space::device>;
+        size_t size = this->size() * sizeof(T);
+        // Create device pointer
+        void *device_ptr = nullptr;
 
-            // Copy data from device to device
-            cudaError_t memcpy_status =
-                cudaMemcpy(device_ptr, static_cast<void *>(const_cast<std::remove_const_t<T> *>(this->data())), size,
-                           cudaMemcpyDeviceToDevice);
-            if (memcpy_status != cudaSuccess) {
-                cudaFree(device_ptr);
-                throw std::runtime_error("Failed to copy data to device");
-            }
+        // Allocate memory on the device
+        cudaError_t malloc_status = cudaMalloc(&device_ptr, size);
+        if (malloc_status != cudaSuccess) {
+            throw std::runtime_error("Failed to allocate device memory");
+        }
 
-            // Create and return the device tensor
-            if constexpr (dynamic_shape<Shape>) {
-                return device_tensor_type(static_cast<std::remove_const_t<T> *>(device_ptr), this->shape_,
-                                          this->strides_);
-            } else {
-                return device_tensor_type(static_cast<std::remove_const_t<T> *>(device_ptr));
-            }
+        // Copy data from device to device
+        cudaError_t memcpy_status =
+            cudaMemcpy(device_ptr, static_cast<void *>(const_cast<std::remove_const_t<T> *>(this->data())), size,
+                       cudaMemcpyDeviceToDevice);
+        if (memcpy_status != cudaSuccess) {
+            cudaFree(device_ptr);
+            throw std::runtime_error("Failed to copy data to device");
+        }
+
+        // Create and return the device tensor
+        if constexpr (dynamic_shape<Shape>) {
+            return device_tensor_type(static_cast<std::remove_const_t<T> *>(device_ptr), this->shape_, this->strides_);
+        } else {
+            return device_tensor_type(static_cast<std::remove_const_t<T> *>(device_ptr));
         }
     }
 
@@ -351,7 +381,7 @@ class tensor {
         // Create and return the device tensor
         if constexpr (dynamic_shape<Shape>) {
             if constexpr (ErrorChecking == error_checking::enabled) {
-                auto column_major_strides = this->compute_strides(this->shape(), layout::column_major);
+                auto column_major_strides = this->compute_strides(layout::column_major, this->shape());
                 auto strides = this->strides();
                 if (!std::equal(strides.begin(), strides.end(), column_major_strides.begin())) {
                     cudaFree(device_ptr);
@@ -375,17 +405,21 @@ class tensor {
         if constexpr (dynamic_shape<Shape>) {
             host_tensor_type host_tensor(this->shape());
             // Copy data from device to host
-            cudaError_t memcpy_status = cudaMemcpy(host_tensor.data(), this->data(), size, cudaMemcpyDeviceToHost);
+            cudaError_t memcpy_status =
+                cudaMemcpy(static_cast<void *>(host_tensor.data()), static_cast<const void *>(this->data()), size,
+                           cudaMemcpyDeviceToHost);
             if (memcpy_status != cudaSuccess) {
-                throw std::runtime_error("Failed to copy data to host");
+                throw std::runtime_error("Failed to copy data to host error code: " + std::to_string(memcpy_status));
             }
             return host_tensor;
         } else {
-            host_tensor_type host_tensor;
+            host_tensor_type host_tensor{};
             // Copy data from device to host
-            cudaError_t memcpy_status = cudaMemcpy(host_tensor.data(), this->data(), size, cudaMemcpyDeviceToHost);
+            cudaError_t memcpy_status =
+                cudaMemcpy(static_cast<void *>(host_tensor.data()), static_cast<const void *>(this->data()), size,
+                           cudaMemcpyDeviceToHost);
             if (memcpy_status != cudaSuccess) {
-                throw std::runtime_error("Failed to copy data to host");
+                throw std::runtime_error("Failed to copy data to host error code: " + std::to_string(memcpy_status));
             }
             return host_tensor;
         }
@@ -546,12 +580,14 @@ class tensor {
     // Comparison operators
     template <typename U, typename OtherShape, typename OtherStrides, enum error_checking OtherErrorChecking,
               enum ownership_type OtherOwnershipType>
-    auto operator==(const tensor<U, OtherShape, OtherStrides, OtherErrorChecking, OtherOwnershipType, MemorySpace>
-                        &other) const -> tensor<bool, Shape, Strides, ErrorChecking, ownership_type::owner, MemorySpace>;
+    auto
+    operator==(const tensor<U, OtherShape, OtherStrides, OtherErrorChecking, OtherOwnershipType, MemorySpace> &other)
+        const -> tensor<std::uint8_t, Shape, Strides, ErrorChecking, ownership_type::owner, MemorySpace>;
     template <typename U, typename OtherShape, typename OtherStrides, enum error_checking OtherErrorChecking,
               enum ownership_type OtherOwnershipType>
-    auto operator!=(const tensor<U, OtherShape, OtherStrides, OtherErrorChecking, OtherOwnershipType, MemorySpace>
-                        &other) const -> tensor<bool, Shape, Strides, ErrorChecking, ownership_type::owner, MemorySpace>;
+    auto
+    operator!=(const tensor<U, OtherShape, OtherStrides, OtherErrorChecking, OtherOwnershipType, MemorySpace> &other)
+        const -> tensor<std::uint8_t, Shape, Strides, ErrorChecking, ownership_type::owner, MemorySpace>;
     // Unary operators
     auto operator-() const -> tensor;
     // scalar operations
diff --git a/include/squint/tensor/tensor_constructors.hpp b/include/squint/tensor/tensor_constructors.hpp
index 060668d..1ed2217 100644
--- a/include/squint/tensor/tensor_constructors.hpp
+++ b/include/squint/tensor/tensor_constructors.hpp
@@ -182,7 +182,7 @@ template <typename T, typename Shape, typename Strides, error_checking ErrorChec
 template <typename U, typename OtherShape, typename OtherStrides>
 tensor<T, Shape, Strides, ErrorChecking, OwnershipType, MemorySpace>::tensor(
     const tensor<U, OtherShape, OtherStrides, ErrorChecking, OwnershipType, MemorySpace> &other)
-    requires fixed_shape<Shape>
+    requires(fixed_shape<Shape> && MemorySpace == memory_space::host)
 {
     if constexpr (OwnershipType == ownership_type::owner) {
         // for owner ownership, only shape must be convertible
@@ -250,7 +250,7 @@ template <typename T, typename Shape, typename Strides, error_checking ErrorChec
           memory_space MemorySpace>
 tensor<T, Shape, Strides, ErrorChecking, OwnershipType, MemorySpace>::tensor(T *data, Shape shape, Strides strides)
     requires(dynamic_shape<Shape> && OwnershipType == ownership_type::reference)
-    : data_(data), shape_(std::move(shape)), strides_(std::move(strides)) {
+    : data_(data), shape_(shape), strides_(strides) {
     if (ErrorChecking == error_checking::enabled) {
         if (!implicit_convertible_shapes_vector(shape, this->shape())) {
             throw std::runtime_error("Invalid shape conversion");
diff --git a/include/squint/tensor/tensor_math.hpp b/include/squint/tensor/tensor_math.hpp
index 815a7af..29918dc 100644
--- a/include/squint/tensor/tensor_math.hpp
+++ b/include/squint/tensor/tensor_math.hpp
@@ -40,7 +40,7 @@ namespace squint {
  * @return The pivot indices.
  * @throws std::runtime_error if the system is singular or an error occurs during the solution.
  */
-template <tensorial T1, tensorial T2> auto solve(T1 &A, T2 &B) {
+template <host_tensor T1, host_tensor T2> auto solve(T1 &A, T2 &B) {
     blas_compatible(A, B);
     solve_compatible(A, B);
     static_assert(dimensionless_scalar<typename T1::value_type>);
@@ -88,7 +88,7 @@ template <tensorial T1, tensorial T2> auto solve(T1 &A, T2 &B) {
  * @return The pivot indices.
  * @throws std::runtime_error if an error occurs during the solution.
  */
-template <tensorial T1, tensorial T2> auto solve_general(T1 &A, T2 &B) {
+template <host_tensor T1, host_tensor T2> auto solve_general(T1 &A, T2 &B) {
     blas_compatible(A, B);
     solve_general_compatible(A, B);
     static_assert(dimensionless_scalar<typename T1::value_type>);
@@ -138,7 +138,7 @@ template <tensorial T1, tensorial T2> auto solve_general(T1 &A, T2 &B) {
  * @return The inverted matrix.
  * @throws std::runtime_error if the matrix is singular or an error occurs during inversion.
  */
-template <tensorial T> auto inv(const T &A) {
+template <host_tensor T> auto inv(const T &A) {
     inversion_compatible(A);
     static_assert(dimensionless_scalar<typename T::value_type>);
     using blas_type = blas_type_t<typename T::value_type>;
@@ -215,6 +215,7 @@ template <tensorial T> auto inv(const T &A) {
  *
  */
 template <fixed_tensor T> auto pinv(const T &A) {
+    static_assert(host_tensor<T>, "Pseudoinverse is only supported for host tensors");
     constexpr int m = make_array(typename T::shape_type{})[0];
     constexpr int n = make_array(typename T::shape_type{}).size() > 1 ? make_array(typename T::shape_type{})[1] : 1;
 
@@ -230,6 +231,7 @@ template <fixed_tensor T> auto pinv(const T &A) {
 }
 
 template <dynamic_tensor T> auto pinv(const T &A) {
+    static_assert(host_tensor<T>, "Pseudoinverse is only supported for host tensors");
     int m = A.shape()[0];
     int n = A.rank() > 1 ? A.shape()[1] : 1;
 
@@ -250,7 +252,7 @@ template <dynamic_tensor T> auto pinv(const T &A) {
  * @return The cross product of a and b.
  * @throws std::invalid_argument if the vectors are not 3D.
  */
-template <tensorial T1, tensorial T2> auto cross(const T1 &a, const T2 &b) {
+template <host_tensor T1, host_tensor T2> auto cross(const T1 &a, const T2 &b) {
     if constexpr (fixed_tensor<T1> && fixed_tensor<T2>) {
         static_assert(T1::shape_type::size() == 1 && T2::shape_type::size() == 1 &&
                           std::get<0>(make_array(typename T1::shape_type{})) == 3 &&
@@ -280,7 +282,7 @@ template <tensorial T1, tensorial T2> auto cross(const T1 &a, const T2 &b) {
  * @return The dot product of a and b.
  * @throws std::invalid_argument if the vectors have different sizes.
  */
-template <tensorial T1, tensorial T2> auto dot(const T1 &a, const T2 &b) {
+template <host_tensor T1, host_tensor T2> auto dot(const T1 &a, const T2 &b) {
     if constexpr (fixed_tensor<T1> && fixed_tensor<T2>) {
         static_assert(T1::shape_type::size() == 1 && T2::shape_type::size() == 1 &&
                           std::get<0>(make_array(typename T1::shape_type{})) ==
@@ -309,7 +311,7 @@ template <tensorial T1, tensorial T2> auto dot(const T1 &a, const T2 &b) {
  * @return The trace of the matrix.
  * @throws std::invalid_argument if the matrix is not square.
  */
-template <tensorial T> auto trace(const T &a) {
+template <host_tensor T> auto trace(const T &a) {
     if constexpr (fixed_tensor<T>) {
         static_assert(T::shape_type::size() == 2 && std::get<0>(make_array(typename T::shape_type{})) ==
                                                         std::get<1>(make_array(typename T::shape_type{})),
@@ -334,7 +336,7 @@ template <tensorial T> auto trace(const T &a) {
  * @param a The input vector.
  * @return The Euclidean norm of the vector.
  */
-template <tensorial T> auto norm(const T &a) {
+template <host_tensor T> auto norm(const T &a) {
     using value_type = typename T::value_type;
     if constexpr (quantitative<value_type>) {
         return sqrt(squared_norm(a));
@@ -348,7 +350,7 @@ template <tensorial T> auto norm(const T &a) {
  * @param a The input vector.
  * @return The squared Euclidean norm of the vector.
  */
-template <tensorial T> auto squared_norm(const T &a) {
+template <host_tensor T> auto squared_norm(const T &a) {
     using value_type = typename T::value_type;
     using result_type =
         std::conditional_t<quantitative<value_type>, decltype(std::declval<value_type>() * std::declval<value_type>()),
@@ -371,14 +373,14 @@ template <tensorial T> auto squared_norm(const T &a) {
  * @param a The input vector.
  * @return The normalized vector.
  */
-template <tensorial T> auto normalize(const T &a) { return a / norm(a); }
+template <host_tensor T> auto normalize(const T &a) { return a / norm(a); }
 
 /**
  * @brief Computes the mean of all elements in the tensor.
  * @param a The input tensor.
  * @return The mean value of all elements.
  */
-template <tensorial T> auto mean(const T &a) {
+template <host_tensor T> auto mean(const T &a) {
     typename T::value_type sum = 0;
     size_t count = 0;
 
@@ -395,21 +397,21 @@ template <tensorial T> auto mean(const T &a) {
  * @param a The input tensor.
  * @return The sum of all elements.
  */
-template <tensorial T> auto sum(const T &a) { return std::accumulate(a.begin(), a.end(), typename T::value_type(0)); }
+template <host_tensor T> auto sum(const T &a) { return std::accumulate(a.begin(), a.end(), typename T::value_type(0)); }
 
 /**
  * @brief Finds the minimum element in the tensor.
  * @param a The input tensor.
  * @return The minimum element.
  */
-template <tensorial T> auto min(const T &a) { return *std::min_element(a.begin(), a.end()); }
+template <host_tensor T> auto min(const T &a) { return *std::min_element(a.begin(), a.end()); }
 
 /**
  * @brief Finds the maximum element in the tensor.
  * @param a The input tensor.
  * @return The maximum element.
  */
-template <tensorial T> auto max(const T &a) { return *std::max_element(a.begin(), a.end()); }
+template <host_tensor T> auto max(const T &a) { return *std::max_element(a.begin(), a.end()); }
 
 /**
  * @brief Checks if two tensors are approximately equal within a given tolerance.
@@ -418,7 +420,7 @@ template <tensorial T> auto max(const T &a) { return *std::max_element(a.begin()
  * @param tol The tolerance for comparison (default is machine epsilon).
  * @return True if the tensors are approximately equal, false otherwise.
  */
-template <tensorial T1, tensorial T2>
+template <host_tensor T1, host_tensor T2>
 auto approx_equal(
     const T1 &a, const T2 &b,
     typename std::common_type_t<typename T1::value_type, typename T2::value_type> tol =
@@ -453,6 +455,8 @@ auto approx_equal(
  */
 template <dynamic_tensor Tensor1, dynamic_tensor Tensor2>
 auto contract(const Tensor1 &A, const Tensor2 &B, const std::vector<std::pair<size_t, size_t>> &contraction_pairs) {
+    static_assert(host_tensor<Tensor1> && host_tensor<Tensor2>,
+                  "Tensor contraction is only supported for host tensors");
     auto A_shape = A.shape();
     auto B_shape = B.shape();
     size_t A_rank = A_shape.size();
@@ -608,6 +612,8 @@ template <typename Tensor1, typename Tensor2, typename Sequence1, typename Seque
  */
 template <fixed_tensor Tensor1, fixed_tensor Tensor2, typename Sequence1, typename Sequence2>
 auto contract(const Tensor1 &A, const Tensor2 &B, const Sequence1 /*unused*/, const Sequence2 /*unused*/) {
+    static_assert(host_tensor<Tensor1> && host_tensor<Tensor2>,
+                  "Tensor contraction is only supported for host tensors");
     using types = contraction_types<Tensor1, Tensor2, Sequence1, Sequence2>;
     using result_value_type = std::common_type_t<typename Tensor1::value_type, typename Tensor2::value_type>;
 
@@ -639,6 +645,8 @@ auto contract(const Tensor1 &A, const Tensor2 &B, const Sequence1 /*unused*/, co
  */
 template <dynamic_tensor Tensor1, dynamic_tensor Tensor2>
 auto einsum(const std::string &subscripts, const Tensor1 &A, const Tensor2 &B) {
+    static_assert(host_tensor<Tensor1> && host_tensor<Tensor2>,
+                  "Tensor contraction is only supported for host tensors");
     // Parse the subscripts
     auto pos = subscripts.find("->");
     if (pos == std::string::npos) {
@@ -709,6 +717,7 @@ auto einsum(const std::string &subscripts, const Tensor1 &A, const Tensor2 &B) {
  * specifies a matrix transpose operation.
  */
 template <dynamic_tensor Tensor> auto einsum(const std::string &subscripts, const Tensor &tensor) {
+    static_assert(host_tensor<Tensor>, "Tensor contraction is only supported for host tensors");
     // Parse the subscripts
     auto pos = subscripts.find("->");
     if (pos == std::string::npos) {
@@ -783,6 +792,8 @@ template <typename ASubscripts, typename BSubscripts> struct get_contraction_ind
 template <typename ASubscripts, typename BSubscripts, typename OutputSubscripts, fixed_tensor Tensor1,
           fixed_tensor Tensor2>
 auto einsum(const Tensor1 &A, const Tensor2 &B) {
+    static_assert(host_tensor<Tensor1> && host_tensor<Tensor2>,
+                  "Tensor contraction is only supported for host tensors");
     using a_contractions = typename get_contraction_indices<ASubscripts, BSubscripts>::type;
     using b_contractions = typename get_contraction_indices<BSubscripts, ASubscripts>::type;
 
@@ -809,6 +820,7 @@ auto einsum(const Tensor1 &A, const Tensor2 &B) {
  * @return The result of the einsum operation.
  */
 template <typename InputSubscripts, typename OutputSubscripts, typename Tensor> auto einsum(const Tensor &tensor) {
+    static_assert(host_tensor<Tensor>, "Tensor contraction is only supported for host tensors");
     if constexpr (std::is_same_v<InputSubscripts, OutputSubscripts>) {
         // No operation needed
         return tensor;
diff --git a/include/squint/tensor/tensor_ops.hpp b/include/squint/tensor/tensor_ops.hpp
index 9146468..7329fa1 100644
--- a/include/squint/tensor/tensor_ops.hpp
+++ b/include/squint/tensor/tensor_ops.hpp
@@ -22,6 +22,12 @@
 #include <utility>
 #include <vector>
 
+#ifdef SQUINT_USE_CUDA
+#include "squint/tensor/cuda/cuda_context.hpp"
+#include <cublas_v2.h>
+#include <cuda_runtime.h>
+#endif
+
 namespace squint {
 
 /**
@@ -30,10 +36,7 @@ namespace squint {
  * @param t2 The second tensor to multiply.
  * @return A new tensor containing the result of the multiplication.
  */
-template <tensorial Tensor1, tensorial Tensor2>
-auto operator*(const Tensor1 &t1, const Tensor2 &t2)
-    requires(host_tensor<Tensor1> && host_tensor<Tensor2>)
-{
+template <tensorial Tensor1, tensorial Tensor2> auto operator*(const Tensor1 &t1, const Tensor2 &t2) {
     matrix_multiply_compatible(t1, t2);
     blas_compatible(t1, t2);
     using blas_type =
@@ -62,64 +65,118 @@ auto operator*(const Tensor1 &t1, const Tensor2 &t2)
     blas_type beta = 0;
 
     if constexpr (fixed_tensor<Tensor1> && fixed_tensor<Tensor2>) {
-        using strides_type = strides::column_major<result_shape_type>;
-        using result_type = tensor<result_value_type, result_shape_type, strides_type, result_error_checking::value,
-                                   ownership_type::owner, memory_space::host>;
-        result_type result{};
-        if constexpr (std::is_same_v<blas_type, float>) {
-            // NOLINTBEGIN
-            cblas_sgemm(
-                CBLAS_ORDER::CblasColMajor, op_a, op_b, m, n, k, alpha,
-                reinterpret_cast<float *>(const_cast<std::remove_const_t<typename Tensor1::value_type> *>(t1.data())),
-                lda,
-                reinterpret_cast<float *>(const_cast<std::remove_const_t<typename Tensor2::value_type> *>(t2.data())),
-                ldb, beta, reinterpret_cast<float *>(result.data()), ldc);
-            // NOLINTEND
-        } else if constexpr (std::is_same_v<blas_type, double>) {
+        if constexpr (host_tensor<Tensor1> && host_tensor<Tensor2>) {
+            using strides_type = strides::column_major<result_shape_type>;
+            using result_type = tensor<result_value_type, result_shape_type, strides_type, result_error_checking::value,
+                                       ownership_type::owner, memory_space::host>;
+            result_type result{};
+            if constexpr (std::is_same_v<blas_type, float>) {
+                // NOLINTBEGIN
+                cblas_sgemm(CBLAS_ORDER::CblasColMajor, op_a, op_b, m, n, k, alpha,
+                            reinterpret_cast<float *>(
+                                const_cast<std::remove_const_t<typename Tensor1::value_type> *>(t1.data())),
+                            lda,
+                            reinterpret_cast<float *>(
+                                const_cast<std::remove_const_t<typename Tensor2::value_type> *>(t2.data())),
+                            ldb, beta, reinterpret_cast<float *>(result.data()), ldc);
+                // NOLINTEND
+            } else if constexpr (std::is_same_v<blas_type, double>) {
+                // NOLINTBEGIN
+                cblas_dgemm(CBLAS_ORDER::CblasColMajor, op_a, op_b, m, n, k, alpha,
+                            reinterpret_cast<double *>(
+                                const_cast<std::remove_const_t<typename Tensor1::value_type> *>(t1.data())),
+                            lda,
+                            reinterpret_cast<double *>(
+                                const_cast<std::remove_const_t<typename Tensor2::value_type> *>(t2.data())),
+                            ldb, beta,
+                            reinterpret_cast<double *>(
+                                const_cast<std::remove_const_t<typename result_type::value_type> *>(result.data())),
+                            ldc);
+                // NOLINTEND
+            }
+            return std::move(result);
+        } else {
+#ifdef SQUINT_USE_CUDA
             // NOLINTBEGIN
-            cblas_dgemm(
-                CBLAS_ORDER::CblasColMajor, op_a, op_b, m, n, k, alpha,
-                reinterpret_cast<double *>(const_cast<std::remove_const_t<typename Tensor1::value_type> *>(t1.data())),
-                lda,
-                reinterpret_cast<double *>(const_cast<std::remove_const_t<typename Tensor2::value_type> *>(t2.data())),
-                ldb, beta,
-                reinterpret_cast<double *>(
-                    const_cast<std::remove_const_t<typename result_type::value_type> *>(result.data())),
-                ldc);
+            using strides_type = strides::column_major<result_shape_type>;
+            using result_type = tensor<result_value_type, result_shape_type, strides_type, result_error_checking::value,
+                                       ownership_type::reference, memory_space::device>;
+            auto handle = cuda::cuda_context::instance().cublas_handle();
+            auto cublas_op_a = (op_a == CBLAS_TRANSPOSE::CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+            auto cublas_op_b = (op_b == CBLAS_TRANSPOSE::CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+            result_type result{};
+            if constexpr (std::is_same_v<blas_type, float>) {
+                cublasSgemm(handle, cublas_op_a, cublas_op_b, m, n, k, &alpha,
+                            reinterpret_cast<const float *>(t1.data()), lda, reinterpret_cast<const float *>(t2.data()),
+                            ldb, &beta, reinterpret_cast<float *>(result.data()), ldc);
+            } else if constexpr (std::is_same_v<blas_type, double>) {
+                cublasDgemm(handle, cublas_op_a, cublas_op_b, m, n, k, &alpha,
+                            reinterpret_cast<const double *>(t1.data()), lda,
+                            reinterpret_cast<const double *>(t2.data()), ldb, &beta,
+                            reinterpret_cast<double *>(result.data()), ldc);
+            }
+            return std::move(result);
             // NOLINTEND
+#endif
         }
-        return result;
     } else {
-        using strides_type = std::vector<std::size_t>;
-        using result_type = tensor<result_value_type, result_shape_type, strides_type, result_error_checking::value,
-                                   ownership_type::owner, memory_space::host>;
-        result_type result({static_cast<std::size_t>(m), static_cast<std::size_t>(n)}, layout::column_major);
-        if constexpr (std::is_same_v<blas_type, float>) {
-            // NOLINTBEGIN
-            cblas_sgemm(
-                CBLAS_ORDER::CblasColMajor, op_a, op_b, m, n, k, alpha,
-                reinterpret_cast<float *>(const_cast<std::remove_const_t<typename Tensor1::value_type> *>(t1.data())),
-                lda,
-                reinterpret_cast<float *>(const_cast<std::remove_const_t<typename Tensor2::value_type> *>(t2.data())),
-                ldb, beta,
-                reinterpret_cast<float *>(
-                    const_cast<std::remove_const_t<typename result_type::value_type> *>(result.data())),
-                ldc);
-            // NOLINTEND
-        } else if constexpr (std::is_same_v<blas_type, double>) {
+        if constexpr (host_tensor<Tensor1> && host_tensor<Tensor2>) {
+            using strides_type = std::vector<std::size_t>;
+            using result_type = tensor<result_value_type, result_shape_type, strides_type, result_error_checking::value,
+                                       ownership_type::owner, memory_space::host>;
+            result_type result({static_cast<std::size_t>(m), static_cast<std::size_t>(n)}, layout::column_major);
+            if constexpr (std::is_same_v<blas_type, float>) {
+                // NOLINTBEGIN
+                cblas_sgemm(CBLAS_ORDER::CblasColMajor, op_a, op_b, m, n, k, alpha,
+                            reinterpret_cast<float *>(
+                                const_cast<std::remove_const_t<typename Tensor1::value_type> *>(t1.data())),
+                            lda,
+                            reinterpret_cast<float *>(
+                                const_cast<std::remove_const_t<typename Tensor2::value_type> *>(t2.data())),
+                            ldb, beta,
+                            reinterpret_cast<float *>(
+                                const_cast<std::remove_const_t<typename result_type::value_type> *>(result.data())),
+                            ldc);
+                // NOLINTEND
+            } else if constexpr (std::is_same_v<blas_type, double>) {
+                // NOLINTBEGIN
+                cblas_dgemm(CBLAS_ORDER::CblasColMajor, op_a, op_b, m, n, k, alpha,
+                            reinterpret_cast<double *>(
+                                const_cast<std::remove_const_t<typename Tensor1::value_type> *>(t1.data())),
+                            lda,
+                            reinterpret_cast<double *>(
+                                const_cast<std::remove_const_t<typename Tensor2::value_type> *>(t2.data())),
+                            ldb, beta,
+                            reinterpret_cast<double *>(
+                                const_cast<std::remove_const_t<typename result_type::value_type> *>(result.data())),
+                            ldc);
+                // NOLINTEND
+            }
+            return std::move(result);
+        } else {
+#ifdef SQUINT_USE_CUDA
             // NOLINTBEGIN
-            cblas_dgemm(
-                CBLAS_ORDER::CblasColMajor, op_a, op_b, m, n, k, alpha,
-                reinterpret_cast<double *>(const_cast<std::remove_const_t<typename Tensor1::value_type> *>(t1.data())),
-                lda,
-                reinterpret_cast<double *>(const_cast<std::remove_const_t<typename Tensor2::value_type> *>(t2.data())),
-                ldb, beta,
-                reinterpret_cast<double *>(
-                    const_cast<std::remove_const_t<typename result_type::value_type> *>(result.data())),
-                ldc);
-            // NOLINTEND
+            using strides_type = std::vector<std::size_t>;
+            using result_type = tensor<result_value_type, result_shape_type, strides_type, result_error_checking::value,
+                                       ownership_type::reference, memory_space::device>;
+            result_type result({static_cast<std::size_t>(m), static_cast<std::size_t>(n)});
+            auto handle = cuda::cuda_context::instance().cublas_handle();
+            auto cublas_op_a = (op_a == CBLAS_TRANSPOSE::CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+            auto cublas_op_b = (op_b == CBLAS_TRANSPOSE::CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+            if constexpr (std::is_same_v<blas_type, float>) {
+                cublasSgemm(handle, cublas_op_a, cublas_op_b, m, n, k, &alpha,
+                            reinterpret_cast<const float *>(t1.data()), lda, reinterpret_cast<const float *>(t2.data()),
+                            ldb, &beta, reinterpret_cast<float *>(result.data()), ldc);
+            } else if constexpr (std::is_same_v<blas_type, double>) {
+                cublasDgemm(handle, cublas_op_a, cublas_op_b, m, n, k, &alpha,
+                            reinterpret_cast<const double *>(t1.data()), lda,
+                            reinterpret_cast<const double *>(t2.data()), ldb, &beta,
+                            reinterpret_cast<double *>(result.data()), ldc);
+            }
+            return std::move(result);
+// NOLINTEND
+#endif
         }
-        return result;
     }
 }
 
@@ -133,7 +190,7 @@ auto operator*(const Tensor1 &t1, const Tensor2 &t2)
  * This function computes the least squares solution when Ax =B is overdetermined and the least norm solution when Ax =
  * B is underdetermined.
  */
-template <tensorial T1, tensorial T2> auto operator/(const T1 &B, const T2 &A) {
+template <host_tensor T1, host_tensor T2> auto operator/(const T1 &B, const T2 &A) {
     using blas_type = std::common_type_t<blas_type_t<typename T1::value_type>, blas_type_t<typename T2::value_type>>;
     using result_value_type =
         decltype(std::declval<typename T1::value_type>() / std::declval<typename T2::value_type>());
diff --git a/include/squint/tensor/tensor_shape_manipulation.hpp b/include/squint/tensor/tensor_shape_manipulation.hpp
index f7bf5b9..c1ed032 100644
--- a/include/squint/tensor/tensor_shape_manipulation.hpp
+++ b/include/squint/tensor/tensor_shape_manipulation.hpp
@@ -298,6 +298,20 @@ auto tensor<T, Shape, Strides, ErrorChecking, OwnershipType, MemorySpace>::set_s
 
     this->shape_ = new_shape;
     this->strides_ = compute_strides(l, new_shape);
+    if constexpr (MemorySpace == memory_space::device && OwnershipType == ownership_type::reference) {
+        // If the tensor is a device reference, we need to update the device shape and strides as well
+        // cuda memcopy
+        cudaError_t memcpy_status = cudaMemcpy(this->device_shape_.data(), new_shape.data(),
+                                               new_shape.size() * sizeof(size_t), cudaMemcpyHostToDevice);
+        if (memcpy_status != cudaSuccess) {
+            throw std::runtime_error("Failed to copy data to device");
+        }
+        memcpy_status = cudaMemcpy(this->device_strides_.data(), this->strides_.data(),
+                                   this->strides_.size() * sizeof(size_t), cudaMemcpyHostToDevice);
+        if (memcpy_status != cudaSuccess) {
+            throw std::runtime_error("Failed to copy data to device");
+        }
+    }
 }
 
 /**
diff --git a/include/squint/tensor/tensor_types.hpp b/include/squint/tensor/tensor_types.hpp
index 8a942f1..c0e5aec 100644
--- a/include/squint/tensor/tensor_types.hpp
+++ b/include/squint/tensor/tensor_types.hpp
@@ -31,9 +31,9 @@ using vec4 = vec4_t<float>;
 using dvec2 = vec2_t<double>;
 using dvec3 = vec3_t<double>;
 using dvec4 = vec4_t<double>;
-using bvec2 = vec2_t<bool>;
-using bvec3 = vec3_t<bool>;
-using bvec4 = vec4_t<bool>;
+using bvec2 = vec2_t<uint8_t>;
+using bvec3 = vec3_t<uint8_t>;
+using bvec4 = vec4_t<uint8_t>;
 
 // Square matrix types
 template <typename T> using mat2_t = tensor<T, shape<2, 2>>;
@@ -51,9 +51,9 @@ using mat4 = mat4_t<float>;
 using dmat2 = mat2_t<double>;
 using dmat3 = mat3_t<double>;
 using dmat4 = mat4_t<double>;
-using bmat2 = mat2_t<bool>;
-using bmat3 = mat3_t<bool>;
-using bmat4 = mat4_t<bool>;
+using bmat2 = mat2_t<uint8_t>;
+using bmat3 = mat3_t<uint8_t>;
+using bmat4 = mat4_t<uint8_t>;
 
 // Non-square matrix types
 template <typename T> using mat2x3_t = tensor<T, shape<2, 3>>;
@@ -86,12 +86,12 @@ using dmat3x2 = mat3x2_t<double>;
 using dmat3x4 = mat3x4_t<double>;
 using dmat4x2 = mat4x2_t<double>;
 using dmat4x3 = mat4x3_t<double>;
-using bmat2x3 = mat2x3_t<bool>;
-using bmat2x4 = mat2x4_t<bool>;
-using bmat3x2 = mat3x2_t<bool>;
-using bmat3x4 = mat3x4_t<bool>;
-using bmat4x2 = mat4x2_t<bool>;
-using bmat4x3 = mat4x3_t<bool>;
+using bmat2x3 = mat2x3_t<uint8_t>;
+using bmat2x4 = mat2x4_t<uint8_t>;
+using bmat3x2 = mat3x2_t<uint8_t>;
+using bmat3x4 = mat3x4_t<uint8_t>;
+using bmat4x2 = mat4x2_t<uint8_t>;
+using bmat4x3 = mat4x3_t<uint8_t>;
 
 // General tensor shapes
 template <typename T, std::size_t... Dims> using ndarr_t = tensor<T, shape<Dims...>>;
@@ -99,14 +99,14 @@ template <std::size_t... Dims> using indarr = ndarr_t<int, Dims...>;
 template <std::size_t... Dims> using undarr = ndarr_t<unsigned char, Dims...>;
 template <std::size_t... Dims> using ndarr = ndarr_t<float, Dims...>;
 template <std::size_t... Dims> using dndarr = ndarr_t<double, Dims...>;
-template <std::size_t... Dims> using bndarr = ndarr_t<bool, Dims...>;
+template <std::size_t... Dims> using bndarr = ndarr_t<uint8_t, Dims...>;
 
 template <typename T> using tens_t = tensor<T, dynamic, dynamic>;
 using itens = tens_t<int>;
 using utens = tens_t<unsigned char>;
 using tens = tens_t<float>;
 using dtens = tens_t<double>;
-using btens = tens_t<bool>;
+using btens = tens_t<uint8_t>;
 
 } // namespace squint
 
diff --git a/main.cpp b/main.cpp
new file mode 100644
index 0000000..183cb3b
--- /dev/null
+++ b/main.cpp
@@ -0,0 +1,15 @@
+#include <squint/squint.hpp>
+
+using namespace squint;
+
+auto main() -> int {
+    tensor<float, dynamic, dynamic> a({2, 3}, std::vector<float>{1, 4, 2, 5, 3, 6});
+    auto a_device = a.to_device();
+    auto b_device = a_device * 2.0f;
+    auto b_host = b_device.to_host();
+    std::cout << "b_host: " << b_host << std::endl;
+    auto permute_device = a_device.permute({1, 0});
+    auto permute_host = permute_device.to_host();
+    std::cout << "permute_host: " << permute_host << std::endl;
+    return 0;
+}
\ No newline at end of file
diff --git a/src/tensor/cuda/element_wise.cu b/src/tensor/cuda/element_wise.cu
index 8ecde63..c373838 100644
--- a/src/tensor/cuda/element_wise.cu
+++ b/src/tensor/cuda/element_wise.cu
@@ -91,7 +91,7 @@ template void element_wise_subtraction<double>(double *output, const double *a,
 
 // CUDA kernel for element-wise equality with different strides
 template <typename T>
-__global__ void element_wise_equality_kernel(bool *output, const T *a, const T *b, const unsigned long *dims,
+__global__ void element_wise_equality_kernel(uint8_t *output, const T *a, const T *b, const unsigned long *dims,
                                              const unsigned long *strides_out, const unsigned long *strides_a,
                                              const unsigned long *strides_b, unsigned long num_dims,
                                              unsigned long total_size) {
@@ -113,7 +113,7 @@ __global__ void element_wise_equality_kernel(bool *output, const T *a, const T *
 }
 
 template <typename T>
-void element_wise_equality(bool *output, const T *a, const T *b, const unsigned long *dims,
+void element_wise_equality(uint8_t *output, const T *a, const T *b, const unsigned long *dims,
                            const unsigned long *strides_out, const unsigned long *strides_a,
                            const unsigned long *strides_b, unsigned long num_dims, unsigned long total_size) {
     int block_size = 256;
@@ -122,19 +122,19 @@ void element_wise_equality(bool *output, const T *a, const T *b, const unsigned
                                                              num_dims, total_size);
 }
 
-template void element_wise_equality<float>(bool *output, const float *a, const float *b, const unsigned long *dims,
+template void element_wise_equality<float>(uint8_t *output, const float *a, const float *b, const unsigned long *dims,
                                            const unsigned long *strides_out, const unsigned long *strides_a,
                                            const unsigned long *strides_b, unsigned long num_dims,
                                            unsigned long total_size);
 
-template void element_wise_equality<double>(bool *output, const double *a, const double *b, const unsigned long *dims,
+template void element_wise_equality<double>(uint8_t *output, const double *a, const double *b, const unsigned long *dims,
                                             const unsigned long *strides_out, const unsigned long *strides_a,
                                             const unsigned long *strides_b, unsigned long num_dims,
                                             unsigned long total_size);
 
 // CUDA kernel for element-wise inequality with different strides
 template <typename T>
-__global__ void element_wise_inequality_kernel(bool *output, const T *a, const T *b, const unsigned long *dims,
+__global__ void element_wise_inequality_kernel(uint8_t *output, const T *a, const T *b, const unsigned long *dims,
                                                const unsigned long *strides_out, const unsigned long *strides_a,
                                                const unsigned long *strides_b, unsigned long num_dims,
                                                unsigned long total_size) {
@@ -156,7 +156,7 @@ __global__ void element_wise_inequality_kernel(bool *output, const T *a, const T
 }
 
 template <typename T>
-void element_wise_inequality(bool *output, const T *a, const T *b, const unsigned long *dims,
+void element_wise_inequality(uint8_t *output, const T *a, const T *b, const unsigned long *dims,
                              const unsigned long *strides_out, const unsigned long *strides_a,
                              const unsigned long *strides_b, unsigned long num_dims, unsigned long total_size) {
     int block_size = 256;
@@ -165,12 +165,49 @@ void element_wise_inequality(bool *output, const T *a, const T *b, const unsigne
                                                                num_dims, total_size);
 }
 
-template void element_wise_inequality<float>(bool *output, const float *a, const float *b, const unsigned long *dims,
+template void element_wise_inequality<float>(uint8_t *output, const float *a, const float *b, const unsigned long *dims,
                                              const unsigned long *strides_out, const unsigned long *strides_a,
                                              const unsigned long *strides_b, unsigned long num_dims,
                                              unsigned long total_size);
 
-template void element_wise_inequality<double>(bool *output, const double *a, const double *b, const unsigned long *dims,
+template void element_wise_inequality<double>(uint8_t *output, const double *a, const double *b, const unsigned long *dims,
                                               const unsigned long *strides_out, const unsigned long *strides_a,
                                               const unsigned long *strides_b, unsigned long num_dims,
                                               unsigned long total_size);
+
+// CUDA kernel for element-wise negation with different strides
+template <typename T>
+__global__ void element_wise_negation_kernel(T *output, const T *a, const unsigned long *dims,
+                                             const unsigned long *strides_out, const unsigned long *strides_a,
+                                             unsigned long num_dims, unsigned long total_size) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < total_size) {
+        int flat_index_out = 0;
+        int flat_index_a = 0;
+        int temp = idx;
+        for (int i = num_dims - 1; i >= 0; --i) {
+            int dim_idx = temp % dims[i];
+            flat_index_out += dim_idx * strides_out[i];
+            flat_index_a += dim_idx * strides_a[i];
+            temp /= dims[i];
+        }
+        output[flat_index_out] = -a[flat_index_a];
+    }
+}
+
+template <typename T>
+void element_wise_negation(T *output, const T *a, const unsigned long *dims, const unsigned long *strides_out,
+                           const unsigned long *strides_a, unsigned long num_dims, unsigned long total_size) {
+    int block_size = 256;
+    int num_blocks = (total_size + block_size - 1) / block_size;
+    element_wise_negation_kernel<<<num_blocks, block_size>>>(output, a, dims, strides_out, strides_a, num_dims,
+                                                             total_size);
+}
+
+template void element_wise_negation<float>(float *output, const float *a, const unsigned long *dims,
+                                           const unsigned long *strides_out, const unsigned long *strides_a,
+                                           unsigned long num_dims, unsigned long total_size);
+
+template void element_wise_negation<double>(double *output, const double *a, const unsigned long *dims,
+                                            const unsigned long *strides_out, const unsigned long *strides_a,
+                                            unsigned long num_dims, unsigned long total_size);
diff --git a/src/tensor/cuda/scalar.cu b/src/tensor/cuda/scalar.cu
new file mode 100644
index 0000000..2938eae
--- /dev/null
+++ b/src/tensor/cuda/scalar.cu
@@ -0,0 +1,44 @@
+#include <cuda_runtime.h>
+#include <device_launch_parameters.h>
+
+#include "squint/tensor/cuda/scalar.hpp"
+
+// CUDA kernel for scalar multiplication with different strides
+template <typename T>
+__global__ void scalar_multiplication_kernel(T scalar, T *output, const T *a, const unsigned long *dims,
+                                             const unsigned long *strides_out, const unsigned long *strides_a,
+                                             unsigned long num_dims,
+                                             unsigned long total_size) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < total_size) {
+        int flat_index_out = 0;
+        int flat_index_a = 0;
+        int temp = idx;
+        for (int i = num_dims - 1; i >= 0; --i) {
+            int dim_idx = temp % dims[i];
+            flat_index_out += dim_idx * strides_out[i];
+            flat_index_a += dim_idx * strides_a[i];
+            temp /= dims[i];
+        }
+        output[flat_index_out] = scalar * a[flat_index_a];
+    }
+}
+
+template <typename T>
+void scalar_multiplication(T scalar, T *output, const T *a, const unsigned long *dims,
+                           const unsigned long *strides_out, const unsigned long *strides_a,
+                           unsigned long num_dims, unsigned long total_size) {
+    int block_size = 256;
+    int num_blocks = (total_size + block_size - 1) / block_size;
+    scalar_multiplication_kernel<<<num_blocks, block_size>>>(scalar, output, a, dims, strides_out, strides_a,
+                                                             num_dims, total_size);
+}
+
+template void scalar_multiplication<float>(float scalar, float *output, const float *a, const unsigned long *dims,
+                                           const unsigned long *strides_out, const unsigned long *strides_a,
+                                           unsigned long num_dims, unsigned long total_size);
+
+template void scalar_multiplication<double>(double scalar, double *output, const double *a, const unsigned long *dims,
+                                            const unsigned long *strides_out, const unsigned long *strides_a,
+                                            unsigned long num_dims, unsigned long total_size);
+
diff --git a/tests/tensor_math_tests.cpp b/tests/tensor_math_tests.cpp
index f9b348a..523c8b7 100644
--- a/tests/tensor_math_tests.cpp
+++ b/tests/tensor_math_tests.cpp
@@ -587,9 +587,6 @@ TEST_CASE("tensor_einsum") {
         auto A = tens::arange(1, 1, {3});
         auto B = tens::arange(1, 1, {2});
         auto result = einsum("i,j->ij", A, B);
-        auto is_equal_mat = result == A * B;
-        bool is_equal = std::all_of(is_equal_mat.begin(), is_equal_mat.end(), [](bool b) { return b; });
-        CHECK(is_equal);
         CHECK(result.shape() == std::vector<size_t>{3, 2});
         CHECK(result(0, 0) == doctest::Approx(1));
         CHECK(result(0, 1) == doctest::Approx(2));
diff --git a/tests/tensor_ops_tests.cpp b/tests/tensor_ops_tests.cpp
index f081177..d68911d 100644
--- a/tests/tensor_ops_tests.cpp
+++ b/tests/tensor_ops_tests.cpp
@@ -23,6 +23,15 @@ TEST_CASE("Element-wise operations") {
                     CHECK(c(i, j) == doctest::Approx(a(i, j) + b(i, j)));
                 }
             }
+#ifdef SQUINT_USE_CUDA
+            auto c_device = a_device + b_device;
+            auto c_host = c_device.to_host();
+            for (int i = 0; i < 2; ++i) {
+                for (int j = 0; j < 3; ++j) {
+                    CHECK(c_host(i, j) == doctest::Approx(a(i, j) + b(i, j)));
+                }
+            }
+#endif
         }
 
         SUBCASE("Subtraction") {
@@ -32,6 +41,15 @@ TEST_CASE("Element-wise operations") {
                     CHECK(c(i, j) == doctest::Approx(a(i, j) - b(i, j)));
                 }
             }
+#ifdef SQUINT_USE_CUDA
+            auto c_device = a_device - b_device;
+            auto c_host = c_device.to_host();
+            for (int i = 0; i < 2; ++i) {
+                for (int j = 0; j < 3; ++j) {
+                    CHECK(c_host(i, j) == doctest::Approx(1));
+                }
+            }
+#endif
         }
 
         SUBCASE("Unary negation") {
@@ -41,6 +59,15 @@ TEST_CASE("Element-wise operations") {
                     CHECK(c(i, j) == doctest::Approx(-a(i, j)));
                 }
             }
+#ifdef SQUINT_USE_CUDA
+            auto c_device = -a_device;
+            auto c_host = c_device.to_host();
+            for (int i = 0; i < 2; ++i) {
+                for (int j = 0; j < 3; ++j) {
+                    CHECK(c_host(i, j) == doctest::Approx(-a(i, j)));
+                }
+            }
+#endif
         }
 
         SUBCASE("In-place addition") {
@@ -75,7 +102,7 @@ TEST_CASE("Element-wise operations") {
             auto a_host = a_device.to_host();
             for (int i = 0; i < 2; ++i) {
                 for (int j = 0; j < 3; ++j) {
-                    CHECK(a_host(i, j) == doctest::Approx(-1));
+                    CHECK(a_host(i, j) == doctest::Approx(1));
                 }
             }
 #endif
@@ -85,6 +112,10 @@ TEST_CASE("Element-wise operations") {
     SUBCASE("Dynamic shape tensors") {
         tensor<float, dynamic, dynamic> a({2, 3}, std::vector<float>{1, 4, 2, 5, 3, 6});
         tensor<float, dynamic, dynamic> b({2, 3}, std::vector<float>{2, 5, 3, 6, 4, 7});
+#ifdef SQUINT_USE_CUDA
+        auto a_device = a.to_device();
+        auto b_device = b.to_device();
+#endif
 
         SUBCASE("Addition") {
             auto c = a + b;
@@ -93,6 +124,15 @@ TEST_CASE("Element-wise operations") {
                     CHECK(c(i, j) == doctest::Approx(a(i, j) + b(i, j)));
                 }
             }
+#ifdef SQUINT_USE_CUDA
+            auto c_device = a_device + b_device;
+            auto c_host = c_device.to_host();
+            for (int i = 0; i < 2; ++i) {
+                for (int j = 0; j < 3; ++j) {
+                    CHECK(c_host(i, j) == doctest::Approx(a(i, j) + b(i, j)));
+                }
+            }
+#endif
         }
 
         SUBCASE("Subtraction") {
@@ -102,6 +142,15 @@ TEST_CASE("Element-wise operations") {
                     CHECK(c(i, j) == doctest::Approx(a(i, j) - b(i, j)));
                 }
             }
+#ifdef SQUINT_USE_CUDA
+            auto c_device = a_device - b_device;
+            auto c_host = c_device.to_host();
+            for (int i = 0; i < 2; ++i) {
+                for (int j = 0; j < 3; ++j) {
+                    CHECK(c_host(i, j) == doctest::Approx(1));
+                }
+            }
+#endif
         }
 
         SUBCASE("Unary negation") {
@@ -111,6 +160,15 @@ TEST_CASE("Element-wise operations") {
                     CHECK(c(i, j) == doctest::Approx(-a(i, j)));
                 }
             }
+#ifdef SQUINT_USE_CUDA
+            auto c_device = -a_device;
+            auto c_host = c_device.to_host();
+            for (int i = 0; i < 2; ++i) {
+                for (int j = 0; j < 3; ++j) {
+                    CHECK(c_host(i, j) == doctest::Approx(-a(i, j)));
+                }
+            }
+#endif
         }
 
         SUBCASE("In-place addition") {
@@ -121,6 +179,16 @@ TEST_CASE("Element-wise operations") {
             CHECK(a(1, 1) == doctest::Approx(11));
             CHECK(a(0, 2) == doctest::Approx(7));
             CHECK(a(1, 2) == doctest::Approx(13));
+#ifdef SQUINT_USE_CUDA
+            a_device += b_device;
+            auto a_host = a_device.to_host();
+            CHECK(a_host(0, 0) == doctest::Approx(3));
+            CHECK(a_host(1, 0) == doctest::Approx(9));
+            CHECK(a_host(0, 1) == doctest::Approx(5));
+            CHECK(a_host(1, 1) == doctest::Approx(11));
+            CHECK(a_host(0, 2) == doctest::Approx(7));
+            CHECK(a_host(1, 2) == doctest::Approx(13));
+#endif
         }
 
         SUBCASE("In-place subtraction") {
@@ -130,6 +198,15 @@ TEST_CASE("Element-wise operations") {
                     CHECK(a(i, j) == doctest::Approx(-1));
                 }
             }
+#ifdef SQUINT_USE_CUDA
+            a_device -= b_device;
+            auto a_host = a_device.to_host();
+            for (int i = 0; i < 2; ++i) {
+                for (int j = 0; j < 3; ++j) {
+                    CHECK(a_host(i, j) == doctest::Approx(1));
+                }
+            }
+#endif
         }
     }
 }
@@ -137,6 +214,9 @@ TEST_CASE("Element-wise operations") {
 TEST_CASE("Scalar operations") {
     SUBCASE("Fixed shape tensors") {
         tensor<float, shape<2, 3>> a({1, 4, 2, 5, 3, 6});
+#ifdef SQUINT_USE_CUDA
+        auto a_device = a.to_device();
+#endif
 
         SUBCASE("Scalar multiplication") {
             auto b = a * 2.0f;
@@ -145,6 +225,15 @@ TEST_CASE("Scalar operations") {
                     CHECK(b(i, j) == doctest::Approx(2 * a(i, j)));
                 }
             }
+#ifdef SQUINT_USE_CUDA
+            auto b_device = a_device * 2.0f;
+            auto b_host = b_device.to_host();
+            for (int i = 0; i < 2; ++i) {
+                for (int j = 0; j < 3; ++j) {
+                    CHECK(b_host(i, j) == doctest::Approx(2 * a(i, j)));
+                }
+            }
+#endif
 
             auto c = 2.0f * a;
             for (int i = 0; i < 2; ++i) {
@@ -152,6 +241,15 @@ TEST_CASE("Scalar operations") {
                     CHECK(c(i, j) == doctest::Approx(b(i, j)));
                 }
             }
+#ifdef SQUINT_USE_CUDA
+            auto c_device = 2.0f * a_device;
+            auto c_host = c_device.to_host();
+            for (int i = 0; i < 2; ++i) {
+                for (int j = 0; j < 3; ++j) {
+                    CHECK(c_host(i, j) == doctest::Approx(b(i, j)));
+                }
+            }
+#endif
         }
 
         SUBCASE("Scalar division") {
@@ -161,6 +259,15 @@ TEST_CASE("Scalar operations") {
                     CHECK(b(i, j) == doctest::Approx(a(i, j) / 2.0f));
                 }
             }
+#ifdef SQUINT_USE_CUDA
+            auto b_device = a_device / 2.0f;
+            auto b_host = b_device.to_host();
+            for (int i = 0; i < 2; ++i) {
+                for (int j = 0; j < 3; ++j) {
+                    CHECK(b_host(i, j) == doctest::Approx(a(i, j) / 2.0f));
+                }
+            }
+#endif
         }
 
         SUBCASE("In-place scalar multiplication") {
@@ -170,6 +277,15 @@ TEST_CASE("Scalar operations") {
                     CHECK(a(i, j) == doctest::Approx(2 * (1 + i * 3 + j)));
                 }
             }
+#ifdef SQUINT_USE_CUDA
+            a_device *= 2.0f;
+            auto a_host = a_device.to_host();
+            for (int i = 0; i < 2; ++i) {
+                for (int j = 0; j < 3; ++j) {
+                    CHECK(a_host(i, j) == doctest::Approx(2 * (1 + i * 3 + j)));
+                }
+            }
+#endif
         }
 
         SUBCASE("In-place scalar division") {
@@ -179,11 +295,23 @@ TEST_CASE("Scalar operations") {
                     CHECK(a(i, j) == doctest::Approx((1 + i * 3 + j) / 2.0f));
                 }
             }
+#ifdef SQUINT_USE_CUDA
+            a_device /= 2.0f;
+            auto a_host = a_device.to_host();
+            for (int i = 0; i < 2; ++i) {
+                for (int j = 0; j < 3; ++j) {
+                    CHECK(a_host(i, j) == doctest::Approx((1 + i * 3 + j) / 2.0f));
+                }
+            }
+#endif
         }
     }
 
     SUBCASE("Dynamic shape tensors") {
         tensor<float, dynamic, dynamic> a({2, 3}, std::vector<float>{1, 4, 2, 5, 3, 6});
+#ifdef SQUINT_USE_CUDA
+        auto a_device = a.to_device();
+#endif
 
         SUBCASE("Scalar multiplication") {
             auto b = a * 2.0f;
@@ -192,6 +320,15 @@ TEST_CASE("Scalar operations") {
                     CHECK(b(i, j) == doctest::Approx(2 * a(i, j)));
                 }
             }
+#ifdef SQUINT_USE_CUDA
+            auto b_device = a_device * 2.0f;
+            auto b_host = b_device.to_host();
+            for (int i = 0; i < 2; ++i) {
+                for (int j = 0; j < 3; ++j) {
+                    CHECK(b_host(i, j) == doctest::Approx(2 * a(i, j)));
+                }
+            }
+#endif
 
             auto c = 2.0f * a;
             for (int i = 0; i < 2; ++i) {
@@ -199,6 +336,15 @@ TEST_CASE("Scalar operations") {
                     CHECK(c(i, j) == doctest::Approx(b(i, j)));
                 }
             }
+#ifdef SQUINT_USE_CUDA
+            auto c_device = 2.0f * a_device;
+            auto c_host = c_device.to_host();
+            for (int i = 0; i < 2; ++i) {
+                for (int j = 0; j < 3; ++j) {
+                    CHECK(c_host(i, j) == doctest::Approx(b(i, j)));
+                }
+            }
+#endif
         }
 
         SUBCASE("Scalar division") {
@@ -208,6 +354,15 @@ TEST_CASE("Scalar operations") {
                     CHECK(b(i, j) == doctest::Approx(a(i, j) / 2.0f));
                 }
             }
+#ifdef SQUINT_USE_CUDA
+            auto b_device = a_device / 2.0f;
+            auto b_host = b_device.to_host();
+            for (int i = 0; i < 2; ++i) {
+                for (int j = 0; j < 3; ++j) {
+                    CHECK(b_host(i, j) == doctest::Approx(a(i, j) / 2.0f));
+                }
+            }
+#endif
         }
 
         SUBCASE("In-place scalar multiplication") {
@@ -217,6 +372,15 @@ TEST_CASE("Scalar operations") {
                     CHECK(a(i, j) == doctest::Approx(2 * (1 + i * 3 + j)));
                 }
             }
+#ifdef SQUINT_USE_CUDA
+            a_device *= 2.0f;
+            auto a_host = a_device.to_host();
+            for (int i = 0; i < 2; ++i) {
+                for (int j = 0; j < 3; ++j) {
+                    CHECK(a_host(i, j) == doctest::Approx(2 * (1 + i * 3 + j)));
+                }
+            }
+#endif
         }
 
         SUBCASE("In-place scalar division") {
@@ -226,6 +390,15 @@ TEST_CASE("Scalar operations") {
                     CHECK(a(i, j) == doctest::Approx((1 + i * 3 + j) / 2.0f));
                 }
             }
+#ifdef SQUINT_USE_CUDA
+            a_device /= 2.0f;
+            auto a_host = a_device.to_host();
+            for (int i = 0; i < 2; ++i) {
+                for (int j = 0; j < 3; ++j) {
+                    CHECK(a_host(i, j) == doctest::Approx((1 + i * 3 + j) / 2.0f));
+                }
+            }
+#endif
         }
     }
 }
@@ -336,6 +509,142 @@ TEST_CASE("Matrix multiplication") {
     }
 }
 
+TEST_CASE("Matrix multiplication device") {
+    SUBCASE("Fixed shape tensors") {
+        SUBCASE("Inner product of vectors") {
+            tensor<float, shape<3>> a({1, 2, 3});
+            tensor<float, shape<3>> b({4, 5, 6});
+            auto a_device = a.to_device();
+            auto b_device = b.to_device();
+            auto c_device = a_device.transpose() * b_device;
+            auto c_host = c_device.to_host();
+            CHECK(c_host(0, 0) == doctest::Approx(32));
+        }
+
+        SUBCASE("Outer product of vectors") {
+            tensor<float, shape<3>> a({1, 2, 3});
+            tensor<float, shape<3>> b({4, 5, 6});
+            auto a_device = a.to_device();
+            auto b_device = b.to_device();
+            auto c_device = a_device * b_device.transpose();
+            auto c_host = c_device.to_host();
+            for (int i = 0; i < 3; ++i) {
+                for (int j = 0; j < 3; ++j) {
+                    CHECK(c_host(i, j) == doctest::Approx(a(i) * b(j)));
+                }
+            }
+        }
+
+        SUBCASE("Vector times matrix") {
+            tensor<float, shape<2>> a({1, 2});
+            tensor<float, shape<2, 3>> b({1, 4, 2, 5, 3, 6});
+            auto a_device = a.to_device();
+            auto b_device = b.to_device();
+            auto c_device = a_device.transpose() * b_device;
+            auto c_host = c_device.to_host();
+            for (int j = 0; j < 3; ++j) {
+                CHECK(c_host(0, j) == doctest::Approx(a(0) * b(0, j) + a(1) * b(1, j)));
+            }
+        }
+
+        SUBCASE("Matrix times vector") {
+            tensor<float, shape<3, 3>> a({1, 2, 3, 4, 5, 6, 7, 8, 9});
+            tensor<float, shape<3>> b({1, 2, 3});
+            auto a_device = a.to_device();
+            auto b_device = b.to_device();
+            auto c_device = a_device * b_device;
+            auto c_host = c_device.to_host();
+            for (int i = 0; i < 3; ++i) {
+                float expected = 0;
+                for (int j = 0; j < 3; ++j) {
+                    expected += a(i, j) * b(j);
+                }
+                CHECK(c_host(i, 0) == doctest::Approx(expected));
+            }
+        }
+
+        SUBCASE("Matrix times matrix") {
+            tensor<float, shape<2, 3>> a({1, 4, 2, 5, 3, 6});
+            tensor<float, shape<3, 2>> b({1, 4, 2, 5, 3, 6});
+            auto a_device = a.to_device();
+            auto b_device = b.to_device();
+            auto c_device = a_device * b_device;
+            auto c_host = c_device.to_host();
+            CHECK(c_host(0, 0) == doctest::Approx(15));
+            CHECK(c_host(0, 1) == doctest::Approx(29));
+            CHECK(c_host(1, 0) == doctest::Approx(36));
+            CHECK(c_host(1, 1) == doctest::Approx(71));
+        }
+    }
+
+    SUBCASE("Dynamic shape tensors") {
+        SUBCASE("Inner product of vectors") {
+            tensor<float, dynamic, dynamic> a({3}, std::vector<float>{1, 2, 3});
+            tensor<float, dynamic, dynamic> b({3}, std::vector<float>{4, 5, 6});
+            auto a_device = a.to_device();
+            auto b_device = b.to_device();
+            auto c_device = a_device.transpose() * b_device;
+            auto c_host = c_device.to_host();
+            CHECK(c_host(0, 0) == doctest::Approx(32));
+        }
+
+        SUBCASE("Outer product of vectors") {
+            tensor<float, dynamic, dynamic> a({3}, std::vector<float>{1, 2, 3});
+            tensor<float, dynamic, dynamic> b({3}, std::vector<float>{4, 5, 6});
+            auto a_device = a.to_device();
+            auto b_device = b.to_device();
+            auto c_device = a_device * b_device.transpose();
+            auto c_host = c_device.to_host();
+            for (int i = 0; i < 3; ++i) {
+                for (int j = 0; j < 3; ++j) {
+                    CHECK(c_host(i, j) == doctest::Approx(a(i) * b(j)));
+                }
+            }
+        }
+
+        SUBCASE("Vector times matrix") {
+            tensor<float, dynamic, dynamic> a({2}, std::vector<float>{1, 2});
+            tensor<float, dynamic, dynamic> b({2, 3}, std::vector<float>{1, 4, 2, 5, 3, 6});
+            auto a_device = a.to_device();
+            auto b_device = b.to_device();
+            auto c_device = a_device.transpose() * b_device;
+            auto c_host = c_device.to_host();
+            for (int j = 0; j < 3; ++j) {
+                CHECK(c_host(0, j) == doctest::Approx(a(0) * b(0, j) + a(1) * b(1, j)));
+            }
+        }
+
+        SUBCASE("Matrix times vector") {
+            tensor<float, dynamic, dynamic> a({3, 3}, std::vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9});
+            tensor<float, dynamic, dynamic> b({3}, std::vector<float>{1, 2, 3});
+            auto a_device = a.to_device();
+            auto b_device = b.to_device();
+            auto c_device = a_device * b_device;
+            auto c_host = c_device.to_host();
+            for (int i = 0; i < 3; ++i) {
+                float expected = 0;
+                for (int j = 0; j < 3; ++j) {
+                    expected += a(i, j) * b(j);
+                }
+                CHECK(c_host(i, 0) == doctest::Approx(expected));
+            }
+        }
+
+        SUBCASE("Matrix times matrix") {
+            tensor<float, dynamic, dynamic> a({2, 3}, std::vector<float>{1, 4, 2, 5, 3, 6});
+            tensor<float, dynamic, dynamic> b({3, 2}, std::vector<float>{1, 4, 2, 5, 3, 6});
+            auto a_device = a.to_device();
+            auto b_device = b.to_device();
+            auto c_device = a_device * b_device;
+            auto c_host = c_device.to_host();
+            CHECK(c_host(0, 0) == doctest::Approx(15));
+            CHECK(c_host(0, 1) == doctest::Approx(29));
+            CHECK(c_host(1, 0) == doctest::Approx(36));
+            CHECK(c_host(1, 1) == doctest::Approx(71));
+        }
+    }
+}
+
 TEST_CASE("General matrix division") {
     SUBCASE("Fixed shape tensors square system") {
         tensor<float, shape<2, 2>> a({1, 3, 2, 4});