
value broadcasting abstraction
ClaudiaComito committed Dec 8, 2023
1 parent f528356 commit 01a1140
Showing 1 changed file with 30 additions and 17 deletions.
47 changes: 30 additions & 17 deletions heat/core/dndarray.py
@@ -2342,6 +2342,25 @@ def __setitem__(
[0., 1., 0., 0., 0.]])
"""

+        def __broadcast_value(
+            arr: DNDarray, key: Union[int, Tuple[int, ...], List[int]], value: DNDarray
+        ):
+            """
+            Broadcasts the given DNDarray `value` to the shape of the indexed array `arr[key]`.
+            """
+            # need information on the indexed array; use a proxy to avoid MPI communication and limit memory usage
+            indexed_proxy = arr.__torch_proxy__()[key]
+            value_shape = value.shape
+            while value.ndim < indexed_proxy.ndim:  # broadcasting
+                value = value.expand_dims(0)
+            try:
+                value_shape = tuple(torch.broadcast_shapes(value.shape, indexed_proxy.shape))
+            except RuntimeError:
+                raise ValueError(
+                    f"could not broadcast input array from shape {value_shape} into shape {tuple(indexed_proxy.shape)}"
+                )
+            return value
+
        def __set(
            arr: DNDarray,
            key: Union[int, Tuple[int, ...], List[int]],
@@ -2376,34 +2395,28 @@ def __set(
        except TypeError:
            raise TypeError(f"Cannot assign object of type {type(value)} to DNDarray.")

-        # use low-memory torch_proxy in sanitation
-        indexed_proxy = self.__torch_proxy__()[key]
-        # `value` might be broadcasted
-        value_shape = value.shape
-        while value.ndim < indexed_proxy.ndim:  # broadcasting
-            value = value.expand_dims(0)
-        try:
-            value_shape = tuple(torch.broadcast_shapes(value.shape, indexed_proxy.shape))
-        except RuntimeError:
-            raise ValueError(
-                f"could not broadcast input array from shape {value_shape} into shape {tuple(indexed_proxy.shape)}"
-            )
-
-        if key is None or key == ... or key == slice(None):
-            # make sure `self` and `value` distribution are aligned
-            value = sanitation.sanitize_distribution(value, target=self)
-            return __set(self, key, value)
+        # workaround for Heat issue #1292. TODO: remove when issue is fixed
+        if not isinstance(key, DNDarray):
+            if key is None or key is ... or key == slice(None):
+                # match dimensions
+                value = __broadcast_value(self, key, value)
+                # make sure `self` and `value` distribution are aligned
+                value = sanitation.sanitize_distribution(value, target=self)
+                return __set(self, key, value)
+
        # single-element key
        scalar = np.isscalar(key) or getattr(key, "ndim", 1) == 0
        if scalar:
            key, root = self.__process_scalar_key(key, indexed_axis=0, return_local_indices=True)
+            # match dimensions
+            value = __broadcast_value(self, key, value)
            # `root` will be None when the indexed axis is not the split axis, or when the
            # indexed axis is the split axis but the indexed element is not local
            if root is not None:
                if self.comm.rank == root:
                    # verify that `self[key]` and `value` distribution are aligned
                    # do not index `self` with `key` directly here, as this would MPI-broadcast to all ranks
                    indexed_proxy = self.__torch_proxy__()[key]
                    if indexed_proxy.names.count("split") != 0:
                        # indexed_split = indexed_proxy.names.index("split")
                        # the lshape_map of the indexed subarray is the same as the lshape_map of the original array after losing the first dimension
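The new `__broadcast_value` helper prepends singleton dimensions to `value` until its rank matches that of the indexed array, then validates the pair of shapes with `torch.broadcast_shapes`. A minimal standalone sketch of the same pattern on plain torch tensors (an illustration of the logic, not heat's actual code path):

import torch

def broadcast_value(indexed_shape: tuple, value: torch.Tensor) -> torch.Tensor:
    # keep the original shape around for a readable error message
    value_shape = tuple(value.shape)
    # prepend singleton dimensions until the ranks match
    while value.ndim < len(indexed_shape):
        value = value.unsqueeze(0)
    try:
        # raises RuntimeError if the two shapes cannot be broadcast together
        torch.broadcast_shapes(tuple(value.shape), indexed_shape)
    except RuntimeError:
        raise ValueError(
            f"could not broadcast input array from shape {value_shape} into shape {indexed_shape}"
        )
    return value

v = broadcast_value((4, 5), torch.arange(5.0))
print(v.shape)  # torch.Size([1, 5]); elementwise ops against (4, 5) now broadcast

Because the check runs against the shape of the torch proxy rather than `arr[key]` itself, no rank has to materialize or communicate the indexed subarray.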

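At the user level, the change affects assignments where the right-hand side has fewer dimensions than the indexed view. A usage sketch, assuming a standard heat setup (launch under MPI, e.g. `mpirun -n 2 python demo.py`):

import heat as ht

a = ht.zeros((4, 5), split=0)            # distributed along axis 0
a[...] = ht.arange(5, dtype=ht.float32)  # shape (5,) is broadcast to (4, 5)
print(a)

This exercises exactly the `key is ...` branch above: the value is first rank-matched by `__broadcast_value`, then distribution-aligned by `sanitize_distribution` before `__set` writes it.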
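For single-element keys, `__process_scalar_key` maps the global index to a root rank plus a local index (local indices are returned here because of `return_local_indices=True`). A hypothetical sketch of that mapping via cumulative local extents along the split axis, without bounds checking; this is not heat's actual helper, only the idea behind it:

import numpy as np

lshape_map = [3, 3, 2, 2]                         # local extents per rank along the split axis
offsets = np.insert(np.cumsum(lshape_map), 0, 0)  # [0, 3, 6, 8, 10]

def owner_and_local(global_idx: int):
    # the last offset that is <= global_idx identifies the owning rank
    root = int(np.searchsorted(offsets, global_idx, side="right")) - 1
    return root, int(global_idx - offsets[root])

print(owner_and_local(4))  # (1, 1): global row 4 lives on rank 1, at local row 1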
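The `indexed_proxy.names.count("split")` check works because `__torch_proxy__` labels the split dimension of a shape-only stand-in tensor, so indexing the proxy reveals whether the split axis survives. A concept sketch using a zero-stride tensor with a named dimension (heat's actual proxy construction may differ):

import torch

gshape, split = (1000, 500), 0
# zero-stride view: pretends to have the global shape without allocating it
proxy = torch.ones(1).as_strided(gshape, [0] * len(gshape))
names = [None] * len(gshape)
names[split] = "split"
proxy = proxy.refine_names(*names)

indexed = proxy[2:10]                # index the proxy instead of the distributed data
print(indexed.shape)                 # torch.Size([8, 500])
print(indexed.names.count("split"))  # 1 -> the split axis survived the indexing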