From 1547a7e04a15231f8e3b935004e57e8788566129 Mon Sep 17 00:00:00 2001
From: Alexander Harvey Nitz
Date: Mon, 2 Dec 2024 17:33:40 -0500
Subject: [PATCH] pytorch array backend

---
 pycbc/scheme.py            | 169 ++++++-----
 pycbc/types/array.py       |  20 +-
 pycbc/types/array_cuda.py  |   3 -
 pycbc/types/array_torch.py | 561 +++++++++++++++++++++++++++++++++++++
 4 files changed, 681 insertions(+), 72 deletions(-)
 create mode 100644 pycbc/types/array_torch.py

diff --git a/pycbc/scheme.py b/pycbc/scheme.py
index 0a9e6740e1e..a6edd43956c 100644
--- a/pycbc/scheme.py
+++ b/pycbc/scheme.py
@@ -1,35 +1,12 @@
-# Copyright (C) 2014 Alex Nitz, Andrew Miller
-#
-# This program is free software; you can redistribute it and/or modify it
-# under the terms of the GNU General Public License as published by the
-# Free Software Foundation; either version 3 of the License, or (at your
-# option) any later version.
-#
-# This program is distributed in the hope that it will be useful, but
-# WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
-# Public License for more details.
-#
-# You should have received a copy of the GNU General Public License along
-# with this program; if not, write to the Free Software Foundation, Inc.,
-# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
-
-
-#
-# =============================================================================
-#
-#                                   Preamble
-#
-# =============================================================================
-#
 """
-This modules provides python contexts that set the default behavior for PyCBC
+This module provides python contexts that set the default behavior for PyCBC
 objects.
 """
 import os
 import pycbc
 from functools import wraps
 import logging
 from .libutils import get_ctypes_library
 
 logger = logging.getLogger('pycbc.scheme')
@@ -42,16 +19,16 @@ def __init__(self):
         if _SchemeManager._single is not None:
             raise RuntimeError("SchemeManager is a private class")
-        _SchemeManager._single= self
+        _SchemeManager._single = self
 
-        self.state= None
-        self._lock= False
+        self.state = None
+        self._lock = False
 
     def lock(self):
-        self._lock= True
+        self._lock = True
 
     def unlock(self):
-        self._lock= False
+        self._lock = False
 
     def shift_to(self, state):
         if self._lock is False:
@@ -59,6 +36,7 @@ def shift_to(self, state):
             self.state = state
         else:
             raise RuntimeError("The state is locked, cannot shift schemes")
 
+
 # Create the global processing scheme manager
 mgr = _SchemeManager()
 DefaultScheme = None
@@ -68,32 +46,39 @@
 
 class Scheme(object):
     """Context that sets PyCBC objects to use CPU processing. """
     _single = None
+
     def __init__(self):
         if DefaultScheme is type(self):
             return
         if Scheme._single is not None:
             raise RuntimeError("Only one processing scheme can be used")
         Scheme._single = True
+
     def __enter__(self):
         mgr.shift_to(self)
         mgr.lock()
+
     def __exit__(self, type, value, traceback):
         mgr.unlock()
         mgr.shift_to(default_context)
+
    def __del__(self):
         if Scheme is not None:
             Scheme._single = None
 
-_cuda_cleanup_list=[]
+
+_cuda_cleanup_list = []
+
 
 def register_clean_cuda(function):
     _cuda_cleanup_list.append(function)
 
+
 def clean_cuda(context):
-    #Before cuda context is destroyed, all item destructions dependent on cuda
+    # Before CUDA context is destroyed, all item destructions dependent on CUDA
     # must take place. This calls all functions that have been registered
-    # with _register_clean_cuda() in reverse order
-    #So the last one registered, is the first one cleaned
+    # with register_clean_cuda(), in reverse order, so the last one
+    # registered is the first one cleaned.
     _cuda_cleanup_list.reverse()
     for func in _cuda_cleanup_list:
         func()
@@ -102,6 +87,7 @@ def clean_cuda(context):
     from pycuda.tools import clear_context_caches
     clear_context_caches()
 
+
 class CUDAScheme(Scheme):
     """Context that sets PyCBC objects to use a CUDA processing scheme. """
     def __init__(self, device_num=0):
@@ -111,9 +97,37 @@ def __init__(self, device_num=0):
         import pycuda.driver
         pycuda.driver.init()
         self.device = pycuda.driver.Device(device_num)
-        self.context = self.device.make_context(flags=pycuda.driver.ctx_flags.SCHED_BLOCKING_SYNC)
+        self.context = self.device.make_context(
+            flags=pycuda.driver.ctx_flags.SCHED_BLOCKING_SYNC)
         import atexit
-        atexit.register(clean_cuda,self.context)
+        atexit.register(clean_cuda, self.context)
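+
+
+# Illustration of the new scheme (assumes PyTorch is installed; a
+# 'cuda:N' device additionally requires a CUDA-enabled PyTorch build):
+#
+#     with TorchScheme(device='cuda:0'):
+#         ...  # schemed array operations dispatch to the torch backend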
+
+
+class TorchScheme(Scheme):
+    """Context that sets PyCBC objects to use a PyTorch processing scheme."""
+    def __init__(self, device='cpu'):
+        Scheme.__init__(self)
+        self.device = device
+        # If a GPU device is requested, check that CUDA is actually
+        # available to PyTorch before accepting it.
+        if self.device != 'cpu':
+            import torch
+            if not torch.cuda.is_available():
+                raise RuntimeError("CUDA device not available for PyTorch")
+        logger.info(f"PyTorch using device: {self.device}")
+
+    def __enter__(self):
+        Scheme.__enter__(self)
+        # Tensors are created on the scheme's device directly, so there
+        # is no need to change PyTorch's default tensor type.
+        import torch
+        self.torch_device = torch.device(self.device)
+        logger.info(f"Entered TorchScheme with device: {self.device}")
+
+    def __exit__(self, type, value, traceback):
+        Scheme.__exit__(self, type, value, traceback)
+        logger.info("Exited TorchScheme")
+
 
 class CUPYScheme(Scheme):
@@ -142,7 +156,7 @@ class CPUScheme(Scheme):
     def __init__(self, num_threads=1):
         if isinstance(num_threads, int):
-            self.num_threads=num_threads
+            self.num_threads = num_threads
         elif num_threads == 'env' and "PYCBC_NUM_THREADS" in os.environ:
             self.num_threads = int(os.environ["PYCBC_NUM_THREADS"])
         else:
@@ -163,7 +177,7 @@ def __enter__(self):
             os.environ["OMP_NUM_THREADS"] = str(self.num_threads)
 
             if self._libgomp is not None:
-                self._libgomp.omp_set_num_threads( int(self.num_threads) )
+                self._libgomp.omp_set_num_threads(int(self.num_threads))
 
     def __exit__(self, type, value, traceback):
         os.environ["OMP_NUM_THREADS"] = "1"
@@ -171,18 +185,21 @@ def __exit__(self, type, value, traceback):
             self._libgomp.omp_set_num_threads(1)
         Scheme.__exit__(self, type, value, traceback)
 
+
 class MKLScheme(CPUScheme):
     def __init__(self, num_threads=1):
         CPUScheme.__init__(self, num_threads)
         if not pycbc.HAVE_MKL:
             raise RuntimeError("Can't find MKL libraries")
 
+
 class NumpyScheme(CPUScheme):
     pass
 
 
 scheme_prefix = {
     CUDAScheme: "cuda",
+    TorchScheme: "torch",
     CPUScheme: "cpu",
     CUPYScheme: "cupy",
     MKLScheme: "mkl",
@@ -201,17 +218,23 @@ class NumpyScheme(CPUScheme):
     ),
 )
 
+
 class DefaultScheme(_default_scheme_class):
     pass
 
+
 default_context = DefaultScheme()
 mgr.state = default_context
 scheme_prefix[DefaultScheme] = _default_scheme_prefix
 
+
 def current_prefix():
     return scheme_prefix[type(mgr.state)]
 
+
 _import_cache = {}
+
+
 def schemed(prefix):
 
     def scheming_function(func):
@@ -237,25 +260,27 @@ def _scheming_function(*args, **kwds):
                     return schemed_fn(*args, **kwds)
 
-            err = (f"Failed to find implementation of {func.__name__} "
-                   f"for {current_prefix()} scheme. ")
+            err = (f"Failed to find implementation of {func.__name__} "
+                   f"for {current_prefix()} scheme.\n")
             for emsg in exc_errors:
-                err += str(emsg) + " "
+                err += str(emsg) + "\n"
             raise RuntimeError(err)
 
         return _scheming_function
 
     return scheming_function
 
+
 def cpuonly(func):
     @wraps(func)
     def _cpuonly(*args, **kwds):
         if not issubclass(type(mgr.state), CPUScheme):
-            raise TypeError(fn.__name__ +
+            raise TypeError(func.__name__ +
                             " can only be called from a CPU processing scheme.")
         else:
             return func(*args, **kwds)
     return _cpuonly
 
+
 def insert_processing_option_group(parser):
     """
     Adds the options used to choose a processing scheme. This should be used
     if your program supports the ability to select the processing scheme.
 
     Parameters
     ----------
     parser : object
         OptionParser instance
     """
     processing_group = parser.add_argument_group("Options for selecting the"
-                                   " processing scheme in this program.")
+                                                 " processing scheme in this program.")
+    scheme_choices = list(set(scheme_prefix.values()))
     processing_group.add_argument("--processing-scheme",
-                      help="The choice of processing scheme. "
-                           "Choices are " + str(list(set(scheme_prefix.values()))) +
-                           ". (optional for CPU scheme) The number of "
-                           "execution threads "
-                           "can be indicated by cpu:NUM_THREADS, "
-                           "where NUM_THREADS "
-                           "is an integer. The default is a single thread. "
-                           "If the scheme is provided as cpu:env, the number "
-                           "of threads can be provided by the PYCBC_NUM_THREADS "
-                           "environment variable. If the environment variable "
-                           "is not set, the number of threads matches the number "
-                           "of logical cores. ",
-                      default="cpu")
+                                  help="The choice of processing scheme. "
+                                       "Choices are " + str(scheme_choices) +
+                                       ". (optional for CPU scheme) The number "
+                                       "of execution threads can be indicated "
+                                       "by cpu:NUM_THREADS, where NUM_THREADS "
+                                       "is an integer. The default is a single "
+                                       "thread. If the scheme is provided as "
+                                       "cpu:env, the number of threads can be "
+                                       "provided by the PYCBC_NUM_THREADS "
+                                       "environment variable. If the "
+                                       "environment variable is not set, the "
+                                       "number of threads matches the number "
+                                       "of logical cores. For the torch "
+                                       "scheme, the device can be given as "
+                                       "torch:DEVICE, where DEVICE is 'cpu', "
+                                       "'cuda:0', 'cuda:1', etc.",
+                                  default="cpu")
 
     processing_group.add_argument("--processing-device-id",
-                      help="(optional) ID of GPU to use for accelerated "
-                           "processing",
-                      default=0, type=int)
+                                  help="(optional) ID of GPU to use for "
+                                       "accelerated processing",
+                                  default=0, type=int)
+
 
 def from_cli(opt):
     """Parses the command line options and returns a processing scheme.
@@ -308,6 +337,15 @@ def from_cli(opt):
     if name == "cuda":
         logger.info("Running with CUDA support")
         ctx = CUDAScheme(opt.processing_device_id)
+    elif name == "torch":
+        logger.info("Running with PyTorch support")
+        # The device may itself contain ':' (e.g. 'cuda:0'), so rejoin
+        # everything after the scheme name; default to the CPU.
+        if len(scheme_str) > 1:
+            device = ':'.join(scheme_str[1:])
+        else:
+            device = 'cpu'
+        ctx = TorchScheme(device)
     elif name == "mkl":
         if len(scheme_str) > 1:
             numt = scheme_str[1]
@@ -331,11 +369,11 @@ def from_cli(opt):
     logger.info("Running with CPU support: %s threads" % ctx.num_threads)
     return ctx
 
+
 def verify_processing_options(opt, parser):
-    """Parses the processing scheme options and verifies that they are 
+    """Parses the processing scheme options and verifies that they are
     reasonable.
 
     Parameters
     ----------
     opt : object
         Result of parsing the CLI with OptionParser, or any object with
         the required attributes.
     parser : object
         OptionParser instance.
     """
     scheme_types = scheme_prefix.values()
     if opt.processing_scheme.split(':')[0] not in scheme_types:
-        parser.error("(%s) is not a valid scheme type.")
+        parser.error("(%s) is not a valid scheme type."
+                     % opt.processing_scheme)
+
 
 class ChooseBySchemeDict(dict):
-    """ This class represents a dictionary whose purpose is to chose objects
+    """ This class represents a dictionary whose purpose is to choose objects
     based on their processing scheme. The keys are intended to be processing
     schemes.
     """
     def __getitem__(self, scheme):
         for base in scheme.__mro__[0:-1]:
             try:
                 return dict.__getitem__(self, base)
-                break
-            except:
+            except KeyError:
                 pass
+        raise KeyError("Scheme not found in ChooseBySchemeDict")
diff --git a/pycbc/types/array.py b/pycbc/types/array.py
index 9959a3700d4..f2ceab74faf 100644
--- a/pycbc/types/array.py
+++ b/pycbc/types/array.py
@@ -83,6 +83,13 @@ def noreal(self, *args, **kwargs):
         raise TypeError( func.__name__ + " does not support real types")
     return noreal
 
+@schemed(BACKEND_PREFIX)
+def _scheme_get_numpy_dtype(dtype):
+    """Get the NumPy dtype corresponding to the scheme's data type.
+
+    The active scheme's backend module provides the implementation;
+    this body is never executed.
+    """
+
 def force_precision_to_match(scalar, precision):
     if _numpy.iscomplexobj(scalar):
         if precision == 'single':
@@ -169,13 +176,16 @@ def __init__(self, initial_array, dtype=None, copy=True):
             self._data = initial_array
 
             # Check that the dtype is supported.
-            if self._data.dtype not in _ALLOWED_DTYPES:
+            try:
+                data_dtype = _scheme_get_numpy_dtype(self._data.dtype)
+            except (RuntimeError, TypeError):
+                # Backends without this hook already use NumPy dtypes.
+                data_dtype = self._data.dtype
+            if data_dtype not in _ALLOWED_DTYPES:
                 raise TypeError(str(self._data.dtype) + ' is not supported')
 
-            if dtype and dtype != self._data.dtype:
+            if dtype and dtype != data_dtype:
                 raise TypeError("Can only set dtype when allowed to copy data")
-
         if copy:
             # First we will check the dtype that we are given
             if not hasattr(initial_array, 'dtype'):
@@ -981,7 +988,12 @@ def lal(self):
 
     @property
     def dtype(self):
-        return self._data.dtype
+        """Return the NumPy dtype of the array."""
+        try:
+            return _scheme_get_numpy_dtype(self._data.dtype)
+        except (NotImplementedError, AttributeError, RuntimeError):
+            # Fall back for backends whose arrays already carry a
+            # NumPy dtype.
+            return self._data.dtype
 
     def save(self, path, group=None):
         """
diff --git a/pycbc/types/array_cuda.py b/pycbc/types/array_cuda.py
index 45a89b73456..2e587c3aeb6 100644
--- a/pycbc/types/array_cuda.py
+++ b/pycbc/types/array_cuda.py
@@ -356,6 +356,3 @@ def _copy_base_array(array):
 
 def _to_device(array):
     return pycuda.gpuarray.to_gpu(array)
-
-
-
diff --git a/pycbc/types/array_torch.py b/pycbc/types/array_torch.py
new file mode 100644
index 00000000000..5ddbf4feb66
--- /dev/null
+++ b/pycbc/types/array_torch.py
@@ -0,0 +1,561 @@
+"""PyTorch backend for PyCBC Array operations."""
+
+import torch
+import numpy as np
+import pycbc.scheme as _scheme
+from pycbc.types.array import Array
+
+# Mapping from NumPy dtypes to Torch dtypes. Note that torch.uint32 is
+# only available in recent PyTorch releases; older versions would need
+# to drop that entry.
+NUMPY_TO_TORCH_DTYPE = {
+    np.dtype(np.float32): torch.float32,
+    np.dtype(np.float64): torch.float64,
+    np.dtype(np.complex64): torch.complex64,
+    np.dtype(np.complex128): torch.complex128,
+    np.dtype(np.int32): torch.int32,
+    np.dtype(np.uint32): torch.uint32,
+    np.dtype(int): torch.int64,
+}
+
+# Reverse mapping from Torch dtypes to NumPy dtypes
+TORCH_TO_NUMPY_DTYPE = {v: k for k, v in NUMPY_TO_TORCH_DTYPE.items()}
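+
+# For illustration, the two tables form an exact round trip, e.g.:
+#
+#     NUMPY_TO_TORCH_DTYPE[np.dtype(np.complex64)]  -> torch.complex64
+#     TORCH_TO_NUMPY_DTYPE[torch.complex64]         -> np.dtype('complex64')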
+
+def _scheme_get_numpy_dtype(dtype):
+    """
+    Get the NumPy dtype corresponding to the given Torch dtype.
+
+    Parameters
+    ----------
+    dtype : torch.dtype
+        The Torch dtype to convert.
+
+    Returns
+    -------
+    numpy.dtype
+        The corresponding NumPy dtype.
+    """
+    if isinstance(dtype, torch.dtype):
+        numpy_dtype = TORCH_TO_NUMPY_DTYPE.get(dtype)
+        if numpy_dtype is None:
+            raise TypeError(f"Torch data type {dtype} does not have a "
+                            f"corresponding NumPy dtype")
+        return numpy_dtype
+    else:
+        # It is already a NumPy dtype (or convertible to one)
+        return np.dtype(dtype)
+
+def _scheme_array_from_initial(initial_array, dtype=None):
+    device = _scheme.mgr.state.device
+
+    if isinstance(initial_array, torch.Tensor):
+        tensor = initial_array.to(device)
+        if dtype is not None:
+            torch_dtype = NUMPY_TO_TORCH_DTYPE.get(np.dtype(dtype))
+            if torch_dtype is None:
+                raise TypeError(f"{dtype} is not supported by the torch "
+                                f"backend")
+            if tensor.dtype != torch_dtype:
+                tensor = tensor.to(dtype=torch_dtype)
+    else:
+        # Convert initial_array to a torch.Tensor
+        if not isinstance(initial_array, np.ndarray):
+            initial_array = np.array(initial_array, dtype=dtype)
+        elif dtype is not None and initial_array.dtype != dtype:
+            initial_array = initial_array.astype(dtype)
+        torch_dtype = NUMPY_TO_TORCH_DTYPE.get(np.dtype(initial_array.dtype))
+        if torch_dtype is None:
+            raise TypeError(f"{initial_array.dtype} is not supported by "
+                            f"the torch backend")
+        tensor = torch.tensor(initial_array, dtype=torch_dtype, device=device)
+
+    return tensor
+
+def _scheme_matches_base_array(array):
+    return isinstance(array, torch.Tensor)
+
+def _copy_base_array(array):
+    return array.clone()
+
+def _to_device(array):
+    """
+    Move the input to the scheme's device as a Torch tensor.
+
+    Parameters
+    ----------
+    array : array-like
+        Input data; can be a NumPy array, list, or Torch tensor.
+
+    Returns
+    -------
+    torch.Tensor
+        Torch tensor on the scheme's device.
+    """
+    device = _scheme.mgr.state.device
+
+    if isinstance(array, torch.Tensor):
+        tensor = array.to(device)
+    else:
+        # Convert the input to a Torch tensor first
+        tensor = torch.tensor(np.asanyarray(array))
+        tensor = tensor.to(device)
+    return tensor
+
+def zeros(length, dtype=np.float64):
+    torch_dtype = NUMPY_TO_TORCH_DTYPE.get(np.dtype(dtype))
+    device = _scheme.mgr.state.device
+    return Array(torch.zeros(length, dtype=torch_dtype, device=device),
+                 copy=False)
+
+def empty(length, dtype=np.float64):
+    torch_dtype = NUMPY_TO_TORCH_DTYPE.get(np.dtype(dtype))
+    device = _scheme.mgr.state.device
+    return Array(torch.empty(length, dtype=torch_dtype, device=device),
+                 copy=False)
+
+def ptr(array):
+    return array._data.data_ptr()
+
+def numpy(array):
+    """
+    Convert the Array to a NumPy array.
+
+    Parameters
+    ----------
+    array : Array
+        Input PyCBC Array instance.
+
+    Returns
+    -------
+    numpy.ndarray
+        NumPy array (copied to host memory if necessary).
+    """
+    return array._data.cpu().numpy()
+
+def inner(a, b):
+    """
+    Compute the inner product of two arrays, conjugating the first
+    operand when it is complex.
+
+    Parameters
+    ----------
+    a : Array
+        Input PyCBC Array instance.
+    b : torch.Tensor
+        Second operand; the Array machinery passes this in as a raw
+        tensor.
+
+    Returns
+    -------
+    scalar
+        The inner product result.
+    """
+    data_a = a._data
+    data_b = b
+    if data_a.is_complex():
+        result = torch.vdot(data_a.view(-1), data_b.view(-1))
+    else:
+        result = torch.dot(data_a.view(-1), data_b.view(-1))
+    return result.item()
+
+vdot = inner  # Alias; inner already conjugates complex input
+
+def squared_norm(array):
+    """
+    Compute the elementwise squared magnitude of the array.
+
+    Parameters
+    ----------
+    array : Array
+        Input PyCBC Array instance.
+
+    Returns
+    -------
+    Array
+        An Array instance containing the squared norm.
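+        For complex input this equals real**2 + imag**2 elementwise.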
+    """
+    data = array._data
+    result_data = torch.abs(data) ** 2
+    return Array(result_data, copy=False)
+
+def sum(array):
+    """
+    Compute the sum of all elements in the array.
+
+    Parameters
+    ----------
+    array : Array
+        Input PyCBC Array instance.
+
+    Returns
+    -------
+    scalar
+        Sum of elements.
+    """
+    return torch.sum(array._data).item()
+
+def max(array):
+    """
+    Find the maximum value in the array.
+
+    Parameters
+    ----------
+    array : Array
+        Input PyCBC Array instance.
+
+    Returns
+    -------
+    scalar
+        Maximum value.
+    """
+    return torch.max(array._data).item()
+
+def max_loc(array):
+    """
+    Find the maximum value and its index in the array.
+
+    Parameters
+    ----------
+    array : Array
+        Input PyCBC Array instance.
+
+    Returns
+    -------
+    tuple
+        Maximum value and its index.
+    """
+    data = array._data
+    max_val, max_idx = torch.max(data, dim=0)
+    return max_val.item(), max_idx.item()
+
+def min(array):
+    """
+    Find the minimum value in the array.
+
+    Parameters
+    ----------
+    array : Array
+        Input PyCBC Array instance.
+
+    Returns
+    -------
+    scalar
+        Minimum value.
+    """
+    return torch.min(array._data).item()
+
+def cumsum(array):
+    """
+    Compute the cumulative sum of the array along dimension 0.
+
+    Parameters
+    ----------
+    array : Array
+        Input PyCBC Array instance.
+
+    Returns
+    -------
+    Array
+        An Array instance containing the cumulative sum.
+    """
+    data = array._data
+    result_data = torch.cumsum(data, dim=0)
+    return Array(result_data, copy=False)
+
+def take(array, indices):
+    """
+    Extract elements from the array at the specified indices.
+
+    Parameters
+    ----------
+    array : Array
+        Input PyCBC Array instance.
+    indices : array-like
+        Indices to take.
+
+    Returns
+    -------
+    Array
+        An Array instance containing the extracted elements.
+    """
+    data = array._data
+    # as_tensor avoids a copy when the indices are already a tensor
+    indices_tensor = torch.as_tensor(indices, dtype=torch.long,
+                                     device=data.device)
+    result_data = torch.take(data, indices_tensor)
+    return Array(result_data, copy=False)
+
+def dot(a, b):
+    """
+    Compute the dot product of two arrays (without conjugation).
+
+    Parameters
+    ----------
+    a : Array
+        Input PyCBC Array instance.
+    b : torch.Tensor
+        Second operand, passed in as a raw tensor.
+
+    Returns
+    -------
+    scalar
+        Dot product result.
+    """
+    data_a = a._data
+    data_b = b
+    result = torch.dot(data_a.view(-1), data_b.view(-1))
+    return result.item()
+
+def abs_max_loc(array):
+    """
+    Find the maximum absolute value and its index in the array.
+
+    Parameters
+    ----------
+    array : Array
+        Input PyCBC Array instance.
+
+    Returns
+    -------
+    tuple
+        Maximum absolute value and its index.
+    """
+    data = array._data
+    abs_data = torch.abs(data)
+    max_val, max_idx = torch.max(abs_data, dim=0)
+    return max_val.item(), max_idx.item()
+
+def clear(array):
+    """
+    Zero out all elements of the array in place.
+
+    Parameters
+    ----------
+    array : Array
+        Input PyCBC Array instance.
+
+    Returns
+    -------
+    None
+    """
+    array._data.zero_()
+
+def _getvalue(array, index):
+    """
+    Get the value at the specified index from the array.
+
+    Parameters
+    ----------
+    array : Array
+        Input PyCBC Array instance.
+    index : int
+        Index to retrieve.
+
+    Returns
+    -------
+    scalar
+        Value at the specified index.
+    """
+    return array._data[index].item()
+
+def _copy(dest, src):
+    """
+    Copy data from the src array into the dest array.
+
+    Parameters
+    ----------
+    dest, src : Array
+        Destination and source PyCBC Array instances.
+
+    Returns
+    -------
+    None
+    """
+    dest._data.copy_(src._data)
+
+def multiply_and_add(a, b, mult_fac):
+    """
+    Multiply b by mult_fac and add the result to a in place.
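+
+    This is equivalent to ``a += b * mult_fac``, evaluated without
+    allocating a temporary for the product.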
+
+    Parameters
+    ----------
+    a, b : Array
+        Input PyCBC Array instances.
+    mult_fac : scalar
+        Multiplication factor.
+
+    Returns
+    -------
+    None
+    """
+    a._data.add_(b._data, alpha=mult_fac)
+
+def weighted_inner(a, b, weight):
+    """
+    Compute the inner product of two arrays, weighted by the inverse
+    of the weight array.
+
+    Parameters
+    ----------
+    a : Array
+        Input PyCBC Array instance.
+    b : torch.Tensor
+        Second operand, passed in as a raw tensor.
+    weight : Array or None
+        Weights array or None.
+
+    Returns
+    -------
+    scalar
+        Weighted inner product result.
+    """
+    if weight is None:
+        return inner(a, b)
+    data_a = a._data
+    data_b = b
+    data_w = weight._data
+    # Conjugate the first operand for complex input, then sum the
+    # weighted products directly.
+    if data_a.is_complex():
+        result = torch.sum(data_a.conj() * data_b / data_w)
+    else:
+        result = torch.sum(data_a * data_b / data_w)
+    return result.item()
+
+def real(array):
+    """
+    Return the real part of the array.
+
+    Parameters
+    ----------
+    array : Array
+        Input PyCBC Array instance.
+
+    Returns
+    -------
+    Array
+        An Array instance containing the real part.
+    """
+    data = torch.real(array._data)
+    return Array(data, copy=False)
+
+def imag(array):
+    """
+    Return the imaginary part of the array.
+
+    Parameters
+    ----------
+    array : Array
+        Input PyCBC Array instance.
+
+    Returns
+    -------
+    Array
+        An Array instance containing the imaginary part.
+    """
+    data = torch.imag(array._data)
+    return Array(data, copy=False)
+
+def conj(array):
+    """
+    Return the complex conjugate of the array.
+
+    Parameters
+    ----------
+    array : Array
+        Input PyCBC Array instance.
+
+    Returns
+    -------
+    Array
+        An Array instance containing the complex conjugate.
+    """
+    data = torch.conj(array._data)
+    return Array(data, copy=False)
+
+def fill(array, value):
+    """
+    Fill the array with the specified value.
+
+    Parameters
+    ----------
+    array : Array
+        Input PyCBC Array instance.
+    value : scalar
+        Value to fill the array with.
+
+    Returns
+    -------
+    None
+    """
+    array._data.fill_(value)
+
+def roll(array, shift):
+    """
+    Roll the elements of the array by the specified shift.
+
+    Parameters
+    ----------
+    array : Array
+        Input PyCBC Array instance.
+    shift : int
+        Number of places by which elements are shifted.
+
+    Returns
+    -------
+    Array
+        An Array instance with rolled elements.
+    """
+    data = torch.roll(array._data, shifts=shift, dims=0)
+    return Array(data, copy=False)
+
+def astype(array, dtype):
+    """
+    Convert the values of the array to a new data type.
+
+    Parameters
+    ----------
+    array : Array
+        Input PyCBC Array instance.
+    dtype : numpy.dtype
+        Desired NumPy data type.
+
+    Returns
+    -------
+    Array
+        An Array instance with the new data type.
+    """
+    torch_dtype = NUMPY_TO_TORCH_DTYPE.get(np.dtype(dtype))
+    data = array._data.to(dtype=torch_dtype)
+    return Array(data, copy=False)
+
+def resize(array, new_size):
+    """
+    Resize the array to the specified new size.
+
+    Parameters
+    ----------
+    array : Array
+        Input PyCBC Array instance.
+    new_size : int
+        New size of the array.
+
+    Returns
+    -------
+    Array
+        An Array instance with the new size.
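+        When enlarging, the new elements are zero filled, matching the
+        NumPy resize convention.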
+    """
+    data = array._data
+    current_size = data.size(0)
+    if new_size == current_size:
+        return array
+    elif new_size < current_size:
+        new_data = data[:new_size].clone()
+    else:
+        # Allocate a zero-filled tensor of the new size and copy the
+        # existing data into it.
+        new_data = torch.zeros(new_size, dtype=data.dtype,
+                               device=data.device)
+        new_data[:current_size] = data
+    return Array(new_data, copy=False)
+
+def view(array, dtype):
+    """
+    Return a view of the array with its bytes reinterpreted as the
+    specified data type (as with numpy.ndarray.view); the returned
+    Array shares storage with the input.
+
+    Parameters
+    ----------
+    array : Array
+        Input PyCBC Array instance.
+    dtype : numpy.dtype
+        Desired NumPy data type.
+
+    Returns
+    -------
+    Array
+        An Array instance viewing the same memory with the new data
+        type.
+    """
+    torch_dtype = NUMPY_TO_TORCH_DTYPE.get(np.dtype(dtype))
+    # Tensor.view(dtype) reinterprets the underlying storage rather
+    # than converting values.
+    data = array._data.view(torch_dtype)
+    return Array(data, copy=False)
+
+# Any other functions that are needed can be added following the same
+# pattern.
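
-- 
A minimal usage sketch, assuming this patch is applied and PyTorch is
installed (a 'cuda:N' device additionally requires a CUDA-enabled
PyTorch build):

    import numpy as np
    from pycbc.scheme import TorchScheme
    from pycbc.types.array import Array

    # Inside the context, Array data is held as a torch.Tensor on the
    # requested device and schemed operations dispatch to
    # pycbc.types.array_torch.
    with TorchScheme(device='cpu'):
        a = Array(np.arange(4, dtype=np.float32))
        print(a.sum(), a.max_loc())

    # The equivalent command-line selection would be
    #     --processing-scheme torch:cpu    (or torch:cuda:0)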