diff --git a/heat/core/factories.py b/heat/core/factories.py
index 802535830..dc2dbb4e0 100644
--- a/heat/core/factories.py
+++ b/heat/core/factories.py
@@ -138,10 +138,12 @@ def arange(
     # compose the local tensor
     start += offset * step
     stop = start + lshape[0] * step
-    data = torch.arange(start, stop, step, device=device.torch_device)
-
     htype = types.canonical_heat_type(dtype)
-    data = data.type(htype.torch_type())
+    if types.issubdtype(htype, types.floating):
+        data = torch.arange(start, stop, step, dtype=htype.torch_type(), device=device.torch_device)
+    else:
+        data = torch.arange(start, stop, step, device=device.torch_device)
+        data = data.type(htype.torch_type())

     return DNDarray(data, gshape, htype, split, device, comm, balanced)

@@ -288,8 +290,11 @@ def array(
     [torch.LongStorage of size 6]
     """
     # sanitize the data type
-    if dtype is not None:
+    if dtype is None:
+        torch_dtype = None
+    else:
         dtype = types.canonical_heat_type(dtype)
+        torch_dtype = dtype.torch_type()

     # sanitize device
     if device is not None:
@@ -337,6 +342,7 @@ def array(
         try:
             obj = torch.tensor(
                 obj,
+                dtype=torch_dtype,
                 device=(
                     device.torch_device
                     if device is not None
@@ -360,15 +366,13 @@ def array(
                 "argument `copy` is set to False, but copy of input object is necessary. \n Set copy=None to reuse the memory buffer whenever possible and allow for copies otherwise."
             )
         try:
-            if not isinstance(obj, torch.Tensor):
-                obj = torch.as_tensor(
-                    obj,
-                    device=(
-                        device.torch_device
-                        if device is not None
-                        else devices.get_device().torch_device
-                    ),
-                )
+            obj = torch.as_tensor(
+                obj,
+                dtype=torch_dtype,
+                device=(
+                    device.torch_device if device is not None else devices.get_device().torch_device
+                ),
+            )
         except RuntimeError:
             raise TypeError(f"invalid data of type {type(obj)}")

@@ -376,7 +380,6 @@ def array(
     if dtype is None:
         dtype = types.canonical_heat_type(obj.dtype)
     else:
-        torch_dtype = dtype.torch_type()
         if obj.dtype != torch_dtype:
             obj = obj.type(torch_dtype)

@@ -1172,9 +1175,18 @@ def linspace(
     # compose the local tensor
     start += offset * step
     stop = start + lshape[0] * step - step
-    data = torch.linspace(start, stop, lshape[0], device=device.torch_device)
-    if dtype is not None:
-        data = data.type(types.canonical_heat_type(dtype).torch_type())
+    if dtype is not None and types.issubdtype(dtype, types.floating):
+        data = torch.linspace(
+            start,
+            stop,
+            lshape[0],
+            dtype=types.canonical_heat_type(dtype).torch_type(),
+            device=device.torch_device,
+        )
+    else:
+        data = torch.linspace(start, stop, lshape[0], device=device.torch_device)
+        if dtype is not None:
+            data = data.type(types.canonical_heat_type(dtype).torch_type())

     # construct the resulting global tensor
     ht_tensor = DNDarray(
diff --git a/heat/core/linalg/basics.py b/heat/core/linalg/basics.py
index 2f57fe774..35c4f208d 100644
--- a/heat/core/linalg/basics.py
+++ b/heat/core/linalg/basics.py
@@ -1261,6 +1261,8 @@ def matrix_norm(
     row_axis, col_axis = axis

+    # dtype = types.promote_types(x.dtype, types.float32)
+
     if ord == 1:
         if col_axis > row_axis and not keepdims:
             col_axis -= 1
diff --git a/heat/core/rounding.py b/heat/core/rounding.py
index 79bc989fd..dcee642b4 100644
--- a/heat/core/rounding.py
+++ b/heat/core/rounding.py
@@ -52,7 +52,7 @@ def abs(
     if dtype is not None and not issubclass(dtype, dtype):
         raise TypeError("dtype must be a heat data type")

-    absolute_values = _operations.__local_op(torch.abs, x, out)
+    absolute_values = _operations.__local_op(torch.abs, x, out, no_cast=True)
     if dtype is not None:
         absolute_values.larray = absolute_values.larray.type(dtype.torch_type())
         absolute_values._DNDarray__dtype = dtype
@@ -181,7 +181,11 @@ def fabs(x: DNDarray, out: Optional[DNDarray] = None) -> DNDarray:
         If not provided or ``None``, a freshly-allocated array is returned.
     """
-    return abs(x, out, dtype=None)
+    if isinstance(x, DNDarray):
+        dtype = types.promote_types(x.dtype, types.float32)
+    else:
+        dtype = types.float32
+    return abs(x, out, dtype=dtype)


 DNDarray.fabs: Callable[[DNDarray, Optional[DNDarray]], DNDarray] = lambda self, out=None: fabs(
diff --git a/heat/core/tests/test_factories.py b/heat/core/tests/test_factories.py
index 6851704cf..25b3845f2 100644
--- a/heat/core/tests/test_factories.py
+++ b/heat/core/tests/test_factories.py
@@ -106,6 +106,9 @@ def test_arange(self):
         # make an in direct check for the sequence, compare against the gaussian sum
         self.assertEqual(three_arg_arange_dtype_float64.sum(axis=0, keepdims=True), 20.0)

+        check_precision = ht.arange(16777217.0, 16777218, 1, dtype=ht.float64)
+        self.assertEqual(check_precision.sum(), 16777217)
+
         # exceptions
         with self.assertRaises(ValueError):
             ht.arange(-5, 3, split=1)
@@ -142,6 +145,8 @@ def test_array(self):
                 == torch.tensor(tuple_data, dtype=torch.int8, device=self.device.torch_device)
             ).all()
         )
+        check_precision = ht.array(16777217.0, dtype=ht.float64)
+        self.assertEqual(check_precision.sum(), 16777217)

         # basic array function, unsplit data, no copy
         torch_tensor = torch.tensor([6, 5, 4, 3, 2, 1], device=self.device.torch_device)
@@ -727,6 +732,8 @@ def test_linspace(self):
         zero_samples = ht.linspace(-3, 5, num=0)
         self.assertEqual(zero_samples.size, 0)
+        check_precision = ht.linspace(0.0, 16777217.0, num=2, dtype=torch.float64)
+        self.assertEqual(check_precision.sum(), 16777217)

         # simple inverse linear space
         descending = ht.linspace(-5, 3, num=100)
diff --git a/heat/core/tests/test_rounding.py b/heat/core/tests/test_rounding.py
index 98e3c459f..761742095 100644
--- a/heat/core/tests/test_rounding.py
+++ b/heat/core/tests/test_rounding.py
@@ -57,6 +57,10 @@ def test_abs(self):
         self.assertEqual(absolute_values.sum(axis=0), 100)
         self.assertEqual(absolute_values.dtype, ht.float32)
         self.assertEqual(absolute_values.larray.dtype, torch.float32)
+        check_precision = ht.asarray(9007199254740993, dtype=ht.int64)
+        precision_absolute_values = ht.abs(check_precision, dtype=ht.int64)
+        self.assertEqual(precision_absolute_values.sum(), check_precision.sum())
+        self.assertEqual(precision_absolute_values.dtype, check_precision.dtype)
         # for fabs
         self.assertEqual(int8_absolute_values_fabs.dtype, ht.float32)
         self.assertEqual(int16_absolute_values_fabs.dtype, ht.float32)
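
Note (illustration, not part of the patch): the `factories.py` changes all target the same failure mode. Previously, tensors were materialized in torch's default dtype (usually `float32`) and only cast to the requested dtype afterwards, by which point precision was already gone. A minimal sketch of the `float32` case behind the `16777217` tests, using only plain `torch` calls:

```python
import torch

# 16777217 == 2**24 + 1, the first integer float32 cannot represent.
# Old behaviour: materialize in the default dtype (float32), then cast.
lossy = torch.tensor(16777217.0).to(torch.float64)
# New behaviour: hand the requested dtype to the factory call itself.
exact = torch.tensor(16777217.0, dtype=torch.float64)

print(lossy.item())  # 16777216.0 -- rounded before the cast could help
print(exact.item())  # 16777217.0 -- exact
```

This is why `arange` and `linspace` now branch on `types.issubdtype(..., types.floating)`: for floating-point targets the dtype must reach `torch.arange`/`torch.linspace` directly, while the create-then-cast path remains safe for integer targets.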
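The `rounding.py` changes guard the analogous integer case. Judging by the new `no_cast=True` flag and the `9007199254740993` test, `_operations.__local_op` could previously route `int64` input through a float cast, and `float64` carries only 53 mantissa bits. A sketch of that effect (again plain `torch`, not Heat internals):

```python
import torch

big = torch.tensor(9007199254740993, dtype=torch.int64)  # 2**53 + 1

# A detour through float64 silently drops the last bit:
print(torch.abs(big.to(torch.float64)).to(torch.int64).item())  # 9007199254740992
# Operating on int64 directly keeps the value exact:
print(torch.abs(big).item())                                    # 9007199254740993
```

`fabs`, by contrast, is defined as a floating-point absolute value, so instead of skipping the cast it now promotes explicitly via `types.promote_types(x.dtype, types.float32)`.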