diff --git a/heat/core/logical.py b/heat/core/logical.py index 2cf4768be1..59051006ed 100644 --- a/heat/core/logical.py +++ b/heat/core/logical.py @@ -9,6 +9,7 @@ from . import factories from . import manipulations +from . import sanitation from . import _operations from . import stride_tricks @@ -516,9 +517,16 @@ def sanitize_input_type( x = sanitize_input_type(x, y) y = sanitize_input_type(y, x) - # if one of the tensors is distributed, unsplit/gather it - if x.split is not None and y.split is None: - t1 = manipulations.resplit(x, axis=None) + # if one of the DNDarrays is distributed and the other is not + if x.is_distributed() and not y.is_distributed() and y.ndim > 0: + t2 = factories.array(y.larray, device=x.device, split=x.split) + x, t2 = sanitation.sanitize_distribution(x, t2, target=x) + return x, t2 + + # if y is distributed, x is not distributed, and x is not a scalar + elif y.is_distributed() and not x.is_distributed() and x.ndim > 0: + t1 = factories.array(x.larray, device=y.device, split=y.split) + t1, y = sanitation.sanitize_distribution(t1, y, target=y) return t1, y elif x.split != y.split: